Skip to content

Commit 9e3f2aa

Browse files
committed
Remove time travel test for cloud engines, handle pyspark DFs in dbx
1 parent 55c5ffc commit 9e3f2aa

2 files changed

Lines changed: 72 additions & 52 deletions

File tree

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 57 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
SourceQuery,
1515
)
1616
from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
17+
from sqlmesh.engines.spark.db_api.spark_session import SparkSessionCursor
1718
from sqlmesh.core.node import IntervalUnit
1819
from sqlmesh.core.schema_diff import NestedSupport
1920
from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker
@@ -379,38 +380,59 @@ def _record_execution_stats(
379380
except:
380381
return
381382

382-
history = self.cursor.fetchall_arrow()
383-
if history.num_rows:
384-
history_df = history.to_pandas()
385-
write_df = history_df[history_df["operation"] == "WRITE"]
386-
write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()]
387-
if not write_df.empty:
388-
metrics = write_df["operationMetrics"][0]
389-
if metrics:
390-
rowcount = None
391-
rowcount_str = [
392-
metric[1] for metric in metrics if metric[0] == "numOutputRows"
393-
]
394-
if rowcount_str:
395-
try:
396-
rowcount = int(rowcount_str[0])
397-
except (TypeError, ValueError):
398-
pass
399-
400-
bytes_processed = None
401-
bytes_str = [
402-
metric[1] for metric in metrics if metric[0] == "numOutputBytes"
403-
]
404-
if bytes_str:
405-
try:
406-
bytes_processed = int(bytes_str[0])
407-
except (TypeError, ValueError):
408-
pass
409-
410-
if rowcount is not None or bytes_processed is not None:
411-
# if no rows were written, df contains 0 for bytes but no value for rows
412-
rowcount = (
413-
0 if rowcount is None and bytes_processed is not None else rowcount
414-
)
415-
416-
QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed)
383+
history = (
384+
self.cursor.fetchdf()
385+
if isinstance(self.cursor, SparkSessionCursor)
386+
else self.cursor.fetchall_arrow()
387+
)
388+
if history is not None:
389+
from pandas import DataFrame as PandasDataFrame
390+
from pyspark.sql import DataFrame as PySparkDataFrame
391+
from pyspark.sql.connect.dataframe import DataFrame as PySparkConnectDataFrame
392+
393+
history_df = None
394+
if isinstance(history, PandasDataFrame):
395+
history_df = history
396+
elif isinstance(history, (PySparkDataFrame, PySparkConnectDataFrame)):
397+
history_df = history.toPandas()
398+
else:
399+
# arrow table
400+
history_df = history.to_pandas()
401+
402+
if history_df is not None and not history_df.empty:
403+
write_df = history_df[history_df["operation"] == "WRITE"]
404+
write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()]
405+
if not write_df.empty:
406+
metrics = write_df["operationMetrics"][0]
407+
if metrics:
408+
rowcount = None
409+
rowcount_str = [
410+
metric[1] for metric in metrics if metric[0] == "numOutputRows"
411+
]
412+
if rowcount_str:
413+
try:
414+
rowcount = int(rowcount_str[0])
415+
except (TypeError, ValueError):
416+
pass
417+
418+
bytes_processed = None
419+
bytes_str = [
420+
metric[1] for metric in metrics if metric[0] == "numOutputBytes"
421+
]
422+
if bytes_str:
423+
try:
424+
bytes_processed = int(bytes_str[0])
425+
except (TypeError, ValueError):
426+
pass
427+
428+
if rowcount is not None or bytes_processed is not None:
429+
# if no rows were written, df contains 0 for bytes but no value for rows
430+
rowcount = (
431+
0
432+
if rowcount is None and bytes_processed is not None
433+
else rowcount
434+
)
435+
436+
QueryExecutionTracker.record_execution(
437+
sql, rowcount, bytes_processed
438+
)

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2431,23 +2431,21 @@ def capture_execution_stats(
24312431
assert actual_execution_stats["full_model"].total_bytes_processed is not None
24322432

24332433
# run that loads 0 rows in incremental model
2434-
actual_execution_stats = {}
2435-
with patch.object(
2436-
context.console, "update_snapshot_evaluation_progress", capture_execution_stats
2437-
):
2438-
with time_machine.travel(date.today() + timedelta(days=1)):
2439-
context.run()
2440-
2441-
if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING:
2442-
assert actual_execution_stats["incremental_model"].total_rows_processed == 0
2443-
# snowflake doesn't track rows for CTAS
2444-
assert actual_execution_stats["full_model"].total_rows_processed == (
2445-
None if ctx.mark.startswith("snowflake") else 3
2446-
)
2447-
2448-
if ctx.mark.startswith("bigquery") or ctx.mark.startswith("databricks"):
2449-
assert actual_execution_stats["incremental_model"].total_bytes_processed is not None
2450-
assert actual_execution_stats["full_model"].total_bytes_processed is not None
2434+
# - some cloud DBs error because time travel messes up token expiration
2435+
if not ctx.is_remote:
2436+
actual_execution_stats = {}
2437+
with patch.object(
2438+
context.console, "update_snapshot_evaluation_progress", capture_execution_stats
2439+
):
2440+
with time_machine.travel(date.today() + timedelta(days=1)):
2441+
context.run()
2442+
2443+
if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING:
2444+
assert actual_execution_stats["incremental_model"].total_rows_processed == 0
2445+
# snowflake doesn't track rows for CTAS
2446+
assert actual_execution_stats["full_model"].total_rows_processed == (
2447+
None if ctx.mark.startswith("snowflake") else 3
2448+
)
24512449

24522450
# make and validate unmodified dev environment
24532451
no_change_plan: Plan = context.plan_builder(

0 commit comments

Comments
 (0)