|
import hashlib
import logging
import os
import random
import string
import subprocess
import unittest
from typing import Callable, Optional

from parameterized import parameterized_class

from sqeleton import databases as db
from sqeleton import connect
from sqeleton.queries import table
from sqeleton.databases import Database
from sqeleton.query_utils import drop_table
| 17 | + |
| 18 | + |
# We write 'or None' because Github sometimes creates empty env vars for secrets;
# an empty string must be treated the same as "not configured" so the backend
# gets skipped instead of attempting to connect with "".
TEST_MYSQL_CONN_STRING: str = "mysql://mysql:Password1@localhost/mysql"
TEST_POSTGRESQL_CONN_STRING: str = "postgresql://postgres:Password1@localhost/postgres"
TEST_SNOWFLAKE_CONN_STRING: Optional[str] = os.environ.get("SNOWFLAKE_URI") or None
TEST_PRESTO_CONN_STRING: Optional[str] = os.environ.get("PRESTO_URI") or None
TEST_BIGQUERY_CONN_STRING: Optional[str] = os.environ.get("BIGQUERY_URI") or None
TEST_REDSHIFT_CONN_STRING: Optional[str] = os.environ.get("REDSHIFT_URI") or None
TEST_ORACLE_CONN_STRING: Optional[str] = None  # no default; configure via local_settings
TEST_DATABRICKS_CONN_STRING: Optional[str] = os.environ.get("DATABRICKS_URI") or None
TEST_TRINO_CONN_STRING: Optional[str] = os.environ.get("TRINO_URI") or None
# clickhouse uri for provided docker - "clickhouse://clickhouse:Password1@localhost:9000/clickhouse"
TEST_CLICKHOUSE_CONN_STRING: Optional[str] = os.environ.get("CLICKHOUSE_URI") or None
# vertica uri provided for docker - "vertica://vertica:Password1@localhost:5433/vertica"
TEST_VERTICA_CONN_STRING: Optional[str] = os.environ.get("VERTICA_URI") or None
TEST_DUCKDB_CONN_STRING: str = "duckdb://main:@:memory:"
| 34 | + |
| 35 | + |
def env_flag(name: str, default: bool) -> bool:
    """Interpret env var `name` as a boolean toggle.

    Unset -> `default`; otherwise '', '0', 'false', 'no', 'off' (any case)
    count as off. The raw string was previously kept as-is, which made
    e.g. BENCHMARK=false truthy.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() not in ("", "0", "false", "no", "off")


DEFAULT_N_SAMPLES = 50
N_SAMPLES = int(os.environ.get("N_SAMPLES", DEFAULT_N_SAMPLES))
BENCHMARK = env_flag("BENCHMARK", False)
N_THREADS = int(os.environ.get("N_THREADS", 1))
TEST_ACROSS_ALL_DBS = env_flag("TEST_ACROSS_ALL_DBS", True)  # Should we run the full db<->db test suite?
| 41 | + |
| 42 | + |
def get_git_revision_short_hash() -> str:
    """Return the short git commit hash of the current checkout.

    Returns "unknown" when git is not installed or we are not inside a git
    repository (e.g. running from a source tarball) — this function runs at
    import time, and a failure here would previously abort test collection.
    """
    try:
        out = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
    except (OSError, subprocess.CalledProcessError):
        return "unknown"
    return out.decode("ascii").strip()
| 45 | + |
| 46 | + |
# Record the revision under test once, at import time.
GIT_REVISION = get_git_revision_short_hash()

# Default to ERROR; a LOG_LEVEL env var (e.g. "debug", "INFO") overrides it.
level = logging.ERROR
_log_level_name = os.environ.get("LOG_LEVEL", False)
if _log_level_name:
    level = getattr(logging, _log_level_name.upper())

logging.basicConfig(level=level)
logging.getLogger("database").setLevel(level)
| 55 | + |
# Optionally pull in developer-local overrides (e.g. the TEST_*_CONN_STRING
# values above). The module's absence is expected and silently ignored.
try:
    from .local_settings import *
except ImportError:
    pass  # No local settings
| 60 | + |
| 61 | + |
# Maps each supported Database class to its connection string. Entries may be
# None when that backend is not configured; they are filtered out further down
# (after _print_used_dbs reports them as skipped).
CONN_STRINGS = {
    db.BigQuery: TEST_BIGQUERY_CONN_STRING,
    db.MySQL: TEST_MYSQL_CONN_STRING,
    db.PostgreSQL: TEST_POSTGRESQL_CONN_STRING,
    db.Snowflake: TEST_SNOWFLAKE_CONN_STRING,
    db.Redshift: TEST_REDSHIFT_CONN_STRING,
    db.Oracle: TEST_ORACLE_CONN_STRING,
    db.Presto: TEST_PRESTO_CONN_STRING,
    db.Databricks: TEST_DATABRICKS_CONN_STRING,
    db.Trino: TEST_TRINO_CONN_STRING,
    db.Clickhouse: TEST_CLICKHOUSE_CONN_STRING,
    db.Vertica: TEST_VERTICA_CONN_STRING,
    db.DuckDB: TEST_DUCKDB_CONN_STRING,
}

# Cache of shared Database connections, keyed by Database class (see get_conn).
_database_instances = {}
| 78 | + |
| 79 | + |
def get_conn(cls: type, shared: bool = True) -> Database:
    """Return a Database connection for the given database class.

    With shared=True (the default), one connection per class is created
    lazily and cached for reuse across tests; shared=False always opens a
    fresh connection with N_THREADS worker threads.
    """
    if not shared:
        return connect(CONN_STRINGS[cls], N_THREADS)

    try:
        return _database_instances[cls]
    except KeyError:
        conn = get_conn(cls, shared=False)
        _database_instances[cls] = conn
        return conn
| 87 | + |
| 88 | + |
def _print_used_dbs():
    """Report which databases are configured for testing, log the skipped
    ones, and note whether the full db<->db matrix is enabled."""
    used = set()
    unused = set()
    for db_cls, conn_string in CONN_STRINGS.items():
        target = used if conn_string is not None else unused
        target.add(db_cls.__name__)

    print(f"Testing databases: {', '.join(used)}")
    if unused:
        logging.info(f"Connection not configured; skipping tests for: {', '.join(unused)}")
    if TEST_ACROSS_ALL_DBS:
        logging.info(
            f"Full tests enabled (every db<->db). May take very long when many dbs are involved. ={TEST_ACROSS_ALL_DBS}"
        )
| 100 | + |
| 101 | + |
# Report configured/skipped databases, then keep only the entries that
# actually have a connection string.
_print_used_dbs()
CONN_STRINGS = {k: v for k, v in CONN_STRINGS.items() if v is not None}
| 104 | + |
| 105 | + |
def random_table_suffix() -> str:
    """Return a random table-name suffix like "_x3k9q".

    An underscore followed by 5 random lowercase letters/digits, so
    concurrent test runs don't collide on table names.
    """
    alphabet = string.ascii_lowercase + string.digits
    return "_" + "".join(random.choices(alphabet, k=5))
| 111 | + |
| 112 | + |
def str_to_checksum(s: str) -> int:
    """Compute a string's checksum the way the tested databases do: take the
    last CHECKSUM_HEXDIGITS hex digits of its MD5 and parse them as an int.

    Example: "hello world"
      -> md5:  5eb63bbbe01eeed093cb22bb8f5acdc3
      -> tail: cb22bb8f5acdc3
      -> int:  273350391345368515
    """
    # NOTE: parameter renamed from `str` (shadowed the builtin).
    digest = hashlib.md5(s.encode("utf-8")).hexdigest()
    # 0-indexed, unlike DBs which are 1-indexed here, so +1 in dbs
    half_pos = db.MD5_HEXDIGITS - db.CHECKSUM_HEXDIGITS
    return int(digest[half_pos:], 16)
| 124 | + |
| 125 | + |
class DbTestCase(unittest.TestCase):
    """Base TestCase that provisions a uniquely-named source table per test.

    Subclasses (usually generated by test_each_database_in_list) must set
    `db_cls`; when `table1_schema` is set, the table is also created in setUp.
    """

    # Database class to connect to; injected by parameterization.
    db_cls = None
    # Optional schema for the source table; when falsy, no table is created.
    table1_schema = None
    # Reuse one cached connection per database class (see get_conn).
    shared_connection = True

    def setUp(self):
        # Fail fast if the subclass/parameterization didn't provide a db class.
        assert self.db_cls, self.db_cls

        self.connection = get_conn(self.db_cls, self.shared_connection)

        # Random suffix keeps concurrent test runs from colliding on names.
        table_suffix = random_table_suffix()
        self.table1_name = f"src{table_suffix}"

        self.table1_path = self.connection.parse_table_name(self.table1_name)

        # Drop any leftover table from a previous (crashed) run before use.
        drop_table(self.connection, self.table1_path)

        self.src_table = table(self.table1_path, schema=self.table1_schema)
        if self.table1_schema:
            self.connection.query(self.src_table.create())

        return super().setUp()

    def tearDown(self):
        # Clean up the table used by the test.
        drop_table(self.connection, self.table1_path)
| 152 | + |
| 153 | + |
def _parameterized_class_per_conn(test_databases):
    """Build a class decorator that parameterizes a TestCase over every
    configured database that also appears in `test_databases`."""
    wanted = set(test_databases)
    params = [(db_cls.__name__, db_cls) for db_cls in CONN_STRINGS if db_cls in wanted]
    return parameterized_class(("name", "db_cls"), params)
| 158 | + |
| 159 | + |
def test_each_database_in_list(databases) -> Callable:
    """Class decorator factory: run the decorated TestCase once per database
    in `databases` that has a connection string configured."""

    def _test_per_database(cls):
        decorate = _parameterized_class_per_conn(databases)
        return decorate(cls)

    return _test_per_database
| 165 | + |
0 commit comments