Skip to content
This repository was archived by the owner on May 2, 2023. It is now read-only.

Commit ad08db8

Browse files
committed
Added tests
1 parent b91b622 commit ad08db8

6 files changed

Lines changed: 623 additions & 0 deletions

File tree

tests/__init__.py

Whitespace-only changes.

tests/common.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
import hashlib
2+
import os
3+
import string
4+
import random
5+
from typing import Callable
6+
import unittest
7+
import logging
8+
import subprocess
9+
10+
from parameterized import parameterized_class
11+
12+
from sqeleton import databases as db
13+
from sqeleton import connect
14+
from sqeleton.queries import table
15+
from sqeleton.databases import Database
16+
from sqeleton.query_utils import drop_table
17+
18+
19+
def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean switch from the environment.

    Env var values are strings, so a plain truthiness test would treat
    "false" or "0" as enabled; parse the common "off" spellings instead.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() not in ("", "0", "false", "no", "off")


# We write 'or None' because Github sometimes creates empty env vars for secrets
TEST_MYSQL_CONN_STRING: str = "mysql://mysql:Password1@localhost/mysql"
TEST_POSTGRESQL_CONN_STRING: str = "postgresql://postgres:Password1@localhost/postgres"
TEST_SNOWFLAKE_CONN_STRING: Optional[str] = os.environ.get("SNOWFLAKE_URI") or None
TEST_PRESTO_CONN_STRING: Optional[str] = os.environ.get("PRESTO_URI") or None
TEST_BIGQUERY_CONN_STRING: Optional[str] = os.environ.get("BIGQUERY_URI") or None
TEST_REDSHIFT_CONN_STRING: Optional[str] = os.environ.get("REDSHIFT_URI") or None
TEST_ORACLE_CONN_STRING: Optional[str] = None  # no test instance configured
TEST_DATABRICKS_CONN_STRING: Optional[str] = os.environ.get("DATABRICKS_URI")
TEST_TRINO_CONN_STRING: Optional[str] = os.environ.get("TRINO_URI") or None
# clickhouse uri for provided docker - "clickhouse://clickhouse:Password1@localhost:9000/clickhouse"
TEST_CLICKHOUSE_CONN_STRING: Optional[str] = os.environ.get("CLICKHOUSE_URI")
# vertica uri provided for docker - "vertica://vertica:Password1@localhost:5433/vertica"
TEST_VERTICA_CONN_STRING: Optional[str] = os.environ.get("VERTICA_URI")
TEST_DUCKDB_CONN_STRING: str = "duckdb://main:@:memory:"


DEFAULT_N_SAMPLES = 50
N_SAMPLES = int(os.environ.get("N_SAMPLES", DEFAULT_N_SAMPLES))
# BUGFIX: these flags were raw os.environ.get() results, so BENCHMARK="false"
# or TEST_ACROSS_ALL_DBS="0" were still truthy and could never disable the
# behavior. Parse them as real booleans instead.
BENCHMARK: bool = _env_flag("BENCHMARK", False)
N_THREADS = int(os.environ.get("N_THREADS", 1))
TEST_ACROSS_ALL_DBS: bool = _env_flag("TEST_ACROSS_ALL_DBS", True)  # Should we run the full db<->db test suite?
43+
def get_git_revision_short_hash() -> str:
    """Return the abbreviated commit hash of the current git HEAD."""
    out = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
    return out.decode("ascii").strip()
45+
46+
47+
GIT_REVISION = get_git_revision_short_hash()

# Log at ERROR unless LOG_LEVEL (e.g. "DEBUG", "INFO") is set in the environment.
_log_level_name = os.environ.get("LOG_LEVEL")
level = getattr(logging, _log_level_name.upper()) if _log_level_name else logging.ERROR

logging.basicConfig(level=level)
logging.getLogger("database").setLevel(level)

# Optional developer-local overrides for the settings above.
try:
    from .local_settings import *
except ImportError:
    pass  # No local settings
60+
61+
62+
# Connection string per supported Database class. A value of None means the
# backend is not configured in this environment; such entries are filtered
# out after _print_used_dbs() reports them.
CONN_STRINGS = {
    db.BigQuery: TEST_BIGQUERY_CONN_STRING,
    db.MySQL: TEST_MYSQL_CONN_STRING,
    db.PostgreSQL: TEST_POSTGRESQL_CONN_STRING,
    db.Snowflake: TEST_SNOWFLAKE_CONN_STRING,
    db.Redshift: TEST_REDSHIFT_CONN_STRING,
    db.Oracle: TEST_ORACLE_CONN_STRING,
    db.Presto: TEST_PRESTO_CONN_STRING,
    db.Databricks: TEST_DATABRICKS_CONN_STRING,
    db.Trino: TEST_TRINO_CONN_STRING,
    db.Clickhouse: TEST_CLICKHOUSE_CONN_STRING,
    db.Vertica: TEST_VERTICA_CONN_STRING,
    db.DuckDB: TEST_DUCKDB_CONN_STRING,
}
76+
77+
# Cache of shared Database connections, keyed by Database class.
_database_instances = {}


def get_conn(cls: type, shared: bool = True) -> Database:
    """Return a connection to the database class `cls`.

    With shared=True (the default) a single cached connection per class is
    reused across tests; shared=False always opens a fresh connection.
    """
    if not shared:
        return connect(CONN_STRINGS[cls], N_THREADS)

    try:
        return _database_instances[cls]
    except KeyError:
        conn = get_conn(cls, shared=False)
        _database_instances[cls] = conn
        return conn
87+
88+
89+
def _print_used_dbs():
    """Report which databases have a configured connection and which are skipped."""
    used = {cls.__name__ for cls, uri in CONN_STRINGS.items() if uri is not None}
    unused = {cls.__name__ for cls, uri in CONN_STRINGS.items() if uri is None}

    print(f"Testing databases: {', '.join(used)}")
    if unused:
        logging.info(f"Connection not configured; skipping tests for: {', '.join(unused)}")
    if TEST_ACROSS_ALL_DBS:
        logging.info(
            f"Full tests enabled (every db<->db). May take very long when many dbs are involved. ={TEST_ACROSS_ALL_DBS}"
        )
100+
101+
102+
_print_used_dbs()
# Keep only the databases that actually have a connection string configured.
CONN_STRINGS = {k: v for k, v in CONN_STRINGS.items() if v is not None}
104+
105+
106+
def random_table_suffix() -> str:
    """Return "_" followed by five random lowercase-alphanumeric characters.

    Used to give each test a unique table name so runs don't collide.
    """
    alphabet = string.ascii_lowercase + string.digits
    chars = [random.choice(alphabet) for _ in range(5)]
    return "_" + "".join(chars)
111+
112+
113+
def str_to_checksum(str: str):
    """Checksum a string the same way the test databases do.

    Example: "hello world"
      md5 hexdigest -> 5eb63bbbe01eeed093cb22bb8f5acdc3
      trailing CHECKSUM_HEXDIGITS digits -> cb22bb8f5acdc3
      as an int -> 273350391345368515
    """
    digest = hashlib.md5(str.encode("utf-8")).hexdigest()
    # Python slices are 0-indexed, unlike the DBs which are 1-indexed,
    # hence the +1 lives on the database side, not here.
    start = db.MD5_HEXDIGITS - db.CHECKSUM_HEXDIGITS
    return int(digest[start:], 16)
124+
125+
126+
class DbTestCase(unittest.TestCase):
    """Base TestCase that provisions a uniquely-named source table per test.

    Subclasses configure:
      db_cls            -- the Database class to connect to (required)
      table1_schema     -- schema for the table; if falsy, no table is created
      shared_connection -- reuse the module-wide cached connection (default True)
    """

    db_cls = None
    table1_schema = None
    shared_connection = True

    def setUp(self):
        assert self.db_cls, self.db_cls

        self.connection = get_conn(self.db_cls, self.shared_connection)

        # Random suffix keeps repeated/parallel runs from colliding.
        table_suffix = random_table_suffix()
        self.table1_name = f"src{table_suffix}"

        self.table1_path = self.connection.parse_table_name(self.table1_name)

        # Defensive cleanup in case an earlier run leaked the table.
        drop_table(self.connection, self.table1_path)

        self.src_table = table(self.table1_path, schema=self.table1_schema)
        if self.table1_schema:
            self.connection.query(self.src_table.create())

        return super().setUp()

    def tearDown(self):
        # Remove the per-test table created (or targeted) in setUp.
        drop_table(self.connection, self.table1_path)
