|
4 | 4 | import base64 |
5 | 5 | import logging |
6 | 6 | import os |
| 7 | +import importlib |
7 | 8 | import pathlib |
8 | 9 | import re |
9 | 10 | import typing as t |
|
49 | 50 | MOTHERDUCK_TOKEN_REGEX = re.compile(r"(\?|\&)(motherduck_token=)(\S*)") |
50 | 51 |
|
51 | 52 |
|
| 53 | +def _get_engine_import_validator( |
| 54 | + import_name: str, engine_type: str, extra_name: t.Optional[str] = None |
| 55 | +) -> t.Callable: |
| 56 | + extra_name = extra_name or engine_type |
| 57 | + |
| 58 | + @model_validator(mode="before") |
| 59 | + def validate(cls: t.Any, data: t.Any) -> t.Any: |
| 60 | + try: |
| 61 | + importlib.import_module(import_name) |
| 62 | + except ImportError: |
| 63 | + raise ConfigError( |
| 64 | + f"Failed to import the '{engine_type}' engine library. Please run `pip install \"sqlmesh[{extra_name}]\"`." |
| 65 | + ) |
| 66 | + |
| 67 | + return data |
| 68 | + |
| 69 | + return validate |
| 70 | + |
| 71 | + |
52 | 72 | class ConnectionConfig(abc.ABC, BaseConfig): |
53 | 73 | type_: str |
54 | 74 | concurrent_tasks: int |
@@ -428,6 +448,7 @@ class SnowflakeConnectionConfig(ConnectionConfig): |
428 | 448 | type_: t.Literal["snowflake"] = Field(alias="type", default="snowflake") |
429 | 449 |
|
430 | 450 | _concurrent_tasks_validator = concurrent_tasks_validator |
| 451 | + _engine_import_validator = _get_engine_import_validator("snowflake", "snowflake") |
431 | 452 |
|
432 | 453 | @model_validator(mode="before") |
433 | 454 | def _validate_authenticator(cls, data: t.Any) -> t.Any: |
@@ -621,6 +642,7 @@ class DatabricksConnectionConfig(ConnectionConfig): |
621 | 642 |
|
622 | 643 | _concurrent_tasks_validator = concurrent_tasks_validator |
623 | 644 | _http_headers_validator = http_headers_validator |
| 645 | + _engine_import_validator = _get_engine_import_validator("databricks", "databricks") |
624 | 646 |
|
625 | 647 | @model_validator(mode="before") |
626 | 648 | def _databricks_connect_validator(cls, data: t.Any) -> t.Any: |
@@ -873,6 +895,8 @@ class BigQueryConnectionConfig(ConnectionConfig): |
873 | 895 |
|
874 | 896 | type_: t.Literal["bigquery"] = Field(alias="type", default="bigquery") |
875 | 897 |
|
| 898 | + _engine_import_validator = _get_engine_import_validator("google.cloud.bigquery", "bigquery") |
| 899 | + |
876 | 900 | @field_validator("execution_project") |
877 | 901 | def validate_execution_project( |
878 | 902 | cls, |
@@ -1015,6 +1039,10 @@ class GCPPostgresConnectionConfig(ConnectionConfig): |
1015 | 1039 | register_comments: bool = True |
1016 | 1040 | pre_ping: bool = True |
1017 | 1041 |
|
| 1042 | + _engine_import_validator = _get_engine_import_validator( |
| 1043 | + "google.cloud.sql", "gcp_postgres", "gcppostgres" |
| 1044 | + ) |
| 1045 | + |
1018 | 1046 | @model_validator(mode="before") |
1019 | 1047 | def _validate_auth_method(cls, data: t.Any) -> t.Any: |
1020 | 1048 | if not isinstance(data, dict): |
@@ -1142,6 +1170,8 @@ class RedshiftConnectionConfig(ConnectionConfig): |
1142 | 1170 |
|
1143 | 1171 | type_: t.Literal["redshift"] = Field(alias="type", default="redshift") |
1144 | 1172 |
|
| 1173 | + _engine_import_validator = _get_engine_import_validator("redshift_connector", "redshift") |
| 1174 | + |
1145 | 1175 | @property |
1146 | 1176 | def _connection_kwargs_keys(self) -> t.Set[str]: |
1147 | 1177 | return { |
@@ -1201,6 +1231,8 @@ class PostgresConnectionConfig(ConnectionConfig): |
1201 | 1231 |
|
1202 | 1232 | type_: t.Literal["postgres"] = Field(alias="type", default="postgres") |
1203 | 1233 |
|
| 1234 | + _engine_import_validator = _get_engine_import_validator("psycopg2", "postgres") |
| 1235 | + |
1204 | 1236 | @property |
1205 | 1237 | def _connection_kwargs_keys(self) -> t.Set[str]: |
1206 | 1238 | return { |
@@ -1252,6 +1284,8 @@ class MySQLConnectionConfig(ConnectionConfig): |
1252 | 1284 |
|
1253 | 1285 | type_: t.Literal["mysql"] = Field(alias="type", default="mysql") |
1254 | 1286 |
|
| 1287 | + _engine_import_validator = _get_engine_import_validator("pymysql", "mysql") |
| 1288 | + |
1255 | 1289 | @property |
1256 | 1290 | def _connection_kwargs_keys(self) -> t.Set[str]: |
1257 | 1291 | connection_keys = { |
@@ -1302,6 +1336,8 @@ class MSSQLConnectionConfig(ConnectionConfig): |
1302 | 1336 |
|
1303 | 1337 | type_: t.Literal["mssql"] = Field(alias="type", default="mssql") |
1304 | 1338 |
|
| 1339 | + _engine_import_validator = _get_engine_import_validator("pymssql", "mssql") |
| 1340 | + |
1305 | 1341 | @property |
1306 | 1342 | def _connection_kwargs_keys(self) -> t.Set[str]: |
1307 | 1343 | return { |
@@ -1357,6 +1393,8 @@ class SparkConnectionConfig(ConnectionConfig): |
1357 | 1393 |
|
1358 | 1394 | type_: t.Literal["spark"] = Field(alias="type", default="spark") |
1359 | 1395 |
|
| 1396 | + _engine_import_validator = _get_engine_import_validator("pyspark", "spark") |
| 1397 | + |
1360 | 1398 | @property |
1361 | 1399 | def _connection_kwargs_keys(self) -> t.Set[str]: |
1362 | 1400 | return { |
@@ -1473,6 +1511,8 @@ class TrinoConnectionConfig(ConnectionConfig): |
1473 | 1511 |
|
1474 | 1512 | type_: t.Literal["trino"] = Field(alias="type", default="trino") |
1475 | 1513 |
|
| 1514 | + _engine_import_validator = _get_engine_import_validator("trino", "trino") |
| 1515 | + |
1476 | 1516 | @field_validator("schema_location_mapping", mode="before") |
1477 | 1517 | @classmethod |
1478 | 1518 | def _validate_regex_keys( |
@@ -1623,6 +1663,8 @@ class ClickhouseConnectionConfig(ConnectionConfig): |
1623 | 1663 |
|
1624 | 1664 | type_: t.Literal["clickhouse"] = Field(alias="type", default="clickhouse") |
1625 | 1665 |
|
| 1666 | + _engine_import_validator = _get_engine_import_validator("clickhouse_connect", "clickhouse") |
| 1667 | + |
1626 | 1668 | @property |
1627 | 1669 | def _connection_kwargs_keys(self) -> t.Set[str]: |
1628 | 1670 | kwargs = { |
@@ -1727,6 +1769,8 @@ class AthenaConnectionConfig(ConnectionConfig): |
1727 | 1769 |
|
1728 | 1770 | type_: t.Literal["athena"] = Field(alias="type", default="athena") |
1729 | 1771 |
|
| 1772 | + _engine_import_validator = _get_engine_import_validator("pyathena", "athena") |
| 1773 | + |
1730 | 1774 | @model_validator(mode="after") |
1731 | 1775 | def _root_validator(self) -> Self: |
1732 | 1776 | work_group = self.work_group |
@@ -1793,6 +1837,8 @@ class RisingwaveConnectionConfig(ConnectionConfig): |
1793 | 1837 |
|
1794 | 1838 | type_: t.Literal["risingwave"] = Field(alias="type", default="risingwave") |
1795 | 1839 |
|
| 1840 | + _engine_import_validator = _get_engine_import_validator("psycopg2", "risingwave") |
| 1841 | + |
1796 | 1842 | @property |
1797 | 1843 | def _connection_kwargs_keys(self) -> t.Set[str]: |
1798 | 1844 | return { |
|
0 commit comments