From 151759e9922031a677008b1ace170f6e67f9ea00 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Wed, 16 Oct 2024 13:21:34 +0200 Subject: [PATCH] update name Signed-off-by: Ashwin Vaidya --- .../models/image/vlm_ad/lightning_model.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/anomalib/models/image/vlm_ad/lightning_model.py b/src/anomalib/models/image/vlm_ad/lightning_model.py index 39468f6a16..1279f7a31e 100644 --- a/src/anomalib/models/image/vlm_ad/lightning_model.py +++ b/src/anomalib/models/image/vlm_ad/lightning_model.py @@ -32,15 +32,15 @@ def __init__( self.vlm_backend: Backend = self._setup_vlm_backend(model, api_key) @staticmethod - def _setup_vlm_backend(model: ModelName, api_key: str | None) -> Backend: - if model == ModelName.LLAMA_OLLAMA: - return Ollama(model_name=model.value) - if model == ModelName.GPT_4O_MINI: - return ChatGPT(api_key=api_key, model_name=model.value) - if model in {ModelName.VICUNA_7B_HF, ModelName.VICUNA_13B_HF, ModelName.MISTRAL_7B_HF}: - return Huggingface(model_name=model.value) - - msg = f"Unsupported VLM model: {model}" + def _setup_vlm_backend(model_name: ModelName, api_key: str | None) -> Backend: + if model_name == ModelName.LLAMA_OLLAMA: + return Ollama(model_name=model_name.value) + if model_name == ModelName.GPT_4O_MINI: + return ChatGPT(api_key=api_key, model_name=model_name.value) + if model_name in {ModelName.VICUNA_7B_HF, ModelName.VICUNA_13B_HF, ModelName.MISTRAL_7B_HF}: + return Huggingface(model_name=model_name.value) + + msg = f"Unsupported VLM model: {model_name}" raise ValueError(msg) def _setup(self) -> None: