@@ -1571,8 +1571,7 @@ def table_diff(
15711571 target : str ,
15721572 on : t .List [str ] | exp .Condition | None = None ,
15731573 skip_columns : t .List [str ] | None = None ,
1574- model_or_snapshot : t .Optional [ModelOrSnapshot ] = None ,
1575- select_model : t .Optional [t .Collection [str ]] = None ,
1574+ select_models : t .Optional [t .Collection [str ]] = None ,
15761575 where : t .Optional [str | exp .Condition ] = None ,
15771576 limit : int = 20 ,
15781577 show : bool = True ,
@@ -1589,7 +1588,7 @@ def table_diff(
15891588 on: The join condition, table aliases must be "s" and "t" for source and target.
15901589 If omitted, the table's grain will be used.
15911590 skip_columns: The columns to skip when computing the table diff.
1592- model_or_snapshot : The model or snapshot to use when environments are passed in.
1591+ select_models : The modelσ or snapshotσ to use when environments are passed in.
15931592 where: An optional where statement to filter results.
15941593 limit: The limit of the sample dataframe.
15951594 show: Show the table diff output in the console.
@@ -1605,51 +1604,81 @@ def table_diff(
16051604 table_diffs : t .List [TableDiff ] = []
16061605
16071606 # Diffs multiple or a single model across two environments
1608- if model_or_snapshot or select_model :
1607+ if select_models :
16091608 source_env = self .state_reader .get_environment (source )
16101609 target_env = self .state_reader .get_environment (target )
16111610 if not source_env :
16121611 raise SQLMeshError (f"Could not find environment '{ source } '" )
16131612 if not target_env :
16141613 raise SQLMeshError (f"Could not find environment '{ target } '" )
16151614
1616- modified_snapshots : t .Set [ModelOrSnapshot ] = (
1617- {model_or_snapshot } if model_or_snapshot else set ()
1618- )
1619- if select_model :
1620- models_to_diff = self ._new_selector ().expand_model_selections (select_model )
1621- target_snapshots = {
1622- s .name : s
1623- for s in self .state_reader .get_snapshots (target_env .snapshots ).values ()
1624- if s .name in models_to_diff
1625- }
1626- context_diff = self ._context_diff (
1627- source ,
1628- snapshots = target_snapshots ,
1629- ensure_finalized_snapshots = self .config .plan .use_finalized_state ,
1615+ selected_models = self ._new_selector ().expand_model_selections (select_models )
1616+ models_to_diff : t .List [t .Tuple [Model , EngineAdapter , str , str ]] = []
1617+ models_in_source : t .List [str ] = []
1618+ models_in_target : t .List [str ] = []
1619+ models_no_diff : t .List [str ] = []
1620+
1621+ for model_or_snapshot in selected_models :
1622+ model = self .get_model (model_or_snapshot , raise_if_missing = True )
1623+ adapter = self ._get_engine_adapter (model .gateway )
1624+ source_snapshot = next (
1625+ (snapshot for snapshot in source_env .snapshots if snapshot .name == model .fqn ),
1626+ None ,
16301627 )
1631- modified_snapshots = {
1632- current_snapshot .snapshot_id .name
1633- for _ , (current_snapshot , _ ) in context_diff .modified_snapshots .items ()
1634- }
1635- tasks_num = min (len (modified_snapshots ), self .concurrent_tasks )
1636- table_diffs = concurrent_apply_to_values (
1637- list (modified_snapshots ),
1638- lambda s : self ._model_diff (
1639- source_env = source_env ,
1640- target_env = target_env ,
1641- model_or_snapshot = s ,
1642- limit = limit ,
1643- decimals = decimals ,
1644- on = on ,
1645- skip_columns = skip_columns ,
1646- where = where ,
1647- show = show ,
1648- temp_schema = temp_schema ,
1649- skip_grain_check = skip_grain_check ,
1650- ),
1651- tasks_num = tasks_num ,
1628+ target_snapshot = next (
1629+ (snapshot for snapshot in target_env .snapshots if snapshot .name == model .fqn ),
1630+ None ,
1631+ )
1632+ if source_snapshot is None and target_snapshot :
1633+ models_in_source .append (model_or_snapshot )
1634+ elif target_snapshot is None and source_snapshot :
1635+ models_in_target .append (model_or_snapshot )
1636+ elif target_snapshot and source_snapshot :
1637+ if source_snapshot .fingerprint != target_snapshot .fingerprint :
1638+ # Compare the virtual layer instead of the physical layer because the virtual layer is guaranteed to point
1639+ # to the correct/active snapshot for the model in the specified environment, taking into account things like dev previews
1640+ source = source_snapshot .qualified_view_name .for_environment (
1641+ source_env .naming_info , adapter .dialect
1642+ )
1643+ target = target_snapshot .qualified_view_name .for_environment (
1644+ target_env .naming_info , adapter .dialect
1645+ )
1646+
1647+ models_to_diff .append ((model , adapter , source , target ))
1648+ else :
1649+ models_no_diff .append (model_or_snapshot )
1650+
1651+ self .console .show_table_diff_details (
1652+ models_in_source ,
1653+ models_in_target ,
1654+ models_no_diff ,
1655+ [model [0 ].name for model in models_to_diff ],
16521656 )
1657+
1658+ if models_to_diff :
1659+ self .console .start_table_diff_progress (len (models_to_diff ))
1660+ tasks_num = min (len (models_to_diff ), self .concurrent_tasks )
1661+ table_diffs = concurrent_apply_to_values (
1662+ list (models_to_diff ),
1663+ lambda model_info : self ._model_diff (
1664+ model = model_info [0 ],
1665+ adapter = model_info [1 ],
1666+ source = model_info [2 ],
1667+ target = model_info [3 ],
1668+ source_alias = source_env .name ,
1669+ target_alias = target_env .name ,
1670+ limit = limit ,
1671+ decimals = decimals ,
1672+ on = on ,
1673+ skip_columns = skip_columns ,
1674+ where = where ,
1675+ show = show ,
1676+ temp_schema = temp_schema ,
1677+ skip_grain_check = skip_grain_check ,
1678+ ),
1679+ tasks_num = tasks_num ,
1680+ )
1681+ self .console .stop_table_diff_progress ()
16531682 else :
16541683 table_diffs = [
16551684 self ._table_diff (
@@ -1673,9 +1702,12 @@ def table_diff(
16731702
16741703 def _model_diff (
16751704 self ,
1676- source_env : Environment ,
1677- target_env : Environment ,
1678- model_or_snapshot : ModelOrSnapshot ,
1705+ model : Model ,
1706+ adapter : EngineAdapter ,
1707+ source : str ,
1708+ target : str ,
1709+ source_alias : str ,
1710+ target_alias : str ,
16791711 limit : int ,
16801712 decimals : int ,
16811713 on : t .Optional [t .List [str ] | exp .Condition ] = None ,
@@ -1685,22 +1717,6 @@ def _model_diff(
16851717 temp_schema : t .Optional [str ] = None ,
16861718 skip_grain_check : bool = False ,
16871719 ) -> TableDiff :
1688- model = self .get_model (model_or_snapshot , raise_if_missing = True )
1689- adapter = self ._get_engine_adapter (model .gateway )
1690-
1691- # Compare the virtual layer instead of the physical layer because the virtual layer is guaranteed to point
1692- # to the correct/active snapshot for the model in the specified environment, taking into account things like dev previews
1693- source = next (
1694- snapshot for snapshot in source_env .snapshots if snapshot .name == model .fqn
1695- ).qualified_view_name .for_environment (source_env .naming_info , adapter .dialect )
1696-
1697- target = next (
1698- snapshot for snapshot in target_env .snapshots if snapshot .name == model .fqn
1699- ).qualified_view_name .for_environment (target_env .naming_info , adapter .dialect )
1700-
1701- source_alias = source_env .name
1702- target_alias = target_env .name
1703-
17041720 if not on :
17051721 on = []
17061722 for expr in [ref .expression for ref in model .all_references if ref .unique ]:
@@ -1710,6 +1726,8 @@ def _model_diff(
17101726 # Handle a single Column or Paren expression
17111727 on .append (expr .this .sql (dialect = adapter .dialect ))
17121728
1729+ self .console .start_table_diff_model_progress (model .name )
1730+
17131731 table_diff = self ._table_diff (
17141732 on = on ,
17151733 skip_columns = skip_columns ,
@@ -1723,10 +1741,13 @@ def _model_diff(
17231741 source_alias = source_alias ,
17241742 target_alias = target_alias ,
17251743 )
1726- # Trigger row_diff in parallel execution so it's available for ordered display later
1744+
17271745 if show :
1746+ # Trigger row_diff in parallel execution so it's available for ordered display later
17281747 table_diff .row_diff (temp_schema = temp_schema , skip_grain_check = skip_grain_check )
17291748
1749+ self .console .update_table_diff_progress (model .name )
1750+
17301751 return table_diff
17311752
17321753 def _table_diff (
0 commit comments