From 27188559d8c3ca2bbba36a0edce0aedca3b27885 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mattias=20Thal=C3=A9n?= Date: Fri, 6 Jun 2025 19:20:00 +0000 Subject: [PATCH 1/7] feat: Add optional dependencies for azuresql-odbc and mssql-odbc --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index b172c86375..160b1be786 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ classifiers = [ [project.optional-dependencies] athena = ["PyAthena[Pandas]"] azuresql = ["pymssql"] +azuresql-odbc = ["pyodbc"] bigquery = [ "google-cloud-bigquery[pandas]", "google-cloud-bigquery-storage" @@ -104,6 +105,7 @@ gcppostgres = ["cloud-sql-python-connector[pg8000]>=1.8.0"] github = ["PyGithub~=2.5.0"] llm = ["langchain", "openai"] mssql = ["pymssql"] +mssql-odbc = ["pyodbc"] mysql = ["pymysql"] mwaa = ["boto3"] postgres = ["psycopg2"] @@ -203,6 +205,7 @@ module = [ "databricks_cli.*", "mysql.*", "pymssql.*", + "pyodbc.*", "psycopg2.*", "langchain.*", "pytest_lazyfixture.*", From ccaff114783e9ce827dcee364fff6153c7fe433f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mattias=20Thal=C3=A9n?= Date: Fri, 6 Jun 2025 19:20:19 +0000 Subject: [PATCH 2/7] feat: Update Azure SQL and MSSQL documentation to include ODBC authentication options --- docs/integrations/engines/azuresql.md | 16 +++++++++++----- docs/integrations/engines/mssql.md | 12 ++++++++++-- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/docs/integrations/engines/azuresql.md b/docs/integrations/engines/azuresql.md index e9b97abaa1..5b54ffa9c6 100644 --- a/docs/integrations/engines/azuresql.md +++ b/docs/integrations/engines/azuresql.md @@ -2,15 +2,18 @@ [Azure SQL](https://azure.microsoft.com/en-us/products/azure-sql) is "a family of managed, secure, and intelligent products that use the SQL Server database engine in the Azure cloud." -The Azure SQL adapter only supports authentication with a username and password. 
It does not support authentication with Microsoft Entra or Azure Active Directory. - ## Local/Built-in Scheduler **Engine Adapter Type**: `azuresql` ### Installation +#### User / Password Authentication: ``` pip install "sqlmesh[azuresql]" ``` +#### Microsoft Entra ID / Azure Active Directory Authentication: +``` +pip install "sqlmesh[azuresql-odbc]" +``` ### Connection options @@ -18,8 +21,8 @@ pip install "sqlmesh[azuresql]" | ----------------- | ---------------------------------------------------------------- | :----------: | :------: | | `type` | Engine type name - must be `azuresql` | string | Y | | `host` | The hostname of the Azure SQL server | string | Y | -| `user` | The username to use for authentication with the Azure SQL server | string | N | -| `password` | The password to use for authentication with the Azure SQL server | string | N | +| `user` | The username / client ID to use for authentication with the Azure SQL server | string | N | +| `password` | The password / client secret to use for authentication with the Azure SQL server | string | N | | `port` | The port number of the Azure SQL server | int | N | | `database` | The target database | string | N | | `charset` | The character set used for the connection | string | N | @@ -27,4 +30,7 @@ pip install "sqlmesh[azuresql]" | `login_timeout` | The timeout for connection and login in seconds. Default: 60 | int | N | | `appname` | The application name to use for the connection | string | N | | `conn_properties` | The list of connection properties | list[string] | N | -| `autocommit` | Is autocommit mode enabled. Default: false | bool | N | \ No newline at end of file +| `autocommit` | Is autocommit mode enabled. Default: false | bool | N | +| `driver` | The Python driver to use for the connection: `pymssql` or `pyodbc`. Default: `pymssql` | string | N | +| `driver_name` | The ODBC driver name to use when `driver` is `pyodbc`. E.g., *ODBC Driver 18 for SQL Server* | string | N | +| `odbc_properties` | The dict of ODBC connection properties.
E.g., authentication: ActiveDirectoryServicePrincipal. See more [here](https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute?view=sql-server-ver16). | dict | N | \ No newline at end of file diff --git a/docs/integrations/engines/mssql.md b/docs/integrations/engines/mssql.md index 1650319d07..f06b5f1387 100644 --- a/docs/integrations/engines/mssql.md +++ b/docs/integrations/engines/mssql.md @@ -4,9 +4,14 @@ **Engine Adapter Type**: `mssql` ### Installation +#### User / Password Authentication: ``` pip install "sqlmesh[mssql]" ``` +#### Microsoft Entra ID / Azure Active Directory Authentication: +``` +pip install "sqlmesh[mssql-odbc]" +``` ### Connection options @@ -14,8 +19,8 @@ pip install "sqlmesh[mssql]" | ----------------- | ------------------------------------------------------------ | :----------: | :------: | | `type` | Engine type name - must be `mssql` | string | Y | | `host` | The hostname of the MSSQL server | string | Y | -| `user` | The username to use for authentication with the MSSQL server | string | N | -| `password` | The password to use for authentication with the MSSQL server | string | N | +| `user` | The username / client ID to use for authentication with the MSSQL server | string | N | +| `password` | The password / client secret to use for authentication with the MSSQL server | string | N | | `port` | The port number of the MSSQL server | int | N | | `database` | The target database | string | N | | `charset` | The character set used for the connection | string | N | @@ -24,3 +29,6 @@ pip install "sqlmesh[mssql]" | `appname` | The application name to use for the connection | string | N | | `conn_properties` | The list of connection properties | list[string] | N | | `autocommit` | Is autocommit mode enabled. Default: false | bool | N | +| `driver` | The Python driver to use for the connection: `pymssql` or `pyodbc`. Default: `pymssql` | string | N | +| `driver_name` | The ODBC driver name to use when `driver` is `pyodbc`.
E.g., *ODBC Driver 18 for SQL Server* | string | N | +| `odbc_properties` | The dict of ODBC connection properties. E.g., authentication: ActiveDirectoryServicePrincipal. See more [here](https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute?view=sql-server-ver16). | dict | N | \ No newline at end of file From ea6bc06c22932636be2d4fc9d3f63ee41ff372bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mattias=20Thal=C3=A9n?= Date: Fri, 6 Jun 2025 19:31:16 +0000 Subject: [PATCH 3/7] feat: Enhance MSSQL connection configuration with ODBC options and improve bulk copy handling --- sqlmesh/core/config/connection.py | 97 +++++++++++++++++++++++++++- sqlmesh/core/engine_adapter/mssql.py | 4 ++ 2 files changed, 98 insertions(+), 3 deletions(-) diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index e7e138c908..ebc491e04e 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -1422,6 +1422,16 @@ class MSSQLConnectionConfig(ConnectionConfig): autocommit: t.Optional[bool] = False tds_version: t.Optional[str] = None + # Driver options + driver: t.Literal["pymssql", "pyodbc"] = "pymssql" + # PyODBC specific options + driver_name: t.Optional[str] = None # e.g. 
"ODBC Driver 18 for SQL Server" + trust_server_certificate: t.Optional[bool] = None + encrypt: t.Optional[bool] = None + # Dictionary of arbitrary ODBC connection properties + # See: https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute + odbc_properties: t.Optional[t.Dict[str, t.Any]] = None + concurrent_tasks: int = 4 register_comments: bool = True pre_ping: bool = True @@ -1432,7 +1442,7 @@ class MSSQLConnectionConfig(ConnectionConfig): @property def _connection_kwargs_keys(self) -> t.Set[str]: - return { + base_keys = { "host", "user", "password", @@ -1447,15 +1457,96 @@ def _connection_kwargs_keys(self) -> t.Set[str]: "tds_version", } + if self.driver == "pyodbc": + base_keys.update( + { + "driver_name", + "trust_server_certificate", + "encrypt", + "odbc_properties", + } + ) + # Remove pymssql-specific parameters + base_keys.discard("tds_version") + base_keys.discard("conn_properties") + + return base_keys + @property def _engine_adapter(self) -> t.Type[EngineAdapter]: return engine_adapter.MSSQLEngineAdapter @property def _connection_factory(self) -> t.Callable: - import pymssql + if self.driver == "pymssql": + import pymssql + + return pymssql.connect + + import pyodbc + + def connect(**kwargs: t.Any) -> t.Callable: + # Extract parameters for connection string + host = kwargs.pop("host") + port = kwargs.pop("port", 1433) + database = kwargs.pop("database", "") + user = kwargs.pop("user", None) + password = kwargs.pop("password", None) + driver_name = kwargs.pop("driver_name", "ODBC Driver 18 for SQL Server") + trust_server_certificate = kwargs.pop("trust_server_certificate", False) + encrypt = kwargs.pop("encrypt", True) + login_timeout = kwargs.pop("login_timeout", 60) + + # Build connection string + conn_str_parts = [ + f"DRIVER={{{driver_name}}}", + f"SERVER={host},{port}", + ] + + if database: + conn_str_parts.append(f"DATABASE={database}") + + # Add security options + conn_str_parts.append(f"Encrypt={'YES' if encrypt else 
'NO'}") + if trust_server_certificate: + conn_str_parts.append("TrustServerCertificate=YES") + + conn_str_parts.append(f"Connection Timeout={login_timeout}") + + # Standard SQL Server authentication + if user: + conn_str_parts.append(f"UID={user}") + if password: + conn_str_parts.append(f"PWD={password}") + + # Add any additional ODBC properties from the odbc_properties dictionary + if self.odbc_properties: + for key, value in self.odbc_properties.items(): + # Skip properties that we've already set above + if key.lower() in ( + "driver", + "server", + "database", + "uid", + "pwd", + "encrypt", + "trustservercertificate", + "connection timeout", + ): + continue - return pymssql.connect + # Handle boolean values properly + if isinstance(value, bool): + conn_str_parts.append(f"{key}={'YES' if value else 'NO'}") + else: + conn_str_parts.append(f"{key}={value}") + + # Create the connection string + conn_str = ";".join(conn_str_parts) + + return pyodbc.connect(conn_str, autocommit=kwargs.get("autocommit", False)) + + return connect @property def _extra_engine_config(self) -> t.Dict[str, t.Any]: diff --git a/sqlmesh/core/engine_adapter/mssql.py b/sqlmesh/core/engine_adapter/mssql.py index 796bb87960..40649f3c2d 100644 --- a/sqlmesh/core/engine_adapter/mssql.py +++ b/sqlmesh/core/engine_adapter/mssql.py @@ -219,6 +219,10 @@ def _df_to_source_queries( assert isinstance(df, pd.DataFrame) temp_table = self._get_temp_table(target_table or "pandas") + # Return the superclass implementation if the connection pool doesn't support bulk_copy + if not hasattr(self._connection_pool.get(), "bulk_copy"): + return super()._df_to_source_queries(df, columns_to_types, batch_size, target_table) + def query_factory() -> Query: # It is possible for the factory to be called multiple times and if so then the temp table will already # be created so we skip creating again. 
This means we are assuming the first call is the same result From 90fdaee0d3224c61965d0915af533cfbac006682 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mattias=20Thal=C3=A9n?= Date: Fri, 6 Jun 2025 22:16:41 +0000 Subject: [PATCH 4/7] feat: Implement MSSQL engine import validator for driver configuration and add corresponding tests --- sqlmesh/core/config/connection.py | 41 ++++++++++++++++++++- tests/core/test_connection_config.py | 53 ++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index ebc491e04e..e270509b40 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -1438,7 +1438,46 @@ class MSSQLConnectionConfig(ConnectionConfig): type_: t.Literal["mssql"] = Field(alias="type", default="mssql") - _engine_import_validator = _get_engine_import_validator("pymssql", "mssql") + @model_validator(mode="before") + @classmethod + def _mssql_engine_import_validator(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + check_import = str_to_bool(str(data.pop("check_import", True))) + if not check_import: + return data + + driver = data.get("driver", "pymssql") + + try: + if driver == "pymssql": + importlib.import_module("pymssql") + elif driver == "pyodbc": + importlib.import_module("pyodbc") + else: + raise ValueError(f"Unsupported driver: {driver}") + except ImportError: + if debug_mode_enabled(): + raise + + logger.exception("Failed to import the MSSQL engine library") + + if driver == "pymssql": + extra_name = "mssql" + elif driver == "pyodbc": + extra_name = "mssql-odbc" + else: + extra_name = "mssql" + + raise ConfigError( + f"Failed to import the '{driver}' library for MSSQL connections. This may be due to a missing " + "or incompatible installation. Please ensure the required dependency is installed by " + f'running: `pip install "sqlmesh[{extra_name}]"`. 
For more details, check the logs ' + "in the 'logs/' folder, or rerun the command with the '--debug' flag." + ) + + return data @property def _connection_kwargs_keys(self) -> t.Set[str]: diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py index d106559b67..996d52d3ad 100644 --- a/tests/core/test_connection_config.py +++ b/tests/core/test_connection_config.py @@ -4,6 +4,7 @@ import pytest from _pytest.fixtures import FixtureRequest +from unittest.mock import patch from sqlmesh.core.config.connection import ( BigQueryConnectionConfig, @@ -19,6 +20,7 @@ SnowflakeConnectionConfig, TrinoAuthenticationMethod, AthenaConnectionConfig, + MSSQLConnectionConfig, _connection_config_validator, _get_engine_import_validator, ) @@ -1127,3 +1129,54 @@ class TestConfigC(PydanticModel): _engine_import_validator = _get_engine_import_validator("sqlmesh", "bigquery") TestConfigC() + + +def test_mssql_engine_import_validator(): + """Test that MSSQL import validator respects driver configuration.""" + with pytest.raises( + ConfigError, + match=re.escape( + "Failed to import the 'pyodbc' library for MSSQL connections. This may be due to a missing " + "or incompatible installation. Please ensure the required dependency is installed by " + 'running: `pip install "sqlmesh[mssql-odbc]"`. For more details, check the logs ' + "in the 'logs/' folder, or rerun the command with the '--debug' flag." + ), + ): + # Test PyODBC driver suggests mssql-odbc extra + with patch("importlib.import_module") as mock_import: + mock_import.side_effect = ImportError("No module named 'pyodbc'") + MSSQLConnectionConfig(host="localhost", driver="pyodbc") + + with pytest.raises( + ConfigError, + match=re.escape( + "Failed to import the 'pymssql' library for MSSQL connections. This may be due to a missing " + "or incompatible installation. Please ensure the required dependency is installed by " + 'running: `pip install "sqlmesh[mssql]"`. 
For more details, check the logs ' + "in the 'logs/' folder, or rerun the command with the '--debug' flag." + ), + ): + # Test PyMSSQL driver suggests mssql extra + with patch("importlib.import_module") as mock_import: + mock_import.side_effect = ImportError("No module named 'pymssql'") + MSSQLConnectionConfig(host="localhost", driver="pymssql") + + with pytest.raises( + ConfigError, + match=re.escape( + "Failed to import the 'pymssql' library for MSSQL connections. This may be due to a missing " + "or incompatible installation. Please ensure the required dependency is installed by " + 'running: `pip install "sqlmesh[mssql]"`. For more details, check the logs ' + "in the 'logs/' folder, or rerun the command with the '--debug' flag." + ), + ): + # Test default driver (pymssql) suggests mssql extra + with patch("importlib.import_module") as mock_import: + mock_import.side_effect = ImportError("No module named 'pymssql'") + MSSQLConnectionConfig(host="localhost") # No driver specified + + # Test successful import doesn't raise exception + with patch("importlib.import_module") as mock_import: + mock_import.return_value = None + config = MSSQLConnectionConfig(host="localhost", driver="pyodbc") + assert config.driver == "pyodbc" From eb1f1e0cc031caa880624515e5370ae60ecafa34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mattias=20Thal=C3=A9n?= Date: Fri, 6 Jun 2025 22:57:17 +0000 Subject: [PATCH 5/7] feat: Add tests for MSSQL connection configuration and parameter validation --- tests/core/test_connection_config.py | 235 +++++++++++++++++++++++++++ 1 file changed, 235 insertions(+) diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py index 996d52d3ad..ba0b46c804 100644 --- a/tests/core/test_connection_config.py +++ b/tests/core/test_connection_config.py @@ -1180,3 +1180,238 @@ def test_mssql_engine_import_validator(): mock_import.return_value = None config = MSSQLConnectionConfig(host="localhost", driver="pyodbc") assert config.driver == 
"pyodbc" + + +def test_mssql_connection_config_parameter_validation(make_config): + """Test MSSQL connection config parameter validation.""" + # Test default driver is pymssql + config = make_config(type="mssql", host="localhost", check_import=False) + assert isinstance(config, MSSQLConnectionConfig) + assert config.driver == "pymssql" + + # Test explicit pyodbc driver + config = make_config(type="mssql", host="localhost", driver="pyodbc", check_import=False) + assert isinstance(config, MSSQLConnectionConfig) + assert config.driver == "pyodbc" + + # Test explicit pymssql driver + config = make_config(type="mssql", host="localhost", driver="pymssql", check_import=False) + assert isinstance(config, MSSQLConnectionConfig) + assert config.driver == "pymssql" + + # Test pyodbc specific parameters + config = make_config( + type="mssql", + host="localhost", + driver="pyodbc", + driver_name="ODBC Driver 18 for SQL Server", + trust_server_certificate=True, + encrypt=False, + odbc_properties={"Authentication": "ActiveDirectoryServicePrincipal"}, + check_import=False, + ) + assert isinstance(config, MSSQLConnectionConfig) + assert config.driver_name == "ODBC Driver 18 for SQL Server" + assert config.trust_server_certificate is True + assert config.encrypt is False + assert config.odbc_properties == {"Authentication": "ActiveDirectoryServicePrincipal"} + + # Test pymssql specific parameters + config = make_config( + type="mssql", + host="localhost", + driver="pymssql", + tds_version="7.4", + conn_properties=["SET ANSI_NULLS ON"], + check_import=False, + ) + assert isinstance(config, MSSQLConnectionConfig) + assert config.tds_version == "7.4" + assert config.conn_properties == ["SET ANSI_NULLS ON"] + + +def test_mssql_connection_kwargs_keys(): + """Test _connection_kwargs_keys returns correct keys for each driver variant.""" + # Test pymssql driver keys + config = MSSQLConnectionConfig(host="localhost", driver="pymssql", check_import=False) + pymssql_keys = 
config._connection_kwargs_keys + expected_pymssql_keys = { + "password", + "user", + "database", + "host", + "timeout", + "login_timeout", + "charset", + "appname", + "port", + "tds_version", + "conn_properties", + "autocommit", + } + assert pymssql_keys == expected_pymssql_keys + + # Test pyodbc driver keys + config = MSSQLConnectionConfig(host="localhost", driver="pyodbc", check_import=False) + pyodbc_keys = config._connection_kwargs_keys + expected_pyodbc_keys = { + "password", + "user", + "database", + "host", + "timeout", + "login_timeout", + "charset", + "appname", + "port", + "autocommit", + "driver_name", + "trust_server_certificate", + "encrypt", + "odbc_properties", + } + assert pyodbc_keys == expected_pyodbc_keys + + # Verify pyodbc keys don't include pymssql-specific parameters + assert "tds_version" not in pyodbc_keys + assert "conn_properties" not in pyodbc_keys + + +def test_mssql_pyodbc_connection_string_generation(): + """Test pyodbc.connect gets invoked with the correct ODBC connection string.""" + with patch("pyodbc.connect") as mock_pyodbc_connect: + # Mock the return value to have the methods we need + mock_connection = mock_pyodbc_connect.return_value + + # Create a pyodbc config + config = MSSQLConnectionConfig( + host="testserver.database.windows.net", + port=1433, + database="testdb", + user="testuser", + password="testpass", + driver="pyodbc", + driver_name="ODBC Driver 18 for SQL Server", + trust_server_certificate=True, + encrypt=True, + login_timeout=30, + check_import=False, + ) + + # Get the connection factory with kwargs and call it + factory_with_kwargs = config._connection_factory_with_kwargs + connection = factory_with_kwargs() + + # Verify pyodbc.connect was called with the correct connection string + mock_pyodbc_connect.assert_called_once() + call_args = mock_pyodbc_connect.call_args + + # Check the connection string (first argument) + conn_str = call_args[0][0] + expected_parts = [ + "DRIVER={ODBC Driver 18 for SQL Server}", + 
"SERVER=testserver.database.windows.net,1433", + "DATABASE=testdb", + "Encrypt=YES", + "TrustServerCertificate=YES", + "Connection Timeout=30", + "UID=testuser", + "PWD=testpass", + ] + + for part in expected_parts: + assert part in conn_str + + # Check autocommit parameter + assert call_args[1]["autocommit"] is False + + +def test_mssql_pyodbc_connection_string_with_odbc_properties(): + """Test pyodbc connection string includes custom ODBC properties.""" + with patch("pyodbc.connect") as mock_pyodbc_connect: + # Create a pyodbc config with custom ODBC properties + config = MSSQLConnectionConfig( + host="testserver.database.windows.net", + database="testdb", + user="client-id", + password="client-secret", + driver="pyodbc", + odbc_properties={ + "Authentication": "ActiveDirectoryServicePrincipal", + "ClientCertificate": "/path/to/cert.pem", + "TrustServerCertificate": "NO", # This should be ignored since we set it explicitly + }, + trust_server_certificate=True, # This should take precedence + check_import=False, + ) + + # Get the connection factory with kwargs and call it + factory_with_kwargs = config._connection_factory_with_kwargs + connection = factory_with_kwargs() + + # Verify pyodbc.connect was called + mock_pyodbc_connect.assert_called_once() + conn_str = mock_pyodbc_connect.call_args[0][0] + + # Check that custom ODBC properties are included + assert "Authentication=ActiveDirectoryServicePrincipal" in conn_str + assert "ClientCertificate=/path/to/cert.pem" in conn_str + + # Verify that explicit trust_server_certificate takes precedence + assert "TrustServerCertificate=YES" in conn_str + + # Should not have the conflicting property from odbc_properties + assert conn_str.count("TrustServerCertificate") == 1 + + +def test_mssql_pyodbc_connection_string_minimal(): + """Test pyodbc connection string with minimal configuration.""" + with patch("pyodbc.connect") as mock_pyodbc_connect: + config = MSSQLConnectionConfig( + host="localhost", + driver="pyodbc", + 
autocommit=True, + check_import=False, + ) + + factory_with_kwargs = config._connection_factory_with_kwargs + connection = factory_with_kwargs() + + mock_pyodbc_connect.assert_called_once() + conn_str = mock_pyodbc_connect.call_args[0][0] + + # Check basic required parts + assert "DRIVER={ODBC Driver 18 for SQL Server}" in conn_str + assert "SERVER=localhost,1433" in conn_str + assert "Encrypt=YES" in conn_str # Default encrypt=True + assert "Connection Timeout=60" in conn_str # Default timeout + + # Check autocommit parameter + assert mock_pyodbc_connect.call_args[1]["autocommit"] is True + + +def test_mssql_pymssql_connection_factory(): + """Test pymssql connection factory returns correct function.""" + # Mock the import of pymssql at the module level + import sys + from unittest.mock import MagicMock + + # Create a mock pymssql module + mock_pymssql = MagicMock() + sys.modules["pymssql"] = mock_pymssql + + try: + config = MSSQLConnectionConfig( + host="localhost", + driver="pymssql", + check_import=False, + ) + + factory = config._connection_factory + + # Verify the factory returns pymssql.connect + assert factory is mock_pymssql.connect + finally: + # Clean up the mock module + if "pymssql" in sys.modules: + del sys.modules["pymssql"] From 34d1e78c4ce89495979b6dbf65f6d8f2ba66d8b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mattias=20Thal=C3=A9n?= Date: Mon, 9 Jun 2025 13:13:06 +0000 Subject: [PATCH 6/7] feat: Refactor MSSQL connection driver import handling --- sqlmesh/core/config/connection.py | 42 ++++++++++------------------ tests/core/test_connection_config.py | 39 ++++++-------------------- 2 files changed, 22 insertions(+), 59 deletions(-) diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index e270509b40..37c000df22 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -1444,40 +1444,26 @@ def _mssql_engine_import_validator(cls, data: t.Any) -> t.Any: if not isinstance(data, dict): return 
data - check_import = str_to_bool(str(data.pop("check_import", True))) - if not check_import: - return data - driver = data.get("driver", "pymssql") - try: - if driver == "pymssql": - importlib.import_module("pymssql") - elif driver == "pyodbc": - importlib.import_module("pyodbc") - else: - raise ValueError(f"Unsupported driver: {driver}") - except ImportError: - if debug_mode_enabled(): - raise + # Define the mapping of driver to import module and extra name + driver_configs = {"pymssql": ("pymssql", "mssql"), "pyodbc": ("pyodbc", "mssql-odbc")} - logger.exception("Failed to import the MSSQL engine library") + if driver not in driver_configs: + raise ValueError(f"Unsupported driver: {driver}") - if driver == "pymssql": - extra_name = "mssql" - elif driver == "pyodbc": - extra_name = "mssql-odbc" - else: - extra_name = "mssql" + import_module, extra_name = driver_configs[driver] - raise ConfigError( - f"Failed to import the '{driver}' library for MSSQL connections. This may be due to a missing " - "or incompatible installation. Please ensure the required dependency is installed by " - f'running: `pip install "sqlmesh[{extra_name}]"`. For more details, check the logs ' - "in the 'logs/' folder, or rerun the command with the '--debug' flag." 
- ) + # Conditionally delegate to the existing _get_engine_import_validator + # Create a validator for the specific driver and call its inner function + validator_func = _get_engine_import_validator(import_module, driver, extra_name) - return data + # Extract the inner validate function from the decorated validator + # The validator_func has a __wrapped__ attribute that contains the original function + inner_validate = getattr(validator_func, "__wrapped__", validator_func) + + # Call the inner validation function directly + return inner_validate(cls, data) @property def _connection_kwargs_keys(self) -> t.Set[str]: diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py index ba0b46c804..ba33cb010b 100644 --- a/tests/core/test_connection_config.py +++ b/tests/core/test_connection_config.py @@ -1133,49 +1133,26 @@ class TestConfigC(PydanticModel): def test_mssql_engine_import_validator(): """Test that MSSQL import validator respects driver configuration.""" - with pytest.raises( - ConfigError, - match=re.escape( - "Failed to import the 'pyodbc' library for MSSQL connections. This may be due to a missing " - "or incompatible installation. Please ensure the required dependency is installed by " - 'running: `pip install "sqlmesh[mssql-odbc]"`. For more details, check the logs ' - "in the 'logs/' folder, or rerun the command with the '--debug' flag." - ), - ): - # Test PyODBC driver suggests mssql-odbc extra + + # Test PyODBC driver suggests mssql-odbc extra when import fails + with pytest.raises(ConfigError, match=r"pip install \"sqlmesh\[mssql-odbc\]\""): with patch("importlib.import_module") as mock_import: mock_import.side_effect = ImportError("No module named 'pyodbc'") MSSQLConnectionConfig(host="localhost", driver="pyodbc") - with pytest.raises( - ConfigError, - match=re.escape( - "Failed to import the 'pymssql' library for MSSQL connections. This may be due to a missing " - "or incompatible installation. 
Please ensure the required dependency is installed by " - 'running: `pip install "sqlmesh[mssql]"`. For more details, check the logs ' - "in the 'logs/' folder, or rerun the command with the '--debug' flag." - ), - ): - # Test PyMSSQL driver suggests mssql extra + # Test PyMSSQL driver suggests mssql extra when import fails + with pytest.raises(ConfigError, match=r"pip install \"sqlmesh\[mssql\]\""): with patch("importlib.import_module") as mock_import: mock_import.side_effect = ImportError("No module named 'pymssql'") MSSQLConnectionConfig(host="localhost", driver="pymssql") - with pytest.raises( - ConfigError, - match=re.escape( - "Failed to import the 'pymssql' library for MSSQL connections. This may be due to a missing " - "or incompatible installation. Please ensure the required dependency is installed by " - 'running: `pip install "sqlmesh[mssql]"`. For more details, check the logs ' - "in the 'logs/' folder, or rerun the command with the '--debug' flag." - ), - ): - # Test default driver (pymssql) suggests mssql extra + # Test default driver (pymssql) suggests mssql extra when import fails + with pytest.raises(ConfigError, match=r"pip install \"sqlmesh\[mssql\]\""): with patch("importlib.import_module") as mock_import: mock_import.side_effect = ImportError("No module named 'pymssql'") MSSQLConnectionConfig(host="localhost") # No driver specified - # Test successful import doesn't raise exception + # Test successful import works without error with patch("importlib.import_module") as mock_import: mock_import.return_value = None config = MSSQLConnectionConfig(host="localhost", driver="pyodbc") From 9c5889bb0f7bd2d75f4ee570299e6dc45c86b575 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mattias=20Thal=C3=A9n?= Date: Wed, 11 Jun 2025 21:39:39 +0000 Subject: [PATCH 7/7] feat: Enhance engine import validator to support optional decoration --- sqlmesh/core/config/connection.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git 
a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index 37c000df22..691b3a7731 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -55,11 +55,10 @@ def _get_engine_import_validator( - import_name: str, engine_type: str, extra_name: t.Optional[str] = None + import_name: str, engine_type: str, extra_name: t.Optional[str] = None, decorate: bool = True ) -> t.Callable: extra_name = extra_name or engine_type - @model_validator(mode="before") def validate(cls: t.Any, data: t.Any) -> t.Any: check_import = ( str_to_bool(str(data.pop("check_import", True))) if isinstance(data, dict) else True @@ -83,7 +82,7 @@ def validate(cls: t.Any, data: t.Any) -> t.Any: return data - return validate + return model_validator(mode="before")(validate) if decorate else validate class ConnectionConfig(abc.ABC, BaseConfig): @@ -1454,16 +1453,14 @@ def _mssql_engine_import_validator(cls, data: t.Any) -> t.Any: import_module, extra_name = driver_configs[driver] - # Conditionally delegate to the existing _get_engine_import_validator - # Create a validator for the specific driver and call its inner function - validator_func = _get_engine_import_validator(import_module, driver, extra_name) - - # Extract the inner validate function from the decorated validator - # The validator_func has a __wrapped__ attribute that contains the original function - inner_validate = getattr(validator_func, "__wrapped__", validator_func) + # Use _get_engine_import_validator with decorate=False to get the raw validation function + # This avoids the __wrapped__ issue in Python 3.9 + validator_func = _get_engine_import_validator( + import_module, driver, extra_name, decorate=False + ) - # Call the inner validation function directly - return inner_validate(cls, data) + # Call the raw validation function directly + return validator_func(cls, data) @property def _connection_kwargs_keys(self) -> t.Set[str]: