From c65e24197eabc7249c9f7ded9190973088ad129b Mon Sep 17 00:00:00 2001 From: Henning Date: Thu, 7 Mar 2024 22:50:24 +0100 Subject: [PATCH] Mistral backend option --- dbtai/cli.py | 34 ++++++++++++++++++++++++---------- dbtai/manifest.py | 35 ++++++++++++++++++++++++++--------- pyproject.toml | 3 ++- 3 files changed, 52 insertions(+), 20 deletions(-) diff --git a/dbtai/cli.py b/dbtai/cli.py index 39b07ac..1b5d0f4 100644 --- a/dbtai/cli.py +++ b/dbtai/cli.py @@ -52,36 +52,42 @@ def setup(): ), inquirer.List('backend', message ="LLM Backend", - choices = ["OpenAI", "Azure OpenAI"], + choices = ["OpenAI", "Azure OpenAI", "Mistral"], default = "OpenAI" ), inquirer.List("auth_type", message = "Authentication Type", choices = ["API Key", "Native Authentication (DefaultAzureCredential)"], default = "API Key", - ignore = lambda answers: answers['backend'] == "OpenAI" + ignore = lambda answers: answers['backend'] in ["OpenAI", "Mistral"] ), inquirer.Text('api_key', - message='OpenAI API Key', + message='API Key', ignore = lambda answers: answers['auth_type'] == "Native Authentication (DefaultAzureCredential)" ), inquirer.List("openai_model_name", message = "Model Name", choices = ["gpt-3.5-turbo", "gpt-4-turbo-preview"], default = "gpt-4-turbo-preview", - ignore = lambda answers: answers['backend'] == "Azure OpenAI" + ignore = lambda answers: answers['backend'] != "OpenAI" + ), + inquirer.List("mistral_model_name", + message = "Model Name", + choices = ["mistral-large-latest"], + default = "mistral-large-latest", + ignore = lambda answers: answers['backend'] != "Mistral" ), inquirer.Text("azure_endpoint", message = "Azure OpenAI Endpoint", - ignore = lambda answers: answers['backend'] == "OpenAI" + ignore = lambda answers: answers['backend'] != "Azure OpenAI" ), inquirer.Text("azure_openai_model", message = "Azure OpenAI Model", - ignore = lambda answers: answers['backend'] == "OpenAI" + ignore = lambda answers: answers['backend'] != "Azure OpenAI" ), inquirer.Text("azure_openai_deployment", message = "Azure OpenAI Deployment", - ignore = lambda answers: answers['backend'] == "OpenAI" + ignore = lambda answers: answers['backend'] != "Azure OpenAI" ), ] answer = inquirer.prompt(question) @@ -147,8 +153,6 @@ def gen(model_name, description, input): @click.option("--diff", "-d", is_flag=True, help="Show the diff between existing and suggested code", default=False) def fix(model_name, description, diff): manifest = Manifest() - click.echo(model_name) - click.echo(description) model = manifest.fix(model_name, description) @@ -211,4 +215,14 @@ def hello(): \____ | |___ /__| \____|__ /___| \/ \/ \/ """ - click.echo(greeting) \ No newline at end of file + click.echo(greeting) + + +@dbtai.command(help="Generate a dbt test") +@click.argument("model", required=True) +@click.argument("description", required=True) +def test(model, description): + raise NotImplementedError("Not yet implemented") + manifest = Manifest() + test = manifest.generate_test(model, description) + click.echo(test) \ No newline at end of file diff --git a/dbtai/manifest.py b/dbtai/manifest.py index 078534d..fa6318f 100644 --- a/dbtai/manifest.py +++ b/dbtai/manifest.py @@ -11,6 +11,7 @@ import appdirs import yaml from openai import OpenAI +from mistralai.client import MistralClient from ruamel.yaml import YAML from ruamel.yaml.comments import CommentedMap import io @@ -40,21 +41,26 @@ def __init__( self.manifest = json.load(file) self.config = self._load_config() - self.client = self._make_openai_client() - + if self.config['backend'] == "Mistral": + self.client = self._make_mistral_client() + elif self.config['backend'] == "Azure OpenAI": + raise NotImplementedError("Azure OpenAI not yet implemented") + else: + self.client = self._make_openai_client() def _make_openai_client(self): """Make the OpenAI client with auth.""" + if self.config['backend'] == "OpenAI": api_key = self.config['api_key'] or os.getenv("OPENAI_API_KEY") return OpenAI(api_key=api_key) else: raise NotImplementedError("Azure OpenAI not yet implemented") - # return OpenAI( - # endpoint=self.config['azure_endpoint'], - # model=self.config['azure_openai_model'], - # deployment=self.config['azure_openai_deployment'] - # ) + def _make_mistral_client(self): + """Make the Mistral client with auth.""" + client = MistralClient(api_key=self.config['api_key']) + return client + def chat_completion(self, messages, response_format_type="json_object"): """Convenience method to call the chat completion endpoint. @@ -67,13 +73,24 @@ def chat_completion(self, messages, response_format_type="json_object"): openai.ChatCompletion: The response from the chat API """ if self.config["backend"] == "OpenAI": + if not self.config.get("openai_model_name"): + raise ValueError("OpenAI model name not set in config") + return self.client.chat.completions.create( - model=self.config['openai_model_name'], + model=self.config.get('openai_model_name', 'gpt-4-turbo-preview'), messages=messages, response_format={"type": response_format_type} ) + elif self.config["backend"] == "Mistral": + + return self.client.chat( + model=self.config.get("mistral_model_name", "mistral-large-latest"), + messages=messages, + response_format={"type": response_format_type}, + ) + else: - raise NotImplementedError("Azure OpenAI not yet implemented") + raise NotImplementedError("Your backend is set to Azure OpenAI not yet implemented") def _load_config(self): """Convenience function to load the user config from the config file.""" diff --git a/pyproject.toml b/pyproject.toml index b97a4be..e650c97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "dbtai" -version = "0.1.0" +version = "0.2.0" description = "`dbtai` is a utility CLI command to generate dbt model documentation for a given model using OpenAI." authors = ["Henning Holgersen"] keywords = [ @@ -13,6 +13,7 @@ license = "Apache 2.0" [tool.poetry.dependencies] python = "<3.14,>=3.8.0" openai = ">1.1.0" +mistralai = ">=0.1.3,<2" click = "^8.1.3" "ruamel.yaml" = "^0.18.6" inquirer = "^3.2.4"