Skip to content

Commit 2cf32bc

Browse files
committed
Add tests for BigFrame and Snowpark dataframes
1 parent 6c283e8 commit 2cf32bc

10 files changed

Lines changed: 144 additions & 36 deletions

File tree

pyproject.toml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ athena = ["PyAthena[Pandas]"]
4040
azuresql = ["pymssql"]
4141
bigquery = [
4242
"google-cloud-bigquery[pandas]",
43-
"google-cloud-bigquery-storage"
43+
"google-cloud-bigquery-storage",
44+
"bigframes>=1.32.0"
4445
]
45-
bigframes = ["bigframes>=1.32.0"]
4646
clickhouse = ["clickhouse-connect"]
4747
databricks = ["databricks-sql-connector[pyarrow]"]
4848
dev = [
@@ -107,8 +107,7 @@ slack = ["slack_sdk"]
107107
snowflake = [
108108
"cryptography",
109109
"snowflake-connector-python[pandas,secure-local-storage]",
110-
# as at 2024-08-05, snowflake-snowpark-python is only available up to Python 3.11
111-
"snowflake-snowpark-python; python_version<'3.12'",
110+
"snowflake-snowpark-python",
112111
]
113112
trino = ["trino"]
114113
web = [

sqlmesh/core/engine_adapter/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def _df_to_source_queries(
249249

250250
# we need to ensure that the order of the columns in columns_to_types columns matches the order of the values
251251
# they can differ if a user specifies columns() on a python model in a different order than what's in the DataFrames emitted by that model
252-
df = df[list(columns_to_types.keys())]
252+
df = df[list(columns_to_types)]
253253
values = list(df.itertuples(index=False, name=None))
254254

255255
return [

sqlmesh/core/engine_adapter/mssql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def query_factory() -> Query:
219219
if not self.table_exists(temp_table):
220220
columns_to_types_create = columns_to_types.copy()
221221
ordered_df = df[
222-
list(columns_to_types_create.keys())
222+
list(columns_to_types_create)
223223
] # reorder DataFrame so it matches columns_to_types
224224
self._convert_df_datetime(ordered_df, columns_to_types_create)
225225
self.create_table(temp_table, columns_to_types_create)

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,25 @@ def _df_to_source_queries(
288288
is_snowpark_dataframe = snowpark and isinstance(df, snowpark.dataframe.DataFrame)
289289

290290
def query_factory() -> Query:
291+
# The catalog needs to be normalized before being passed to Snowflake's library functions because they
292+
# just wrap whatever they are given in quotes without checking if it's already quoted
293+
database = (
294+
normalize_identifiers(temp_table.catalog, dialect=self.dialect)
295+
if temp_table.catalog
296+
else None
297+
)
298+
291299
if is_snowpark_dataframe:
292-
df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect, identify=True)) # type: ignore
300+
temp_table.set("catalog", database)
301+
df_renamed = df.rename(
302+
{
303+
col: exp.to_identifier(col, quoted=True).sql(dialect=self.dialect)
304+
for col in columns_to_types
305+
}
306+
) # type: ignore
307+
df_renamed.createOrReplaceTempView(
308+
temp_table.sql(dialect=self.dialect, identify=True)
309+
) # type: ignore
293310
elif isinstance(df, pd.DataFrame):
294311
from snowflake.connector.pandas_tools import write_pandas
295312

@@ -325,11 +342,7 @@ def query_factory() -> Query:
325342
df,
326343
temp_table.name,
327344
schema=temp_table.db or None,
328-
database=normalize_identifiers(temp_table.catalog, dialect=self.dialect).sql(
329-
dialect=self.dialect
330-
)
331-
if temp_table.catalog
332-
else None,
345+
database=database.sql(dialect=self.dialect) if database else None,
333346
chunk_size=self.DEFAULT_BATCH_SIZE,
334347
overwrite=True,
335348
table_type="temp",

sqlmesh/core/engine_adapter/spark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,14 +281,14 @@ def _ensure_pyspark_df(
281281
if pyspark_df:
282282
if columns_to_types:
283283
# ensure Spark dataframe column order matches columns_to_types
284-
pyspark_df = pyspark_df.select(*list(columns_to_types.keys()))
284+
pyspark_df = pyspark_df.select(*list(columns_to_types))
285285
return pyspark_df
286286
df = self.try_get_pandas_df(generic_df)
287287
if df is None:
288288
raise SQLMeshError("Ensure PySpark DF can only be run on a PySpark or Pandas DataFrame")
289289
if columns_to_types:
290290
# ensure Pandas dataframe column order matches columns_to_types
291-
df = df[list(columns_to_types.keys())]
291+
df = df[list(columns_to_types)]
292292
kwargs = (
293293
dict(schema=self.sqlglot_to_spark_types(columns_to_types)) if columns_to_types else {}
294294
)

tests/core/engine_adapter/integration/__init__.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def get_table_comment(
359359
FROM pg_class c
360360
INNER JOIN pg_description d ON c.oid = d.objoid AND d.objsubid = 0
361361
INNER JOIN pg_namespace n ON c.relnamespace = n.oid
362-
WHERE
362+
WHERE
363363
c.relname = '{table_name}'
364364
AND n.nspname= '{schema_name}'
365365
AND c.relkind = '{"v" if table_kind == "VIEW" else "r"}'
@@ -465,12 +465,12 @@ def get_column_comments(
465465
INNER JOIN pg_namespace n ON c.relnamespace = n.oid
466466
INNER JOIN pg_attribute a ON c.oid = a.attrelid
467467
INNER JOIN pg_description d
468-
ON
468+
ON
469469
a.attnum = d.objsubid
470470
AND d.objoid = c.oid
471471
WHERE
472472
n.nspname = '{schema_name}'
473-
AND c.relname = '{table_name}'
473+
AND c.relname = '{table_name}'
474474
AND c.relkind = '{"v" if table_kind == "VIEW" else "r"}'
475475
;
476476
"""
@@ -494,6 +494,7 @@ def create_context(
494494
self,
495495
config_mutator: t.Optional[t.Callable[[str, Config], None]] = None,
496496
path: t.Optional[pathlib.Path] = None,
497+
ephemeral_state_connection: bool = True,
497498
) -> Context:
498499
private_sqlmesh_dir = pathlib.Path(pathlib.Path().home(), ".sqlmesh")
499500
config = load_config_from_paths(
@@ -509,14 +510,12 @@ def create_context(
509510
config.gateways = {self.gateway: config.gateways[self.gateway]}
510511

511512
gateway_config = config.gateways[self.gateway]
512-
if (
513-
(sc := gateway_config.state_connection)
514-
and (conn := gateway_config.connection)
515-
and sc.type_ == "duckdb"
516-
):
517-
# if duckdb is being used as the state connection, set concurrent_tasks=1 on the main connection
518-
# to prevent duckdb from being accessed from multiple threads and getting deadlocked
519-
conn.concurrent_tasks = 1
513+
if ephemeral_state_connection:
514+
# Override whatever state connection has been configured on the integration test config to use in-memory DuckDB instead
515+
# This is so tests that initialize a SQLMesh context can run concurrently without clobbering each other's state
516+
from sqlmesh.core.config.connection import DuckDBConnectionConfig
517+
518+
gateway_config.state_connection = DuckDBConnectionConfig()
520519

521520
if "athena" in self.gateway:
522521
conn = gateway_config.connection

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2721,7 +2721,9 @@ def _use_warehouse_as_state_connection(gateway_name: str, config: Config):
27212721

27222722
config.gateways[gateway_name].state_schema = test_schema
27232723

2724-
sqlmesh_context = ctx.create_context(config_mutator=_use_warehouse_as_state_connection)
2724+
sqlmesh_context = ctx.create_context(
2725+
config_mutator=_use_warehouse_as_state_connection, ephemeral_state_connection=False
2726+
)
27252727
assert sqlmesh_context.config.get_state_schema(ctx.gateway) == test_schema
27262728

27272729
state_sync = (
@@ -2742,8 +2744,7 @@ def test_python_model_column_order(ctx: TestContext, tmp_path_factory: pytest.Te
27422744
pytest.skip("python model column order test only needs to be run once per db")
27432745

27442746
tmp_path = tmp_path_factory.mktemp(f"column_order_{ctx.test_id}")
2745-
2746-
test_schema = ctx.add_test_suffix("column_order")
2747+
schema = ctx.add_test_suffix(TEST_SCHEMA)
27472748

27482749
(tmp_path / "models").mkdir()
27492750

@@ -2772,7 +2773,7 @@ def execute(
27722773
return context.spark.createDataFrame([
27732774
Row(name="foo", id=1)
27742775
])
2775-
""".replace("TEST_SCHEMA", test_schema)
2776+
""".replace("TEST_SCHEMA", schema)
27762777
)
27772778
else:
27782779
# python model that emits a Pandas DataFrame
@@ -2796,7 +2797,7 @@ def execute(
27962797
return pd.DataFrame([
27972798
{"name": "foo", "id": 1}
27982799
])
2799-
""".replace("TEST_SCHEMA", test_schema)
2800+
""".replace("TEST_SCHEMA", schema)
28002801
)
28012802

28022803
sqlmesh_ctx = ctx.create_context(path=tmp_path)
@@ -2808,6 +2809,9 @@ def execute(
28082809

28092810
engine_adapter = sqlmesh_ctx.engine_adapter
28102811

2811-
df = engine_adapter.fetchdf(f"select * from {test_schema}.model")
2812+
query = exp.select("*").from_(
2813+
exp.to_table(f"{schema}.model", dialect=ctx.dialect), dialect=ctx.dialect
2814+
)
2815+
df = engine_adapter.fetchdf(query, quote_identifiers=True)
28122816
assert len(df) == 1
28132817
assert df.iloc[0].to_dict() == {"id": 1, "name": "foo"}

tests/core/engine_adapter/integration/test_integration_bigquery.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,3 +433,50 @@ def test_table_diff_table_name_matches_column_name(ctx: TestContext):
433433

434434
assert row_diff.stats["join_count"] == 1
435435
assert row_diff.full_match_count == 1
436+
437+
438+
def test_bigframe_python_model_column_order(ctx: TestContext, tmp_path: Path):
439+
model_name = ctx.table("TEST")
440+
441+
(tmp_path / "models").mkdir()
442+
443+
# note: this model deliberately defines the columns in the @model definition to be in a different order than what
444+
# is returned by the DataFrame within the model
445+
model_path = tmp_path / "models" / "python_model.py"
446+
447+
# python model that emits a BigFrame dataframe
448+
model_path.write_text(
449+
"""
450+
from bigframes.pandas import DataFrame
451+
import typing as t
452+
from sqlmesh import ExecutionContext, model
453+
454+
@model(
455+
'MODEL_NAME',
456+
columns={
457+
"id": "int",
458+
"name": "varchar"
459+
},
460+
dialect="bigquery"
461+
)
462+
def execute(
463+
context: ExecutionContext,
464+
**kwargs: t.Any,
465+
) -> DataFrame:
466+
return DataFrame({'name': ['foo'], 'id': [1]}, session=context.bigframe)
467+
""".replace("MODEL_NAME", model_name.sql(dialect="bigquery"))
468+
)
469+
470+
sqlmesh_ctx = ctx.create_context(path=tmp_path)
471+
472+
assert len(sqlmesh_ctx.models) == 1
473+
474+
plan = sqlmesh_ctx.plan(auto_apply=True)
475+
assert len(plan.new_snapshots) == 1
476+
477+
engine_adapter = sqlmesh_ctx.engine_adapter
478+
479+
query = exp.select("*").from_(model_name)
480+
df = engine_adapter.fetchdf(query, quote_identifiers=True)
481+
assert len(df) == 1
482+
assert df.iloc[0].to_dict() == {"id": 1, "name": "foo"}

tests/core/engine_adapter/integration/test_integration_snowflake.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import typing as t
22
import pytest
33
from sqlglot import exp
4+
from pathlib import Path
45
from sqlglot.optimizer.qualify_columns import quote_identifiers
56
from sqlglot.helper import seq_get
67
from sqlmesh.core.engine_adapter import SnowflakeEngineAdapter
@@ -210,3 +211,49 @@ def test_create_iceberg_table(ctx: TestContext, engine_adapter: SnowflakeEngineA
210211
result = sqlmesh.plan(auto_apply=True)
211212

212213
assert len(result.new_snapshots) == 2
214+
215+
216+
def test_snowpark_python_model_column_order(ctx: TestContext, tmp_path: Path):
217+
model_name = ctx.table("TEST")
218+
219+
(tmp_path / "models").mkdir()
220+
221+
# note: this model deliberately defines the columns in the @model definition to be in a different order than what
222+
# is returned by the DataFrame within the model
223+
model_path = tmp_path / "models" / "python_model.py"
224+
225+
# python model that emits a Snowpark DataFrame
226+
model_path.write_text(
227+
"""
228+
from snowflake.snowpark.dataframe import DataFrame
229+
import typing as t
230+
from sqlmesh import ExecutionContext, model
231+
232+
@model(
233+
'MODEL_NAME',
234+
columns={
235+
"id": "int",
236+
"name": "varchar"
237+
}
238+
)
239+
def execute(
240+
context: ExecutionContext,
241+
**kwargs: t.Any,
242+
) -> DataFrame:
243+
return context.snowpark.create_dataframe([["foo", 1]], schema=["name", "id"])
244+
""".replace("MODEL_NAME", model_name.sql(dialect="snowflake"))
245+
)
246+
247+
sqlmesh_ctx = ctx.create_context(path=tmp_path)
248+
249+
assert len(sqlmesh_ctx.models) == 1
250+
251+
plan = sqlmesh_ctx.plan(auto_apply=True)
252+
assert len(plan.new_snapshots) == 1
253+
254+
engine_adapter = sqlmesh_ctx.engine_adapter
255+
256+
query = exp.select("*").from_(plan.environment.snapshots[0].fully_qualified_table)
257+
df = engine_adapter.fetchdf(query, quote_identifiers=True)
258+
assert len(df) == 1
259+
assert df.iloc[0].to_dict() == {"id": 1, "name": "foo"}

tests/core/engine_adapter/test_snowflake.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,10 @@ def test_replace_query_snowpark_dataframe(
424424
from snowflake.snowpark.dataframe import DataFrame as SnowparkDataFrame
425425

426426
session = Session.builder.config("local_testing", True).create()
427+
# df.createOrReplaceTempView() throws "[Local Testing] Mocking SnowflakePlan Rename is not supported" when used against the Snowflake local_testing session
428+
# since we can't trace any queries from the Snowpark library anyway, we just suppress this and verify the cleanup queries issued by our EngineAdapter
429+
session._conn._suppress_not_implemented_error = True
430+
427431
df: SnowparkDataFrame = session.create_dataframe([(1, "name")], schema=["ID", "NAME"])
428432
assert isinstance(df, SnowparkDataFrame)
429433

@@ -439,11 +443,6 @@ def test_replace_query_snowpark_dataframe(
439443
columns_to_types={"ID": exp.DataType.build("INT"), "NAME": exp.DataType.build("VARCHAR")},
440444
)
441445

442-
# the Snowflake library generates "CREATE TEMPORARY VIEW" from a direct DataFrame call
443-
# which doesnt pass through our EngineAdapter so we cant capture it
444-
spy.assert_called()
445-
assert "__temp_foo_e6wjkjj6" in spy.call_args[0][0]
446-
447446
# verify that DROP VIEW is called instead of DROP TABLE
448447
assert to_sql_calls(adapter) == [
449448
'CREATE OR REPLACE TABLE "foo" AS SELECT CAST("ID" AS INT) AS "ID", CAST("NAME" AS VARCHAR) AS "NAME" FROM (SELECT CAST("ID" AS INT) AS "ID", CAST("NAME" AS VARCHAR) AS "NAME" FROM "__temp_foo_e6wjkjj6") AS "_subquery"',

0 commit comments

Comments
 (0)