Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 9f27117

Browse files
author
Sergey Vasilyev
committed
Group raw column info from rows to structures for schema parsing
1 parent 60ac169 commit 9f27117

14 files changed

Lines changed: 182 additions & 138 deletions

File tree

data_diff/__main__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from rich.logging import RichHandler
1313
import click
1414

15-
from data_diff import Database
16-
from data_diff.schema import create_schema
15+
from data_diff import Database, DbPath
16+
from data_diff.schema import RawColumnInfo, create_schema
1717
from data_diff.queries.api import current_timestamp
1818

1919
from data_diff.dbt import dbt_diff
@@ -72,7 +72,7 @@ def _remove_passwords_in_dict(d: dict) -> None:
7272
d[k] = remove_password_from_url(v)
7373

7474

75-
def _get_schema(pair):
75+
def _get_schema(pair: Tuple[Database, DbPath]) -> Dict[str, RawColumnInfo]:
7676
db, table_path = pair
7777
return db.query_table_schema(table_path)
7878

data_diff/databases/base.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from data_diff.abcs.compiler import AbstractCompiler, Compilable
2121
from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, Checksum, NormalizeAsString
22+
from data_diff.schema import RawColumnInfo
2223
from data_diff.utils import ArithString, is_uuid, join_iter, safezip
2324
from data_diff.queries.api import Expr, table, Select, SKIP, Explain, Code, this
2425
from data_diff.queries.ast_classes import (
@@ -707,27 +708,18 @@ def type_repr(self, t) -> str:
707708
datetime: "TIMESTAMP",
708709
}[t]
709710

710-
def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
711-
return self.TYPE_CLASSES.get(type_repr)
712-
713-
def parse_type(
714-
self,
715-
table_path: DbPath,
716-
col_name: str,
717-
type_repr: str,
718-
datetime_precision: int = None,
719-
numeric_precision: int = None,
720-
numeric_scale: int = None,
721-
) -> ColType:
711+
def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
722712
"Parse type info as returned by the database"
723713

724-
cls = self._parse_type_repr(type_repr)
714+
cls = self.TYPE_CLASSES.get(info.type_repr)
725715
if cls is None:
726-
return UnknownColType(type_repr)
716+
return UnknownColType(info.type_repr)
727717

728718
if issubclass(cls, TemporalType):
729719
return cls(
730-
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
720+
precision=info.datetime_precision
721+
if info.datetime_precision is not None
722+
else DEFAULT_DATETIME_PRECISION,
731723
rounds=self.ROUNDS_ON_PREC_LOSS,
732724
)
733725

@@ -738,22 +730,22 @@ def parse_type(
738730
return cls()
739731

740732
elif issubclass(cls, Decimal):
741-
if numeric_scale is None:
742-
numeric_scale = 0 # Needed for Oracle.
743-
return cls(precision=numeric_scale)
733+
if info.numeric_scale is None:
734+
return cls(precision=0) # Needed for Oracle.
735+
return cls(precision=info.numeric_scale)
744736

745737
elif issubclass(cls, Float):
746738
# assert numeric_scale is None
747739
return cls(
748740
precision=self._convert_db_precision_to_digits(
749-
numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
741+
info.numeric_precision if info.numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
750742
)
751743
)
752744

753745
elif issubclass(cls, (JSON, Array, Struct, Text, Native_UUID)):
754746
return cls()
755747

756-
raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.")
748+
raise TypeError(f"Parsing {info.type_repr} returned an unknown type {cls!r}.")
757749

758750
def _convert_db_precision_to_digits(self, p: int) -> int:
759751
"""Convert from binary precision, used by floats, to decimal precision."""
@@ -1018,7 +1010,7 @@ def select_table_schema(self, path: DbPath) -> str:
10181010
f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
10191011
)
10201012

