From 8bc9b71c22467f683e17bda651588dd0655ce6b4 Mon Sep 17 00:00:00 2001 From: Robin Tuszik <47579899+rtuszik@users.noreply.github.com> Date: Fri, 23 Aug 2024 20:13:29 +0200 Subject: [PATCH] Squashed commit of the following: commit 0378aebbe22beb55cbbc4500d8d82636bc4d89f6 Author: Robin Tuszik <47579899+rtuszik@users.noreply.github.com> Date: Fri Aug 23 20:12:18 2024 +0200 feat(gui): enhance image generation interface - Add output folder selection and recent models feature - Implement image gallery with lightbox functionality - Improve error handling and user notifications - Update requirements.txt with httpx dependency commit 79711e979c64c5b3ca1935d5fc28813e8fb022a7 Author: Robin Tuszik <47579899+rtuszik@users.noreply.github.com> Date: Fri Aug 23 12:41:01 2024 +0200 refactor(image-gen): update ImageGenerator class - Add logging with Loguru for better debugging - Implement model setting functionality - Update dependencies in requirements.txt - Ignore settings.json in .gitignore commit 62b3fed6f8038bf622944eee178e63813a62ed94 Author: Robin Tuszik <47579899+rtuszik@users.noreply.github.com> Date: Fri Aug 23 12:38:50 2024 +0200 refactor(image-gen): update ImageGenerator class - Add logging with Loguru for better debugging - Implement model setting functionality - Update dependencies in requirements.txt - Ignore settings.json in .gitignore commit 0b8104588c8642088d51396b9687c15d96dd1921 Author: Robin Tuszik <47579899+rtuszik@users.noreply.github.com> Date: Fri Aug 23 12:38:33 2024 +0200 refactor(main): migrate to NiceGUI and add logging - Replace PyQt6 with NiceGUI for the application's GUI - Implement Loguru for structured logging - Reorganize main.py structure for better modularity - Update imports and main execution flow --- requirements.txt | 6 +- src/gui.py | 969 +++++++++++++++++------------------------ src/image_generator.py | 49 ++- src/main.py | 38 +- 4 files changed, 472 insertions(+), 590 deletions(-) diff --git a/requirements.txt b/requirements.txt index e77f435..73be64b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,7 @@ -PyQt6==6.7.1 -replicate==0.31.0 +replicate==0.32.0 python-dotenv==1.0.1 token_count==0.2.1 pillow==10.4.0 +loguru==0.7.2 +nicegui==1.4.36 +httpx \ No newline at end of file diff --git a/src/gui.py b/src/gui.py index 347ac89..7100542 100644 --- a/src/gui.py +++ b/src/gui.py @@ -1,614 +1,433 @@ +import asyncio +import json import os -import time -from urllib.request import urlretrieve - -from PyQt6.QtCore import QSettings, Qt, QThreadPool, QTimer -from PyQt6.QtGui import QGuiApplication, QPixmap, QResizeEvent -from PyQt6.QtWidgets import ( - QCheckBox, - QComboBox, - QDialog, - QDoubleSpinBox, - QFileDialog, - QFormLayout, - QGridLayout, - QHBoxLayout, - QLabel, - QLineEdit, - QMainWindow, - QMessageBox, - QProgressBar, - QPushButton, - QScrollArea, - QSizePolicy, - QSpinBox, - QStatusBar, - QTextEdit, - QVBoxLayout, - QWidget, +import sys +import urllib.parse +from datetime import datetime +from pathlib import Path + +import httpx +from loguru import logger +from nicegui import events, ui + +# Configure Loguru +logger.remove() # Remove the default handler +logger.add( + sys.stderr, format="{time} {level} {message}", filter="my_module", level="INFO" ) -from utils import ImageGeneratorThread, ImageLoader, TokenCounter +logger.add("gui.log", rotation="10 MB", format="{time} {level} {message}", level="INFO") -class ImageViewer(QDialog): - def __init__(self, pixmap, parent=None): - super().__init__(parent) - self.original_pixmap = pixmap - self.initUI() +class Lightbox: + def __init__(self): + with ui.dialog().props("maximized").classes("bg-black") as self.dialog: + ui.keyboard(self._handle_key) + self.large_image = ui.image().props("no-spinner fit=scale-down") + self.image_list = [] - def initUI(self): - self.setWindowTitle("Image Viewer") - self.setGeometry(100, 100, 1920, 1080) + def add_image(self, thumb_url: str, orig_url: str) -> ui.image: + self.image_list.append(orig_url) + with ui.button(on_click=lambda: self._open(orig_url)).props( + "flat dense square" + ): + return ui.image(thumb_url) - layout = QVBoxLayout(self) - layout.setContentsMargins(0, 0, 0, 0) - - self.image_label = QLabel(self) - self.image_label.setAlignment(Qt.AlignmentFlag.AlignCenter) - layout.addWidget(self.image_label) - - self.save_button = QPushButton("Save Image", self) - self.save_button.clicked.connect(self.saveImage) - layout.addWidget(self.save_button) + def _handle_key(self, event_args: events.KeyEventArguments) -> None: + if not event_args.action.keydown: + return + if event_args.key.escape: + self.dialog.close() + image_index = self.image_list.index(self.large_image.source) + if event_args.key.arrow_left and image_index > 0: + self._open(self.image_list[image_index - 1]) + if event_args.key.arrow_right and image_index < len(self.image_list) - 1: + self._open(self.image_list[image_index + 1]) - self.updateImage() + def _open(self, url: str) -> None: + self.large_image.set_source(url) + self.dialog.open() - def updateImage(self): - if self.image_label: - button_height = self.save_button.height() if self.save_button else 0 - scaled_pixmap = self.original_pixmap.scaled( - self.width(), - self.height() - button_height, - Qt.AspectRatioMode.KeepAspectRatio, - Qt.TransformationMode.SmoothTransformation, - ) - self.image_label.setPixmap(scaled_pixmap) - def resizeEvent(self, a0: QResizeEvent | None) -> None: - if a0 is not None: - self.updateImage() - super().resizeEvent(a0) - - def saveImage(self): - file_name, _ = QFileDialog.getSaveFileName( - self, "Save Image", "", "Images (*.png *.jpg *.bmp)" - ) - if file_name: - self.original_pixmap.save(file_name) - - -class ImagePreviewWidget(QLabel): - def __init__(self, pixmap, file_path, parent=None): - super().__init__(parent) - self.original_pixmap = pixmap - self.file_path = file_path - self.setPixmap( - pixmap.scaled( - 300, - 300, - Qt.AspectRatioMode.KeepAspectRatio, - Qt.TransformationMode.SmoothTransformation, - ) - ) - self.setAlignment(Qt.AlignmentFlag.AlignCenter) - self.setStyleSheet(""" - QLabel { - border: 2px solid #555555; - border-radius: 10px; - padding: 5px; - margin: 5px; - } - QLabel:hover { - border-color: #0078d7; - } - """) - self.setMinimumSize(310, 310) - - def mousePressEvent(self, ev): - if ev.button() == Qt.MouseButton.LeftButton: - viewer = ImageViewer(self.original_pixmap, self.parent()) - viewer.exec() - - -class ImageGeneratorGUI(QMainWindow): +class ImageGeneratorGUI: def __init__(self, image_generator): - super().__init__() self.image_generator = image_generator - self.settings = QSettings("rtuszik", "Flux-Dev-Lora-GUI") - self.threadpool = QThreadPool() - self.current_thread = None - self.is_grid_view = True - self.save_metadata_checkbox = None - self.initUI() - self.loadSettings() - if self.save_metadata_checkbox is None: - print( - "Warning: save_metadata_checkbox is still None after initUI and loadSettings" - ) - QTimer.singleShot(100, self.loadImagesAsync) - - def initUI(self): - self.setStyleSheet(self.getStyleSheet()) - self.setupMainWidget() - self.setupLeftPanel() - self.setupRightPanel() - self.setupBottomPanel() - self.setupStatusBar() - self.setWindowTitle("Image Generator") - self.resize(1900, 800) - self.setMinimumWidth(1600) - - def getStyleSheet(self): - return """ - QMainWindow, QWidget { - background-color: #2b2b2b; - color: #f0f0f0; - font-family: 'Arial', 'Sans-Serif'; - font-size: 13px; - } - QLineEdit, QTextEdit, QComboBox, QSpinBox, QDoubleSpinBox { - background-color: #3c3c3c; - border: 1px solid #555555; - border-radius: 5px; - padding: 5px; - color: #f0f0f0; - width: 100%; - } - QPushButton { - background-color: #5c5c5c; - color: white; - border: none; - border-radius: 5px; - padding: 8px 16px; - font-weight: 500; - min-height: 30px; - width: 100%; - } - QPushButton:hover { - background-color: #6c6c6c; - } - QPushButton:pressed { - background-color: #4c4c4c; - border: 1px solid #333333; - } - QLabel { - color: #f0f0f0; - } - QScrollArea { - border: none; - background-color: #3c3c3c; - } - QCheckBox { - spacing: 5px; - color: #f0f0f0; - } - QCheckBox::indicator { - width: 18px; - height: 18px; - } - QCheckBox::indicator:unchecked { - border: 2px solid #888888; - background-color: #3c3c3c; - } - QCheckBox::indicator:checked { - border: 2px solid #0078d7; - background-color: #0078d7; - } - """ - - def setupMainWidget(self): - main_widget = QWidget() - main_layout = QVBoxLayout(main_widget) - main_layout.setContentsMargins(20, 20, 20, 20) - main_layout.setSpacing(20) - self.setCentralWidget(main_widget) - - top_layout = QHBoxLayout() - main_layout.addLayout(top_layout, 1) - - self.left_layout = QVBoxLayout() - top_layout.addLayout(self.left_layout, 1) - - self.right_layout = QVBoxLayout() - top_layout.addLayout(self.right_layout, 2) - - self.bottom_layout = QVBoxLayout() - main_layout.addLayout(self.bottom_layout) - - def setupLeftPanel(self): - self.setupFormInputs() - self.setupSaveSettings() - - def setupFormInputs(self): - form_layout = QFormLayout() - form_layout.setSpacing(10) - form_layout.setLabelAlignment(Qt.AlignmentFlag.AlignRight) - form_layout.setFormAlignment( - Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignTop - ) - - self.aspect_ratio_input = self.createComboBox( - [ - "1:1", - "16:9", - "21:9", - "3:2", - "2:3", - "4:5", - "5:4", - "3:4", - "4:3", - "9:16", - "9:21", - ] + self.settings_file = "settings.json" + self.load_settings() + self.recent_replicate_models = self.load_recent_replicate_models() + self.setup_ui() + logger.info("ImageGeneratorGUI initialized") + + def setup_ui(self): + ui.dark_mode().enable() + + with ui.column().classes("w-full max-w-7xl mx-auto p-4 space-y-4"): + with ui.card().classes("w-full"): + ui.label("Image Generator").classes("text-2xl font-bold mb-4") + with ui.row().classes("w-full justify-between"): + with ui.column().classes("w-1/2 pr-2"): + self.setup_left_panel() + with ui.column().classes("w-1/2 pl-2"): + self.setup_right_panel() + self.setup_bottom_panel() + logger.info("UI setup completed") + + def setup_left_panel(self): + self.replicate_model_input = ui.input( + "Replicate Model", value=self.settings.get("replicate_model", "") + ).classes("w-full") + self.replicate_model_input.on("change", self.update_replicate_model) + + self.recent_models_select = ui.select( + options=self.recent_replicate_models, + label="Recent Models", + value=None, + on_change=self.select_recent_model, + ).classes("w-full") + + self.folder_path = self.settings.get( + "output_folder", str(Path.home() / "Downloads") ) - self.num_outputs_input = self.createSpinBox(1, 4) - self.num_inference_steps_input = self.createSpinBox(1, 50) - self.guidance_scale_input = self.createDoubleSpinBox(0, 10, 0.1) - self.seed_input = self.createSpinBox(-2147483648, 2147483647, "Random") - self.output_format_input = self.createComboBox(["png", "jpg", "webp"]) - self.output_quality_input = self.createSpinBox(0, 100) - self.hf_lora_input = QLineEdit() - self.lora_scale_input = self.createDoubleSpinBox(0, 1, 0.1) - self.disable_safety_checker_input = QCheckBox("Disable Safety") - self.disable_safety_checker_input.setChecked(True) - - self.addFormRow(form_layout, "Aspect Ratio:", self.aspect_ratio_input) - self.addFormRow(form_layout, "Outputs:", self.num_outputs_input) - self.addFormRow(form_layout, "Inference Steps:", self.num_inference_steps_input) - self.addFormRow(form_layout, "Guidance Scale:", self.guidance_scale_input) - self.addFormRow(form_layout, "Seed:", self.seed_input) - self.addFormRow(form_layout, "Output Format:", self.output_format_input) - self.addFormRow(form_layout, "Quality:", self.output_quality_input) - self.addFormRow(form_layout, "HF LoRA:", self.hf_lora_input) - self.addFormRow(form_layout, "LoRA Scale:", self.lora_scale_input) - self.addFormRow(form_layout, "", self.disable_safety_checker_input) - - self.left_layout.addLayout(form_layout) - - def setupSaveSettings(self): - self.auto_save_checkbox = QCheckBox("Auto-save") - self.save_metadata_checkbox = QCheckBox("Save prompt as metadata") - self.save_dir_input = QLineEdit() - self.save_dir_input.setReadOnly(True) - self.choose_dir_button = QPushButton("Choose Directory") - self.choose_dir_button.clicked.connect(self.choose_save_directory) - - save_settings_layout = QVBoxLayout() - checkbox_layout = QHBoxLayout() - checkbox_layout.addWidget(self.auto_save_checkbox) - checkbox_layout.addWidget(self.save_metadata_checkbox) - save_settings_layout.addLayout(checkbox_layout) - - dir_layout = QHBoxLayout() - dir_layout.addWidget(self.save_dir_input) - dir_layout.addWidget(self.choose_dir_button) - save_settings_layout.addLayout(dir_layout) - - self.left_layout.addLayout(save_settings_layout) - - def setupRightPanel(self): - self.gallery_scroll = QScrollArea() - self.gallery_scroll.setWidgetResizable(True) - self.gallery_widget = QWidget() - self.gallery_layout = QGridLayout(self.gallery_widget) - self.gallery_layout.setSpacing(10) - self.gallery_scroll.setWidget(self.gallery_widget) - - self.right_layout.addWidget(self.gallery_scroll) - - self.view_toggle = QPushButton("Toggle View") - self.view_toggle.clicked.connect(self.toggle_view) - self.right_layout.addWidget(self.view_toggle) - - def setupBottomPanel(self): - self.prompt_input = QTextEdit() - self.prompt_input.setFixedHeight(100) - self.prompt_input.setSizePolicy( - QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed + self.folder_input = ui.input( + label="Output Folder", value=self.folder_path + ).classes("w-full") + self.folder_input.on("change", self.update_folder_path) + + self.flux_model_select = ( + ui.select( + ["dev", "schnell"], + label="Flux Model", + value=self.settings.get("flux_model", "dev"), + ) + .classes("w-full") + .bind_value(self, "flux_model") ) - self.bottom_layout.addWidget(QLabel("Prompt:")) - self.bottom_layout.addWidget(self.prompt_input) - - self.token_counter = TokenCounter(self.prompt_input) - self.bottom_layout.addWidget(self.token_counter) - self.generate_button = QPushButton("Generate Images") - self.generate_button.setSizePolicy( - QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed - ) - self.generate_button.setFixedHeight(50) - self.generate_button.clicked.connect(self.generate_images) - self.bottom_layout.addWidget(self.generate_button) - - self.progress_bar = QProgressBar() - self.progress_bar.setRange(0, 0) - self.progress_bar.setTextVisible(False) - self.progress_bar.hide() - self.bottom_layout.addWidget(self.progress_bar) - - self.interrupt_button = QPushButton("Interrupt Generation") - self.interrupt_button.clicked.connect(self.interrupt_generation) - self.interrupt_button.setEnabled(False) - self.bottom_layout.addWidget(self.interrupt_button) - - def setupStatusBar(self): - status_bar = QStatusBar() - self.setStatusBar(status_bar) - status_bar.showMessage("Ready") - - def createComboBox(self, items): - combo_box = QComboBox() - combo_box.addItems(items) - return combo_box - - def createSpinBox(self, min_value, max_value, special_value_text=None): - spin_box = QSpinBox() - spin_box.setRange(min_value, max_value) - if special_value_text: - spin_box.setSpecialValueText(special_value_text) - return spin_box - - def createDoubleSpinBox(self, min_value, max_value, step): - double_spin_box = QDoubleSpinBox() - double_spin_box.setRange(min_value, max_value) - double_spin_box.setSingleStep(step) - return double_spin_box - - def addFormRow(self, form_layout, label, widget): - form_layout.addRow(label, widget) - - def loadImagesAsync(self): - folder_path = self.save_dir_input.text() - loader = ImageLoader(folder_path) - loader.signals.finished.connect(self.updateGallery) - self.threadpool.start(loader) - - def toggle_view(self): - self.is_grid_view = not self.is_grid_view - self.clearGallery() - self.updateGallery() - - def updateGallery(self, image_paths=None): - if image_paths is None: - image_paths = [ - item.widget().file_path - for i in range(self.gallery_layout.count()) - if (item := self.gallery_layout.itemAt(i)) - and isinstance(item.widget(), ImagePreviewWidget) - ] - - sorted_images = sorted( - image_paths, key=lambda x: os.path.getctime(x), reverse=True + self.aspect_ratio_select = ( + ui.select( + [ + "1:1", + "16:9", + "21:9", + "3:2", + "2:3", + "4:5", + "5:4", + "3:4", + "4:3", + "9:16", + "9:21", + "custom", + ], + label="Aspect Ratio", + value=self.settings.get("aspect_ratio", "1:1"), + ) + .classes("w-full") + .bind_value(self, "aspect_ratio") ) + self.aspect_ratio_select.on("change", self.toggle_custom_dimensions) + + with ui.column().classes("w-full").bind_visibility_from( + self.aspect_ratio_select, "value", value="custom" + ): + self.width_input = ( + ui.number( + "Width", value=self.settings.get("width", 1024), min=256, max=1440 + ) + .classes("w-full") + .bind_value(self, "width") + ) + self.height_input = ( + ui.number( + "Height", value=self.settings.get("height", 1024), min=256, max=1440 + ) + .classes("w-full") + .bind_value(self, "height") + ) - existing_images = { - item.widget().file_path - for i in range(self.gallery_layout.count()) - if (item := self.gallery_layout.itemAt(i)) - and isinstance(item.widget(), ImagePreviewWidget) - } - - for path in sorted_images: - if path not in existing_images: - pixmap = QPixmap(path) - preview = ImagePreviewWidget(pixmap, path) - if self.is_grid_view: - row = self.gallery_layout.count() // 3 - col = self.gallery_layout.count() % 3 - else: - row = self.gallery_layout.count() - col = 0 - self.gallery_layout.addWidget(preview, row, col) - - for i in range(self.gallery_layout.count()): - item = self.gallery_layout.itemAt(i) - if item and isinstance(item.widget(), ImagePreviewWidget): - item.widget().show() - - scrollbar = self.gallery_scroll.verticalScrollBar() - if scrollbar: - scrollbar.setValue(scrollbar.minimum()) - - def clearGallery(self): - for i in reversed(range(self.gallery_layout.count())): - item = self.gallery_layout.itemAt(i) - if item and isinstance(item.widget(), ImagePreviewWidget): - item.widget().hide() - self.gallery_layout.removeWidget(item.widget()) - - def center(self): - primary_screen = QGuiApplication.primaryScreen() - if primary_screen: - screen_geometry = primary_screen.geometry() - center_point = screen_geometry.center() - frame_geometry = self.frameGeometry() - frame_geometry.moveCenter(center_point) - self.move(frame_geometry.topLeft()) - - def display_images(self, image_urls): - self.progress_bar.hide() - self.generate_button.setEnabled(True) - self.interrupt_button.setEnabled(False) - - timestamp = time.strftime("%Y%m%d_%H%M%S") - new_image_paths = [] - for i, image_url in enumerate(image_urls): - if self.auto_save_checkbox.isChecked(): - base_name = f"generated_image_{timestamp}_{i+1}.{self.output_format_input.currentText()}" - image_path = os.path.join(self.save_dir_input.text(), base_name) - counter = 1 - while os.path.exists(image_path): - new_name = f"generated_image_{timestamp}_{i+1}_{counter}.{self.output_format_input.currentText()}" - image_path = os.path.join(self.save_dir_input.text(), new_name) - counter += 1 - urlretrieve(image_url, image_path) - - if ( - self.save_metadata_checkbox is not None - and self.save_metadata_checkbox.isChecked() - ): - self.add_metadata_to_image( - image_path, self.prompt_input.toPlainText() - ) - else: - print("Warning: save_metadata_checkbox is None or not checked") - - new_image_paths.append(image_path) - else: - image_path = f"temp_image_{i}.{self.output_format_input.currentText()}" - urlretrieve(image_url, image_path) - new_image_paths.append(image_path) - - self.updateGallery(new_image_paths) - - if not self.auto_save_checkbox.isChecked(): - for path in new_image_paths: - os.remove(path) - - def add_metadata_to_image(self, image_path, prompt): - try: - from PIL import Image - from PIL.PngImagePlugin import PngInfo - - with Image.open(image_path) as img: - if img.format == "PNG": - metadata = PngInfo() - metadata.add_text("prompt", prompt) - img.save(image_path, pnginfo=metadata) - elif img.format in ["JPEG", "WEBP"]: - exif = img.getexif() - exif[0x9286] = prompt # 0x9286 is the UserComment EXIF tag - img.save(image_path, exif=exif) - except Exception as e: - print(f"Error adding metadata to {image_path}: {str(e)}") - - def loadSettings(self): - self.prompt_input.setPlainText(self.settings.value("prompt", "")) - self.aspect_ratio_input.setCurrentText( - self.settings.value("aspect_ratio", "1:1") + self.num_outputs_input = ( + ui.number( + "Num Outputs", value=self.settings.get("num_outputs", 1), min=1, max=4 + ) + .classes("w-full") + .bind_value(self, "num_outputs") ) - self.num_outputs_input.setValue(int(self.settings.value("num_outputs", 1))) - self.num_inference_steps_input.setValue( - int(self.settings.value("num_inference_steps", 28)) + self.lora_scale_input = ( + ui.number( + "LoRA Scale", + value=self.settings.get("lora_scale", 1), + min=-1, + max=2, + step=0.1, + ) + .classes("w-full") + .bind_value(self, "lora_scale") ) - self.guidance_scale_input.setValue( - float(self.settings.value("guidance_scale", 3.5)) + self.num_inference_steps_input = ( + ui.number( + "Num Inference Steps", + value=self.settings.get("num_inference_steps", 28), + min=1, + max=50, + ) + .classes("w-full") + .bind_value(self, "num_inference_steps") ) - self.seed_input.setValue(int(self.settings.value("seed", -2147483648))) - self.output_format_input.setCurrentText( - self.settings.value("output_format", "webp") + self.guidance_scale_input = ( + ui.number( + "Guidance Scale", + value=self.settings.get("guidance_scale", 3.5), + min=0, + max=10, + step=0.1, + ) + .classes("w-full") + .bind_value(self, "guidance_scale") ) - self.output_quality_input.setValue( - int(self.settings.value("output_quality", 80)) + self.seed_input = ( + ui.number( + "Seed", + value=self.settings.get("seed", -1), + min=-2147483648, + max=2147483647, + ) + .classes("w-full") + .bind_value(self, "seed") ) - self.hf_lora_input.setText(self.settings.value("hf_lora", "")) - self.lora_scale_input.setValue(float(self.settings.value("lora_scale", 0.8))) - self.disable_safety_checker_input.setChecked( - self.settings.value("disable_safety_checker", True, type=bool) + self.output_format_select = ( + ui.select( + ["webp", "jpg", "png"], + label="Output Format", + value=self.settings.get("output_format", "webp"), + ) + .classes("w-full") + .bind_value(self, "output_format") ) - self.auto_save_checkbox.setChecked( - self.settings.value("auto_save", True, type=bool) + self.output_quality_input = ( + ui.number( + "Output Quality", + value=self.settings.get("output_quality", 80), + min=0, + max=100, + ) + .classes("w-full") + .bind_value(self, "output_quality") ) - self.save_dir_input.setText( - self.settings.value( - "save_directory", os.path.expanduser("~/Downloads/replicate") + self.disable_safety_checker_switch = ( + ui.switch( + "Disable Safety Checker", + value=self.settings.get("disable_safety_checker", False), ) + .classes("w-full") + .bind_value(self, "disable_safety_checker") ) - if self.save_metadata_checkbox: - self.save_metadata_checkbox.setChecked( - self.settings.value("save_metadata", False, type=bool) + + def update_folder_path(self, e): + new_path = e.value + if os.path.isdir(new_path): + self.folder_path = new_path + self.save_settings() + logger.info(f"Output folder set to: {self.folder_path}") + ui.notify(f"Output folder updated to: {self.folder_path}", type="success") + else: + ui.notify( + "Invalid folder path. Please enter a valid directory.", type="error" ) + self.folder_input.value = self.folder_path - self.loadImagesAsync() + def setup_right_panel(self): + self.spinner = ui.spinner(size="lg") + self.spinner.visible = False - def saveSettings(self): - self.settings.setValue("prompt", self.prompt_input.toPlainText()) - self.settings.setValue("aspect_ratio", self.aspect_ratio_input.currentText()) - self.settings.setValue("num_outputs", self.num_outputs_input.value()) - self.settings.setValue( - "num_inference_steps", self.num_inference_steps_input.value() + # Add gallery view + self.gallery_container = ui.column().classes("w-full mt-4") + self.lightbox = Lightbox() + + def setup_bottom_panel(self): + self.prompt_input = ( + ui.textarea("Prompt", value=self.settings.get("prompt", "")) + .classes("w-full") + .bind_value(self, "prompt") ) - self.settings.setValue("guidance_scale", self.guidance_scale_input.value()) - self.settings.setValue("seed", self.seed_input.value()) - self.settings.setValue("output_format", self.output_format_input.currentText()) - self.settings.setValue("output_quality", self.output_quality_input.value()) - self.settings.setValue("hf_lora", self.hf_lora_input.text()) - self.settings.setValue("lora_scale", self.lora_scale_input.value()) - self.settings.setValue( - "disable_safety_checker", self.disable_safety_checker_input.isChecked() + self.token_counter = ui.label("Tokens: 0").classes("text-sm text-gray-500") + self.prompt_input.on("input", self.update_token_count) + self.generate_button = ui.button( + "Generate Images", on_click=self.start_generation + ).classes( + "w-full bg-blue-500 hover:bg-blue-600 text-white font-bold py-2 px-4 rounded" ) - self.settings.setValue("auto_save", self.auto_save_checkbox.isChecked()) - self.settings.setValue("save_directory", self.save_dir_input.text()) - if self.save_metadata_checkbox is not None: - self.settings.setValue( - "save_metadata", self.save_metadata_checkbox.isChecked() - ) + def select_folder(self): + def on_folder_selected(e): + if e.value: + self.folder_path = e.value + self.folder_input.value = self.folder_path + self.save_settings() + logger.info(f"Output folder set to: {self.folder_path}") + + ui.open_directory_dialog(on_folder_selected) + + def update_replicate_model(self, e): + new_model = e.value + if new_model: + self.image_generator.set_model(new_model) + self.save_settings() + self.add_recent_replicate_model(new_model) + logger.info(f"Replicate model updated to: {new_model}") + self.generate_button.enable() + else: + logger.warning("Empty Replicate model provided") + self.generate_button.disable() + + def select_recent_model(self, e): + if e.value: + self.replicate_model_input.value = e.value + self.update_replicate_model(e) + self.recent_models_select.value = None + + def add_recent_replicate_model(self, model): + if model not in self.recent_replicate_models: + self.recent_replicate_models.insert(0, model) + self.recent_replicate_models = self.recent_replicate_models[ + :5 + ] # Keep only the last 5 + self.save_settings() + self.recent_models_select.options = self.recent_replicate_models + + def load_recent_replicate_models(self): + return self.settings.get("recent_replicate_models", []) + + def toggle_custom_dimensions(self, e): + if e.value == "custom": + self.width_input.enable() + self.height_input.enable() else: - print("Warning: save_metadata_checkbox is None") + self.width_input.disable() + self.height_input.disable() + self.save_settings() + logger.info(f"Custom dimensions toggled: {e.value}") + + def update_token_count(self, e): + token_count = len(e.value.split()) + self.token_counter.text = f"Tokens: {token_count}" + if token_count > 77: + ui.notify("Warning: Tokens beyond 77 will be ignored", type="warning") + self.save_settings() + + async def start_generation(self): + if not self.replicate_model_input.value: + ui.notify( + "Please set a Replicate model before generating images.", type="error" + ) + logger.warning( + "Attempted to generate images without setting a Replicate model" + ) + return - def choose_save_directory(self): - dir_path = QFileDialog.getExistingDirectory(self, "Choose Save Directory") - if dir_path: - self.save_dir_input.setText(dir_path) - self.saveSettings() - self.loadImagesAsync() + # Ensure the model is set in the ImageGenerator + self.image_generator.set_model(self.replicate_model_input.value) - def generate_images(self): + self.save_settings() params = { - "prompt": self.prompt_input.toPlainText(), - "aspect_ratio": self.aspect_ratio_input.currentText(), - "num_outputs": self.num_outputs_input.value(), - "num_inference_steps": self.num_inference_steps_input.value(), - "guidance_scale": self.guidance_scale_input.value(), - "output_format": self.output_format_input.currentText(), - "output_quality": self.output_quality_input.value(), - "hf_lora": self.hf_lora_input.text(), - "lora_scale": self.lora_scale_input.value(), - "disable_safety_checker": self.disable_safety_checker_input.isChecked(), + "prompt": self.prompt, + "flux_model": self.flux_model, + "aspect_ratio": self.aspect_ratio, + "num_outputs": self.num_outputs, + "lora_scale": self.lora_scale, + "num_inference_steps": self.num_inference_steps, + "guidance_scale": self.guidance_scale, + "output_format": self.output_format, + "output_quality": self.output_quality, + "disable_safety_checker": self.disable_safety_checker, } - if self.seed_input.value() != self.seed_input.minimum(): - params["seed"] = self.seed_input.value() + if self.aspect_ratio == "custom": + params["width"] = self.width + params["height"] = self.height + + if self.seed != -1: + params["seed"] = self.seed + + self.generate_button.disable() + self.spinner.visible = True + ui.notify("Generating images...", type="info") + logger.info(f"Generating images with params: {json.dumps(params, indent=2)}") + + try: + output = await asyncio.to_thread( + self.image_generator.generate_images, params + ) + await self.download_and_display_images(output) + logger.success(f"Images generated successfully: {output}") + except Exception as e: + error_message = f"An error occurred: {str(e)}" + ui.notify(error_message, type="error") + logger.exception(error_message) + finally: + self.generate_button.enable() + self.spinner.visible = False + + async def download_and_display_images(self, image_urls): + downloaded_images = [] + async with httpx.AsyncClient() as client: + for i, url in enumerate(image_urls): + response = await client.get(url) + if response.status_code == 200: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + url_part = urllib.parse.urlparse(url).path.split("/")[-2][ + :8 + ] # Get first 8 chars of the unique part + file_name = f"generated_image_{timestamp}_{url_part}_{i+1}.png" + file_path = Path(self.folder_path) / file_name + with open(file_path, "wb") as f: + f.write(response.content) + downloaded_images.append(str(file_path)) + logger.info(f"Image downloaded: {file_path}") + else: + logger.error(f"Failed to download image from {url}") + + self.update_gallery(downloaded_images) + ui.notify("Images generated and downloaded successfully!", type="success") + + def update_gallery(self, image_paths): + self.gallery_container.clear() + with self.gallery_container: + for image_path in image_paths: + self.lightbox.add_image(image_path, image_path).classes( + "w-32 h-32 object-cover m-1" + ) + + def load_settings(self): + if os.path.exists(self.settings_file): + with open(self.settings_file, "r") as f: + self.settings = json.load(f) + logger.info("Settings loaded successfully") + else: + self.settings = {} + logger.info("No existing settings found, using defaults") + + def save_settings(self): + settings_to_save = { + "replicate_model": self.replicate_model_input.value, + "output_folder": self.folder_path, + "flux_model": self.flux_model, + "aspect_ratio": self.aspect_ratio, + "width": self.width, + "height": self.height, + "num_outputs": self.num_outputs, + "lora_scale": self.lora_scale, + "num_inference_steps": self.num_inference_steps, + "guidance_scale": self.guidance_scale, + "seed": self.seed, + "output_format": self.output_format, + "output_quality": self.output_quality, + "disable_safety_checker": self.disable_safety_checker, + "prompt": self.prompt, + "recent_replicate_models": self.recent_replicate_models, + } + with open(self.settings_file, "w") as f: + json.dump(settings_to_save, f) + logger.info("Settings saved successfully") - if not params["prompt"]: - QMessageBox.warning(self, "Error", "Please enter a prompt.") - return - self.saveSettings() - self.clear_images() - self.progress_bar.show() - self.generate_button.setEnabled(False) - - self.current_thread = ImageGeneratorThread(self.image_generator, params) - self.current_thread.finished.connect(self.display_images) - self.current_thread.error.connect(self.show_error) - self.current_thread.start() - - self.interrupt_button.setEnabled(True) - - def interrupt_generation(self): - if self.current_thread and self.current_thread.isRunning(): - self.current_thread.terminate() - self.current_thread.wait() - self.show_error("Image generation interrupted by user.") - self.interrupt_button.setEnabled(False) - - def show_error(self, error_message): - self.progress_bar.hide() - self.generate_button.setEnabled(True) - QMessageBox.critical(self, "Error", f"An error occurred: {error_message}") - self.interrupt_button.setEnabled(False) - - def clear_images(self): - for i in reversed(range(self.gallery_layout.count())): - item = self.gallery_layout.itemAt(i) - if item and isinstance(item.widget(), ImagePreviewWidget): - widget = item.widget() - widget.setParent(None) - - def closeEvent(self, a0): - self.saveSettings() - super().closeEvent(a0) - self.interrupt_button.setEnabled(False) +async def create_gui(image_generator): + return ImageGeneratorGUI(image_generator) diff --git a/src/image_generator.py b/src/image_generator.py index 6000d6c..d2f7df7 100644 --- a/src/image_generator.py +++ b/src/image_generator.py @@ -1,19 +1,62 @@ +import json +import sys + import replicate from dotenv import load_dotenv +from loguru import logger load_dotenv() +# Configure Loguru +logger.remove() # Remove the default handler +logger.add( + sys.stderr, format="{time} {level} {message}", filter="my_module", level="INFO" +) +logger.add( + "image_generator.log", + rotation="10 MB", + format="{time} {level} {message}", + level="INFO", +) + class ImageGenerator: def __init__(self): - self.model = "rtuszik/fluxlyptus:b23b9b488de7af95eba09786ef3156d345d979024712f54b3e5a32d61f14e568" + self.replicate_model = None + logger.info("ImageGenerator initialized") + + def set_model(self, replicate_model): + self.replicate_model = replicate_model + logger.info(f"Model set to: {replicate_model}") def generate_images(self, params): + if not self.replicate_model: + error_message = ( + "No Replicate model set. Please set a model before generating images." + ) + logger.error(error_message) + raise ImageGenerationError(error_message) + try: - output = replicate.run(self.model, input=params) + # Remove the Flux model from params and store it separately + flux_model = params.pop("flux_model", "dev") + + # Add the Flux model choice to the input parameters + params["model"] = flux_model + + logger.info( + f"Generating images with params: {json.dumps(params, indent=2)}" + ) + logger.info(f"Using Replicate model: {self.replicate_model}") + + output = replicate.run(self.replicate_model, input=params) + + logger.success(f"Images generated successfully. Output: {output}") return output except Exception as e: - raise ImageGenerationError(f"Error generating images: {str(e)}") + error_message = f"Error generating images: {str(e)}" + logger.exception(error_message) + raise ImageGenerationError(error_message) class ImageGenerationError(Exception): diff --git a/src/main.py b/src/main.py index 0799010..645ac58 100644 --- a/src/main.py +++ b/src/main.py @@ -1,15 +1,33 @@ -from gui import ImageGeneratorGUI +import sys + +from gui import create_gui from image_generator import ImageGenerator -from PyQt6.QtWidgets import QApplication +from loguru import logger +from nicegui import ui + +# Configure Loguru +logger.remove() # Remove the default handler +logger.add( + sys.stderr, format="{time} {level} {message}", filter="my_module", level="INFO" +) +logger.add( + "main.log", rotation="10 MB", format="{time} {level} {message}", level="INFO" +) + +# Create the ImageGenerator instance +logger.info("Initializing ImageGenerator") +generator = ImageGenerator() + +# Create and setup the GUI +logger.info("Creating and setting up GUI") -def main(): - app = QApplication([]) - generator = ImageGenerator() - window = ImageGeneratorGUI(generator) - window.show() - app.exec() +@ui.page("/") +async def main_page(): + await create_gui(generator) + logger.info("NiceGUI server is running") -if __name__ == "__main__": - main() +# Run the NiceGUI server +logger.info("Starting NiceGUI server") +ui.run(port=8080)