Skip to content

Commit

Permalink
feat(model-handling): improve user model addition
Browse files Browse the repository at this point in the history
- Add error handling for model addition process
- Retrieve and use latest model version when adding user models
- Update client initialization in ImageGenerator class
  • Loading branch information
rtuszik committed Sep 12, 2024
1 parent 71dca34 commit caec2d0
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 20 deletions.
33 changes: 20 additions & 13 deletions src/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,19 +524,26 @@ async def add_model():
async def add_user_model(self, new_model):
logger.debug(f"Adding user model: {new_model}")
if new_model and new_model not in self.user_added_models:
self.user_added_models[new_model] = new_model
self.model_options = list(self.user_added_models.keys())
self.replicate_model_select.options = self.model_options
self.replicate_model_select.value = new_model
await self.update_replicate_model(new_model)
models_json = json.dumps(
{"user_added": list(self.user_added_models.keys())}
)
set_setting("default", "models", models_json)
save_settings()
ui.notify(f"Model '{new_model}' added successfully", type="positive")
self.model_list.refresh()
logger.info(f"User model added: {new_model}")
try:
latest_v = await asyncio.to_thread(
self.image_generator.get_model_version, new_model
)
self.user_added_models[new_model] = latest_v
self.model_options = list(self.user_added_models.values())
self.replicate_model_select.options = self.model_options
self.replicate_model_select.value = latest_v
await self.update_replicate_model(latest_v)
models_json = json.dumps(
{"user_added": list(self.user_added_models.values())}
)
set_setting("default", "models", models_json)
save_settings()
ui.notify(f"Model '{latest_v}' added successfully", type="positive")
self.model_list.refresh()
logger.info(f"User model added: {latest_v}")
except Exception as e:
logger.error(f"Error adding model: {str(e)}")
ui.notify(f"Error adding model: {str(e)}", type="negative")
else:
logger.warning(f"Invalid model name or model already exists: {new_model}")
ui.notify("Invalid model name or model already exists", type="negative")
Expand Down
39 changes: 32 additions & 7 deletions src/replicate_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,46 @@ class ImageGenerator:
def __init__(self):
self.replicate_model = None
self.api_key = None
self.client = None
logger.info("ImageGenerator initialized")

def set_api_key(self, api_key):
self.api_key = api_key
os.environ["REPLICATE_API_KEY"] = api_key
logger.info("API key set")
self.client = replicate.Client(api_token=self.api_key)
logger.info("API key set and client initialized")

def set_model(self, replicate_model):
self.replicate_model = replicate_model
logger.info(f"Model set to: {replicate_model}")

def get_model_version(self, user_input):
if not self.client:
error_message = (
"No API key set. Please set an API key before getting model version."
)
logger.error(error_message)
raise ImageGenerationError(error_message)

logger.info(f"Parsing model string: {user_input}")
if ":" in user_input:
logger.debug("Model string contains version")
return user_input
else:
logger.debug("Model string does not contain version")
owner, name = user_input.split("/")
logger.debug(f"Retrieving latest version for {owner}/{name}")
if not self.client:
error_message = "No API key set. Please set an API key before getting model version."
logger.error(error_message)
raise ImageGenerationError(error_message)

model = self.client.models.get(f"{owner}/{name}")
version = model.latest_version.id
latest_version = f"{owner}/{name}:{version}"
logger.info(f"Latest version retrieved: {latest_version}")
return latest_version

def generate_images(self, params):
if not self.replicate_model:
error_message = (
Expand All @@ -31,7 +60,7 @@ def generate_images(self, params):
logger.error(error_message)
raise ImageGenerationError(error_message)

if not self.api_key:
if not self.client:
error_message = (
"No API key set. Please set an API key before generating images."
)
Expand All @@ -40,17 +69,13 @@ def generate_images(self, params):

try:
flux_model = params.pop("flux_model", "dev")

params["model"] = flux_model

logger.info(
f"Generating images with params: {json.dumps(params, indent=2)}"
)
logger.info(f"Using Replicate model: {self.replicate_model}")

client = replicate.Client(api_token=self.api_key)
output = client.run(self.replicate_model, input=params)

output = self.client.run(self.replicate_model, input=params)
logger.success(f"Images generated successfully. Output: {output}")
return output
except Exception as e:
Expand Down

0 comments on commit caec2d0

Please sign in to comment.