From 3fb2fdc05d28935dc1386411377c919e6b5fe1ae Mon Sep 17 00:00:00 2001 From: Robin Tuszik Date: Thu, 12 Sep 2024 17:53:50 +0200 Subject: [PATCH] feat(model-handling): improve user model addition (#17) - Add error handling for model addition process - Retrieve and use latest model version when adding user models - Update client initialization in ImageGenerator class --- src/gui.py | 33 ++++++++++++++++++++------------- src/replicate_api.py | 39 ++++++++++++++++++++++++++++++++------- 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/src/gui.py b/src/gui.py index a4bf86b..5e84df3 100644 --- a/src/gui.py +++ b/src/gui.py @@ -524,19 +524,26 @@ async def add_model(): async def add_user_model(self, new_model): logger.debug(f"Adding user model: {new_model}") if new_model and new_model not in self.user_added_models: - self.user_added_models[new_model] = new_model - self.model_options = list(self.user_added_models.keys()) - self.replicate_model_select.options = self.model_options - self.replicate_model_select.value = new_model - await self.update_replicate_model(new_model) - models_json = json.dumps( - {"user_added": list(self.user_added_models.keys())} - ) - set_setting("default", "models", models_json) - save_settings() - ui.notify(f"Model '{new_model}' added successfully", type="positive") - self.model_list.refresh() - logger.info(f"User model added: {new_model}") + try: + latest_v = await asyncio.to_thread( + self.image_generator.get_model_version, new_model + ) + self.user_added_models[new_model] = latest_v + self.model_options = list(self.user_added_models.values()) + self.replicate_model_select.options = self.model_options + self.replicate_model_select.value = latest_v + await self.update_replicate_model(latest_v) + models_json = json.dumps( + {"user_added": list(self.user_added_models.values())} + ) + set_setting("default", "models", models_json) + save_settings() + ui.notify(f"Model '{latest_v}' added successfully", type="positive") + self.model_list.refresh() + logger.info(f"User model added: {latest_v}") + except Exception as e: + logger.error(f"Error adding model: {str(e)}") + ui.notify(f"Error adding model: {str(e)}", type="negative") else: logger.warning(f"Invalid model name or model already exists: {new_model}") ui.notify("Invalid model name or model already exists", type="negative") diff --git a/src/replicate_api.py b/src/replicate_api.py index 0aaf4cf..cafd3e0 100644 --- a/src/replicate_api.py +++ b/src/replicate_api.py @@ -12,17 +12,46 @@ class ImageGenerator: def __init__(self): self.replicate_model = None self.api_key = None + self.client = None logger.info("ImageGenerator initialized") def set_api_key(self, api_key): self.api_key = api_key os.environ["REPLICATE_API_KEY"] = api_key - logger.info("API key set") + self.client = replicate.Client(api_token=self.api_key) + logger.info("API key set and client initialized") def set_model(self, replicate_model): self.replicate_model = replicate_model logger.info(f"Model set to: {replicate_model}") + def get_model_version(self, user_input): + if not self.client: + error_message = ( + "No API key set. Please set an API key before getting model version." + ) + logger.error(error_message) + raise ImageGenerationError(error_message) + + logger.info(f"Parsing model string: {user_input}") + if ":" in user_input: + logger.debug("Model string contains version") + return user_input + else: + logger.debug("Model string does not contain version") + owner, name = user_input.split("/") + logger.debug(f"Retrieving latest version for {owner}/{name}") + if not self.client: + error_message = "No API key set. Please set an API key before getting model version." + logger.error(error_message) + raise ImageGenerationError(error_message) + + model = self.client.models.get(f"{owner}/{name}") + version = model.latest_version.id + latest_version = f"{owner}/{name}:{version}" + logger.info(f"Latest version retrieved: {latest_version}") + return latest_version + def generate_images(self, params): if not self.replicate_model: error_message = ( @@ -31,7 +60,7 @@ def generate_images(self, params): logger.error(error_message) raise ImageGenerationError(error_message) - if not self.api_key: + if not self.client: error_message = ( "No API key set. Please set an API key before generating images." ) @@ -40,17 +69,13 @@ def generate_images(self, params): try: flux_model = params.pop("flux_model", "dev") - params["model"] = flux_model - logger.info( f"Generating images with params: {json.dumps(params, indent=2)}" ) logger.info(f"Using Replicate model: {self.replicate_model}") - client = replicate.Client(api_token=self.api_key) - output = client.run(self.replicate_model, input=params) - + output = self.client.run(self.replicate_model, input=params) logger.success(f"Images generated successfully. Output: {output}") return output except Exception as e: