Skip to content

Commit cc3e493

Browse files
authored
Fix: Improve error message for missing engine imports (#4447)
1 parent beee40c commit cc3e493

4 files changed

Lines changed: 115 additions & 4 deletions

File tree

sqlmesh/core/config/connection.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import base64
55
import logging
66
import os
7+
import importlib
78
import pathlib
89
import re
910
import typing as t
@@ -24,6 +25,7 @@
2425
)
2526
from sqlmesh.core.engine_adapter.shared import CatalogSupport
2627
from sqlmesh.core.engine_adapter import EngineAdapter
28+
from sqlmesh.utils import str_to_bool
2729
from sqlmesh.utils.errors import ConfigError
2830
from sqlmesh.utils.pydantic import (
2931
ValidationInfo,
@@ -49,6 +51,30 @@
4951
MOTHERDUCK_TOKEN_REGEX = re.compile(r"(\?|\&)(motherduck_token=)(\S*)")
5052

5153

54+
def _get_engine_import_validator(
    import_name: str, engine_type: str, extra_name: t.Optional[str] = None
) -> t.Callable:
    """Build a pydantic ``mode="before"`` model validator that checks the
    engine's driver library is importable before the connection config is
    validated, producing an actionable install hint when it is not.

    Args:
        import_name: Module to attempt to import (e.g. ``"psycopg2"``).
        engine_type: Engine name used in the error message.
        extra_name: Name of the ``sqlmesh[...]`` extra suggested in the
            install hint; defaults to ``engine_type``.

    Returns:
        A model validator. It honors an optional ``check_import`` key in the
        raw config dict (popped so it never reaches the model's fields);
        when truthy (the default), a failed import raises ``ConfigError``.
    """
    extra_name = extra_name or engine_type

    @model_validator(mode="before")
    def validate(cls: t.Any, data: t.Any) -> t.Any:
        # Pop `check_import` so it is consumed here and not treated as an
        # unknown model field. Non-dict payloads (already-built models) are
        # always checked.
        check_import = (
            str_to_bool(str(data.pop("check_import", True))) if isinstance(data, dict) else True
        )
        if not check_import:
            return data
        try:
            importlib.import_module(import_name)
        except ImportError as e:
            # Chain the original ImportError explicitly so the real root
            # cause (e.g. a missing transitive dependency) is reported as
            # the direct cause rather than ambient context.
            raise ConfigError(
                f"Failed to import the '{engine_type}' engine library. Please run `pip install \"sqlmesh[{extra_name}]\"`."
            ) from e

        return data

    return validate
76+
77+
5278
class ConnectionConfig(abc.ABC, BaseConfig):
5379
type_: str
5480
concurrent_tasks: int
@@ -428,6 +454,7 @@ class SnowflakeConnectionConfig(ConnectionConfig):
428454
type_: t.Literal["snowflake"] = Field(alias="type", default="snowflake")
429455

430456
_concurrent_tasks_validator = concurrent_tasks_validator
457+
_engine_import_validator = _get_engine_import_validator("snowflake", "snowflake")
431458

432459
@model_validator(mode="before")
433460
def _validate_authenticator(cls, data: t.Any) -> t.Any:
@@ -621,6 +648,7 @@ class DatabricksConnectionConfig(ConnectionConfig):
621648

622649
_concurrent_tasks_validator = concurrent_tasks_validator
623650
_http_headers_validator = http_headers_validator
651+
_engine_import_validator = _get_engine_import_validator("databricks", "databricks")
624652

625653
@model_validator(mode="before")
626654
def _databricks_connect_validator(cls, data: t.Any) -> t.Any:
@@ -873,6 +901,8 @@ class BigQueryConnectionConfig(ConnectionConfig):
873901

874902
type_: t.Literal["bigquery"] = Field(alias="type", default="bigquery")
875903

904+
_engine_import_validator = _get_engine_import_validator("google.cloud.bigquery", "bigquery")
905+
876906
@field_validator("execution_project")
877907
def validate_execution_project(
878908
cls,
@@ -1015,6 +1045,10 @@ class GCPPostgresConnectionConfig(ConnectionConfig):
10151045
register_comments: bool = True
10161046
pre_ping: bool = True
10171047

1048+
_engine_import_validator = _get_engine_import_validator(
1049+
"google.cloud.sql", "gcp_postgres", "gcppostgres"
1050+
)
1051+
10181052
@model_validator(mode="before")
10191053
def _validate_auth_method(cls, data: t.Any) -> t.Any:
10201054
if not isinstance(data, dict):
@@ -1142,6 +1176,8 @@ class RedshiftConnectionConfig(ConnectionConfig):
11421176

11431177
type_: t.Literal["redshift"] = Field(alias="type", default="redshift")
11441178

1179+
_engine_import_validator = _get_engine_import_validator("redshift_connector", "redshift")
1180+
11451181
@property
11461182
def _connection_kwargs_keys(self) -> t.Set[str]:
11471183
return {
@@ -1201,6 +1237,8 @@ class PostgresConnectionConfig(ConnectionConfig):
12011237

12021238
type_: t.Literal["postgres"] = Field(alias="type", default="postgres")
12031239

1240+
_engine_import_validator = _get_engine_import_validator("psycopg2", "postgres")
1241+
12041242
@property
12051243
def _connection_kwargs_keys(self) -> t.Set[str]:
12061244
return {
@@ -1252,6 +1290,8 @@ class MySQLConnectionConfig(ConnectionConfig):
12521290

12531291
type_: t.Literal["mysql"] = Field(alias="type", default="mysql")
12541292

1293+
_engine_import_validator = _get_engine_import_validator("pymysql", "mysql")
1294+
12551295
@property
12561296
def _connection_kwargs_keys(self) -> t.Set[str]:
12571297
connection_keys = {
@@ -1302,6 +1342,8 @@ class MSSQLConnectionConfig(ConnectionConfig):
13021342

13031343
type_: t.Literal["mssql"] = Field(alias="type", default="mssql")
13041344

1345+
_engine_import_validator = _get_engine_import_validator("pymssql", "mssql")
1346+
13051347
@property
13061348
def _connection_kwargs_keys(self) -> t.Set[str]:
13071349
return {
@@ -1357,6 +1399,8 @@ class SparkConnectionConfig(ConnectionConfig):
13571399

13581400
type_: t.Literal["spark"] = Field(alias="type", default="spark")
13591401

1402+
_engine_import_validator = _get_engine_import_validator("pyspark", "spark")
1403+
13601404
@property
13611405
def _connection_kwargs_keys(self) -> t.Set[str]:
13621406
return {
@@ -1473,6 +1517,8 @@ class TrinoConnectionConfig(ConnectionConfig):
14731517

14741518
type_: t.Literal["trino"] = Field(alias="type", default="trino")
14751519

1520+
_engine_import_validator = _get_engine_import_validator("trino", "trino")
1521+
14761522
@field_validator("schema_location_mapping", mode="before")
14771523
@classmethod
14781524
def _validate_regex_keys(
@@ -1623,6 +1669,8 @@ class ClickhouseConnectionConfig(ConnectionConfig):
16231669

16241670
type_: t.Literal["clickhouse"] = Field(alias="type", default="clickhouse")
16251671

1672+
_engine_import_validator = _get_engine_import_validator("clickhouse_connect", "clickhouse")
1673+
16261674
@property
16271675
def _connection_kwargs_keys(self) -> t.Set[str]:
16281676
kwargs = {
@@ -1727,6 +1775,8 @@ class AthenaConnectionConfig(ConnectionConfig):
17271775

17281776
type_: t.Literal["athena"] = Field(alias="type", default="athena")
17291777

1778+
_engine_import_validator = _get_engine_import_validator("pyathena", "athena")
1779+
17301780
@model_validator(mode="after")
17311781
def _root_validator(self) -> Self:
17321782
work_group = self.work_group
@@ -1793,6 +1843,8 @@ class RisingwaveConnectionConfig(ConnectionConfig):
17931843

17941844
type_: t.Literal["risingwave"] = Field(alias="type", default="risingwave")
17951845

1846+
_engine_import_validator = _get_engine_import_validator("psycopg2", "risingwave")
1847+
17961848
@property
17971849
def _connection_kwargs_keys(self) -> t.Set[str]:
17981850
return {

tests/core/engine_adapter/integration/config.yaml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ gateways:
1717
catalog: datalake
1818
http_scheme: http
1919
retries: 20
20+
check_import: false
2021
state_connection:
2122
type: duckdb
2223
inttest_trino_iceberg:
@@ -28,6 +29,7 @@ gateways:
2829
catalog: datalake_iceberg
2930
http_scheme: http
3031
retries: 20
32+
check_import: false
3133
state_connection:
3234
type: duckdb
3335
inttest_trino_delta:
@@ -39,6 +41,7 @@ gateways:
3941
catalog: datalake_delta
4042
http_scheme: http
4143
retries: 20
44+
check_import: false
4245
state_connection:
4346
type: duckdb
4447
inttest_trino_nessie:
@@ -50,13 +53,15 @@ gateways:
5053
catalog: datalake_nessie
5154
http_scheme: http
5255
retries: 20
56+
check_import: false
5357
state_connection:
5458
type: duckdb
5559
inttest_spark:
5660
connection:
5761
type: spark
5862
config:
5963
spark.remote: sc://{{ env_var('DOCKER_HOSTNAME', 'localhost') }}
64+
check_import: false
6065
state_connection:
6166
type: duckdb
6267
inttest_mssql:
@@ -65,6 +70,7 @@ gateways:
6570
host: {{ env_var('DOCKER_HOSTNAME', 'localhost') }}
6671
user: sa
6772
password: 1StrongPwd@@
73+
check_import: false
6874
inttest_postgres:
6975
connection:
7076
type: postgres
@@ -73,6 +79,7 @@ gateways:
7379
database: postgres
7480
host: {{ env_var('DOCKER_HOSTNAME', 'localhost') }}
7581
port: 5432
82+
check_import: false
7683
inttest_mysql:
7784
connection:
7885
type: mysql
@@ -81,13 +88,15 @@ gateways:
8188
password: mysql
8289
port: 3306
8390
charset: utf8
91+
check_import: false
8492
inttest_clickhouse_standalone:
8593
connection:
8694
type: clickhouse
8795
host: {{ env_var('DOCKER_HOSTNAME', 'localhost') }}
8896
port: 8123
8997
username: clickhouse
9098
password: clickhouse
99+
check_import: false
91100
state_connection:
92101
type: duckdb
93102
inttest_clickhouse_cluster:
@@ -98,6 +107,7 @@ gateways:
98107
username: clickhouse
99108
password: clickhouse
100109
cluster: cluster1
110+
check_import: false
101111
state_connection:
102112
type: duckdb
103113
inttest_risingwave:
@@ -107,6 +117,7 @@ gateways:
107117
database: dev
108118
host: {{ env_var('DOCKER_HOSTNAME', 'localhost') }}
109119
port: 4566
120+
check_import: false
110121

111122

112123
# Cloud databases
@@ -118,6 +129,7 @@ gateways:
118129
database: {{ env_var('SNOWFLAKE_DATABASE') }}
119130
user: {{ env_var('SNOWFLAKE_USER') }}
120131
password: {{ env_var('SNOWFLAKE_PASSWORD') }}
132+
check_import: false
121133
state_connection:
122134
type: duckdb
123135

@@ -128,6 +140,7 @@ gateways:
128140
server_hostname: {{ env_var('DATABRICKS_SERVER_HOSTNAME') }}
129141
http_path: {{ env_var('DATABRICKS_HTTP_PATH') }}
130142
access_token: {{ env_var('DATABRICKS_ACCESS_TOKEN') }}
143+
check_import: false
131144

132145
inttest_redshift:
133146
connection:
@@ -136,12 +149,14 @@ gateways:
136149
user: {{ env_var('REDSHIFT_USER') }}
137150
password: {{ env_var('REDSHIFT_PASSWORD') }}
138151
database: {{ env_var('REDSHIFT_DATABASE') }}
152+
check_import: false
139153

140154
inttest_bigquery:
141155
connection:
142156
type: bigquery
143157
method: service-account
144158
keyfile: {{ env_var('BIGQUERY_KEYFILE') }}
159+
check_import: false
145160
state_connection:
146161
type: duckdb
147162

@@ -155,6 +170,7 @@ gateways:
155170
connect_timeout: 30
156171
connection_pool_options:
157172
retries: 5
173+
check_import: false
158174
state_connection:
159175
type: duckdb
160176

@@ -166,6 +182,7 @@ gateways:
166182
region_name: {{ env_var("AWS_REGION") }}
167183
work_group: {{ env_var("ATHENA_WORK_GROUP", "primary") }}
168184
s3_warehouse_location: {{ env_var("ATHENA_S3_WAREHOUSE_LOCATION", "") }}
185+
check_import: false
169186
state_connection:
170187
type: duckdb
171188

tests/core/test_config.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ def test_load_yaml_config_env_var_gateway_override(tmp_path_factory):
343343
os.environ,
344344
{
345345
"SQLMESH__GATEWAYS__TESTING__STATE_CONNECTION__TYPE": "bigquery",
346+
"SQLMESH__GATEWAYS__TESTING__STATE_CONNECTION__CHECK_IMPORT": "false",
346347
"SQLMESH__DEFAULT_GATEWAY": "testing",
347348
},
348349
):
@@ -353,7 +354,7 @@ def test_load_yaml_config_env_var_gateway_override(tmp_path_factory):
353354
gateways={
354355
"testing": GatewayConfig(
355356
connection=MotherDuckConnectionConfig(database="blah"),
356-
state_connection=BigQueryConnectionConfig(),
357+
state_connection=BigQueryConnectionConfig(check_import=False),
357358
),
358359
},
359360
model_defaults=ModelDefaultsConfig(dialect="bigquery"),
@@ -373,6 +374,7 @@ def test_load_py_config_env_var_gateway_override(tmp_path_factory):
373374
os.environ,
374375
{
375376
"SQLMESH__GATEWAYS__DUCKDB_GATEWAY__STATE_CONNECTION__TYPE": "bigquery",
377+
"SQLMESH__GATEWAYS__DUCKDB_GATEWAY__STATE_CONNECTION__CHECK_IMPORT": "false",
376378
"SQLMESH__DEFAULT_GATEWAY": "duckdb_gateway",
377379
},
378380
):
@@ -384,7 +386,7 @@ def test_load_py_config_env_var_gateway_override(tmp_path_factory):
384386
gateways={ # type: ignore
385387
"duckdb_gateway": GatewayConfig(
386388
connection=DuckDBConnectionConfig(),
387-
state_connection=BigQueryConnectionConfig(),
389+
state_connection=BigQueryConnectionConfig(check_import=False),
388390
),
389391
},
390392
model_defaults=ModelDefaultsConfig(dialect=""),
@@ -823,6 +825,7 @@ def test_gcp_postgres_ip_and_scopes(tmp_path):
823825
gcp_postgres:
824826
connection:
825827
type: gcp_postgres
828+
check_import: false
826829
instance_connection_string: something
827830
user: user
828831
password: password

0 commit comments

Comments
 (0)