Skip to content

Commit 6c283e8

Browse files
committed
fix tests
1 parent cdd062a commit 6c283e8

3 files changed

Lines changed: 57 additions & 17 deletions

File tree

sqlmesh/core/engine_adapter/spark.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,16 @@ def _ensure_pyspark_df(
279279
) -> PySparkDataFrame:
280280
pyspark_df = self.try_get_pyspark_df(generic_df)
281281
if pyspark_df:
282+
if columns_to_types:
283+
# ensure Spark dataframe column order matches columns_to_types
284+
pyspark_df = pyspark_df.select(*list(columns_to_types.keys()))
282285
return pyspark_df
283286
df = self.try_get_pandas_df(generic_df)
284287
if df is None:
285288
raise SQLMeshError("Ensure PySpark DF can only be run on a PySpark or Pandas DataFrame")
289+
if columns_to_types:
290+
# ensure Pandas dataframe column order matches columns_to_types
291+
df = df[list(columns_to_types.keys())]
286292
kwargs = (
287293
dict(schema=self.sqlglot_to_spark_types(columns_to_types)) if columns_to_types else {}
288294
)

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2735,7 +2735,10 @@ def _use_warehouse_as_state_connection(gateway_name: str, config: Config):
27352735

27362736

27372737
def test_python_model_column_order(ctx: TestContext, tmp_path_factory: pytest.TempPathFactory):
2738-
if ctx.test_type != "df":
2738+
if ctx.test_type == "pyspark" and ctx.dialect in ("spark", "databricks"):
2739+
# don't skip
2740+
pass
2741+
elif ctx.test_type != "df":
27392742
pytest.skip("python model column order test only needs to be run once per db")
27402743