1021-
def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
1013+
def query_table_schema(self, path: DbPath) -> Dict[str, RawColumnInfo]:
10221014
"""Query the table for its schema for table in 'path', and return {column: RawColumnInfo}
10231015
where RawColumnInfo carries (column_name, type_repr, datetime_precision?, numeric_precision?, numeric_scale?, collation_name?)
10241016
@@ -1029,7 +1021,17 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
10291021
if not rows:
10301022
raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
10311023

1032-
d = {r[0]: r for r in rows}
1024+
d = {
1025+
r[0]: RawColumnInfo(
1026+
column_name=r[0],
1027+
type_repr=r[1],
1028+
datetime_precision=r[2],
1029+
numeric_precision=r[3],
1030+
numeric_scale=r[4],
1031+
collation_name=r[5] if len(r) > 5 else None,
1032+
)
1033+
for r in rows
1034+
}
10331035
assert len(d) == len(rows)
10341036
return d
10351037

@@ -1051,7 +1053,11 @@ def query_table_unique_columns(self, path: DbPath) -> List[str]:
10511053
return list(res)
10521054

10531055
def _process_table_schema(
1054-
self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str] = None, where: str = None
1056+
self,
1057+
path: DbPath,
1058+
raw_schema: Dict[str, RawColumnInfo],
1059+
filter_columns: Sequence[str] = None,
1060+
where: str = None,
10551061
):
10561062
"""Process the result of query_table_schema().
10571063
@@ -1067,7 +1073,7 @@ def _process_table_schema(
10671073
accept = {i.lower() for i in filter_columns}
10681074
filtered_schema = {name: row for name, row in raw_schema.items() if name.lower() in accept}
10691075

1070-
col_dict = {row[0]: self.dialect.parse_type(path, *row) for _name, row in filtered_schema.items()}
1076+
col_dict = {info.column_name: self.dialect.parse_type(path, info) for info in filtered_schema.values()}
10711077

10721078
self._refine_coltypes(path, col_dict, where)
10731079

@@ -1076,15 +1082,15 @@ def _process_table_schema(
10761082

10771083
def _refine_coltypes(
10781084
self, table_path: DbPath, col_dict: Dict[str, ColType], where: Optional[str] = None, sample_size=64
1079-
):
1085+
) -> Dict[str, ColType]:
10801086
"""Refine the types in the column dict, by querying the database for a sample of their values
10811087
10821088
'where' restricts the rows to be sampled.
10831089
"""
10841090

10851091
text_columns = [k for k, v in col_dict.items() if isinstance(v, Text)]
10861092
if not text_columns:
1087-
return
1093+
return col_dict
10881094

10891095
fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns]
10901096

@@ -1118,6 +1124,8 @@ def _refine_coltypes(
11181124
assert col_name in col_dict
11191125
col_dict[col_name] = String_VaryingAlphanum()
11201126

1127+
return col_dict
1128+
11211129
def _normalize_table_path(self, path: DbPath) -> DbPath:
11221130
if len(path) == 1:
11231131
return self.default_schema, path[0]

data_diff/databases/bigquery.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
MD5_HEXDIGITS,
3434
)
3535
from data_diff.databases.base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter
36+
from data_diff.schema import RawColumnInfo
3637

3738

3839
@import_helper(text="Please install BigQuery and configure your google-cloud access.")
@@ -91,27 +92,21 @@ def type_repr(self, t) -> str:
9192
except KeyError:
9293
return super().type_repr(t)
9394

94-
def parse_type(
95-
self,
96-
table_path: DbPath,
97-
col_name: str,
98-
type_repr: str,
99-
*args: Any, # pass-through args
100-
**kwargs: Any, # pass-through args
101-
) -> ColType:
102-
col_type = super().parse_type(table_path, col_name, type_repr, *args, **kwargs)
95+
def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
96+
col_type = super().parse_type(table_path, info)
10397
if isinstance(col_type, UnknownColType):
104-
m = self.TYPE_ARRAY_RE.fullmatch(type_repr)
98+
m = self.TYPE_ARRAY_RE.fullmatch(info.type_repr)
10599
if m:
106-
item_type = self.parse_type(table_path, col_name, m.group(1), *args, **kwargs)
100+
item_info = attrs.evolve(info, data_type=m.group(1))
101+
item_type = self.parse_type(table_path, item_info)
107102
col_type = Array(item_type=item_type)
108103

109104
# We currently ignore structs' structure, but later can parse it too. Examples:
110105
# - STRUCT<INT64, STRING(10)> (unnamed)
111106
# - STRUCT<foo INT64, bar STRING(10)> (named)
112107
# - STRUCT<foo INT64, bar ARRAY<INT64>> (with complex fields)
113108
# - STRUCT<foo INT64, bar STRUCT<a INT64, b INT64>> (nested)
114-
m = self.TYPE_STRUCT_RE.fullmatch(type_repr)
109+
m = self.TYPE_STRUCT_RE.fullmatch(info.type_repr)
115110
if m:
116111
col_type = Struct()
117112

data_diff/databases/clickhouse.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from data_diff.abcs.database_types import (
1616
ColType,
17+
DbPath,
1718
Decimal,
1819
Float,
1920
Integer,
@@ -24,6 +25,7 @@
2425
Timestamp,
2526
Boolean,
2627
)
28+
from data_diff.schema import RawColumnInfo
2729

2830
# https://clickhouse.com/docs/en/operations/server-configuration-parameters/settings/#default-database
2931
DEFAULT_DATABASE = "default"
@@ -75,19 +77,19 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
7577
# because it does not help for float with a big integer part.
7678
return super()._convert_db_precision_to_digits(p) - 2
7779

78-
def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
80+
def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
7981
nullable_prefix = "Nullable("
80-
if type_repr.startswith(nullable_prefix):
81-
type_repr = type_repr[len(nullable_prefix) :].rstrip(")")
82+
if info.type_repr.startswith(nullable_prefix):
83+
info = attrs.evolve(info, data_type=info.type_repr[len(nullable_prefix) :].rstrip(")"))
8284

83-
if type_repr.startswith("Decimal"):
84-
type_repr = "Decimal"
85-
elif type_repr.startswith("FixedString"):
86-
type_repr = "FixedString"
87-
elif type_repr.startswith("DateTime64"):
88-
type_repr = "DateTime64"
85+
if info.type_repr.startswith("Decimal"):
86+
info = attrs.evolve(info, data_type="Decimal")
87+
elif info.type_repr.startswith("FixedString"):
88+
info = attrs.evolve(info, data_type="FixedString")
89+
elif info.type_repr.startswith("DateTime64"):
90+
info = attrs.evolve(info, data_type="DateTime64")
8991

90-
return self.TYPE_CLASSES.get(type_repr)
92+
return super().parse_type(table_path, info)
9193

9294
# def timestamp_value(self, t: DbTime) -> str:
9395
# # return f"'{t}'"

data_diff/databases/databricks.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import_helper,
2727
parse_table_name,
2828
)
29+
from data_diff.schema import RawColumnInfo
2930

3031

3132
@import_helper(text="You can install it using 'pip install databricks-sql-connector'")
@@ -138,7 +139,7 @@ def create_connection(self):
138139
except databricks.sql.exc.Error as e:
139140
raise ConnectionError(*e.args) from e
140141

141-
def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
142+
def query_table_schema(self, path: DbPath) -> Dict[str, RawColumnInfo]:
142143
# Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL.
143144
# https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html
144145
# So, to obtain information about schema, we should use another approach.
@@ -155,7 +156,12 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
155156
if not rows:
156157
raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
157158

158-
d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows}
159+
d = {
160+
r.COLUMN_NAME: RawColumnInfo(
161+
column_name=r.COLUMN_NAME, type_repr=r.TYPE_NAME, datetime_precision=r.DECIMAL_DIGITS
162+
)
163+
for r in rows
164+
}
159165
assert len(d) == len(rows)
160166
return d
161167

@@ -173,37 +179,39 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
173179
# )
174180

