Skip to content

Commit b7a5e07

Browse files
committed
Handle snowflake lack of CTAS tracking
1 parent 59cec3b commit b7a5e07

4 files changed

Lines changed: 55 additions & 6 deletions

File tree

sqlmesh/core/engine_adapter/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2434,6 +2434,11 @@ def _log_sql(
24342434

24352435
logger.log(self._execute_log_level, "Executing SQL: %s", sql_to_log)
24362436

2437+
def _record_execution_stats(
    self,
    sql: str,
    rowcount: t.Optional[int] = None,
    bytes_processed: t.Optional[int] = None,
) -> None:
    """Forward the stats of an executed statement to ``QueryExecutionTracker``.

    Args:
        sql: The SQL text that was executed.
        rowcount: Number of rows reported for the statement, if any.
        bytes_processed: Number of bytes reported for the statement, if any.
    """
    # Arguments are passed positionally and unchanged; no filtering happens here.
    QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed)
2441+
24372442
def _execute(self, sql: str, track_execution_stats: bool = False, **kwargs: t.Any) -> None:
24382443
self.cursor.execute(sql, **kwargs)
24392444

@@ -2450,7 +2455,7 @@ def _execute(self, sql: str, track_execution_stats: bool = False, **kwargs: t.An
24502455
except (TypeError, ValueError):
24512456
pass
24522457

2453-
QueryExecutionTracker.record_execution(sql, rowcount, None)
2458+
self._record_execution_stats(sql, rowcount)
24542459

24552460
@contextlib.contextmanager
24562461
def temp_table(

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import contextlib
44
import logging
5+
import re
56
import typing as t
67

78
from sqlglot import exp
@@ -24,6 +25,7 @@
2425
set_catalog,
2526
)
2627
from sqlmesh.core.schema_diff import SchemaDiffer
28+
from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker
2729
from sqlmesh.utils import optional_import, get_source_columns_to_types
2830
from sqlmesh.utils.errors import SQLMeshError
2931
from sqlmesh.utils.pandas import columns_to_types_from_dtypes
@@ -73,6 +75,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi
7375
)
7476
MANAGED_TABLE_KIND = "DYNAMIC TABLE"
7577
SNOWPARK = "snowpark"
78+
SUPPORTS_QUERY_EXECUTION_TRACKING = True
7679

7780
@contextlib.contextmanager
7881
def session(self, properties: SessionProperties) -> t.Iterator[None]:
@@ -665,3 +668,33 @@ def close(self) -> t.Any:
665668
self._connection_pool.set_attribute(self.SNOWPARK, None)
666669

667670
return super().close()
671+
672+
def _record_execution_stats(
    self, sql: str, rowcount: t.Optional[int] = None, bytes_processed: t.Optional[int] = None
) -> None:
    """Record execution stats unless the statement was a Snowflake CTAS.

    Snowflake does not report row counts for CTAS like other DML operations.

    They neither report the sentinel value -1 nor do they report 0 rows. Instead, they return a single data row
    containing the string "Table <table_name> successfully created." and a row count of 1.

    We do not want to record the row count of 1 for CTAS operations, so we check for that data pattern and return
    early if it is detected.

    Regex explanation - Snowflake identifiers may be:
    - An unquoted contiguous set of [a-zA-Z0-9_$] characters
    - A double-quoted string that may contain spaces and nested double-quotes represented by `""`
        - Example: " my ""table"" name "
        - Pattern: "(?:[^"]|"")+"
            - ?: is a non-capturing group
            - [^"] matches any single character except a double-quote
            - "" matches two sequential double-quotes

    Args:
        sql: The SQL text that was executed.
        rowcount: Row count reported by the cursor for the statement, if any.
        bytes_processed: Bytes processed by the statement, if known.
    """
    if rowcount == 1:
        # NOTE: fetchall() consumes the cursor's pending result set.
        results = self.cursor.fetchall()
        if results and len(results) == 1:
            try:
                first_value = results[0][0]
            except (IndexError, KeyError, TypeError):
                # Empty row or a row type that is not positionally indexable:
                # it cannot be the CTAS status message, so fall through and record.
                first_value = None

            # re.match raises TypeError on non-string input, and the first column is
            # not guaranteed to be a string (NOTE(review): single-row DML presumably
            # returns a numeric count here -- confirm against the connector).
            if isinstance(first_value, str) and re.match(
                r'Table ([a-zA-Z0-9_$]+|"(?:[^"]|"")+") successfully created\.', first_value
            ):
                return

    QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed)

sqlmesh/core/snapshot/execution_tracker.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
@dataclass
1111
class QueryExecutionStats:
1212
snapshot_batch_id: str
13-
total_rows_processed: int = 0
14-
total_bytes_processed: int = 0
13+
total_rows_processed: t.Optional[int] = None
14+
total_bytes_processed: t.Optional[int] = None
1515
query_count: int = 0
1616
queries_executed: t.List[t.Tuple[str, t.Optional[int], t.Optional[int], float]] = field(
1717
default_factory=list
@@ -42,12 +42,18 @@ def add_execution(
4242
self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int]
4343
) -> None:
4444
if row_count is not None and row_count >= 0:
45-
self.stats.total_rows_processed += row_count
45+
if self.stats.total_rows_processed is None:
46+
self.stats.total_rows_processed = row_count
47+
else:
48+
self.stats.total_rows_processed += row_count
4649

4750
# conditional on row_count because we should only count bytes corresponding to
4851
# DML actions whose rows were captured
4952
if bytes_processed is not None and bytes_processed >= 0:
50-
self.stats.total_bytes_processed += bytes_processed
53+
if self.stats.total_bytes_processed is None:
54+
self.stats.total_bytes_processed = bytes_processed
55+
else:
56+
self.stats.total_bytes_processed += bytes_processed
5157

5258
self.stats.query_count += 1
5359
# TODO: remove this

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2457,7 +2457,12 @@ def capture_execution_stats(
24572457

24582458
if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING:
24592459
assert actual_execution_stats["incremental_model"].total_rows_processed == 7
2460-
assert actual_execution_stats["full_model"].total_rows_processed == 3
2460+
# snowflake doesn't track rows for CTAS
2461+
assert actual_execution_stats["full_model"].total_rows_processed == (
2462+
None if ctx.mark.startswith("snowflake") else 3
2463+
)
2464+
# seed rows aren't tracked
2465+
assert actual_execution_stats["seed_model"].total_rows_processed is None
24612466

24622467
if ctx.mark.startswith("bigquery"):
24632468
assert actual_execution_stats["incremental_model"].total_bytes_processed

0 commit comments

Comments
 (0)