diff --git a/dbtai/chatbot.py b/dbtai/chatbot.py index 0eed6a6..56b5230 100644 --- a/dbtai/chatbot.py +++ b/dbtai/chatbot.py @@ -40,11 +40,19 @@ def chat_completion(self, messages): Returns: openai.ChatCompletion: The response from the chat API """ - - return self.client.chat.completions.create( - model="gpt-3.5-turbo", - messages=messages - ) + if self.config["backend"] == "OpenAI": + return self.client.chat.completions.create( + model=self.config["openai_model_name"], + messages=messages + ) + else: + return self.client.chat.completions.create( + model=self.config["azure_openai_model"], + messages=messages, + deployment=self.config["azure_openai_deployment"], + endpoint=self.config["azure_endpoint"] + ) + def run(self): print(f""" @@ -58,7 +66,7 @@ def run(self): print("Goodbye!") break - if user_input == "\save": + if user_input == r"\save": with open('chat_history.txt', 'a') as f: f.write(f"Chat history for the dbt model: {self.model_name}, on {datetime.datetime.now().isoformat()}\n\n") for item in self.chat_history: diff --git a/dbtai/cli.py b/dbtai/cli.py index 49bf48e..d65b621 100644 --- a/dbtai/cli.py +++ b/dbtai/cli.py @@ -132,7 +132,7 @@ def constraints(): @dbtai.command(help="Generate model code") @click.argument("model_name", required=True) @click.argument("description", required=True) -@click.option("--input", "-i", required=False, help="Input model", multiple=True) +@click.option("--input", "-i", required=False, help="Name of Input model. Can be passed multiple times to reference several models", multiple=True) def gen(model_name, description, input): manifest = Manifest() model = manifest.generate_model(model_name, description, input)