Skip to content

Commit 18911b3

Browse files
authored
Feat: Allow supplying complex connection config fields as JSON strings (#4519)
1 parent 657cf0d commit 18911b3

5 files changed

Lines changed: 152 additions & 1 deletion

File tree

sqlmesh/core/config/connection.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import pydantic
1515
from pydantic import Field
16+
from pydantic_core import from_json
1617
from packaging import version
1718
from sqlglot import exp
1819
from sqlglot.helper import subclasses
@@ -33,6 +34,7 @@
3334
field_validator,
3435
model_validator,
3536
validation_error_message,
37+
get_concrete_types_from_typehint,
3638
)
3739
from sqlmesh.utils.aws import validate_s3_uri
3840

@@ -177,6 +179,42 @@ def get_catalog(self) -> t.Optional[str]:
177179
return self.db
178180
return None
179181

182+
@model_validator(mode="before")
@classmethod
def _expand_json_strings_to_concrete_types(cls, data: t.Any) -> t.Any:
    """Decode JSON-encoded strings into the container values a field expects.

    Connection settings can arrive from sources such as environment variables,
    where a list- or dict-typed field can only be expressed as a string. This
    pre-validator transparently parses such strings for every field whose
    annotation involves a list, tuple, set, or dict, so individual config
    subclasses do not each need their own piecemeal handling.
    """
    if not (data and isinstance(data, dict)):
        return data

    for field_name in cls._get_list_and_dict_field_names():
        raw = data.get(field_name)
        if not (raw and isinstance(raw, str)):
            continue
        # Crude structural sniff: only attempt to parse strings that look like
        # JSON objects/arrays, rather than every string we receive.
        candidate = raw.strip()
        if candidate.startswith(("{", "[")):
            data[field_name] = from_json(candidate)

    return data
204+
205+
@classmethod
def _get_list_and_dict_field_names(cls) -> t.Set[str]:
    """Return names of fields whose values may be supplied as JSON strings.

    A field qualifies when its annotation (after unwrapping Optional / Union /
    parameterized generics) involves a container type - list, tuple, set, or
    dict - since those are the values that sources such as environment
    variables can only express as JSON text.
    """
    # NOTE: deliberately no loop variable named `t` here - the original
    # implementation shadowed the module-level `typing as t` alias inside a
    # generator expression, which is confusing in a file that uses `t.` heavily.
    container_types = {list, tuple, set, dict}
    field_names = set()
    for name, field in cls.model_fields.items():
        if field.annotation is None:
            continue
        field_types = get_concrete_types_from_typehint(field.annotation)
        # The field is a candidate if any concrete type it may hold is a
        # container that could conceivably arrive encoded as a JSON string.
        if not container_types.isdisjoint(field_types):
            field_names.add(name)
    return field_names
217+
180218

181219
class DuckDBAttachOptions(BaseConfig):
182220
type: str

sqlmesh/utils/pydantic.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,24 @@ def cron_validator(v: t.Any) -> str:
306306
return v
307307

308308

309+
def get_concrete_types_from_typehint(typehint: type[t.Any]) -> set[type[t.Any]]:
    """Resolve a typehint to the set of concrete runtime types it may contain.

    Examples:
        dict                        -> {dict}
        t.List[str]                 -> {list}
        t.Optional[t.Dict[...]]     -> {dict, NoneType}
        t.Union[str, t.List[str]]   -> {str, list}

    Non-class hints with no origin (e.g. ``t.Any``) resolve to an empty set.
    """
    # Local import keeps the module's top-level import surface unchanged.
    import types

    origin = t.get_origin(typehint)
    if origin is None:
        # A bare class such as `dict` or a user-defined type. The original
        # check `type(typehint) == type(type)` missed classes with a custom
        # metaclass; isinstance() handles those correctly. Anything that is
        # not a class (e.g. t.Any) has no concrete runtime type to report.
        return {typehint} if isinstance(typehint, type) else set()

    # Handle both typing.Union / t.Optional and PEP 604 `X | Y` unions
    # (types.UnionType exists on Python 3.10+; getattr keeps 3.9 working).
    if origin is t.Union or origin is getattr(types, "UnionType", None):
        concrete_types: set[type[t.Any]] = set()
        for member in t.get_args(typehint):
            # Recurse on every member so nested hints (t.Optional[t.List[str]])
            # and builtin generics (dict[str, int]) both unwrap to their origin,
            # instead of only members whose repr starts with "typing.".
            concrete_types |= get_concrete_types_from_typehint(member)
        return concrete_types

    # A parameterized container such as t.List[str] -> list.
    return {origin}
325+
326+
309327
if t.TYPE_CHECKING:
310328
SQLGlotListOfStrings = t.List[str]
311329
SQLGlotString = str

tests/core/test_config.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,3 +985,37 @@ def test_config_subclassing() -> None:
985985
class ConfigSubclass(Config): ...
986986

