Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit bf4eeea

Browse files
author
Sergey Vasilyev
committed
Move table name parsing to dialects, where they semantically belong
1 parent 651b8bc commit bf4eeea

7 files changed

Lines changed: 27 additions & 23 deletions

File tree

data_diff/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Sequence, Tuple, Iterator, Optional, Union
22

33
from data_diff.abcs.database_types import DbTime, DbPath
4+
from data_diff.databases import Database
45
from data_diff.tracking import disable_tracking
56
from data_diff.databases._connect import connect
67
from data_diff.diff_tables import Algorithm
@@ -31,10 +32,10 @@ def connect_to_table(
3132
if isinstance(key_columns, str):
3233
key_columns = (key_columns,)
3334

34-
db = connect(db_info, thread_count=thread_count)
35+
db: Database = connect(db_info, thread_count=thread_count)
3536

3637
if isinstance(table_name, str):
37-
table_name = db.parse_table_name(table_name)
38+
table_name = db.dialect.parse_table_name(table_name)
3839

3940
return TableSegment(db, table_name, key_columns, **kwargs)
4041

@@ -161,7 +162,8 @@ def diff_tables(
161162
)
162163
elif algorithm == Algorithm.JOINDIFF:
163164
if isinstance(materialize_to_table, str):
164-
materialize_to_table = table1.database.parse_table_name(eval_name_template(materialize_to_table))
165+
table_name = eval_name_template(materialize_to_table)
166+
materialize_to_table = table1.database.dialect.parse_table_name(table_name)
165167
differ = JoinDiffer(
166168
threaded=threaded,
167169
max_threadpool_size=max_threadpool_size,

data_diff/__main__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
import json
77
import logging
88
from itertools import islice
9-
from typing import Dict, Optional
9+
from typing import Dict, Optional, Tuple
1010

1111
import rich
1212
from rich.logging import RichHandler
1313
import click
1414

15+
from data_diff import Database
1516
from data_diff.schema import create_schema
1617
from data_diff.queries.api import current_timestamp
1718

@@ -425,7 +426,7 @@ def _data_diff(
425426
logging.error(f"Error while parsing age expression: {e}")
426427
return
427428

428-
dbs = db1, db2
429+
dbs: Tuple[Database, Database] = db1, db2
429430

430431
if interactive:
431432
for db in dbs:
@@ -444,7 +445,7 @@ def _data_diff(
444445
materialize_all_rows=materialize_all_rows,
445446
table_write_limit=table_write_limit,
446447
materialize_to_table=materialize_to_table
447-
and db1.parse_table_name(eval_name_template(materialize_to_table)),
448+
and db1.dialect.parse_table_name(eval_name_template(materialize_to_table)),
448449
)
449450
else:
450451
assert algorithm == Algorithm.HASHDIFF
@@ -456,7 +457,7 @@ def _data_diff(
456457
)
457458

458459
table_names = table1, table2
459-
table_paths = [db.parse_table_name(t) for db, t in safezip(dbs, table_names)]
460+
table_paths = [db.dialect.parse_table_name(t) for db, t in safezip(dbs, table_names)]
460461

461462
schemas = list(differ._thread_map(_get_schema, safezip(dbs, table_paths)))
462463
schema1, schema2 = schemas = [

data_diff/abcs/database_types.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ class UnknownColType(ColType):
176176
class AbstractDialect(ABC):
177177
"""Dialect-dependent query expressions"""
178178

179+
@abstractmethod
180+
def parse_table_name(self, name: str) -> DbPath:
181+
"Parse the given table name into a DbPath"
182+
179183
@property
180184
@abstractmethod
181185
def name(self) -> str:
@@ -319,10 +323,6 @@ def _process_table_schema(
319323
320324
"""
321325

322-
@abstractmethod
323-
def parse_table_name(self, name: str) -> DbPath:
324-
"Parse the given table name into a DbPath"
325-
326326
@abstractmethod
327327
def close(self):
328328
"Close connection(s) to the database instance. Querying will stop functioning."

data_diff/databases/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,9 @@ class BaseDialect(AbstractDialect):
156156

157157
PLACEHOLDER_TABLE = None # Used for Oracle
158158

159+
def parse_table_name(self, name: str) -> DbPath:
160+
return parse_table_name(name)
161+
159162
def offset_limit(
160163
self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
161164
) -> str:
@@ -518,9 +521,6 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
518521

519522
raise ValueError(f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table")
520523

521-
def parse_table_name(self, name: str) -> DbPath:
522-
return parse_table_name(name)
523-
524524
def _query_cursor(self, c, sql_code: str) -> QueryResult:
525525
assert isinstance(sql_code, str), sql_code
526526
try:

data_diff/databases/bigquery.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,10 @@ def to_comparable(self, value: str, coltype: ColType) -> str:
212212
def set_timezone_to_utc(self) -> str:
213213
raise NotImplementedError()
214214

215+
def parse_table_name(self, name: str) -> DbPath:
216+
path = parse_table_name(name)
217+
return tuple(i for i in path if i is not None)
218+
215219

216220
class BigQuery(Database):
217221
CONNECT_URI_HELP = "bigquery://<project>/<dataset>"
@@ -288,10 +292,6 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
288292
f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: [project.]schema.table"
289293
)
290294

291-
def parse_table_name(self, name: str) -> DbPath:
292-
path = parse_table_name(name)
293-
return tuple(i for i in self._normalize_table_path(path) if i is not None)
294-
295295
@property
296296
def is_autocommit(self) -> bool:
297297
return True

data_diff/databases/databricks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
9494
def set_timezone_to_utc(self) -> str:
9595
return "SET TIME ZONE 'UTC'"
9696

97+
def parse_table_name(self, name: str) -> DbPath:
98+
path = parse_table_name(name)
99+
return tuple(i for i in path if i is not None)
100+
97101

98102
class Databricks(ThreadedDatabase):
99103
dialect = Dialect()
@@ -178,10 +182,6 @@ def _process_table_schema(
178182
self._refine_coltypes(path, col_dict, where)
179183
return col_dict
180184

181-
def parse_table_name(self, name: str) -> DbPath:
182-
path = parse_table_name(name)
183-
return tuple(i for i in self._normalize_table_path(path) if i is not None)
184-
185185
@property
186186
def is_autocommit(self) -> bool:
187187
return True

data_diff/queries/compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ def new_unique_name(self, prefix="tmp"):
7979

8080
def new_unique_table_name(self, prefix="tmp") -> DbPath:
8181
self._counter[0] += 1
82-
return self.database.parse_table_name(f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}")
82+
table_name = f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}"
83+
return self.database.dialect.parse_table_name(table_name)
8384

8485
def add_table_context(self, *tables: Sequence, **kw) -> Self:
8586
return self.replace(_table_context=self._table_context + list(tables), **kw)

0 commit comments

Comments
 (0)