175181
def _process_table_schema(
176-
self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
182+
self, path: DbPath, raw_schema: Dict[str, RawColumnInfo], filter_columns: Sequence[str], where: str = None
177183
):
178184
accept = {i.lower() for i in filter_columns}
179-
rows = [row for name, row in raw_schema.items() if name.lower() in accept]
185+
col_infos = [row for name, row in raw_schema.items() if name.lower() in accept]
180186

181187
resulted_rows = []
182-
for row in rows:
183-
row_type = "DECIMAL" if row[1].startswith("DECIMAL") else row[1]
188+
for info in col_infos:
189+
row_type = "DECIMAL" if info.type_repr.startswith("DECIMAL") else info.type_repr
190+
info = attrs.evolve(info, type_repr=row_type)
184191
type_cls = self.dialect.TYPE_CLASSES.get(row_type, UnknownColType)
185192

186193
if issubclass(type_cls, Integer):
187-
row = (row[0], row_type, None, None, 0)
194+
info = attrs.evolve(info, numeric_scale=0)
188195

189196
elif issubclass(type_cls, Float):
190-
numeric_precision = math.ceil(row[2] / math.log(2, 10))
191-
row = (row[0], row_type, None, numeric_precision, None)
197+
numeric_precision = math.ceil(info.datetime_precision / math.log(2, 10))
198+
info = attrs.evolve(info, numeric_precision=numeric_precision)
192199

193200
elif issubclass(type_cls, Decimal):
194-
items = row[1][8:].rstrip(")").split(",")
201+
items = info.type_repr[8:].rstrip(")").split(",")
195202
numeric_precision, numeric_scale = int(items[0]), int(items[1])
196-
row = (row[0], row_type, None, numeric_precision, numeric_scale)
203+
info = attrs.evolve(
204+
info,
205+
numeric_precision=numeric_precision,
206+
numeric_scale=numeric_scale,
207+
)
197208

198209
elif issubclass(type_cls, Timestamp):
199-
row = (row[0], row_type, row[2], None, None)
210+
info = attrs.evolve(info, datetime_precision=info.datetime_precision)
200211

201-
else:
202-
row = (row[0], row_type, None, None, None)
212+
resulted_rows.append(info)
203213

204-
resulted_rows.append(row)
205-
206-
col_dict: Dict[str, ColType] = {row[0]: self.dialect.parse_type(path, *row) for row in resulted_rows}
214+
col_dict: Dict[str, ColType] = {info.column_name: self.dialect.parse_type(path, info) for info in resulted_rows}
207215

208216
self._refine_coltypes(path, col_dict, where)
209217
return col_dict

data_diff/databases/duckdb.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import attrs
44
from packaging.version import parse as parse_version
55

6+
from data_diff.schema import RawColumnInfo
67
from data_diff.utils import match_regexps
78
from data_diff.abcs.database_types import (
89
Timestamp,
@@ -74,24 +75,16 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
7475
# Subtracting 2 due to weird precision issues in PostgreSQL
7576
return super()._convert_db_precision_to_digits(p) - 2
7677

77-
def parse_type(
78-
self,
79-
table_path: DbPath,
80-
col_name: str,
81-
type_repr: str,
82-
datetime_precision: int = None,
83-
numeric_precision: int = None,
84-
numeric_scale: int = None,
85-
) -> ColType:
78+
def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
8679
regexps = {
8780
r"DECIMAL\((\d+),(\d+)\)": Decimal,
8881
}
8982

90-
for m, t_cls in match_regexps(regexps, type_repr):
83+
for m, t_cls in match_regexps(regexps, info.type_repr):
9184
precision = int(m.group(2))
9285
return t_cls(precision=precision)
9386

94-
return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale)
87+
return super().parse_type(table_path, info)
9588

9689
def set_timezone_to_utc(self) -> str:
9790
return "SET GLOBAL TimeZone='UTC'"

0 commit comments

Comments
 (0)