27412744
tmp_path = tmp_path_factory.mktemp(f"column_order_{ctx.test_id}")
@@ -2746,8 +2749,35 @@ def test_python_model_column_order(ctx: TestContext, tmp_path_factory: pytest.Te
27462749

27472750
# note: this model deliberately defines the columns in the @model definition to be in a different order than what
27482751
# is returned by the DataFrame within the model
2749-
(tmp_path / "models" / "python_model.py").write_text(
2750-
"""
2752+
model_path = tmp_path / "models" / "python_model.py"
2753+
if ctx.test_type == "pyspark":
2754+
# python model that emits a PySpark DataFrame
2755+
model_path.write_text(
2756+
"""
2757+
from pyspark.sql import DataFrame, Row
2758+
import typing as t
2759+
from sqlmesh import ExecutionContext, model
2760+
2761+
@model(
2762+
"TEST_SCHEMA.model",
2763+
columns={
2764+
"id": "int",
2765+
"name": "varchar"
2766+
}
2767+
)
2768+
def execute(
2769+
context: ExecutionContext,
2770+
**kwargs: t.Any,
2771+
) -> DataFrame:
2772+
return context.spark.createDataFrame([
2773+
Row(name="foo", id=1)
2774+
])
2775+
""".replace("TEST_SCHEMA", test_schema)
2776+
)
2777+
else:
2778+
# python model that emits a Pandas DataFrame
2779+
model_path.write_text(
2780+
"""
27512781
import pandas as pd
27522782
import typing as t
27532783
from sqlmesh import ExecutionContext, model
@@ -2766,8 +2796,8 @@ def execute(
27662796
return pd.DataFrame([
27672797
{"name": "foo", "id": 1}
27682798
])
2769-
""".replace("TEST_SCHEMA", test_schema)
2770-
)
2799+
""".replace("TEST_SCHEMA", test_schema)
2800+
)
27712801

27722802
sqlmesh_ctx = ctx.create_context(path=tmp_path)
27732803

tests/core/engine_adapter/test_base.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -966,7 +966,7 @@ def test_merge_upsert(make_mocked_engine_adapter: t.Callable, assert_exp_eq):
966966
def test_merge_upsert_pandas(make_mocked_engine_adapter: t.Callable):
967967
adapter = make_mocked_engine_adapter(EngineAdapter)
968968

969-
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
969+
df = pd.DataFrame({"id": [1, 2, 3], "ts": [4, 5, 6], "val": [1, 2, 3]})
970970
adapter.merge(
971971
target_table="target",
972972
source_table=df,
@@ -978,7 +978,7 @@ def test_merge_upsert_pandas(make_mocked_engine_adapter: t.Callable):
978978
unique_key=[exp.to_identifier("id")],
979979
)
980980
adapter.cursor.execute.assert_called_once_with(
981-
'MERGE INTO "target" AS "__MERGE_TARGET__" USING (SELECT CAST("id" AS INT) AS "id", CAST("ts" AS TIMESTAMP) AS "ts", CAST("val" AS INT) AS "val" FROM (VALUES (1, 4), (2, 5), (3, 6)) AS "t"("id", "ts", "val")) AS "__MERGE_SOURCE__" ON "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id" '
981+
'MERGE INTO "target" AS "__MERGE_TARGET__" USING (SELECT CAST("id" AS INT) AS "id", CAST("ts" AS TIMESTAMP) AS "ts", CAST("val" AS INT) AS "val" FROM (VALUES (1, 4, 1), (2, 5, 2), (3, 6, 3)) AS "t"("id", "ts", "val")) AS "__MERGE_SOURCE__" ON "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id" '
982982
'WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id", "__MERGE_TARGET__"."ts" = "__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val" '
983983
'WHEN NOT MATCHED THEN INSERT ("id", "ts", "val") VALUES ("__MERGE_SOURCE__"."id", "__MERGE_SOURCE__"."ts", "__MERGE_SOURCE__"."val")'
984984
)
@@ -995,7 +995,7 @@ def test_merge_upsert_pandas(make_mocked_engine_adapter: t.Callable):
995995
unique_key=[exp.to_identifier("id"), exp.to_identifier("ts")],
996996
)
997997
adapter.cursor.execute.assert_called_once_with(
998-
'MERGE INTO "target" AS "__MERGE_TARGET__" USING (SELECT CAST("id" AS INT) AS "id", CAST("ts" AS TIMESTAMP) AS "ts", CAST("val" AS INT) AS "val" FROM (VALUES (1, 4), (2, 5), (3, 6)) AS "t"("id", "ts", "val")) AS "__MERGE_SOURCE__" ON "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id" AND "__MERGE_TARGET__"."ts" = "__MERGE_SOURCE__"."ts" '
998+
'MERGE INTO "target" AS "__MERGE_TARGET__" USING (SELECT CAST("id" AS INT) AS "id", CAST("ts" AS TIMESTAMP) AS "ts", CAST("val" AS INT) AS "val" FROM (VALUES (1, 4, 1), (2, 5, 2), (3, 6, 3)) AS "t"("id", "ts", "val")) AS "__MERGE_SOURCE__" ON "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id" AND "__MERGE_TARGET__"."ts" = "__MERGE_SOURCE__"."ts" '
999999
'WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id", "__MERGE_TARGET__"."ts" = "__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val" '
10001000
'WHEN NOT MATCHED THEN INSERT ("id", "ts", "val") VALUES ("__MERGE_SOURCE__"."id", "__MERGE_SOURCE__"."ts", "__MERGE_SOURCE__"."val")'
10011001
)
@@ -1175,23 +1175,23 @@ def test_merge_filter(make_mocked_engine_adapter: t.Callable, assert_exp_eq):
11751175
"""
11761176
MERGE INTO "target" AS "__MERGE_TARGET__"
11771177
USING (
1178-
SELECT "ID", "ts", "val"
1178+
SELECT "ID", "ts", "val"
11791179
FROM "source"
11801180
) AS "__MERGE_SOURCE__"
11811181
ON (
1182-
"__MERGE_SOURCE__"."ID" > 0
1182+
"__MERGE_SOURCE__"."ID" > 0
11831183
AND "__MERGE_TARGET__"."ts" < TIMESTAMP("2020-02-05")
11841184
)
11851185
AND "__MERGE_TARGET__"."ID" = "__MERGE_SOURCE__"."ID"
1186-
WHEN MATCHED THEN
1187-
UPDATE SET
1186+
WHEN MATCHED THEN
1187+
UPDATE SET
11881188
"__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val",
11891189
"__MERGE_TARGET__"."ts" = COALESCE("__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."ts")
1190-
WHEN NOT MATCHED THEN
1191-
INSERT ("ID", "ts", "val")
1190+
WHEN NOT MATCHED THEN
1191+
INSERT ("ID", "ts", "val")
11921192
VALUES (
1193-
"__MERGE_SOURCE__"."ID",
1194-
"__MERGE_SOURCE__"."ts",
1193+
"__MERGE_SOURCE__"."ID",
1194+
"__MERGE_SOURCE__"."ts",
11951195
"__MERGE_SOURCE__"."val"
11961196
);
11971197
""",
@@ -1585,7 +1585,11 @@ def test_merge_scd_type_2_pandas(make_mocked_engine_adapter: t.Callable):
15851585
"id2": [4, 5, 6],
15861586
"name": ["muffins", "chips", "soda"],
15871587
"price": [4.0, 5.0, 6.0],
1588-
"updated_at": ["2020-01-01 10:00:00", "2020-01-02 15:00:00", "2020-01-03 12:00:00"],
1588+
"test_updated_at": [
1589+
"2020-01-01 10:00:00",
1590+
"2020-01-02 15:00:00",
1591+
"2020-01-03 12:00:00",
1592+
],
15891593
}
15901594
)
15911595
adapter.scd_type_2_by_time(

0 commit comments

Comments
 (0)