152+
153+
154+
def _parameterized_class_per_conn(test_databases):
    """Parameterize a TestCase over the configured databases in `test_databases`."""
    wanted = set(test_databases)
    # CONN_STRINGS only holds configured backends, so this also filters
    # out databases without a connection string.
    cases = [(cls.__name__, cls) for cls in CONN_STRINGS if cls in wanted]
    return parameterized_class(("name", "db_cls"), cases)
158+
159+
160+
def test_each_database_in_list(databases) -> Callable:
    """Return a class decorator running a TestCase once per database in `databases`."""

    def _decorator(cls):
        return _parameterized_class_per_conn(databases)(cls)

    return _decorator
165+

tests/test_database.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from typing import Callable, List
2+
from datetime import datetime
3+
import unittest
4+
5+
from .common import str_to_checksum, TEST_MYSQL_CONN_STRING
6+
from .common import str_to_checksum, test_each_database_in_list, get_conn, random_table_suffix
7+
8+
from sqeleton.queries import table, current_timestamp
9+
10+
from sqeleton import databases as dbs
11+
from sqeleton import connect
12+
13+
14+
# Database backends exercised by the parameterized test cases below.
# Only those with a configured connection string actually run.
TEST_DATABASES = {
    dbs.MySQL,
    dbs.PostgreSQL,
    dbs.Oracle,
    dbs.Redshift,
    dbs.Snowflake,
    dbs.DuckDB,
    dbs.BigQuery,
    dbs.Presto,
    dbs.Trino,
    dbs.Vertica,
}

# Class decorator that instantiates a TestCase once per database above.
test_each_database: Callable = test_each_database_in_list(TEST_DATABASES)
28+
29+
30+
class TestDatabase(unittest.TestCase):
    """Basic connectivity smoke test against the MySQL test instance."""

    def setUp(self):
        self.mysql = connect(TEST_MYSQL_CONN_STRING)

    def test_connect_to_db(self):
        # A trivial SELECT proves the connect + query round-trip end to end.
        self.assertEqual(1, self.mysql.query("SELECT 1", int))
36+
37+
class TestMD5(unittest.TestCase):
    """The MD5 dialect's SQL checksum must agree with str_to_checksum."""

    def test_md5_as_int(self):
        # NOTE(review): mixins conventionally precede the base class in the
        # MRO; confirm Mixin_MD5 doesn't need to override anything Dialect defines.
        class MD5Dialect(dbs.mysql.Dialect, dbs.mysql.Mixin_MD5):
            pass

        self.mysql = connect(TEST_MYSQL_CONN_STRING)
        self.mysql.dialect = MD5Dialect()

        # NOTE(review): `str` shadows the builtin within this method.
        str = "hello world"
        query_fragment = self.mysql.dialect.md5_as_int("'{0}'".format(str))
        query = f"SELECT {query_fragment}"

        # The database-computed checksum must equal the Python reference value.
        self.assertEqual(str_to_checksum(str), self.mysql.query(query, int))
50+
51+
52+
class TestConnect(unittest.TestCase):
    """connect() must reject malformed or ambiguous connection URIs."""

    def test_bad_uris(self):
        bad_uris = (
            "p",
            "postgresql:///bla/foo",
            "snowflake://user:pass@foo/bar/TEST1",
            "snowflake://user:pass@foo/bar/TEST1?warehouse=ha&schema=dup",
        )
        for uri in bad_uris:
            with self.assertRaises(ValueError):
                connect(uri)
58+
59+
60+
@test_each_database
class TestSchema(unittest.TestCase):
    """dialect.list_tables must track table creation and removal."""

    def test_table_list(self):
        name = "tbl_" + random_table_suffix()
        db = get_conn(self.db_cls)
        tbl = table(db.parse_table_name(name), schema={"id": int})
        q = db.dialect.list_tables(db.default_schema, name)
        # The randomly-suffixed table must not exist yet.
        assert not db.query(q)

        db.query(tbl.create())
        self.assertEqual(db.query(q, List[str]), [name])

        db.query(tbl.drop())
        # After dropping, the listing must be empty again.
        assert not db.query(q)
74+
75+
76+
@test_each_database
class TestQueries(unittest.TestCase):
    """Minimal checks for query helpers shared by all dialects."""

    def test_current_timestamp(self):
        db = get_conn(self.db_cls)
        # Whatever the driver returns must parse into a Python datetime.
        res = db.query(current_timestamp(), datetime)
        assert isinstance(res, datetime), (res, type(res))

0 commit comments

Comments
 (0)