Skip to content

Commit 9ea63b1

Browse files
committed
Ensure correct AST nodes are created when reading from state as well
1 parent 44496e0 commit 9ea63b1

4 files changed

Lines changed: 262 additions & 1 deletion

File tree

sqlmesh/core/model/meta.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing_extensions import Self
66

77
from pydantic import Field
8-
from sqlglot import Dialect, exp
8+
from sqlglot import Dialect, exp, parse_one
99
from sqlglot.helper import ensure_collection, ensure_list
1010
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1111

@@ -182,6 +182,14 @@ def _gateway_validator(cls, v: t.Any) -> t.Optional[str]:
182182
def _partition_and_cluster_validator(
183183
cls, v: t.Any, info: ValidationInfo
184184
) -> t.List[exp.Expression]:
185+
if isinstance(v, list) and info.field_name == "partitioned_by_":
186+
# this branch gets hit when we are deserializing from json because `partitioned_by` is stored as a List[str]
187+
string_to_parse = (
188+
f"({','.join(v)})" # recreate the (a, b, c) part of "partitioned_by (a, b, c)"
189+
)
190+
parsed = parse_one(string_to_parse, into=exp.PartitionedByProperty)
191+
v = parsed.this.expressions if isinstance(parsed.this, exp.Schema) else v
192+
185193
expressions = list_of_fields_validator(v, info.data)
186194

