Skip to content

Commit

Permalink
feat(gui): enhance image generation interface
Browse files Browse the repository at this point in the history
- Add output folder selection and recent models feature
- Implement image gallery with lightbox functionality
- Improve error handling and user notifications
- Update requirements.txt with httpx dependency
  • Loading branch information
rtuszik committed Aug 23, 2024
1 parent 79711e9 commit 0378aeb
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 14 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ token_count==0.2.1
pillow==10.4.0
loguru==0.7.2
nicegui==1.4.36
httpx
143 changes: 129 additions & 14 deletions src/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
import json
import os
import sys
import urllib.parse
from datetime import datetime
from pathlib import Path

import httpx
from loguru import logger
from nicegui import ui
from nicegui import events, ui

# Configure Loguru
logger.remove() # Remove the default handler
Expand All @@ -14,18 +18,49 @@
logger.add("gui.log", rotation="10 MB", format="{time} {level} {message}", level="INFO")


class Lightbox:
def __init__(self):
with ui.dialog().props("maximized").classes("bg-black") as self.dialog:
ui.keyboard(self._handle_key)
self.large_image = ui.image().props("no-spinner fit=scale-down")
self.image_list = []

def add_image(self, thumb_url: str, orig_url: str) -> ui.image:
self.image_list.append(orig_url)
with ui.button(on_click=lambda: self._open(orig_url)).props(
"flat dense square"
):
return ui.image(thumb_url)

def _handle_key(self, event_args: events.KeyEventArguments) -> None:
if not event_args.action.keydown:
return
if event_args.key.escape:
self.dialog.close()
image_index = self.image_list.index(self.large_image.source)
if event_args.key.arrow_left and image_index > 0:
self._open(self.image_list[image_index - 1])
if event_args.key.arrow_right and image_index < len(self.image_list) - 1:
self._open(self.image_list[image_index + 1])

def _open(self, url: str) -> None:
self.large_image.set_source(url)
self.dialog.open()


class ImageGeneratorGUI:
def __init__(self, image_generator):
self.image_generator = image_generator
self.settings_file = "settings.json"
self.load_settings()
self.recent_replicate_models = self.load_recent_replicate_models()
self.setup_ui()
logger.info("ImageGeneratorGUI initialized")

def setup_ui(self):
ui.dark_mode().enable()

with ui.column().classes("w-full max-w-3xl mx-auto p-4 space-y-4"):
with ui.column().classes("w-full max-w-7xl mx-auto p-4 space-y-4"):
with ui.card().classes("w-full"):
ui.label("Image Generator").classes("text-2xl font-bold mb-4")
with ui.row().classes("w-full justify-between"):
Expand All @@ -42,6 +77,21 @@ def setup_left_panel(self):
).classes("w-full")
self.replicate_model_input.on("change", self.update_replicate_model)

self.recent_models_select = ui.select(
options=self.recent_replicate_models,
label="Recent Models",
value=None,
on_change=self.select_recent_model,
).classes("w-full")

self.folder_path = self.settings.get(
"output_folder", str(Path.home() / "Downloads")
)
self.folder_input = ui.input(
label="Output Folder", value=self.folder_path
).classes("w-full")
self.folder_input.on("change", self.update_folder_path)

