diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 1d5a66a5..ce0b89f3 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -9,6 +9,5 @@ jobs: steps: - uses: actions/checkout@v3 - uses: chartboost/ruff-action@v1 - - uses: rickstaa/action-black@v1 - with: - black_args: ". --check" + - uses: psf/black@stable + diff --git a/services/engine/.env.example b/services/engine/.env.example index b89080a1..4d1d2027 100644 --- a/services/engine/.env.example +++ b/services/engine/.env.example @@ -8,7 +8,7 @@ AGENT_MAX_ITERATIONS = 15 #timeout in seconds for the engine to return a response. Defaults to 150 seconds DH_ENGINE_TIMEOUT = 150 #tmeout for SQL execution, our agents exceute the SQL query to recover from errors, this is the timeout for that execution. Defaults to 30 seconds -SQL_EXECUTION_TIMEOUT = +SQL_EXECUTION_TIMEOUT = 30 #The upper limit on number of rows returned from the query engine (equivalent to using LIMIT N in PostgreSQL/MySQL/SQlite). Defauls to 50 UPPER_LIMIT_QUERY_RETURN_ROWS = 50 #Encryption key for storing DB connection data in Mongo diff --git a/services/engine/dataherald/api/fastapi.py b/services/engine/dataherald/api/fastapi.py index e9edbd57..6824e3ce 100644 --- a/services/engine/dataherald/api/fastapi.py +++ b/services/engine/dataherald/api/fastapi.py @@ -613,9 +613,11 @@ def create_finetuning_job( Finetuning( db_connection_id=fine_tuning_request.db_connection_id, schemas=fine_tuning_request.schemas, - alias=fine_tuning_request.alias - if fine_tuning_request.alias - else f"{db_connection.alias}_{datetime.datetime.now().strftime('%Y%m%d%H')}", + alias=( + fine_tuning_request.alias + if fine_tuning_request.alias + else f"{db_connection.alias}_{datetime.datetime.now().strftime('%Y%m%d%H')}" + ), base_llm=base_llm, golden_sqls=[str(golden_sql.id) for golden_sql in golden_sqls], metadata=fine_tuning_request.metadata, diff --git a/services/engine/dataherald/db_scanner/__init__.py b/services/engine/dataherald/db_scanner/__init__.py index cd62dddd..13f92f67 100644 --- a/services/engine/dataherald/db_scanner/__init__.py +++ b/services/engine/dataherald/db_scanner/__init__.py @@ -1,4 +1,5 @@ """Base class that all scanner classes inherit from.""" + from abc import ABC, abstractmethod from dataherald.config import Component diff --git a/services/engine/dataherald/finetuning/openai_finetuning.py b/services/engine/dataherald/finetuning/openai_finetuning.py index 2876d2c8..148c64bd 100644 --- a/services/engine/dataherald/finetuning/openai_finetuning.py +++ b/services/engine/dataherald/finetuning/openai_finetuning.py @@ -298,13 +298,15 @@ def create_fine_tuning_job(self): finetuning_request = self.client.fine_tuning.jobs.create( training_file=model.finetuning_file_id, model=model.base_llm.model_name, - hyperparameters=model.base_llm.model_parameters - if model.base_llm.model_parameters - else { - "batch_size": 1, - "learning_rate_multiplier": "auto", - "n_epochs": 3, - }, + hyperparameters=( + model.base_llm.model_parameters + if model.base_llm.model_parameters + else { + "batch_size": 1, + "learning_rate_multiplier": "auto", + "n_epochs": 3, + } + ), ) model.finetuning_job_id = finetuning_request.id if finetuning_request.status == "failed": diff --git a/services/engine/dataherald/scripts/migrate_v006_to_v100.py b/services/engine/dataherald/scripts/migrate_v006_to_v100.py index 4cf330d9..927f2514 100644 --- a/services/engine/dataherald/scripts/migrate_v006_to_v100.py +++ b/services/engine/dataherald/scripts/migrate_v006_to_v100.py @@ -94,9 +94,9 @@ def update_object_id_fields(field_name: str, collection_name: str): "_id": question["_id"], "db_connection_id": str(question["db_connection_id"]), "text": question["question"], - "created_at": None - if len(responses) == 0 - else responses[0]["created_at"], + "created_at": ( + None if len(responses) == 0 else responses[0]["created_at"] + ), "metadata": None, }, ) @@ -112,17 +112,21 @@ def update_object_id_fields(field_name: str, collection_name: str): { "_id": response["_id"], "prompt_id": str(response["question_id"]), - "evaluate": False - if response["confidence_score"] is None - else True, + "evaluate": ( + False if response["confidence_score"] is None else True + ), "sql": response["sql_query"], - "status": "VALID" - if response["sql_generation_status"] == "VALID" - else "INVALID", - "completed_at": response["created_at"] - + timedelta(seconds=response["exec_time"]) - if response["exec_time"] - else None, + "status": ( + "VALID" + if response["sql_generation_status"] == "VALID" + else "INVALID" + ), + "completed_at": ( + response["created_at"] + + timedelta(seconds=response["exec_time"]) + if response["exec_time"] + else None + ), "tokens_used": response["total_tokens"], "confidence_score": response["confidence_score"], "error": response["error_message"], @@ -140,10 +144,12 @@ def update_object_id_fields(field_name: str, collection_name: str): { "sql_generation_id": str(response["_id"]), "text": response["response"], - "created_at": response["created_at"] - + timedelta(seconds=response["exec_time"]) - if response["exec_time"] - else response["created_at"], + "created_at": ( + response["created_at"] + + timedelta(seconds=response["exec_time"]) + if response["exec_time"] + else response["created_at"] + ), "metadata": None, }, ) diff --git a/services/engine/dataherald/server/fastapi/__init__.py b/services/engine/dataherald/server/fastapi/__init__.py index 2380972e..e35475ef 100644 --- a/services/engine/dataherald/server/fastapi/__init__.py +++ b/services/engine/dataherald/server/fastapi/__init__.py @@ -516,9 +516,9 @@ def export_csv_file(self, sql_generation_id: str) -> StreamingResponse: stream = self._api.export_csv_file(sql_generation_id) response = StreamingResponse(iter([stream.getvalue()]), media_type="text/csv") - response.headers[ - "Content-Disposition" - ] = f"attachment; filename=sql_generation_{sql_generation_id}.csv" + response.headers["Content-Disposition"] = ( + f"attachment; filename=sql_generation_{sql_generation_id}.csv" + ) return response def delete_golden_sql(self, golden_sql_id: str) -> dict: diff --git a/services/engine/dataherald/services/nl_generations.py b/services/engine/dataherald/services/nl_generations.py index 11da18e8..fa89e2f1 100644 --- a/services/engine/dataherald/services/nl_generations.py +++ b/services/engine/dataherald/services/nl_generations.py @@ -30,9 +30,11 @@ def create( initial_nl_generation = NLGeneration( sql_generation_id=sql_generation_id, created_at=datetime.now(), - llm_config=nl_generation_request.llm_config - if nl_generation_request.llm_config - else LLMConfig(), + llm_config=( + nl_generation_request.llm_config + if nl_generation_request.llm_config + else LLMConfig() + ), metadata=nl_generation_request.metadata, ) self.nl_generation_repository.insert(initial_nl_generation) @@ -46,9 +48,11 @@ def create( nl_generator = GeneratesNlAnswer( self.system, self.storage, - nl_generation_request.llm_config - if nl_generation_request.llm_config - else LLMConfig(), + ( + nl_generation_request.llm_config + if nl_generation_request.llm_config + else LLMConfig() + ), ) try: nl_generation = nl_generator.execute( diff --git a/services/engine/dataherald/services/sql_generations.py b/services/engine/dataherald/services/sql_generations.py index 413101ca..6dc6792b 100644 --- a/services/engine/dataherald/services/sql_generations.py +++ b/services/engine/dataherald/services/sql_generations.py @@ -69,9 +69,11 @@ def create( # noqa: PLR0912 initial_sql_generation = SQLGeneration( prompt_id=prompt_id, created_at=datetime.now(), - llm_config=sql_generation_request.llm_config - if sql_generation_request.llm_config - else LLMConfig(), + llm_config=( + sql_generation_request.llm_config + if sql_generation_request.llm_config + else LLMConfig() + ), metadata=sql_generation_request.metadata, ) langsmith_metadata = ( @@ -115,16 +117,20 @@ def create( # noqa: PLR0912 ) sql_generator = DataheraldSQLAgent( self.system, - sql_generation_request.llm_config - if sql_generation_request.llm_config - else LLMConfig(), + ( + sql_generation_request.llm_config + if sql_generation_request.llm_config + else LLMConfig() + ), ) else: sql_generator = DataheraldFinetuningAgent( self.system, - sql_generation_request.llm_config - if sql_generation_request.llm_config - else LLMConfig(), + ( + sql_generation_request.llm_config + if sql_generation_request.llm_config + else LLMConfig() + ), ) sql_generator.finetuning_id = sql_generation_request.finetuning_id sql_generator.use_fintuned_model_only = ( @@ -184,9 +190,11 @@ def start_streaming( initial_sql_generation = SQLGeneration( prompt_id=prompt_id, created_at=datetime.now(), - llm_config=sql_generation_request.llm_config - if sql_generation_request.llm_config - else LLMConfig(), + llm_config=( + sql_generation_request.llm_config + if sql_generation_request.llm_config + else LLMConfig() + ), metadata=sql_generation_request.metadata, ) langsmith_metadata = ( @@ -215,16 +223,20 @@ def start_streaming( ) sql_generator = DataheraldSQLAgent( self.system, - sql_generation_request.llm_config - if sql_generation_request.llm_config - else LLMConfig(), + ( + sql_generation_request.llm_config + if sql_generation_request.llm_config + else LLMConfig() + ), ) else: sql_generator = DataheraldFinetuningAgent( self.system, - sql_generation_request.llm_config - if sql_generation_request.llm_config - else LLMConfig(), + ( + sql_generation_request.llm_config + if sql_generation_request.llm_config + else LLMConfig() + ), ) sql_generator.finetuning_id = sql_generation_request.finetuning_id sql_generator.use_fintuned_model_only = ( diff --git a/services/engine/dataherald/smart_cache/__init__.py b/services/engine/dataherald/smart_cache/__init__.py index 060152d5..498b7805 100644 --- a/services/engine/dataherald/smart_cache/__init__.py +++ b/services/engine/dataherald/smart_cache/__init__.py @@ -1,4 +1,5 @@ """Base class that all cache classes inherit from.""" + from abc import ABC, abstractmethod from typing import Any, Union diff --git a/services/engine/dataherald/sql_database/base.py b/services/engine/dataherald/sql_database/base.py index ae528024..c35445f7 100644 --- a/services/engine/dataherald/sql_database/base.py +++ b/services/engine/dataherald/sql_database/base.py @@ -1,4 +1,5 @@ """SQL wrapper around SQLDatabase in langchain.""" + import logging import re from typing import List diff --git a/services/engine/dataherald/sql_generator/__init__.py b/services/engine/dataherald/sql_generator/__init__.py index 6612332b..ceb59023 100644 --- a/services/engine/dataherald/sql_generator/__init__.py +++ b/services/engine/dataherald/sql_generator/__init__.py @@ -1,4 +1,5 @@ """Base class that all sql generation classes inherit from.""" + import datetime import logging import os diff --git a/services/engine/dataherald/sql_generator/dataherald_finetuning_agent.py b/services/engine/dataherald/sql_generator/dataherald_finetuning_agent.py index fe54dcf4..29d95794 100644 --- a/services/engine/dataherald/sql_generator/dataherald_finetuning_agent.py +++ b/services/engine/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -493,8 +493,9 @@ def create_sql_agent( suffix: str = FINETUNING_AGENT_SUFFIX, format_instructions: str = FORMAT_INSTRUCTIONS, input_variables: List[str] | None = None, - max_iterations: int - | None = int(os.getenv("AGENT_MAX_ITERATIONS", "12")), # noqa: B008 + max_iterations: int | None = int( + os.getenv("AGENT_MAX_ITERATIONS", "12") + ), # noqa: B008 max_execution_time: float | None = None, early_stopping_method: str = "generate", verbose: bool = False, diff --git a/services/engine/dataherald/sql_generator/dataherald_sqlagent.py b/services/engine/dataherald/sql_generator/dataherald_sqlagent.py index a93a3091..d49e2589 100644 --- a/services/engine/dataherald/sql_generator/dataherald_sqlagent.py +++ b/services/engine/dataherald/sql_generator/dataherald_sqlagent.py @@ -655,8 +655,9 @@ def create_sql_agent( input_variables: List[str] | None = None, max_examples: int = 20, number_of_instructions: int = 1, - max_iterations: int - | None = int(os.getenv("AGENT_MAX_ITERATIONS", "15")), # noqa: B008 + max_iterations: int | None = int( + os.getenv("AGENT_MAX_ITERATIONS", "15") + ), # noqa: B008 max_execution_time: float | None = None, early_stopping_method: str = "generate", verbose: bool = False, diff --git a/services/engine/dataherald/utils/s3.py b/services/engine/dataherald/utils/s3.py index 96497d96..6e375d3f 100644 --- a/services/engine/dataherald/utils/s3.py +++ b/services/engine/dataherald/utils/s3.py @@ -12,7 +12,12 @@ class S3: def __init__(self): self.settings = Settings() - def _get_client(self, access_key: str | None = None, secret_access_key: str | None = None, region: str | None = None) -> boto3.client: + def _get_client( + self, + access_key: str | None = None, + secret_access_key: str | None = None, + region: str | None = None, + ) -> boto3.client: _access_key = access_key or self.settings.s3_aws_access_key_id _secret_access_key = secret_access_key or self.settings.s3_aws_secret_access_key _region = region or self.settings.s3_region @@ -44,7 +49,9 @@ def upload(self, file_location, file_storage: FileStorage | None = None) -> str: bucket_name = file_storage.bucket s3_client = self._get_client( access_key=fernet_encrypt.decrypt(file_storage.access_key_id), - secret_access_key=fernet_encrypt.decrypt(file_storage.secret_access_key), + secret_access_key=fernet_encrypt.decrypt( + file_storage.secret_access_key + ), region=file_storage.region, ) else: @@ -63,7 +70,9 @@ def download(self, path: str, file_storage: FileStorage | None = None) -> str: fernet_encrypt = FernetEncrypt() s3_client = self._get_client( access_key=fernet_encrypt.decrypt(file_storage.access_key_id), - secret_access_key=fernet_encrypt.decrypt(file_storage.secret_access_key), + secret_access_key=fernet_encrypt.decrypt( + file_storage.secret_access_key + ), region=file_storage.region, ) else: diff --git a/services/engine/dataherald/vector_store/astra.py b/services/engine/dataherald/vector_store/astra.py index c3307b59..71bcf7ac 100644 --- a/services/engine/dataherald/vector_store/astra.py +++ b/services/engine/dataherald/vector_store/astra.py @@ -94,9 +94,11 @@ def add_records(self, golden_sqls: List[GoldenSQL], collection: str): { "_id": str(golden_sqls[key].id), "$vector": embeds[key], - "tables_used": ", ".join(Parser(golden_sqls[key].sql)) - if isinstance(Parser(golden_sqls[key].sql), list) - else "", + "tables_used": ( + ", ".join(Parser(golden_sqls[key].sql)) + if isinstance(Parser(golden_sqls[key].sql), list) + else "" + ), "db_connection_id": str(golden_sqls[key].db_connection_id), } ) diff --git a/services/engine/dataherald/vector_store/chroma.py b/services/engine/dataherald/vector_store/chroma.py index 6aa1cd02..90d0af81 100644 --- a/services/engine/dataherald/vector_store/chroma.py +++ b/services/engine/dataherald/vector_store/chroma.py @@ -47,9 +47,11 @@ def add_records(self, golden_sqls: List[GoldenSQL], collection: str): collection, [ { - "tables_used": ", ".join(Parser(golden_sql.sql)) - if isinstance(Parser(golden_sql.sql), list) - else "", + "tables_used": ( + ", ".join(Parser(golden_sql.sql)) + if isinstance(Parser(golden_sql.sql), list) + else "" + ), "db_connection_id": str(golden_sql.db_connection_id), } ], diff --git a/services/engine/setup.py b/services/engine/setup.py index 6a964a71..d24c2cd2 100644 --- a/services/engine/setup.py +++ b/services/engine/setup.py @@ -1,4 +1,5 @@ """Set up the package.""" + import os from pathlib import Path