187195
for expression in expressions:
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""Remove superfluous exp.Paren references from partitioned_by"""
2+
3+
import json
4+
5+
import pandas as pd
6+
from sqlglot import exp
7+
8+
from sqlmesh.utils.migration import index_text_type
9+
from sqlmesh.utils.migration import blob_text_type
10+
11+
12+
def migrate(state_sync, **kwargs):  # type: ignore
    """Strip a superfluous surrounding parenthesis pair from each serialized
    `partitioned_by` entry in stored snapshots, i.e. rewrite '(foo)' to 'foo',
    then rewrite the snapshots table only if anything actually changed."""
    engine_adapter = state_sync.engine_adapter
    schema = state_sync.schema
    snapshots_table = "_snapshots"
    index_type = index_text_type(engine_adapter.dialect)
    if schema:
        snapshots_table = f"{schema}.{snapshots_table}"

    # One canonical column order, shared by the SELECT and the rows we rebuild.
    columns = (
        "name",
        "identifier",
        "version",
        "snapshot",
        "kind_name",
        "updated_ts",
        "unpaused_ts",
        "ttl_ms",
        "unrestorable",
    )

    new_snapshots = []
    updated = False

    for row in engine_adapter.fetchall(
        exp.select(*columns).from_(snapshots_table),
        quote_identifiers=True,
    ):
        record = dict(zip(columns, row))
        parsed_snapshot = json.loads(record["snapshot"])

        if partitioned_by := parsed_snapshot["node"].get("partitioned_by"):
            # rewrite '(foo)' to 'foo'
            stripped = [
                item[1:-1] if item.startswith("(") and item.endswith(")") else item
                for item in partitioned_by
            ]
            if stripped != partitioned_by:
                updated = True
            parsed_snapshot["node"]["partitioned_by"] = stripped

        record["snapshot"] = json.dumps(parsed_snapshot)
        new_snapshots.append(record)

    # Only rewrite the table when at least one snapshot was actually modified.
    if new_snapshots and updated:
        engine_adapter.delete_from(snapshots_table, "TRUE")
        blob_type = blob_text_type(engine_adapter.dialect)

        engine_adapter.insert_append(
            snapshots_table,
            pd.DataFrame(new_snapshots),
            columns_to_types={
                "name": exp.DataType.build(index_type),
                "identifier": exp.DataType.build(index_type),
                "version": exp.DataType.build(index_type),
                "snapshot": exp.DataType.build(blob_type),
                "kind_name": exp.DataType.build(index_type),
                "updated_ts": exp.DataType.build("bigint"),
                "unpaused_ts": exp.DataType.build("bigint"),
                "ttl_ms": exp.DataType.build("bigint"),
                "unrestorable": exp.DataType.build("boolean"),
            },
        )

tests/core/test_model.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,6 +1514,134 @@ def test_render_definition_with_defaults():
15141514
) == d.format_model_expressions(expected_expressions)
15151515

15161516

1517+
def test_render_definition_partitioned_by():
    """`partitioned_by` renders bare for a single column and parenthesized for multiple."""

    def _load(sql, **kwargs):
        return load_sql_based_model(d.parse(sql), **kwargs)

    def _render(model):
        return model.render_definition()[0].sql(pretty=True)

    # both single-column spellings (bare and parenthesized) render without parenthesis
    single_column_rendered = (
        'MODEL (\n  name db.table,\n  kind FULL,\n  partitioned_by "a"\n)'
    )

    # no parenthesis in definition, no parenthesis when rendered
    model = _load(
        """
        MODEL (
            name db.table,
            kind FULL,
            partitioned_by a
        );

        select 1 as a;
        """
    )
    assert model.partitioned_by == [exp.column("a", quoted=True)]
    assert _render(model) == single_column_rendered

    # single column wrapped in parenthesis in definition, no parenthesis in rendered
    model = _load(
        """
        MODEL (
            name db.table,
            kind FULL,
            partitioned_by (a)
        );

        select 1 as a;
        """
    )
    assert model.partitioned_by == [exp.column("a", quoted=True)]
    assert _render(model) == single_column_rendered

    # multiple columns wrapped in parenthesis in definition, parenthesis in rendered
    model = _load(
        """
        MODEL (
            name db.table,
            kind FULL,
            partitioned_by (a, b)
        );

        select 1 as a, 2 as b;
        """
    )
    assert model.partitioned_by == [exp.column("a", quoted=True), exp.column("b", quoted=True)]
    assert _render(model) == (
        'MODEL (\n  name db.table,\n  kind FULL,\n  partitioned_by ("a", "b")\n)'
    )

    # multiple columns not wrapped in parenthesis in the definition is an error
    with pytest.raises(ParseError, match=r"keyword: 'value' missing"):
        _load(
            """
            MODEL (
                name db.table,
                kind FULL,
                partitioned_by a, b
            );

            select 1 as a, 2 as b;
            """
        )

    # Iceberg transforms / functions
    model = _load(
        """
        MODEL (
            name db.table,
            kind FULL,
            partitioned_by (day(a), truncate(b, 4), bucket(c, 3))
        );

        select 1 as a, 2 as b, 3 as c;
        """,
        dialect="trino",
    )
    assert model.partitioned_by == [
        exp.Day(this=exp.column("a", quoted=True)),
        exp.PartitionByTruncate(
            this=exp.column("b", quoted=True), expression=exp.Literal.number(4)
        ),
        exp.PartitionedByBucket(
            this=exp.column("c", quoted=True), expression=exp.Literal.number(3)
        ),
    ]
    assert _render(model) == (
        "MODEL (\n"
        "  name db.table,\n"
        "  dialect trino,\n"
        "  kind FULL,\n"
        '  partitioned_by (DAY("a"), TRUNCATE("b", 4), BUCKET("c", 3))\n'
        ")"
    )
15171645
def test_cron():
15181646
daily = _Node(name="x", cron="@daily")
15191647
assert to_datetime(daily.cron_prev("2020-01-01")) == to_datetime("2019-12-31")

tests/core/test_snapshot.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2938,3 +2938,37 @@ def check_types(batch, env: str, sql: list[SQL], table: exp.Table, default: int
29382938
)
29392939
snapshot_a = make_snapshot(sql_model)
29402940
assert snapshot_a.check_ready_intervals([(0, 1)], mocker.Mock()) == [(0, 1)]
2941+
2942+
2943+
def test_partitioned_by_roundtrip(make_snapshot: t.Callable):
    """Serializing a snapshot to JSON and back must preserve `partitioned_by` AST nodes."""
    sql_model = load_sql_based_model(
        parse("""
        MODEL (
            name test_schema.test_model,
            kind full,
            partitioned_by (a, bucket(4, b), truncate(3, c), month(d))
        );
        SELECT a, b, c, d FROM tbl;
        """)
    )
    snapshot = make_snapshot(sql_model)
    assert isinstance(snapshot, Snapshot)
    assert isinstance(snapshot.node, SqlModel)

    expected = [
        exp.column("a", quoted=True),
        exp.PartitionedByBucket(
            this=exp.column("b", quoted=True), expression=exp.Literal.number(4)
        ),
        exp.PartitionByTruncate(
            this=exp.column("c", quoted=True), expression=exp.Literal.number(3)
        ),
        exp.Month(this=exp.column("d", quoted=True)),
    ]
    assert snapshot.node.partitioned_by == expected

    # roundtrip through json and ensure we get correct AST nodes on the other end
    deserialized = snapshot.parse_raw(snapshot.json())

    assert isinstance(deserialized.node, SqlModel)
    assert deserialized.node.partitioned_by == expected

0 commit comments

Comments
 (0)