Skip to content

Commit a6e5879

Browse files
committed
Fix: Improve error message for missing engine imports
1 parent e24c9da commit a6e5879

2 files changed

Lines changed: 81 additions & 0 deletions

File tree

sqlmesh/core/config/connection.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import base64
55
import logging
66
import os
7+
import importlib
78
import pathlib
89
import re
910
import typing as t
@@ -49,6 +50,25 @@
4950
MOTHERDUCK_TOKEN_REGEX = re.compile(r"(\?|\&)(motherduck_token=)(\S*)")
5051

5152

53+
def _get_engine_import_validator(
54+
import_name: str, engine_type: str, extra_name: t.Optional[str] = None
55+
) -> t.Callable:
56+
extra_name = extra_name or engine_type
57+
58+
@model_validator(mode="before")
59+
def validate(cls: t.Any, data: t.Any) -> t.Any:
60+
try:
61+
importlib.import_module(import_name)
62+
except ImportError:
63+
raise ConfigError(
64+
f"Failed to import the '{engine_type}' engine library. Please run `pip install \"sqlmesh[{extra_name}]\"`."
65+
)
66+
67+
return data
68+
69+
return validate
70+
71+
5272
class ConnectionConfig(abc.ABC, BaseConfig):
5373
type_: str
5474
concurrent_tasks: int
@@ -428,6 +448,7 @@ class SnowflakeConnectionConfig(ConnectionConfig):
428448
type_: t.Literal["snowflake"] = Field(alias="type", default="snowflake")
429449

430450
_concurrent_tasks_validator = concurrent_tasks_validator
451+
_engine_import_validator = _get_engine_import_validator("snowflake", "snowflake")
431452

432453
@model_validator(mode="before")
433454
def _validate_authenticator(cls, data: t.Any) -> t.Any:
@@ -621,6 +642,7 @@ class DatabricksConnectionConfig(ConnectionConfig):
621642

622643
_concurrent_tasks_validator = concurrent_tasks_validator
623644
_http_headers_validator = http_headers_validator
645+
_engine_import_validator = _get_engine_import_validator("databricks", "databricks")
624646

625647
@model_validator(mode="before")
626648
def _databricks_connect_validator(cls, data: t.Any) -> t.Any:
@@ -873,6 +895,8 @@ class BigQueryConnectionConfig(ConnectionConfig):
873895

874896
type_: t.Literal["bigquery"] = Field(alias="type", default="bigquery")
875897

