@@ -1569,8 +1569,8 @@ def table_diff(
15691569 self ,
15701570 source : str ,
15711571 target : str ,
1572- on : t .List [str ] | exp .Condition | None = None ,
1573- skip_columns : t .List [str ] | None = None ,
1572+ on : t .Optional [ t . List [str ] | exp .Condition ] = None ,
1573+ skip_columns : t .Optional [ t . List [str ]] = None ,
15741574 select_models : t .Optional [t .Collection [str ]] = None ,
15751575 where : t .Optional [str | exp .Condition ] = None ,
15761576 limit : int = 20 ,
@@ -1579,7 +1579,7 @@ def table_diff(
15791579 decimals : int = 3 ,
15801580 skip_grain_check : bool = False ,
15811581 temp_schema : t .Optional [str ] = None ,
1582- ) -> t .Union [ TableDiff , t . List [TableDiff ] ]:
1582+ ) -> t .List [TableDiff ]:
15831583 """Show a diff between two tables.
15841584
15851585 Args:
@@ -1613,26 +1613,30 @@ def table_diff(
16131613 raise SQLMeshError (f"Could not find environment '{ target } '" )
16141614
16151615 selected_models = self ._new_selector ().expand_model_selections (select_models )
1616- models_to_diff : t .List [t .Tuple [Model , EngineAdapter , str , str ]] = []
1616+ models_to_diff : t .List [
1617+ t .Tuple [Model , EngineAdapter , str , str , t .Optional [t .List [str ] | exp .Condition ]]
1618+ ] = []
16171619 models_in_source : t .List [str ] = []
16181620 models_in_target : t .List [str ] = []
16191621 models_no_diff : t .List [str ] = []
1622+ models_without_grain : t .List [Model ] = []
1623+ source_snapshots_to_name = {
1624+ snapshot .name : snapshot for snapshot in source_env .snapshots
1625+ }
1626+ target_snapshots_to_name = {
1627+ snapshot .name : snapshot for snapshot in target_env .snapshots
1628+ }
16201629
1621- for model_or_snapshot in selected_models :
1622- model = self .get_model ( model_or_snapshot , raise_if_missing = True )
1630+ for model_fqn in selected_models :
1631+ model = self ._models [ model_fqn ]
16231632 adapter = self ._get_engine_adapter (model .gateway )
1624- source_snapshot = next (
1625- (snapshot for snapshot in source_env .snapshots if snapshot .name == model .fqn ),
1626- None ,
1627- )
1628- target_snapshot = next (
1629- (snapshot for snapshot in target_env .snapshots if snapshot .name == model .fqn ),
1630- None ,
1631- )
1633+ source_snapshot = source_snapshots_to_name .get (model .fqn )
1634+ target_snapshot = target_snapshots_to_name .get (model .fqn )
1635+
16321636 if source_snapshot is None and target_snapshot :
1633- models_in_source .append (model_or_snapshot )
1637+ models_in_source .append (model_fqn )
16341638 elif target_snapshot is None and source_snapshot :
1635- models_in_target .append (model_or_snapshot )
1639+ models_in_target .append (model_fqn )
16361640 elif target_snapshot and source_snapshot :
16371641 if source_snapshot .fingerprint != target_snapshot .fingerprint :
16381642 # Compare the virtual layer instead of the physical layer because the virtual layer is guaranteed to point
@@ -1644,9 +1648,27 @@ def table_diff(
16441648 target_env .naming_info , adapter .dialect
16451649 )
16461650
1647- models_to_diff .append ((model , adapter , source , target ))
1651+ model_on = []
1652+ if not on :
1653+ for expr in [
1654+ ref .expression for ref in model .all_references if ref .unique
1655+ ]:
1656+ if isinstance (expr , exp .Tuple ):
1657+ model_on .extend (
1658+ [
1659+ key .this .sql (dialect = adapter .dialect )
1660+ for key in expr .expressions
1661+ ]
1662+ )
1663+ else :
1664+ # Handle a single Column or Paren expression
1665+ model_on .append (expr .this .sql (dialect = adapter .dialect ))
1666+
1667+ models_to_diff .append ((model , adapter , source , target , on or model_on ))
1668+ if not (on or model_on ):
1669+ models_without_grain .append (model )
16481670 else :
1649- models_no_diff .append (model_or_snapshot )
1671+ models_no_diff .append (model_fqn )
16501672
16511673 self .console .show_table_diff_details (
16521674 models_in_source ,
@@ -1656,6 +1678,15 @@ def table_diff(
16561678 )
16571679
16581680 if models_to_diff :
1681+ if models_without_grain :
1682+ model_names = "\n " .join (
1683+ f"─ { model .name } \n at '{ model ._path } '" for model in models_without_grain
1684+ )
1685+ raise SQLMeshError (
1686+ f"SQLMesh doesn't know how to join the tables for the following models:\n { model_names } \n "
1687+ "\n Please specify the `grains` in each model definition."
1688+ )
1689+
16591690 self .console .start_table_diff_progress (len (models_to_diff ))
16601691 tasks_num = min (len (models_to_diff ), self .concurrent_tasks )
16611692 table_diffs = concurrent_apply_to_values (
@@ -1665,11 +1696,11 @@ def table_diff(
16651696 adapter = model_info [1 ],
16661697 source = model_info [2 ],
16671698 target = model_info [3 ],
1699+ on = model_info [4 ],
16681700 source_alias = source_env .name ,
16691701 target_alias = target_env .name ,
16701702 limit = limit ,
16711703 decimals = decimals ,
1672- on = on ,
16731704 skip_columns = skip_columns ,
16741705 where = where ,
16751706 show = show ,
@@ -1698,7 +1729,7 @@ def table_diff(
16981729 if show :
16991730 self .console .show_table_diff (table_diffs , show_sample , skip_grain_check , temp_schema )
17001731
1701- return table_diffs [ 0 ] if len ( table_diffs ) == 1 else table_diffs
1732+ return table_diffs
17021733
17031734 def _model_diff (
17041735 self ,
@@ -1717,15 +1748,6 @@ def _model_diff(
17171748 temp_schema : t .Optional [str ] = None ,
17181749 skip_grain_check : bool = False ,
17191750 ) -> TableDiff :
1720- if not on :
1721- on = []
1722- for expr in [ref .expression for ref in model .all_references if ref .unique ]:
1723- if isinstance (expr , exp .Tuple ):
1724- on .extend ([key .this .sql (dialect = adapter .dialect ) for key in expr .expressions ])
1725- else :
1726- # Handle a single Column or Paren expression
1727- on .append (expr .this .sql (dialect = adapter .dialect ))
1728-
17291751 self .console .start_table_diff_model_progress (model .name )
17301752
17311753 table_diff = self ._table_diff (
0 commit comments