Skip to content

Commit 87923db

Browse files
refactor on method in modelmeta
1 parent 8887e2b commit 87923db

2 files changed

Lines changed: 17 additions & 19 deletions

File tree

sqlmesh/core/context.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1647,25 +1647,9 @@ def table_diff(
16471647
target = target_snapshot.qualified_view_name.for_environment(
16481648
target_env.naming_info, adapter.dialect
16491649
)
1650-
1651-
model_on = []
1652-
if not on:
1653-
for expr in [
1654-
ref.expression for ref in model.all_references if ref.unique
1655-
]:
1656-
if isinstance(expr, exp.Tuple):
1657-
model_on.extend(
1658-
[
1659-
key.this.sql(dialect=adapter.dialect)
1660-
for key in expr.expressions
1661-
]
1662-
)
1663-
else:
1664-
# Handle a single Column or Paren expression
1665-
model_on.append(expr.this.sql(dialect=adapter.dialect))
1666-
1667-
models_to_diff.append((model, adapter, source, target, on or model_on))
1668-
if not (on or model_on):
1650+
model_on = on or model.on
1651+
models_to_diff.append((model, adapter, source, target, model_on))
1652+
if not model_on:
16691653
models_without_grain.append(model)
16701654
else:
16711655
models_no_diff.append(model_fqn)

sqlmesh/core/model/meta.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,20 @@ def all_references(self) -> t.List[Reference]:
450450
Reference(model_name=self.name, expression=e, unique=True) for e in self.references
451451
]
452452

453+
@property
454+
def on(self) -> t.List[str]:
455+
"""The grains to be used as join condition in table_diff."""
456+
457+
on: t.List[str] = []
458+
for expr in [ref.expression for ref in self.all_references if ref.unique]:
459+
if isinstance(expr, exp.Tuple):
460+
on.extend([key.this.sql(dialect=self.dialect) for key in expr.expressions])
461+
else:
462+
# Handle a single Column or Paren expression
463+
on.append(expr.this.sql(dialect=self.dialect))
464+
465+
return on
466+
453467
@property
454468
def managed_columns(self) -> t.Dict[str, exp.DataType]:
455469
return getattr(self.kind, "managed_columns", {})

0 commit comments

Comments
 (0)