From 6b55e7c757bb4c2afb750d198f5629d9cf7c349c Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Thu, 17 Oct 2024 11:57:23 -0400 Subject: [PATCH] add query_id to SQLQueryStatus (demonstration only) --- dbt/adapters/snowflake/connections.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/dbt/adapters/snowflake/connections.py b/dbt/adapters/snowflake/connections.py index 10bee30f0..85ecf0531 100644 --- a/dbt/adapters/snowflake/connections.py +++ b/dbt/adapters/snowflake/connections.py @@ -15,7 +15,7 @@ from contextlib import contextmanager from dataclasses import dataclass from io import StringIO -from time import sleep +from time import sleep, perf_counter from typing import Optional, Tuple, Union, Any, List, Iterable, TYPE_CHECKING @@ -43,8 +43,11 @@ DbtRuntimeError, DbtConfigError, ) +from dbt_common.events.contextvars import get_node_info from dbt_common.exceptions import DbtDatabaseError +from dbt_common.events.functions import fire_event from dbt_common.record import get_record_mode_from_env, RecorderMode +from dbt.adapters.events.types import SQLQueryStatus from dbt.adapters.exceptions.connection import FailedToConnectError from dbt.adapters.contracts.connection import AdapterResponse, Connection, Credentials from dbt.adapters.sql import SQLConnectionManager @@ -86,7 +89,7 @@ def snowflake_private_key(private_key: RSAPrivateKey) -> bytes: @dataclass class SnowflakeAdapterResponse(AdapterResponse): - query_id: str = "" + pass @dataclass @@ -536,6 +539,8 @@ def add_query( bindings: Optional[Any] = None, abridge_sql_log: bool = False, ) -> Tuple[Connection, Any]: + pre = perf_counter() + if bindings: # The snowflake connector is stricter than, e.g., psycopg2 - # which allows any iterable thing to be passed as a binding. @@ -561,6 +566,15 @@ def add_query( if cursor is None: self._raise_cursor_not_found_error(sql) + fire_event( + SQLQueryStatus( + status=str(self.get_response(cursor)), + elapsed=perf_counter() - pre, + node_info=get_node_info(), + query_id=cursor.sfqid, + ) + ) + return connection, cursor def _stripped_queries(self, sql: str) -> List[str]: