Skip to content

Commit

Permalink
Update .env.example (#508)
Browse files Browse the repository at this point in the history
* Update .env.example

* fix formatting issues

* update formatter to use official black formatter
  • Loading branch information
aazo11 authored Jun 24, 2024
1 parent f3cb505 commit 19080de
Show file tree
Hide file tree
Showing 18 changed files with 116 additions and 71 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,5 @@ jobs:
steps:
- uses: actions/checkout@v3
- uses: chartboost/ruff-action@v1
- uses: rickstaa/action-black@v1
with:
black_args: ". --check"
- uses: psf/black@stable

2 changes: 1 addition & 1 deletion services/engine/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ AGENT_MAX_ITERATIONS = 15
#timeout in seconds for the engine to return a response. Defaults to 150 seconds
DH_ENGINE_TIMEOUT = 150
#timeout for SQL execution, our agents execute the SQL query to recover from errors, this is the timeout for that execution. Defaults to 30 seconds
SQL_EXECUTION_TIMEOUT =
SQL_EXECUTION_TIMEOUT = 30
#The upper limit on number of rows returned from the query engine (equivalent to using LIMIT N in PostgreSQL/MySQL/SQLite). Defaults to 50
UPPER_LIMIT_QUERY_RETURN_ROWS = 50
#Encryption key for storing DB connection data in Mongo
Expand Down
8 changes: 5 additions & 3 deletions services/engine/dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,9 +613,11 @@ def create_finetuning_job(
Finetuning(
db_connection_id=fine_tuning_request.db_connection_id,
schemas=fine_tuning_request.schemas,
alias=fine_tuning_request.alias
if fine_tuning_request.alias
else f"{db_connection.alias}_{datetime.datetime.now().strftime('%Y%m%d%H')}",
alias=(
fine_tuning_request.alias
if fine_tuning_request.alias
else f"{db_connection.alias}_{datetime.datetime.now().strftime('%Y%m%d%H')}"
),
base_llm=base_llm,
golden_sqls=[str(golden_sql.id) for golden_sql in golden_sqls],
metadata=fine_tuning_request.metadata,
Expand Down
1 change: 1 addition & 0 deletions services/engine/dataherald/db_scanner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Base class that all scanner classes inherit from."""

from abc import ABC, abstractmethod

from dataherald.config import Component
Expand Down
16 changes: 9 additions & 7 deletions services/engine/dataherald/finetuning/openai_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,15 @@ def create_fine_tuning_job(self):
finetuning_request = self.client.fine_tuning.jobs.create(
training_file=model.finetuning_file_id,
model=model.base_llm.model_name,
hyperparameters=model.base_llm.model_parameters
if model.base_llm.model_parameters
else {
"batch_size": 1,
"learning_rate_multiplier": "auto",
"n_epochs": 3,
},
hyperparameters=(
model.base_llm.model_parameters
if model.base_llm.model_parameters
else {
"batch_size": 1,
"learning_rate_multiplier": "auto",
"n_epochs": 3,
}
),
)
model.finetuning_job_id = finetuning_request.id
if finetuning_request.status == "failed":
Expand Down
40 changes: 23 additions & 17 deletions services/engine/dataherald/scripts/migrate_v006_to_v100.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def update_object_id_fields(field_name: str, collection_name: str):
"_id": question["_id"],
"db_connection_id": str(question["db_connection_id"]),
"text": question["question"],
"created_at": None
if len(responses) == 0
else responses[0]["created_at"],
"created_at": (
None if len(responses) == 0 else responses[0]["created_at"]
),
"metadata": None,
},
)
Expand All @@ -112,17 +112,21 @@ def update_object_id_fields(field_name: str, collection_name: str):
{
"_id": response["_id"],
"prompt_id": str(response["question_id"]),
"evaluate": False
if response["confidence_score"] is None
else True,
"evaluate": (
False if response["confidence_score"] is None else True
),
"sql": response["sql_query"],
"status": "VALID"
if response["sql_generation_status"] == "VALID"
else "INVALID",
"completed_at": response["created_at"]
+ timedelta(seconds=response["exec_time"])
if response["exec_time"]
else None,
"status": (
"VALID"
if response["sql_generation_status"] == "VALID"
else "INVALID"
),
"completed_at": (
response["created_at"]
+ timedelta(seconds=response["exec_time"])
if response["exec_time"]
else None
),
"tokens_used": response["total_tokens"],
"confidence_score": response["confidence_score"],
"error": response["error_message"],
Expand All @@ -140,10 +144,12 @@ def update_object_id_fields(field_name: str, collection_name: str):
{
"sql_generation_id": str(response["_id"]),
"text": response["response"],
"created_at": response["created_at"]
+ timedelta(seconds=response["exec_time"])
if response["exec_time"]
else response["created_at"],
"created_at": (
response["created_at"]
+ timedelta(seconds=response["exec_time"])
if response["exec_time"]
else response["created_at"]
),
"metadata": None,
},
)
Expand Down
6 changes: 3 additions & 3 deletions services/engine/dataherald/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,9 +516,9 @@ def export_csv_file(self, sql_generation_id: str) -> StreamingResponse:
stream = self._api.export_csv_file(sql_generation_id)

response = StreamingResponse(iter([stream.getvalue()]), media_type="text/csv")
response.headers[
"Content-Disposition"
] = f"attachment; filename=sql_generation_{sql_generation_id}.csv"
response.headers["Content-Disposition"] = (
f"attachment; filename=sql_generation_{sql_generation_id}.csv"
)
return response

def delete_golden_sql(self, golden_sql_id: str) -> dict:
Expand Down
16 changes: 10 additions & 6 deletions services/engine/dataherald/services/nl_generations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ def create(
initial_nl_generation = NLGeneration(
sql_generation_id=sql_generation_id,
created_at=datetime.now(),
llm_config=nl_generation_request.llm_config
if nl_generation_request.llm_config
else LLMConfig(),
llm_config=(
nl_generation_request.llm_config
if nl_generation_request.llm_config
else LLMConfig()
),
metadata=nl_generation_request.metadata,
)
self.nl_generation_repository.insert(initial_nl_generation)
Expand All @@ -46,9 +48,11 @@ def create(
nl_generator = GeneratesNlAnswer(
self.system,
self.storage,
nl_generation_request.llm_config
if nl_generation_request.llm_config
else LLMConfig(),
(
nl_generation_request.llm_config
if nl_generation_request.llm_config
else LLMConfig()
),
)
try:
nl_generation = nl_generator.execute(
Expand Down
48 changes: 30 additions & 18 deletions services/engine/dataherald/services/sql_generations.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ def create( # noqa: PLR0912
initial_sql_generation = SQLGeneration(
prompt_id=prompt_id,
created_at=datetime.now(),
llm_config=sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig(),
llm_config=(
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig()
),
metadata=sql_generation_request.metadata,
)
langsmith_metadata = (
Expand Down Expand Up @@ -115,16 +117,20 @@ def create( # noqa: PLR0912
)
sql_generator = DataheraldSQLAgent(
self.system,
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig(),
(
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig()
),
)
else:
sql_generator = DataheraldFinetuningAgent(
self.system,
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig(),
(
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig()
),
)
sql_generator.finetuning_id = sql_generation_request.finetuning_id
sql_generator.use_fintuned_model_only = (
Expand Down Expand Up @@ -184,9 +190,11 @@ def start_streaming(
initial_sql_generation = SQLGeneration(
prompt_id=prompt_id,
created_at=datetime.now(),
llm_config=sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig(),
llm_config=(
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig()
),
metadata=sql_generation_request.metadata,
)
langsmith_metadata = (
Expand Down Expand Up @@ -215,16 +223,20 @@ def start_streaming(
)
sql_generator = DataheraldSQLAgent(
self.system,
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig(),
(
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig()
),
)
else:
sql_generator = DataheraldFinetuningAgent(
self.system,
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig(),
(
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig()
),
)
sql_generator.finetuning_id = sql_generation_request.finetuning_id
sql_generator.use_fintuned_model_only = (
Expand Down
1 change: 1 addition & 0 deletions services/engine/dataherald/smart_cache/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Base class that all cache classes inherit from."""

from abc import ABC, abstractmethod
from typing import Any, Union

Expand Down
1 change: 1 addition & 0 deletions services/engine/dataherald/sql_database/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""SQL wrapper around SQLDatabase in langchain."""

import logging
import re
from typing import List
Expand Down
1 change: 1 addition & 0 deletions services/engine/dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Base class that all sql generation classes inherit from."""

import datetime
import logging
import os
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -493,8 +493,9 @@ def create_sql_agent(
suffix: str = FINETUNING_AGENT_SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: List[str] | None = None,
max_iterations: int
| None = int(os.getenv("AGENT_MAX_ITERATIONS", "12")), # noqa: B008
max_iterations: int | None = int(
os.getenv("AGENT_MAX_ITERATIONS", "12")
), # noqa: B008
max_execution_time: float | None = None,
early_stopping_method: str = "generate",
verbose: bool = False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -655,8 +655,9 @@ def create_sql_agent(
input_variables: List[str] | None = None,
max_examples: int = 20,
number_of_instructions: int = 1,
max_iterations: int
| None = int(os.getenv("AGENT_MAX_ITERATIONS", "15")), # noqa: B008
max_iterations: int | None = int(
os.getenv("AGENT_MAX_ITERATIONS", "15")
), # noqa: B008
max_execution_time: float | None = None,
early_stopping_method: str = "generate",
verbose: bool = False,
Expand Down
15 changes: 12 additions & 3 deletions services/engine/dataherald/utils/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@ class S3:
def __init__(self):
self.settings = Settings()

def _get_client(self, access_key: str | None = None, secret_access_key: str | None = None, region: str | None = None) -> boto3.client:
def _get_client(
self,
access_key: str | None = None,
secret_access_key: str | None = None,
region: str | None = None,
) -> boto3.client:
_access_key = access_key or self.settings.s3_aws_access_key_id
_secret_access_key = secret_access_key or self.settings.s3_aws_secret_access_key
_region = region or self.settings.s3_region
Expand Down Expand Up @@ -44,7 +49,9 @@ def upload(self, file_location, file_storage: FileStorage | None = None) -> str:
bucket_name = file_storage.bucket
s3_client = self._get_client(
access_key=fernet_encrypt.decrypt(file_storage.access_key_id),
secret_access_key=fernet_encrypt.decrypt(file_storage.secret_access_key),
secret_access_key=fernet_encrypt.decrypt(
file_storage.secret_access_key
),
region=file_storage.region,
)
else:
Expand All @@ -63,7 +70,9 @@ def download(self, path: str, file_storage: FileStorage | None = None) -> str:
fernet_encrypt = FernetEncrypt()
s3_client = self._get_client(
access_key=fernet_encrypt.decrypt(file_storage.access_key_id),
secret_access_key=fernet_encrypt.decrypt(file_storage.secret_access_key),
secret_access_key=fernet_encrypt.decrypt(
file_storage.secret_access_key
),
region=file_storage.region,
)
else:
Expand Down
8 changes: 5 additions & 3 deletions services/engine/dataherald/vector_store/astra.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,11 @@ def add_records(self, golden_sqls: List[GoldenSQL], collection: str):
{
"_id": str(golden_sqls[key].id),
"$vector": embeds[key],
"tables_used": ", ".join(Parser(golden_sqls[key].sql))
if isinstance(Parser(golden_sqls[key].sql), list)
else "",
"tables_used": (
", ".join(Parser(golden_sqls[key].sql))
if isinstance(Parser(golden_sqls[key].sql), list)
else ""
),
"db_connection_id": str(golden_sqls[key].db_connection_id),
}
)
Expand Down
8 changes: 5 additions & 3 deletions services/engine/dataherald/vector_store/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,11 @@ def add_records(self, golden_sqls: List[GoldenSQL], collection: str):
collection,
[
{
"tables_used": ", ".join(Parser(golden_sql.sql))
if isinstance(Parser(golden_sql.sql), list)
else "",
"tables_used": (
", ".join(Parser(golden_sql.sql))
if isinstance(Parser(golden_sql.sql), list)
else ""
),
"db_connection_id": str(golden_sql.db_connection_id),
}
],
Expand Down
1 change: 1 addition & 0 deletions services/engine/setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Set up the package."""

import os
from pathlib import Path

Expand Down

0 comments on commit 19080de

Please sign in to comment.