Commit e3bf211

Enhances MSSQL connection with pyodbc support v2 (#4686)
1 parent 11ce8c8 commit e3bf211

6 files changed

Lines changed: 413 additions & 14 deletions

docs/integrations/engines/azuresql.md

Lines changed: 11 additions & 5 deletions
@@ -2,29 +2,35 @@

 [Azure SQL](https://azure.microsoft.com/en-us/products/azure-sql) is "a family of managed, secure, and intelligent products that use the SQL Server database engine in the Azure cloud."

-The Azure SQL adapter only supports authentication with a username and password. It does not support authentication with Microsoft Entra or Azure Active Directory.
-
 ## Local/Built-in Scheduler
 **Engine Adapter Type**: `azuresql`

 ### Installation
+#### User / Password Authentication:
 ```
 pip install "sqlmesh[azuresql]"
 ```
+#### Microsoft Entra ID / Azure Active Directory Authentication:
+```
+pip install "sqlmesh[azuresql-odbc]"
+```

 ### Connection options

 | Option | Description | Type | Required |
 | ----------------- | ---------------------------------------------------------------- | :----------: | :------: |
 | `type` | Engine type name - must be `azuresql` | string | Y |
 | `host` | The hostname of the Azure SQL server | string | Y |
-| `user` | The username to use for authentication with the Azure SQL server | string | N |
-| `password` | The password to use for authentication with the Azure SQL server | string | N |
+| `user` | The username / client ID to use for authentication with the Azure SQL server | string | N |
+| `password` | The password / client secret to use for authentication with the Azure SQL server | string | N |
 | `port` | The port number of the Azure SQL server | int | N |
 | `database` | The target database | string | N |
 | `charset` | The character set used for the connection | string | N |
 | `timeout` | The query timeout in seconds. Default: no timeout | int | N |
 | `login_timeout` | The timeout for connection and login in seconds. Default: 60 | int | N |
 | `appname` | The application name to use for the connection | string | N |
 | `conn_properties` | The list of connection properties | list[string] | N |
-| `autocommit` | Is autocommit mode enabled. Default: false | bool | N |
+| `autocommit` | Is autocommit mode enabled. Default: false | bool | N |
+| `driver` | The driver to use for the connection. Default: pymssql | string | N |
+| `driver_name` | The driver name to use for the connection. E.g., *ODBC Driver 18 for SQL Server* | string | N |
+| `odbc_properties` | The dict of ODBC connection properties. E.g., authentication: ActiveDirectoryServicePrincipal. See more [here](https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute?view=sql-server-ver16). | dict | N |
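
Reviewer note (not part of the commit): the new options in the table above might combine as follows for service principal authentication. This is only a sketch written as a plain Python dict rather than SQLMesh's own config objects. The option names come from the table; every value (host, database, client ID, secret) is a made-up placeholder.

```python
# Hypothetical connection settings; only the option names come from the docs table.
azuresql_connection = {
    "type": "azuresql",
    "host": "example-server.database.windows.net",   # placeholder host
    "database": "analytics",                          # placeholder database
    "user": "00000000-0000-0000-0000-000000000000",   # client ID (placeholder)
    "password": "client-secret-placeholder",          # client secret (placeholder)
    "driver": "pyodbc",                               # the documented default is pymssql
    "driver_name": "ODBC Driver 18 for SQL Server",
    "odbc_properties": {"authentication": "ActiveDirectoryServicePrincipal"},
}

# Print the settings one per line for a quick visual check.
for option, value in azuresql_connection.items():
    print(f"{option}: {value}")
```

With `driver` left at its default, the `azuresql` extra (pymssql) is the one to install; the `azuresql-odbc` extra is only needed when `driver` is `pyodbc`.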

docs/integrations/engines/mssql.md

Lines changed: 10 additions & 2 deletions
@@ -4,18 +4,23 @@
 **Engine Adapter Type**: `mssql`

 ### Installation
+#### User / Password Authentication:
 ```
 pip install "sqlmesh[mssql]"
 ```
+#### Microsoft Entra ID / Azure Active Directory Authentication:
+```
+pip install "sqlmesh[mssql-odbc]"
+```

 ### Connection options

 | Option | Description | Type | Required |
 | ----------------- | ------------------------------------------------------------ | :----------: | :------: |
 | `type` | Engine type name - must be `mssql` | string | Y |
 | `host` | The hostname of the MSSQL server | string | Y |
-| `user` | The username to use for authentication with the MSSQL server | string | N |
-| `password` | The password to use for authentication with the MSSQL server | string | N |
+| `user` | The username / client id to use for authentication with the MSSQL server | string | N |
+| `password` | The password / client secret to use for authentication with the MSSQL server | string | N |
 | `port` | The port number of the MSSQL server | int | N |
 | `database` | The target database | string | N |
 | `charset` | The character set used for the connection | string | N |
@@ -24,3 +29,6 @@ pip install "sqlmesh[mssql]"
 | `appname` | The application name to use for the connection | string | N |
 | `conn_properties` | The list of connection properties | list[string] | N |
 | `autocommit` | Is autocommit mode enabled. Default: false | bool | N |
+| `driver` | The driver to use for the connection. Default: pymssql | string | N |
+| `driver_name` | The driver name to use for the connection. E.g., *ODBC Driver 18 for SQL Server* | string | N |
+| `odbc_properties` | The dict of ODBC connection properties. E.g., authentication: ActiveDirectoryServicePrincipal. See more [here](https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute?view=sql-server-ver16). | dict | N |

pyproject.toml

Lines changed: 3 additions & 0 deletions
@@ -39,6 +39,7 @@ classifiers = [
 [project.optional-dependencies]
 athena = ["PyAthena[Pandas]"]
 azuresql = ["pymssql"]
+azuresql-odbc = ["pyodbc"]
 bigquery = [
     "google-cloud-bigquery[pandas]",
     "google-cloud-bigquery-storage"
@@ -104,6 +105,7 @@ gcppostgres = ["cloud-sql-python-connector[pg8000]>=1.8.0"]
 github = ["PyGithub~=2.5.0"]
 llm = ["langchain", "openai"]
 mssql = ["pymssql"]
+mssql-odbc = ["pyodbc"]
 mysql = ["pymysql"]
 mwaa = ["boto3"]
 postgres = ["psycopg2"]
@@ -203,6 +205,7 @@ module = [
     "databricks_cli.*",
     "mysql.*",
     "pymssql.*",
+    "pyodbc.*",
     "psycopg2.*",
     "langchain.*",
     "pytest_lazyfixture.*",

sqlmesh/core/config/connection.py

Lines changed: 120 additions & 7 deletions
@@ -55,11 +55,10 @@


 def _get_engine_import_validator(
-    import_name: str, engine_type: str, extra_name: t.Optional[str] = None
+    import_name: str, engine_type: str, extra_name: t.Optional[str] = None, decorate: bool = True
 ) -> t.Callable:
     extra_name = extra_name or engine_type

-    @model_validator(mode="before")
     def validate(cls: t.Any, data: t.Any) -> t.Any:
         check_import = (
             str_to_bool(str(data.pop("check_import", True))) if isinstance(data, dict) else True
@@ -83,7 +82,7 @@ def validate(cls: t.Any, data: t.Any) -> t.Any:

         return data

-    return validate
+    return model_validator(mode="before")(validate) if decorate else validate


 class ConnectionConfig(abc.ABC, BaseConfig):
@@ -1422,17 +1421,50 @@ class MSSQLConnectionConfig(ConnectionConfig):
     autocommit: t.Optional[bool] = False
     tds_version: t.Optional[str] = None

+    # Driver options
+    driver: t.Literal["pymssql", "pyodbc"] = "pymssql"
+    # PyODBC specific options
+    driver_name: t.Optional[str] = None  # e.g. "ODBC Driver 18 for SQL Server"
+    trust_server_certificate: t.Optional[bool] = None
+    encrypt: t.Optional[bool] = None
+    # Dictionary of arbitrary ODBC connection properties
+    # See: https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute
+    odbc_properties: t.Optional[t.Dict[str, t.Any]] = None
+
     concurrent_tasks: int = 4
     register_comments: bool = True
     pre_ping: bool = True

     type_: t.Literal["mssql"] = Field(alias="type", default="mssql")

-    _engine_import_validator = _get_engine_import_validator("pymssql", "mssql")
+    @model_validator(mode="before")
+    @classmethod
+    def _mssql_engine_import_validator(cls, data: t.Any) -> t.Any:
+        if not isinstance(data, dict):
+            return data
+
+        driver = data.get("driver", "pymssql")
+
+        # Define the mapping of driver to import module and extra name
+        driver_configs = {"pymssql": ("pymssql", "mssql"), "pyodbc": ("pyodbc", "mssql-odbc")}
+
+        if driver not in driver_configs:
+            raise ValueError(f"Unsupported driver: {driver}")
+
+        import_module, extra_name = driver_configs[driver]
+
+        # Use _get_engine_import_validator with decorate=False to get the raw validation function
+        # This avoids the __wrapped__ issue in Python 3.9
+        validator_func = _get_engine_import_validator(
+            import_module, driver, extra_name, decorate=False
+        )
+
+        # Call the raw validation function directly
+        return validator_func(cls, data)

     @property
     def _connection_kwargs_keys(self) -> t.Set[str]:
-        return {
+        base_keys = {
             "host",
             "user",
             "password",
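
Reviewer note (not part of the commit): the validator in the hunk above picks the module to import and the pip extra to suggest based on `driver`. The standalone sketch below shows that lookup without Pydantic. The function name `check_driver`, the injectable `find_spec` parameter, and the error wording are all hypothetical; only the driver mapping and the "Unsupported driver" message come from the diff.

```python
import importlib.util
from typing import Callable, Dict, Optional, Tuple

# Driver -> (module to import, pip extra), mirroring driver_configs in the diff.
DRIVER_CONFIGS: Dict[str, Tuple[str, str]] = {
    "pymssql": ("pymssql", "mssql"),
    "pyodbc": ("pyodbc", "mssql-odbc"),
}


def check_driver(
    driver: str = "pymssql",
    find_spec: Callable[[str], Optional[object]] = importlib.util.find_spec,
) -> str:
    """Return the module name for `driver`, or raise with an install hint."""
    if driver not in DRIVER_CONFIGS:
        raise ValueError(f"Unsupported driver: {driver}")
    module_name, extra_name = DRIVER_CONFIGS[driver]
    # find_spec returns None when a top-level module is not installed.
    if find_spec(module_name) is None:
        # Hypothetical wording; the real message lives in the shared validator.
        raise ImportError(
            f"Failed to import '{module_name}'. "
            f'Try: pip install "sqlmesh[{extra_name}]"'
        )
    return module_name
```

Passing `find_spec` explicitly makes the missing-package path easy to exercise in a test without uninstalling anything.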
@@ -1447,15 +1479,96 @@ def _connection_kwargs_keys(self) -> t.Set[str]:
             "tds_version",
         }

+        if self.driver == "pyodbc":
+            base_keys.update(
+                {
+                    "driver_name",
+                    "trust_server_certificate",
+                    "encrypt",
+                    "odbc_properties",
+                }
+            )
+            # Remove pymssql-specific parameters
+            base_keys.discard("tds_version")
+            base_keys.discard("conn_properties")
+
+        return base_keys
+
     @property
     def _engine_adapter(self) -> t.Type[EngineAdapter]:
         return engine_adapter.MSSQLEngineAdapter

     @property
     def _connection_factory(self) -> t.Callable:
-        import pymssql
+        if self.driver == "pymssql":
+            import pymssql
+
+            return pymssql.connect
+
+        import pyodbc
+
+        def connect(**kwargs: t.Any) -> t.Callable:
+            # Extract parameters for connection string
+            host = kwargs.pop("host")
+            port = kwargs.pop("port", 1433)
+            database = kwargs.pop("database", "")
+            user = kwargs.pop("user", None)
+            password = kwargs.pop("password", None)
+            driver_name = kwargs.pop("driver_name", "ODBC Driver 18 for SQL Server")
+            trust_server_certificate = kwargs.pop("trust_server_certificate", False)
+            encrypt = kwargs.pop("encrypt", True)
+            login_timeout = kwargs.pop("login_timeout", 60)
+
+            # Build connection string
+            conn_str_parts = [
+                f"DRIVER={{{driver_name}}}",
+                f"SERVER={host},{port}",
+            ]
+
+            if database:
+                conn_str_parts.append(f"DATABASE={database}")
+
+            # Add security options
+            conn_str_parts.append(f"Encrypt={'YES' if encrypt else 'NO'}")
+            if trust_server_certificate:
+                conn_str_parts.append("TrustServerCertificate=YES")
+
+            conn_str_parts.append(f"Connection Timeout={login_timeout}")
+
+            # Standard SQL Server authentication
+            if user:
+                conn_str_parts.append(f"UID={user}")
+            if password:
+                conn_str_parts.append(f"PWD={password}")
+
+            # Add any additional ODBC properties from the odbc_properties dictionary
+            if self.odbc_properties:
+                for key, value in self.odbc_properties.items():
+                    # Skip properties that we've already set above
+                    if key.lower() in (
+                        "driver",
+                        "server",
+                        "database",
+                        "uid",
+                        "pwd",
+                        "encrypt",
+                        "trustservercertificate",
+                        "connection timeout",
+                    ):
+                        continue

-        return pymssql.connect
+                    # Handle boolean values properly
+                    if isinstance(value, bool):
+                        conn_str_parts.append(f"{key}={'YES' if value else 'NO'}")
+                    else:
+                        conn_str_parts.append(f"{key}={value}")
+
+            # Create the connection string
+            conn_str = ";".join(conn_str_parts)
+
+            return pyodbc.connect(conn_str, autocommit=kwargs.get("autocommit", False))
+
+        return connect

     @property
     def _extra_engine_config(self) -> t.Dict[str, t.Any]:
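
Reviewer note (not part of the commit): to see what string the pyodbc `connect` closure above ends up building, the same assembly steps can be run without pyodbc installed. The sketch below lifts the logic from the diff into a free function. The name `build_connection_string`, the `RESERVED_KEYS` constant, and all sample values are hypothetical.

```python
from typing import Any, Dict, List, Optional

# Keys the builder sets itself; matching odbc_properties entries are skipped.
RESERVED_KEYS = {
    "driver", "server", "database", "uid", "pwd",
    "encrypt", "trustservercertificate", "connection timeout",
}


def build_connection_string(
    host: str,
    port: int = 1433,
    database: str = "",
    user: Optional[str] = None,
    password: Optional[str] = None,
    driver_name: str = "ODBC Driver 18 for SQL Server",
    encrypt: bool = True,
    trust_server_certificate: bool = False,
    login_timeout: int = 60,
    odbc_properties: Optional[Dict[str, Any]] = None,
) -> str:
    """Assemble an ODBC connection string in the same order as the diff."""
    parts: List[str] = [f"DRIVER={{{driver_name}}}", f"SERVER={host},{port}"]
    if database:
        parts.append(f"DATABASE={database}")
    parts.append(f"Encrypt={'YES' if encrypt else 'NO'}")
    if trust_server_certificate:
        parts.append("TrustServerCertificate=YES")
    parts.append(f"Connection Timeout={login_timeout}")
    if user:
        parts.append(f"UID={user}")
    if password:
        parts.append(f"PWD={password}")
    for key, value in (odbc_properties or {}).items():
        if key.lower() in RESERVED_KEYS:
            continue  # already set from the dedicated options above
        if isinstance(value, bool):
            value = "YES" if value else "NO"
        parts.append(f"{key}={value}")
    return ";".join(parts)


# Placeholder values for a service principal login.
conn_str = build_connection_string(
    host="example-server.database.windows.net",
    database="analytics",
    user="my-client-id",
    password="my-client-secret",
    odbc_properties={"Authentication": "ActiveDirectoryServicePrincipal", "Encrypt": False},
)
print(conn_str)
```

Note that the `Encrypt` entry inside `odbc_properties` is ignored because the dedicated `encrypt` option wins. Like the diff, this sketch does not escape values, so a password containing `;` or `}` would likely need brace-quoting before being placed in the string.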

sqlmesh/core/engine_adapter/mssql.py

Lines changed: 4 additions & 0 deletions
@@ -219,6 +219,10 @@ def _df_to_source_queries(
         assert isinstance(df, pd.DataFrame)
         temp_table = self._get_temp_table(target_table or "pandas")

+        # Return the superclass implementation if the connection pool doesn't support bulk_copy
+        if not hasattr(self._connection_pool.get(), "bulk_copy"):
+            return super()._df_to_source_queries(df, columns_to_types, batch_size, target_table)
+
         def query_factory() -> Query:
             # It is possible for the factory to be called multiple times and if so then the temp table will already
             # be created so we skip creating again. This means we are assuming the first call is the same result