diff --git a/requirements.txt b/requirements.txt index 18bb3ee..73be64b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ token_count==0.2.1 pillow==10.4.0 loguru==0.7.2 nicegui==1.4.36 +httpx \ No newline at end of file diff --git a/src/gui.py b/src/gui.py index 72bb0e9..7100542 100644 --- a/src/gui.py +++ b/src/gui.py @@ -2,9 +2,13 @@ import json import os import sys +import urllib.parse +from datetime import datetime +from pathlib import Path +import httpx from loguru import logger -from nicegui import ui +from nicegui import events, ui # Configure Loguru logger.remove() # Remove the default handler @@ -14,18 +18,49 @@ logger.add("gui.log", rotation="10 MB", format="{time} {level} {message}", level="INFO") +class Lightbox: + def __init__(self): + with ui.dialog().props("maximized").classes("bg-black") as self.dialog: + ui.keyboard(self._handle_key) + self.large_image = ui.image().props("no-spinner fit=scale-down") + self.image_list = [] + + def add_image(self, thumb_url: str, orig_url: str) -> ui.image: + self.image_list.append(orig_url) + with ui.button(on_click=lambda: self._open(orig_url)).props( + "flat dense square" + ): + return ui.image(thumb_url) + + def _handle_key(self, event_args: events.KeyEventArguments) -> None: + if not event_args.action.keydown: + return + if event_args.key.escape: + self.dialog.close() + image_index = self.image_list.index(self.large_image.source) + if event_args.key.arrow_left and image_index > 0: + self._open(self.image_list[image_index - 1]) + if event_args.key.arrow_right and image_index < len(self.image_list) - 1: + self._open(self.image_list[image_index + 1]) + + def _open(self, url: str) -> None: + self.large_image.set_source(url) + self.dialog.open() + + class ImageGeneratorGUI: def __init__(self, image_generator): self.image_generator = image_generator self.settings_file = "settings.json" self.load_settings() + self.recent_replicate_models = self.load_recent_replicate_models() self.setup_ui() logger.info("ImageGeneratorGUI initialized") def setup_ui(self): ui.dark_mode().enable() - with ui.column().classes("w-full max-w-3xl mx-auto p-4 space-y-4"): + with ui.column().classes("w-full max-w-7xl mx-auto p-4 space-y-4"): with ui.card().classes("w-full"): ui.label("Image Generator").classes("text-2xl font-bold mb-4") with ui.row().classes("w-full justify-between"): @@ -42,6 +77,21 @@ def setup_left_panel(self): ).classes("w-full") self.replicate_model_input.on("change", self.update_replicate_model) + self.recent_models_select = ui.select( + options=self.recent_replicate_models, + label="Recent Models", + value=None, + on_change=self.select_recent_model, + ).classes("w-full") + + self.folder_path = self.settings.get( + "output_folder", str(Path.home() / "Downloads") + ) + self.folder_input = ui.input( + label="Output Folder", value=self.folder_path + ).classes("w-full") + self.folder_input.on("change", self.update_folder_path) + self.flux_model_select = ( ui.select( ["dev", "schnell"], @@ -171,14 +221,27 @@ def setup_left_panel(self): .bind_value(self, "disable_safety_checker") ) + def update_folder_path(self, e): + new_path = e.value + if os.path.isdir(new_path): + self.folder_path = new_path + self.save_settings() + logger.info(f"Output folder set to: {self.folder_path}") + ui.notify(f"Output folder updated to: {self.folder_path}", type="success") + else: + ui.notify( + "Invalid folder path. Please enter a valid directory.", type="error" + ) + self.folder_input.value = self.folder_path + def setup_right_panel(self): - self.output_area = ui.textarea(label="Generated Image URLs").classes( - "w-full h-64" - ) - self.output_area.props("readonly") self.spinner = ui.spinner(size="lg") self.spinner.visible = False + # Add gallery view + self.gallery_container = ui.column().classes("w-full mt-4") + self.lightbox = Lightbox() + def setup_bottom_panel(self): self.prompt_input = ( ui.textarea("Prompt", value=self.settings.get("prompt", "")) @@ -193,20 +256,46 @@ def setup_bottom_panel(self): "w-full bg-blue-500 hover:bg-blue-600 text-white font-bold py-2 px-4 rounded" ) - if not self.replicate_model_input.value: - self.generate_button.disable() + def select_folder(self): + def on_folder_selected(e): + if e.value: + self.folder_path = e.value + self.folder_input.value = self.folder_path + self.save_settings() + logger.info(f"Output folder set to: {self.folder_path}") + + ui.open_directory_dialog(on_folder_selected) def update_replicate_model(self, e): new_model = e.value if new_model: self.image_generator.set_model(new_model) self.save_settings() + self.add_recent_replicate_model(new_model) logger.info(f"Replicate model updated to: {new_model}") self.generate_button.enable() else: logger.warning("Empty Replicate model provided") self.generate_button.disable() + def select_recent_model(self, e): + if e.value: + self.replicate_model_input.value = e.value + self.update_replicate_model(e) + self.recent_models_select.value = None + + def add_recent_replicate_model(self, model): + if model not in self.recent_replicate_models: + self.recent_replicate_models.insert(0, model) + self.recent_replicate_models = self.recent_replicate_models[ + :5 + ] # Keep only the last 5 + self.save_settings() + self.recent_models_select.options = self.recent_replicate_models + + def load_recent_replicate_models(self): + return self.settings.get("recent_replicate_models", []) + def toggle_custom_dimensions(self, e): if e.value == "custom": self.width_input.enable() @@ -260,7 +349,6 @@ async def start_generation(self): self.generate_button.disable() self.spinner.visible = True - self.output_area.value = "Generating images..." ui.notify("Generating images...", type="info") logger.info(f"Generating images with params: {json.dumps(params, indent=2)}") @@ -268,20 +356,45 @@ async def start_generation(self): output = await asyncio.to_thread( self.image_generator.generate_images, params ) - self.display_image_urls(output) + await self.download_and_display_images(output) logger.success(f"Images generated successfully: {output}") except Exception as e: error_message = f"An error occurred: {str(e)}" ui.notify(error_message, type="error") logger.exception(error_message) - self.output_area.value = error_message finally: self.generate_button.enable() self.spinner.visible = False - def display_image_urls(self, image_urls): - self.output_area.value = "\n".join(image_urls) - ui.notify("Images generated successfully!", type="success") + async def download_and_display_images(self, image_urls): + downloaded_images = [] + async with httpx.AsyncClient() as client: + for i, url in enumerate(image_urls): + response = await client.get(url) + if response.status_code == 200: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + url_part = urllib.parse.urlparse(url).path.split("/")[-2][ + :8 + ] # Get first 8 chars of the unique part + file_name = f"generated_image_{timestamp}_{url_part}_{i+1}.png" + file_path = Path(self.folder_path) / file_name + with open(file_path, "wb") as f: + f.write(response.content) + downloaded_images.append(str(file_path)) + logger.info(f"Image downloaded: {file_path}") + else: + logger.error(f"Failed to download image from {url}") + + self.update_gallery(downloaded_images) + ui.notify("Images generated and downloaded successfully!", type="success") + + def update_gallery(self, image_paths): + self.gallery_container.clear() + with self.gallery_container: + for image_path in image_paths: + self.lightbox.add_image(image_path, image_path).classes( + "w-32 h-32 object-cover m-1" + ) def load_settings(self): if os.path.exists(self.settings_file): @@ -295,6 +408,7 @@ def load_settings(self): def save_settings(self): settings_to_save = { "replicate_model": self.replicate_model_input.value, + "output_folder": self.folder_path, "flux_model": self.flux_model, "aspect_ratio": self.aspect_ratio, "width": self.width, @@ -308,6 +422,7 @@ def save_settings(self): "output_quality": self.output_quality, "disable_safety_checker": self.disable_safety_checker, "prompt": self.prompt, + "recent_replicate_models": self.recent_replicate_models, } with open(self.settings_file, "w") as f: json.dump(settings_to_save, f)