898+
_engine_import_validator = _get_engine_import_validator("google.cloud.bigquery", "bigquery")
899+
876900
@field_validator("execution_project")
877901
def validate_execution_project(
878902
cls,
@@ -1015,6 +1039,10 @@ class GCPPostgresConnectionConfig(ConnectionConfig):
10151039
register_comments: bool = True
10161040
pre_ping: bool = True
10171041

1042+
_engine_import_validator = _get_engine_import_validator(
1043+
"google.cloud.sql", "gcp_postgres", "gcppostgres"
1044+
)
1045+
10181046
@model_validator(mode="before")
10191047
def _validate_auth_method(cls, data: t.Any) -> t.Any:
10201048
if not isinstance(data, dict):
@@ -1142,6 +1170,8 @@ class RedshiftConnectionConfig(ConnectionConfig):
11421170

11431171
type_: t.Literal["redshift"] = Field(alias="type", default="redshift")
11441172

1173+
_engine_import_validator = _get_engine_import_validator("redshift_connector", "redshift")
1174+
11451175
@property
11461176
def _connection_kwargs_keys(self) -> t.Set[str]:
11471177
return {
@@ -1201,6 +1231,8 @@ class PostgresConnectionConfig(ConnectionConfig):
12011231

12021232
type_: t.Literal["postgres"] = Field(alias="type", default="postgres")
12031233

1234+
_engine_import_validator = _get_engine_import_validator("psycopg2", "postgres")
1235+
12041236
@property
12051237
def _connection_kwargs_keys(self) -> t.Set[str]:
12061238
return {
@@ -1252,6 +1284,8 @@ class MySQLConnectionConfig(ConnectionConfig):
12521284

12531285
type_: t.Literal["mysql"] = Field(alias="type", default="mysql")
12541286

1287+
_engine_import_validator = _get_engine_import_validator("pymysql", "mysql")
1288+
12551289
@property
12561290
def _connection_kwargs_keys(self) -> t.Set[str]:
12571291
connection_keys = {
@@ -1302,6 +1336,8 @@ class MSSQLConnectionConfig(ConnectionConfig):
13021336

13031337
type_: t.Literal["mssql"] = Field(alias="type", default="mssql")
13041338

1339+
_engine_import_validator = _get_engine_import_validator("pymssql", "mssql")
1340+
13051341
@property
13061342
def _connection_kwargs_keys(self) -> t.Set[str]:
13071343
return {
@@ -1357,6 +1393,8 @@ class SparkConnectionConfig(ConnectionConfig):
13571393

13581394
type_: t.Literal["spark"] = Field(alias="type", default="spark")
13591395

1396+
_engine_import_validator = _get_engine_import_validator("pyspark", "spark")
1397+
13601398
@property
13611399
def _connection_kwargs_keys(self) -> t.Set[str]:
13621400
return {
@@ -1473,6 +1511,8 @@ class TrinoConnectionConfig(ConnectionConfig):
14731511

14741512
type_: t.Literal["trino"] = Field(alias="type", default="trino")
14751513

1514+
_engine_import_validator = _get_engine_import_validator("trino", "trino")
1515+
14761516
@field_validator("schema_location_mapping", mode="before")
14771517
@classmethod
14781518
def _validate_regex_keys(
@@ -1623,6 +1663,8 @@ class ClickhouseConnectionConfig(ConnectionConfig):
16231663

16241664
type_: t.Literal["clickhouse"] = Field(alias="type", default="clickhouse")
16251665

1666+
_engine_import_validator = _get_engine_import_validator("clickhouse_connect", "clickhouse")
1667+
16261668
@property
16271669
def _connection_kwargs_keys(self) -> t.Set[str]:
16281670
kwargs = {
@@ -1727,6 +1769,8 @@ class AthenaConnectionConfig(ConnectionConfig):
17271769

17281770
type_: t.Literal["athena"] = Field(alias="type", default="athena")
17291771

1772+
_engine_import_validator = _get_engine_import_validator("pyathena", "athena")
1773+
17301774
@model_validator(mode="after")
17311775
def _root_validator(self) -> Self:
17321776
work_group = self.work_group
@@ -1793,6 +1837,8 @@ class RisingwaveConnectionConfig(ConnectionConfig):
17931837

17941838
type_: t.Literal["risingwave"] = Field(alias="type", default="risingwave")
17951839

1840+
_engine_import_validator = _get_engine_import_validator("psycopg2", "risingwave")
1841+
17961842
@property
17971843
def _connection_kwargs_keys(self) -> t.Set[str]:
17981844
return {

tests/core/test_connection_config.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
TrinoAuthenticationMethod,
2121
AthenaConnectionConfig,
2222
_connection_config_validator,
23+
_get_engine_import_validator,
2324
)
2425
from sqlmesh.utils.errors import ConfigError
26+
from sqlmesh.utils.pydantic import PydanticModel
2527

2628

2729
@pytest.fixture
@@ -994,3 +996,36 @@ def test_databricks(make_config):
994996
server_hostname="dbc-test.cloud.databricks.com",
995997
auth_type="databricks-oauth",
996998
)
999+
1000+
1001+
def test_engine_import_validator():
1002+
with pytest.raises(
1003+
ConfigError,
1004+
match=re.escape(
1005+
"""Failed to import the 'bigquery' engine library. Please run `pip install "sqlmesh[bigquery]"`."""
1006+
),
1007+
):
1008+
1009+
class TestConfigA(PydanticModel):
1010+
_engine_import_validator = _get_engine_import_validator("missing", "bigquery")
1011+
1012+
TestConfigA()
1013+
1014+
with pytest.raises(
1015+
ConfigError,
1016+
match=re.escape(
1017+
"""Failed to import the 'bigquery' engine library. Please run `pip install "sqlmesh[bigquery_extra]"`."""
1018+
),
1019+
):
1020+
1021+
class TestConfigB(PydanticModel):
1022+
_engine_import_validator = _get_engine_import_validator(
1023+
"missing", "bigquery", "bigquery_extra"
1024+
)
1025+
1026+
TestConfigB()
1027+
1028+
class TestConfigC(PydanticModel):
1029+
_engine_import_validator = _get_engine_import_validator("sqlmesh", "bigquery")
1030+
1031+
TestConfigC()

0 commit comments

Comments
 (0)