987987
ConfigSubclass()
988+
989+
990+
def test_config_complex_types_supplied_as_json_strings_from_env(tmp_path: Path) -> None:
    """Complex connection fields set via env vars arrive as JSON strings and
    should be decoded into the concrete types the config declares."""
    config_path = tmp_path / "config.yaml"
    config_path.write_text("""
gateways:
  bigquery:
    connection:
      type: bigquery
      project: unit-test

default_gateway: bigquery

model_defaults:
  dialect: bigquery
""")

    env_overrides = {
        # note: leading whitespace is deliberate
        "SQLMESH__GATEWAYS__BIGQUERY__CONNECTION__SCOPES": ' ["a","b","c"]',
        "SQLMESH__GATEWAYS__BIGQUERY__CONNECTION__KEYFILE_JSON": '{ "foo": "bar" }',
    }

    with mock.patch.dict(os.environ, env_overrides):
        config = load_config_from_paths(Config, project_paths=[config_path])

    connection = config.gateways["bigquery"].connection
    assert isinstance(connection, BigQueryConnectionConfig)
    assert connection.project == "unit-test"
    assert connection.scopes == ("a", "b", "c")
    assert connection.keyfile_json == {"foo": "bar"}

tests/core/test_connection_config.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,27 @@ def test_duckdb_attach_options():
617617
assert options.to_sql(alias="db") == "ATTACH IF NOT EXISTS 'test.db' AS db"
618618

619619

620+
def test_duckdb_config_json_strings(make_config):
    """Complex DuckDB fields supplied as JSON strings are parsed transparently."""
    extensions_json = '["foo","bar"]'
    catalogs_json = """{
        "test1": "test1.duckdb",
        "test2": {
            "type": "duckdb",
            "path": "test2.duckdb"
        }
    }"""

    config = make_config(
        type="duckdb",
        extensions=extensions_json,
        catalogs=catalogs_json,
    )

    assert isinstance(config, DuckDBConnectionConfig)
    assert config.extensions == ["foo", "bar"]
    assert config.get_catalog() == "test1"
    assert config.catalogs.get("test1") == "test1.duckdb"
    assert config.catalogs.get("test2").path == "test2.duckdb"
639+
640+
620641
def test_motherduck_attach_catalog(make_config):
621642
config = make_config(
622643
type="motherduck",
@@ -779,6 +800,21 @@ def test_bigquery(make_config):
779800
make_config(type="bigquery", quota_project="quota_project", check_import=False)
780801

781802

803+
def test_bigquery_config_json_string(make_config):
    """Complex BigQuery fields supplied as JSON strings are parsed transparently."""
    # these can be present as strings if they came from env vars
    scopes_json = '["a","b","c"]'
    keyfile_json_string = '{"foo":"bar"}'

    config = make_config(
        type="bigquery",
        project="project",
        scopes=scopes_json,
        keyfile_json=keyfile_json_string,
    )

    assert isinstance(config, BigQueryConnectionConfig)
    assert config.scopes == ("a", "b", "c")
    assert config.keyfile_json == {"foo": "bar"}
816+
817+
782818
def test_postgres(make_config):
783819
config = make_config(
784820
type="postgres",

tests/utils/test_pydantic.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import typing as t
2+
import pytest
13
from functools import cached_property
24

35
from sqlmesh.utils.date import TimeLike, to_date, to_datetime
4-
from sqlmesh.utils.pydantic import PydanticModel
6+
from sqlmesh.utils.pydantic import PydanticModel, get_concrete_types_from_typehint
57

68

79
def test_datetime_date_serialization() -> None:
@@ -62,3 +64,26 @@ class TestModel(PydanticModel):
6264
name: str
6365

6466
assert TestModel(name="foo").dict(by_alias=True)
67+
68+
69+
@pytest.mark.parametrize(
    "typehint,expected",
    [
        (t.Dict[str, t.Any], {dict}),
        (dict, {dict}),
        (t.List[str], {list}),
        (list, {list}),
        (t.Tuple[str, ...], {tuple}),
        (tuple, {tuple}),
        (t.Set[str], {set}),
        (set, {set}),
        (t.Optional[t.Dict[str, t.Any]], {dict, type(None)}),
        (t.Optional[t.List[str]], {list, type(None)}),
        (
            t.Union[str, t.List[str], t.Dict[str, t.Any], t.Optional[t.Set[str]]],
            {str, list, dict, set, type(None)},
        ),
    ],
)
def test_get_concrete_types_from_typehint(typehint: t.Any, expected: set[type]) -> None:
    """Typehints unwrap to the set of concrete runtime types they may contain.

    Parameter renamed from ``input`` to ``typehint`` to avoid shadowing the
    ``input`` builtin.
    """
    assert get_concrete_types_from_typehint(typehint) == expected

0 commit comments

Comments
 (0)