Skip to content

Commit

Permalink
DH-5777/adding the safe int conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza committed May 17, 2024
1 parent dc4e743 commit 6c0dcb1
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 5 deletions.
6 changes: 6 additions & 0 deletions dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ def replace_unprocessable_characters(text: str) -> str:
text = text.strip()
return text.replace(r"\_", "_")

def safe_int_conversion(value: str, default: int = 0) -> int:
try:
return int(value)
except ValueError:
return default


class SQLGenerator(Component, ABC):
metadata: Any
Expand Down
3 changes: 2 additions & 1 deletion dataherald/sql_generator/create_sql_query_status.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

from dataherald.sql_database.base import SQLDatabase, SQLInjectionError
from dataherald.sql_generator import safe_int_conversion
from dataherald.types import SQLGeneration
from dataherald.utils.timeout_utils import run_with_timeout

Expand Down Expand Up @@ -34,7 +35,7 @@ def create_sql_query_status(
run_with_timeout(
db.run_sql,
args=(query,),
timeout_duration=int(os.getenv("SQL_EXECUTION_TIMEOUT", "60")),
timeout_duration=safe_int_conversion(os.getenv("SQL_EXECUTION_TIMEOUT"), 60),
)
sql_generation.status = "VALID"
sql_generation.error = None
Expand Down
4 changes: 2 additions & 2 deletions dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from dataherald.sql_database.models.types import (
DatabaseConnection,
)
from dataherald.sql_generator import EngineTimeOutORItemLimitError, SQLGenerator
from dataherald.sql_generator import EngineTimeOutORItemLimitError, SQLGenerator, safe_int_conversion
from dataherald.types import FineTuningStatus, Prompt, SQLGeneration
from dataherald.utils.agent_prompts import (
ERROR_PARSING_MESSAGE,
Expand Down Expand Up @@ -290,7 +290,7 @@ def _run(
self.db.run_sql,
args=(query,),
kwargs={"top_k": TOP_K},
timeout_duration=int(os.getenv("SQL_EXECUTION_TIMEOUT", "60")),
timeout_duration=safe_int_conversion(os.getenv("SQL_EXECUTION_TIMEOUT"), 60),
)[0]
except TimeoutError:
return "SQL query execution time exceeded, proceed without query execution"
Expand Down
4 changes: 2 additions & 2 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from dataherald.sql_database.models.types import (
DatabaseConnection,
)
from dataherald.sql_generator import EngineTimeOutORItemLimitError, SQLGenerator
from dataherald.sql_generator import EngineTimeOutORItemLimitError, SQLGenerator, safe_int_conversion
from dataherald.types import Prompt, SQLGeneration
from dataherald.utils.agent_prompts import (
AGENT_PREFIX,
Expand Down Expand Up @@ -171,7 +171,7 @@ def _run(
self.db.run_sql,
args=(query,),
kwargs={"top_k": top_k},
timeout_duration=int(os.getenv("SQL_EXECUTION_TIMEOUT", "60")),
timeout_duration=safe_int_conversion(os.getenv("SQL_EXECUTION_TIMEOUT"), 60),
)[0]
except TimeoutError:
return "SQL query execution time exceeded, proceed without query execution"
Expand Down

0 comments on commit 6c0dcb1

Please sign in to comment.