self.flux_model_select = (
ui.select(
["dev", "schnell"],
Expand Down Expand Up @@ -171,14 +221,27 @@ def setup_left_panel(self):
.bind_value(self, "disable_safety_checker")
)

def update_folder_path(self, e):
new_path = e.value
if os.path.isdir(new_path):
self.folder_path = new_path
self.save_settings()
logger.info(f"Output folder set to: {self.folder_path}")
ui.notify(f"Output folder updated to: {self.folder_path}", type="success")
else:
ui.notify(
"Invalid folder path. Please enter a valid directory.", type="error"
)
self.folder_input.value = self.folder_path

def setup_right_panel(self):
self.output_area = ui.textarea(label="Generated Image URLs").classes(
"w-full h-64"
)
self.output_area.props("readonly")
self.spinner = ui.spinner(size="lg")
self.spinner.visible = False

# Add gallery view
self.gallery_container = ui.column().classes("w-full mt-4")
self.lightbox = Lightbox()

def setup_bottom_panel(self):
self.prompt_input = (
ui.textarea("Prompt", value=self.settings.get("prompt", ""))
Expand All @@ -193,20 +256,46 @@ def setup_bottom_panel(self):
"w-full bg-blue-500 hover:bg-blue-600 text-white font-bold py-2 px-4 rounded"
)

if not self.replicate_model_input.value:
self.generate_button.disable()
def select_folder(self):
def on_folder_selected(e):
if e.value:
self.folder_path = e.value
self.folder_input.value = self.folder_path
self.save_settings()
logger.info(f"Output folder set to: {self.folder_path}")

ui.open_directory_dialog(on_folder_selected)

def update_replicate_model(self, e):
new_model = e.value
if new_model:
self.image_generator.set_model(new_model)
self.save_settings()
self.add_recent_replicate_model(new_model)
logger.info(f"Replicate model updated to: {new_model}")
self.generate_button.enable()
else:
logger.warning("Empty Replicate model provided")
self.generate_button.disable()

def select_recent_model(self, e):
if e.value:
self.replicate_model_input.value = e.value
self.update_replicate_model(e)
self.recent_models_select.value = None

def add_recent_replicate_model(self, model):
if model not in self.recent_replicate_models:
self.recent_replicate_models.insert(0, model)
self.recent_replicate_models = self.recent_replicate_models[
:5
] # Keep only the last 5
self.save_settings()
self.recent_models_select.options = self.recent_replicate_models

def load_recent_replicate_models(self):
return self.settings.get("recent_replicate_models", [])

def toggle_custom_dimensions(self, e):
if e.value == "custom":
self.width_input.enable()
Expand Down Expand Up @@ -260,28 +349,52 @@ async def start_generation(self):

self.generate_button.disable()
self.spinner.visible = True
self.output_area.value = "Generating images..."
ui.notify("Generating images...", type="info")
logger.info(f"Generating images with params: {json.dumps(params, indent=2)}")

try:
output = await asyncio.to_thread(
self.image_generator.generate_images, params
)
self.display_image_urls(output)
await self.download_and_display_images(output)
logger.success(f"Images generated successfully: {output}")
except Exception as e:
error_message = f"An error occurred: {str(e)}"
ui.notify(error_message, type="error")
logger.exception(error_message)
self.output_area.value = error_message
finally:
self.generate_button.enable()
self.spinner.visible = False

def display_image_urls(self, image_urls):
self.output_area.value = "\n".join(image_urls)
ui.notify("Images generated successfully!", type="success")
async def download_and_display_images(self, image_urls):
downloaded_images = []
async with httpx.AsyncClient() as client:
for i, url in enumerate(image_urls):
response = await client.get(url)
if response.status_code == 200:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
url_part = urllib.parse.urlparse(url).path.split("/")[-2][
:8
] # Get first 8 chars of the unique part
file_name = f"generated_image_{timestamp}_{url_part}_{i+1}.png"
file_path = Path(self.folder_path) / file_name
with open(file_path, "wb") as f:
f.write(response.content)
downloaded_images.append(str(file_path))
logger.info(f"Image downloaded: {file_path}")
else:
logger.error(f"Failed to download image from {url}")

self.update_gallery(downloaded_images)
ui.notify("Images generated and downloaded successfully!", type="success")

def update_gallery(self, image_paths):
self.gallery_container.clear()
with self.gallery_container:
for image_path in image_paths:
self.lightbox.add_image(image_path, image_path).classes(
"w-32 h-32 object-cover m-1"
)

def load_settings(self):
if os.path.exists(self.settings_file):
Expand All @@ -295,6 +408,7 @@ def load_settings(self):
def save_settings(self):
settings_to_save = {
"replicate_model": self.replicate_model_input.value,
"output_folder": self.folder_path,
"flux_model": self.flux_model,
"aspect_ratio": self.aspect_ratio,
"width": self.width,
Expand All @@ -308,6 +422,7 @@ def save_settings(self):
"output_quality": self.output_quality,
"disable_safety_checker": self.disable_safety_checker,
"prompt": self.prompt,
"recent_replicate_models": self.recent_replicate_models,
}
with open(self.settings_file, "w") as f:
json.dump(settings_to_save, f)
Expand Down

0 comments on commit 0378aeb

Please sign in to comment.