|
4 | 4 | import base64 |
5 | 5 | import logging |
6 | 6 | import os |
| 7 | +import importlib |
7 | 8 | import pathlib |
8 | 9 | import re |
9 | 10 | import typing as t |
|
24 | 25 | ) |
25 | 26 | from sqlmesh.core.engine_adapter.shared import CatalogSupport |
26 | 27 | from sqlmesh.core.engine_adapter import EngineAdapter |
| 28 | +from sqlmesh.utils import str_to_bool |
27 | 29 | from sqlmesh.utils.errors import ConfigError |
28 | 30 | from sqlmesh.utils.pydantic import ( |
29 | 31 | ValidationInfo, |
|
49 | 51 | MOTHERDUCK_TOKEN_REGEX = re.compile(r"(\?|\&)(motherduck_token=)(\S*)") |
50 | 52 |
|
51 | 53 |
|
def _get_engine_import_validator(
    import_name: str, engine_type: str, extra_name: t.Optional[str] = None
) -> t.Callable:
    """Build a pydantic ``mode="before"`` validator that verifies the engine's
    driver library is importable before the connection config is validated.

    Args:
        import_name: Module to import to prove the driver is installed
            (e.g. ``"google.cloud.bigquery"``).
        engine_type: Engine name used in the error message.
        extra_name: Name of the ``sqlmesh[...]`` pip extra that installs the
            driver; defaults to ``engine_type``.

    Returns:
        A validator suitable for assignment on a ``ConnectionConfig`` subclass.

    Raises (from the returned validator):
        ConfigError: If the driver module cannot be imported.
    """
    extra_name = extra_name or engine_type

    @model_validator(mode="before")
    def validate(cls: t.Any, data: t.Any) -> t.Any:
        # Users can opt out of the check via a `check_import` key in the raw
        # config dict; it is popped so it never reaches field validation.
        check_import = (
            str_to_bool(str(data.pop("check_import", True))) if isinstance(data, dict) else True
        )
        if not check_import:
            return data
        try:
            importlib.import_module(import_name)
        except ImportError as e:
            # Chain the original ImportError so the underlying cause
            # (e.g. a broken transitive dependency) stays visible.
            raise ConfigError(
                f"Failed to import the '{engine_type}' engine library. Please run `pip install \"sqlmesh[{extra_name}]\"`."
            ) from e

        return data

    return validate
| 77 | + |
52 | 78 | class ConnectionConfig(abc.ABC, BaseConfig): |
53 | 79 | type_: str |
54 | 80 | concurrent_tasks: int |
@@ -428,6 +454,7 @@ class SnowflakeConnectionConfig(ConnectionConfig): |
428 | 454 | type_: t.Literal["snowflake"] = Field(alias="type", default="snowflake") |
429 | 455 |
|
430 | 456 | _concurrent_tasks_validator = concurrent_tasks_validator |
| 457 | + _engine_import_validator = _get_engine_import_validator("snowflake", "snowflake") |
431 | 458 |
|
432 | 459 | @model_validator(mode="before") |
433 | 460 | def _validate_authenticator(cls, data: t.Any) -> t.Any: |
@@ -621,6 +648,7 @@ class DatabricksConnectionConfig(ConnectionConfig): |
621 | 648 |
|
622 | 649 | _concurrent_tasks_validator = concurrent_tasks_validator |
623 | 650 | _http_headers_validator = http_headers_validator |
| 651 | + _engine_import_validator = _get_engine_import_validator("databricks", "databricks") |
624 | 652 |
|
625 | 653 | @model_validator(mode="before") |
626 | 654 | def _databricks_connect_validator(cls, data: t.Any) -> t.Any: |
@@ -873,6 +901,8 @@ class BigQueryConnectionConfig(ConnectionConfig): |
873 | 901 |
|
874 | 902 | type_: t.Literal["bigquery"] = Field(alias="type", default="bigquery") |
875 | 903 |
|
| 904 | + _engine_import_validator = _get_engine_import_validator("google.cloud.bigquery", "bigquery") |
| 905 | + |
876 | 906 | @field_validator("execution_project") |
877 | 907 | def validate_execution_project( |
878 | 908 | cls, |
@@ -1015,6 +1045,10 @@ class GCPPostgresConnectionConfig(ConnectionConfig): |
1015 | 1045 | register_comments: bool = True |
1016 | 1046 | pre_ping: bool = True |
1017 | 1047 |
|
| 1048 | + _engine_import_validator = _get_engine_import_validator( |
| 1049 | + "google.cloud.sql", "gcp_postgres", "gcppostgres" |
| 1050 | + ) |
| 1051 | + |
1018 | 1052 | @model_validator(mode="before") |
1019 | 1053 | def _validate_auth_method(cls, data: t.Any) -> t.Any: |
1020 | 1054 | if not isinstance(data, dict): |
@@ -1142,6 +1176,8 @@ class RedshiftConnectionConfig(ConnectionConfig): |
1142 | 1176 |
|
1143 | 1177 | type_: t.Literal["redshift"] = Field(alias="type", default="redshift") |
1144 | 1178 |
|
| 1179 | + _engine_import_validator = _get_engine_import_validator("redshift_connector", "redshift") |
| 1180 | + |
1145 | 1181 | @property |
1146 | 1182 | def _connection_kwargs_keys(self) -> t.Set[str]: |
1147 | 1183 | return { |
@@ -1201,6 +1237,8 @@ class PostgresConnectionConfig(ConnectionConfig): |
1201 | 1237 |
|
1202 | 1238 | type_: t.Literal["postgres"] = Field(alias="type", default="postgres") |
1203 | 1239 |
|
| 1240 | + _engine_import_validator = _get_engine_import_validator("psycopg2", "postgres") |
| 1241 | + |
1204 | 1242 | @property |
1205 | 1243 | def _connection_kwargs_keys(self) -> t.Set[str]: |
1206 | 1244 | return { |
@@ -1252,6 +1290,8 @@ class MySQLConnectionConfig(ConnectionConfig): |
1252 | 1290 |
|
1253 | 1291 | type_: t.Literal["mysql"] = Field(alias="type", default="mysql") |
1254 | 1292 |
|
| 1293 | + _engine_import_validator = _get_engine_import_validator("pymysql", "mysql") |
| 1294 | + |
1255 | 1295 | @property |
1256 | 1296 | def _connection_kwargs_keys(self) -> t.Set[str]: |
1257 | 1297 | connection_keys = { |
@@ -1302,6 +1342,8 @@ class MSSQLConnectionConfig(ConnectionConfig): |
1302 | 1342 |
|
1303 | 1343 | type_: t.Literal["mssql"] = Field(alias="type", default="mssql") |
1304 | 1344 |
|
| 1345 | + _engine_import_validator = _get_engine_import_validator("pymssql", "mssql") |
| 1346 | + |
1305 | 1347 | @property |
1306 | 1348 | def _connection_kwargs_keys(self) -> t.Set[str]: |
1307 | 1349 | return { |
@@ -1357,6 +1399,8 @@ class SparkConnectionConfig(ConnectionConfig): |
1357 | 1399 |
|
1358 | 1400 | type_: t.Literal["spark"] = Field(alias="type", default="spark") |
1359 | 1401 |
|
| 1402 | + _engine_import_validator = _get_engine_import_validator("pyspark", "spark") |
| 1403 | + |
1360 | 1404 | @property |
1361 | 1405 | def _connection_kwargs_keys(self) -> t.Set[str]: |
1362 | 1406 | return { |
@@ -1473,6 +1517,8 @@ class TrinoConnectionConfig(ConnectionConfig): |
1473 | 1517 |
|
1474 | 1518 | type_: t.Literal["trino"] = Field(alias="type", default="trino") |
1475 | 1519 |
|
| 1520 | + _engine_import_validator = _get_engine_import_validator("trino", "trino") |
| 1521 | + |
1476 | 1522 | @field_validator("schema_location_mapping", mode="before") |
1477 | 1523 | @classmethod |
1478 | 1524 | def _validate_regex_keys( |
@@ -1623,6 +1669,8 @@ class ClickhouseConnectionConfig(ConnectionConfig): |
1623 | 1669 |
|
1624 | 1670 | type_: t.Literal["clickhouse"] = Field(alias="type", default="clickhouse") |
1625 | 1671 |
|
| 1672 | + _engine_import_validator = _get_engine_import_validator("clickhouse_connect", "clickhouse") |
| 1673 | + |
1626 | 1674 | @property |
1627 | 1675 | def _connection_kwargs_keys(self) -> t.Set[str]: |
1628 | 1676 | kwargs = { |
@@ -1727,6 +1775,8 @@ class AthenaConnectionConfig(ConnectionConfig): |
1727 | 1775 |
|
1728 | 1776 | type_: t.Literal["athena"] = Field(alias="type", default="athena") |
1729 | 1777 |
|
| 1778 | + _engine_import_validator = _get_engine_import_validator("pyathena", "athena") |
| 1779 | + |
1730 | 1780 | @model_validator(mode="after") |
1731 | 1781 | def _root_validator(self) -> Self: |
1732 | 1782 | work_group = self.work_group |
@@ -1793,6 +1843,8 @@ class RisingwaveConnectionConfig(ConnectionConfig): |
1793 | 1843 |
|
1794 | 1844 | type_: t.Literal["risingwave"] = Field(alias="type", default="risingwave") |
1795 | 1845 |
|
| 1846 | + _engine_import_validator = _get_engine_import_validator("psycopg2", "risingwave") |
| 1847 | + |
1796 | 1848 | @property |
1797 | 1849 | def _connection_kwargs_keys(self) -> t.Set[str]: |
1798 | 1850 | return { |
|
0 commit comments