22from numbers import Number
33import logging
44from collections import defaultdict
5- from typing import Iterator
5+ from typing import Any , Collection , Dict , Iterator , List , Sequence , Set , Tuple
66
77import attrs
8+ from typing_extensions import Literal
89
910from data_diff .abcs .database_types import ColType_UUID , NumericType , PrecisionType , StringType , Boolean , JSON
1011from data_diff .info_tree import InfoTree
2021
2122logger = logging .getLogger ("hashdiff_tables" )
2223
23-
24- def diff_sets (a : list , b : list , json_cols : dict = None ) -> Iterator :
25- sa = set (a )
26- sb = set (b )
24+ # Just for local readability: TODO: later switch to real type declarations of these.
25+ _Op = Literal ["+" , "-" ]
26+ _PK = Any
27+ _Row = Tuple [Any ]
28+
29+
30+ def diff_sets (
31+ a : Sequence [_Row ],
32+ b : Sequence [_Row ],
33+ * ,
34+ json_cols : dict = None ,
35+ columns1 : Sequence [str ],
36+ columns2 : Sequence [str ],
37+ ignored_columns1 : Collection [str ],
38+ ignored_columns2 : Collection [str ],
39+ ) -> Iterator :
40+ # Differ only by columns of interest (PKs+relevant-ignored). But yield with ignored ones!
41+ sa : Set [_Row ] = {tuple (val for col , val in safezip (columns1 , row ) if col not in ignored_columns1 ) for row in a }
42+ sb : Set [_Row ] = {tuple (val for col , val in safezip (columns2 , row ) if col not in ignored_columns2 ) for row in b }
2743
2844 # The first item is always the key (see TableDiffer.relevant_columns)
2945 # TODO update when we add compound keys to hashdiff
30- d = defaultdict (list )
46+ diffs_by_pks : Dict [ _PK , List [ Tuple [ _Op , _Row ]]] = defaultdict (list )
3147 for row in a :
32- if row not in sb :
33- d [row [0 ]].append (("-" , row ))
48+ cutrow : _Row = tuple (val for col , val in zip (columns1 , row ) if col not in ignored_columns1 )
49+ if cutrow not in sb :
50+ diffs_by_pks [row [0 ]].append (("-" , row ))
3451 for row in b :
35- if row not in sa :
36- d [row [0 ]].append (("+" , row ))
52+ cutrow : _Row = tuple (val for col , val in zip (columns2 , row ) if col not in ignored_columns2 )
53+ if cutrow not in sa :
54+ diffs_by_pks [row [0 ]].append (("+" , row ))
3755
3856 warned_diff_cols = set ()
39- for _k , v in sorted ( d . items (), key = lambda i : i [ 0 ] ):
57+ for diffs in ( diffs_by_pks [ pk ] for pk in sorted ( diffs_by_pks ) ):
4058 if json_cols :
41- parsed_match , overriden_diff_cols = diffs_are_equiv_jsons (v , json_cols )
59+ parsed_match , overriden_diff_cols = diffs_are_equiv_jsons (diffs , json_cols )
4260 if parsed_match :
4361 to_warn = overriden_diff_cols - warned_diff_cols
4462 for w in to_warn :
@@ -48,7 +66,7 @@ def diff_sets(a: list, b: list, json_cols: dict = None) -> Iterator:
4866 )
4967 warned_diff_cols .add (w )
5068 continue
51- yield from v
69+ yield from diffs
5270
5371
5472@attrs .define (frozen = False )
@@ -201,7 +219,17 @@ def _bisect_and_diff_segments(
201219 for i , colname in enumerate (table1 .extra_columns )
202220 if isinstance (table1 ._schema [colname ], JSON )
203221 }
204- diff = list (diff_sets (rows1 , rows2 , json_cols ))
222+ diff = list (
223+ diff_sets (
224+ rows1 ,
225+ rows2 ,
226+ json_cols = json_cols ,
227+ columns1 = table1 .relevant_columns ,
228+ columns2 = table2 .relevant_columns ,
229+ ignored_columns1 = self .ignored_columns1 ,
230+ ignored_columns2 = self .ignored_columns1 ,
231+ )
232+ )
205233
206234 info_tree .info .set_diff (diff )
207235 info_tree .info .rowcounts = {1 : len (rows1 ), 2 : len (rows2 )}
0 commit comments