From 28f9569f3acf827d7b82e87985a01af3b4ee9366 Mon Sep 17 00:00:00 2001 From: Gaurav Sharma <8655500+bewithgaurav@users.noreply.github.com> Date: Mon, 25 May 2026 10:05:29 +0530 Subject: [PATCH 1/6] FEAT: Add ActiveDirectoryServicePrincipal support for bulk copy Wires a Python token-factory callback into the mssql-py-core connection context so bulk copy can authenticate with `Authentication=ActiveDirectoryServicePrincipal`. The callback is invoked by mssql-tds mid-handshake (FedAuth workflow 0x02), receives the STS URL from the server, parses the tenant_id from it, and uses `azure.identity.ClientSecretCredential` to acquire a JWT. Necessary because tenant_id is not known client-side until the server returns it during the handshake, so the pre-acquired-token model (Model A) used by Default / DeviceCode / Interactive / MSI cannot be used here. Builds on the shared module-level `_credential_cache` introduced for MSI in #573, keyed by ("serviceprincipal", tenant_id, client_id), so SP gets the same per-instance token reuse semantics as the other AD methods. `client_secret` is intentionally not in the cache key; credentials are looked up by identity, not secret. Other behaviors: - `_parse_tenant_id` rejects non-https / non-URL inputs so a malformed STS URL cannot silently become a tenant id. - Empty SPN from the server is rejected early with a clear message rather than producing scope="/.default". - The surfaced `RuntimeError` is intentionally generic; provider message stays in `logger.error` only so any sensitive provider text does not reach the user-facing exception chain. - Bulk copy `finally` cleanup also pops `entra_id_token_factory` from pycore_context. Tests (in tests/test_008_auth.py): - 10 `TestParseTenantId` cases: GUID tenant with and without a trailing slash, domain tenant, query string, extra path segments, empty, no path, bare-string rejection, path-only URL rejection, http scheme rejection. 
- 12 `TestServicePrincipalAuth` cases: factory shape, missing client_id/secret, UTF-16LE return, credential kwargs forwarded to `ClientSecretCredential`, scope construction with/without `/.default` suffix, unparseable STS URL, authentication error propagation, no provider-message leak into chained exception, empty-SPN rejection, per-tenant credential caching (asserts 1 construction across 3 calls for one tenant, 2 across two tenants). - `TestProcessAuthParameters` cases: SP leaves ODBC path alone, case-insensitive recognition. - `TestExtractAuthType.test_serviceprincipal`. Requires mssql-py-core 0.1.5+ for the `entra_id_token_factory` dict key wiring (ADO mssql-rs PR 7542). py-core pin bump will land in the release bundle. Partial fix for #534. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 5 + mssql_python/auth.py | 148 ++++++++++++++++++++- mssql_python/constants.py | 1 + mssql_python/cursor.py | 87 +++++++++---- tests/test_008_auth.py | 268 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 477 insertions(+), 32 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 517a60bfc..4fa44f85b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Added - New feature: Support for macOS and Linux. - Documentation: Added API documentation in the Wiki. +- Bulk copy now supports `Authentication=ActiveDirectoryServicePrincipal` + via an `entra_id_token_factory` callback registered on the mssql-py-core + connection. The callback is invoked by mssql-tds mid-handshake (FedAuth + workflow 0x02) so the tenant id can be resolved from the server-supplied + STS URL. Requires `mssql-py-core` 0.1.5+. Partial fix for #534. ### Changed - Improved error handling in the connection module. 
diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 9b488c6d4..ceded2dff 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -154,6 +154,135 @@ def _acquire_token( raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e +def _parse_tenant_id(sts_url: str) -> Optional[str]: + """Extract tenant ID (GUID or domain) from a FedAuthInfo STS URL. + + Expected formats: + https://login.microsoftonline.com/<tenant>/ + https://login.microsoftonline.com/<tenant>/?... + https://login.microsoftonline.com/<tenant> + where <tenant> is either a GUID (e.g. ``aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee``) + or a verified domain (e.g. ``contoso.onmicrosoft.com``). Both forms are + accepted by ``azure.identity.ClientSecretCredential``. + """ + # pylint: disable=import-outside-toplevel + from urllib.parse import urlparse + + try: + parsed = urlparse(sts_url) + except (ValueError, AttributeError): + return None + # Reject anything that isn't an https URL with a netloc. ``urlparse`` will + # happily put a bare string like ``"tenant-guid"`` into ``path``, which + # would then look like a valid tenant. Azure AD STS URLs are always https. + if parsed.scheme != "https" or not parsed.netloc: + return None + path = (parsed.path or "").strip("/") + if not path: + return None + first_segment = path.split("/", 1)[0] + return first_segment or None + + +class ServicePrincipalAuth: + """Builds an ``entra_id_token_factory`` callable for ActiveDirectoryServicePrincipal. + + The bulkcopy path through mssql-py-core uses callback-based token + acquisition (FedAuth workflow ``0x02``) because tenant_id is only known + from the STS URL that the server returns during the TDS handshake. + """ + + @staticmethod + def make_token_factory(client_id: str, client_secret: str): + """Return a callable suitable for ``entra_id_token_factory``. + + Signature: ``(spn: str, sts_url: str, auth_method: str) -> bytes``. + Returns the JWT encoded as UTF-16LE bytes (the TDS FedAuth wire format). 
+ + ``ClientSecretCredential`` instances are reused across calls via the + module-level ``_credential_cache``, keyed by + ``("serviceprincipal", tenant_id, client_id)`` so that azure-identity's + in-memory token cache (which is per-credential-instance) actually + works across handshake retries, reconnects, and separate bulkcopy + invocations using the same identity. + """ + if not client_id: + raise ValueError("ServicePrincipal auth requires a non-empty client_id (UID)") + if not client_secret: + raise ValueError("ServicePrincipal auth requires a non-empty client_secret (PWD)") + + def _factory(spn: str, sts_url: str, auth_method: str) -> bytes: + # pylint: disable=import-outside-toplevel,unused-argument + try: + from azure.identity import ClientSecretCredential + from azure.core.exceptions import ClientAuthenticationError + except ImportError as e: + raise RuntimeError( + "Azure authentication libraries are not installed. " + "Please install with: pip install azure-identity azure-core" + ) from e + + if not spn: + raise RuntimeError( + "ServicePrincipal token factory: empty SPN from server " + "(cannot construct token scope)" + ) + tenant_id = _parse_tenant_id(sts_url) + if not tenant_id: + raise RuntimeError(f"Could not extract tenant_id from STS URL: {sts_url!r}") + + logger.info( + "ServicePrincipal token factory: acquiring token for tenant=%s, spn=%s", + tenant_id, + spn, + ) + try: + # Reuse the shared credential cache (introduced for MSI in PR #573) + # so SP credentials get the same per-instance token reuse semantics + # as the other AD methods. Key includes tenant_id so a server that + # somehow returns different tenants on different handshakes still + # gets distinct credentials. client_secret is intentionally NOT in + # the key — credentials are looked up by identity, not by secret; + # if the secret rotates, the closure will still hold the old one + # and AAD will reject the token, surfacing as ClientAuthenticationError. 
+ cache_key = _credential_cache_key( + "serviceprincipal", + {"tenant_id": tenant_id, "client_id": client_id}, + ) + with _credential_cache_lock: + credential = _credential_cache.get(cache_key) + if credential is None: + credential = ClientSecretCredential( + tenant_id=tenant_id, + client_id=client_id, + client_secret=client_secret, + ) + _credential_cache[cache_key] = credential + # mssql-tds passes the resource SPN; azure-identity wants a scope. + scope = spn if spn.endswith("/.default") else spn.rstrip("/") + "/.default" + token = credential.get_token(scope).token + logger.info( + "ServicePrincipal token factory: token acquired, length=%d chars", + len(token), + ) + return token.encode("utf-16-le") + except ClientAuthenticationError as e: + # Keep the detailed provider error in debug logs only. The + # surfaced message is intentionally generic so that any + # secret-bearing provider text never reaches the user-facing + # exception chain. + logger.error( + "ServicePrincipal authentication failed: tenant=%s, error=%s", + tenant_id, + str(e), + ) + raise RuntimeError( + "ServicePrincipal authentication failed; " "see debug logs for provider details" + ) from None + + return _factory + + def _extract_msi_client_id(connection_string: str) -> Optional[str]: """Pull UID out of a connection string for user-assigned MSI. @@ -230,6 +359,17 @@ def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[ # Managed identity authentication (system- or user-assigned) logger.debug("process_auth_parameters: Managed identity authentication detected") auth_type = "msi" + elif value_lower == AuthType.SERVICE_PRINCIPAL.value: + # ServicePrincipal authentication. ODBC (msodbcsql 17.3+) + # handles this natively for regular queries, so leave + # auth_type=None to let ODBC own the query path. 
+ # Bulkcopy still needs the auth type — extract_auth_type() + # propagates it as "serviceprincipal" so the bulkcopy path + # can register an entra_id_token_factory callback (Model B, + # required because tenant_id is only known from the STS URL + # that the server returns during the FedAuth handshake). + logger.debug("process_auth_parameters: Service principal authentication detected") + auth_type = None modified_parameters.append(param) logger.debug( @@ -299,6 +439,7 @@ def extract_auth_type(connection_string: str) -> Optional[str]: AuthType.DEVICE_CODE.value: "devicecode", AuthType.DEFAULT.value: "default", AuthType.MSI.value: "msi", + AuthType.SERVICE_PRINCIPAL.value: "serviceprincipal", } for part in connection_string.split(";"): key, _, value = part.strip().partition("=") @@ -313,13 +454,6 @@ def process_connection_string( """ Process connection string and handle authentication. - NOTE: Returns a 4-tuple. Callers must unpack all four elements. - Destructuring with three names raises ``ValueError: too many values - to unpack``. The fourth element (``credential_kwargs``) is needed by - Connection.__init__ to persist credential constructor args (e.g. the - user-assigned MSI ``client_id``) for the bulkcopy fresh-token path, - since UID is stripped from the sanitized connection string. 
- Args: connection_string: The connection string to process diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 5de02eceb..f9f9331db 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -338,6 +338,7 @@ class AuthType(Enum): DEVICE_CODE = "activedirectorydevicecode" DEFAULT = "activedirectorydefault" MSI = "activedirectorymsi" + SERVICE_PRINCIPAL = "activedirectoryserviceprincipal" class SQLTypes: diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index ece27c61e..9915eea24 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2933,31 +2933,60 @@ def bulkcopy( # Token acquisition — only thing cursor must handle (needs azure-identity SDK) if self.connection._auth_type: - # Fresh token acquisition for mssql-py-core connection. credential - # kwargs (e.g. user-assigned MSI client_id) were captured by - # Connection.__init__ before remove_sensitive_params stripped UID - # from connection_str — re-parsing here would miss them. - from mssql_python.auth import AADAuth - - try: - raw_token = AADAuth.get_raw_token( + # Fresh token acquisition for mssql-py-core connection + from mssql_python.auth import AADAuth, ServicePrincipalAuth + + if self.connection._auth_type == "serviceprincipal": + # Model B: callback-based. tenant_id is only known from the + # STS URL the server returns mid-handshake, so we register a + # factory that py-core invokes during FedAuth (workflow 0x02). + client_id = params.get("uid", "") + client_secret = params.get("pwd", "") + if not client_id or not client_secret: + raise RuntimeError( + "Bulk copy with Authentication=ActiveDirectoryServicePrincipal " + "requires UID (client_id) and PWD (client_secret) in the " + "connection string." 
+ ) + try: + factory = ServicePrincipalAuth.make_token_factory(client_id, client_secret) + except (RuntimeError, ValueError) as e: + raise RuntimeError( + f"Bulk copy failed: unable to build ServicePrincipal token factory: {e}" + ) from e + pycore_context["entra_id_token_factory"] = factory + # Keep authentication/user_name/password in pycore_context — + # py-core's auth validator + transformer need them to resolve + # the auth method to ActiveDirectoryServicePrincipal before + # the factory is dispatched at handshake time. + logger.debug("Bulk copy: registered ServicePrincipal token factory") + else: + # Model A: pre-acquired token. Used for Default, DeviceCode, + # Interactive (non-Windows), MSI (system- or user-assigned), + # and any other AD method whose tenant_id is discoverable + # client-side via Azure Identity SDK. credential kwargs + # (e.g. user-assigned MSI client_id) were captured by + # Connection.__init__ before remove_sensitive_params stripped + # UID from connection_str — re-parsing here would miss them. + try: + raw_token = AADAuth.get_raw_token( + self.connection._auth_type, + self.connection._credential_kwargs, + ) + except (RuntimeError, ValueError) as e: + raise RuntimeError( + f"Bulk copy failed: unable to acquire Azure AD token " + f"for auth_type '{self.connection._auth_type}': {e}" + ) from e + pycore_context["access_token"] = raw_token + # Token replaces credential fields — py-core's validator rejects + # access_token combined with authentication/user_name/password. 
+ for key in ("authentication", "user_name", "password"): + pycore_context.pop(key, None) + logger.debug( + "Bulk copy: acquired fresh Azure AD token for auth_type=%s", self.connection._auth_type, - self.connection._credential_kwargs, ) - except (RuntimeError, ValueError) as e: - raise RuntimeError( - f"Bulk copy failed: unable to acquire Azure AD token " - f"for auth_type '{self.connection._auth_type}': {e}" - ) from e - pycore_context["access_token"] = raw_token - # Token replaces credential fields — py-core's validator rejects - # access_token combined with authentication/user_name/password. - for key in ("authentication", "user_name", "password"): - pycore_context.pop(key, None) - logger.debug( - "Bulk copy: acquired fresh Azure AD token for auth_type=%s", - self.connection._auth_type, - ) pycore_connection = None pycore_cursor = None @@ -3007,9 +3036,17 @@ def bulkcopy( raise type(e)(str(e)) from None finally: - # Clear sensitive data to minimize memory exposure + # Clear sensitive data to minimize memory exposure. The + # entra_id_token_factory closure captures client_secret, so drop + # our dict reference to it (Rust still holds an Arc until the + # connection is dropped, but at least we don't keep an extra ref). 
if pycore_context: - for key in ("password", "user_name", "access_token"): + for key in ( + "password", + "user_name", + "access_token", + "entra_id_token_factory", + ): pycore_context.pop(key, None) # Clean up bulk copy resources for resource in (pycore_cursor, pycore_connection): diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index f8df6f6f5..54f6236d7 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -11,6 +11,8 @@ from unittest.mock import patch, MagicMock from mssql_python.auth import ( AADAuth, + ServicePrincipalAuth, + _parse_tenant_id, process_auth_parameters, remove_sensitive_params, get_auth_token, @@ -44,6 +46,20 @@ class MockInteractiveBrowserCredential: def get_token(self, scope): return MockToken() + class MockClientSecretCredential: + # Captures construction kwargs and get_token args so ServicePrincipal + # tests can assert the right tenant/client_id/secret/scope flowed + # through from the connection string + STS URL. + last_init_kwargs = None + last_scope = None + + def __init__(self, **kwargs): + MockClientSecretCredential.last_init_kwargs = kwargs + + def get_token(self, scope): + MockClientSecretCredential.last_scope = scope + return MockToken() + class MockManagedIdentityCredential: # Captures construction kwargs so user-assigned MSI tests can assert # client_id was forwarded correctly. 
@@ -63,6 +79,7 @@ class MockIdentity: DefaultAzureCredential = MockDefaultAzureCredential DeviceCodeCredential = MockDeviceCodeCredential InteractiveBrowserCredential = MockInteractiveBrowserCredential + ClientSecretCredential = MockClientSecretCredential ManagedIdentityCredential = MockManagedIdentityCredential class MockCore: @@ -100,6 +117,7 @@ def test_auth_type_constants(self): assert AuthType.DEVICE_CODE.value == "activedirectorydevicecode" assert AuthType.DEFAULT.value == "activedirectorydefault" assert AuthType.MSI.value == "activedirectorymsi" + assert AuthType.SERVICE_PRINCIPAL.value == "activedirectoryserviceprincipal" class TestAADAuth: @@ -330,6 +348,20 @@ def test_default_auth(self): _, auth_type = process_auth_parameters(params) assert auth_type == "default" + def test_service_principal_auth_leaves_odbc_path_alone(self): + """ServicePrincipal is handled natively by ODBC. process_auth_parameters + must return auth_type=None so the ODBC path doesn't pre-acquire a token + (which would require tenant_id we don't have client-side).""" + params = ["Authentication=ActiveDirectoryServicePrincipal", "Server=test"] + modified_params, auth_type = process_auth_parameters(params) + assert "Authentication=ActiveDirectoryServicePrincipal" in modified_params + assert auth_type is None + + def test_service_principal_auth_case_insensitive(self): + params = ["authentication=activedirectoryserviceprincipal", "Server=test"] + _, auth_type = process_auth_parameters(params) + assert auth_type is None + def test_msi_auth(self): params = ["Authentication=ActiveDirectoryMSI", "Server=test"] _, auth_type = process_auth_parameters(params) @@ -433,6 +465,12 @@ def test_devicecode(self): == "devicecode" ) + def test_serviceprincipal(self): + assert ( + extract_auth_type("Server=test;Authentication=ActiveDirectoryServicePrincipal;") + == "serviceprincipal" + ) + def test_msi(self): assert extract_auth_type("Server=test;Authentication=ActiveDirectoryMSI;") == "msi" @@ -1012,3 
+1050,233 @@ def __init__(self): assert credential_kwargs is None finally: azure_identity.DefaultAzureCredential = original + + +class TestParseTenantId: + def test_guid_tenant(self): + url = "https://login.microsoftonline.com/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/" + assert _parse_tenant_id(url) == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + def test_guid_tenant_no_trailing_slash(self): + url = "https://login.microsoftonline.com/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + assert _parse_tenant_id(url) == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + def test_domain_tenant(self): + url = "https://login.microsoftonline.com/contoso.onmicrosoft.com/" + assert _parse_tenant_id(url) == "contoso.onmicrosoft.com" + + def test_tenant_with_query_string(self): + url = "https://login.microsoftonline.com/tenant-guid/?foo=bar" + assert _parse_tenant_id(url) == "tenant-guid" + + def test_extra_path_segments_after_tenant(self): + url = "https://login.microsoftonline.com/tenant-guid/oauth2/authorize" + assert _parse_tenant_id(url) == "tenant-guid" + + def test_empty_string(self): + assert _parse_tenant_id("") is None + + def test_no_path(self): + assert _parse_tenant_id("https://login.microsoftonline.com/") is None + + def test_rejects_bare_string_without_scheme(self): + # urlparse puts a bare string into path; without a scheme/netloc check + # this would be silently treated as a tenant id. + assert _parse_tenant_id("tenant-guid") is None + + def test_rejects_path_only_url(self): + assert _parse_tenant_id("/tenant-guid/oauth2") is None + + def test_rejects_http_scheme(self): + # Azure AD STS URLs are always https. Reject http to avoid trusting + # a downgraded URL. 
+ assert _parse_tenant_id("http://login.microsoftonline.com/tenant/") is None + + +class TestServicePrincipalAuth: + """Tests for the ActiveDirectoryServicePrincipal token factory.""" + + def test_make_token_factory_returns_callable(self): + factory = ServicePrincipalAuth.make_token_factory("client-id", "client-secret") + assert callable(factory) + + def test_factory_requires_client_id(self): + with pytest.raises(ValueError, match="client_id"): + ServicePrincipalAuth.make_token_factory("", "client-secret") + + def test_factory_requires_client_secret(self): + with pytest.raises(ValueError, match="client_secret"): + ServicePrincipalAuth.make_token_factory("client-id", "") + + def test_factory_returns_utf16le_bytes(self): + factory = ServicePrincipalAuth.make_token_factory("client-id", "client-secret") + result = factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + assert isinstance(result, bytes) + # SAMPLE_TOKEN is hex chars (ASCII). UTF-16LE encoding doubles each byte + # and inserts a 0x00 high byte after each ASCII char. 
+ assert result == SAMPLE_TOKEN.encode("utf-16-le") + assert len(result) == len(SAMPLE_TOKEN) * 2 + + def test_factory_forwards_credentials_to_ClientSecretCredential(self): + az = sys.modules["azure.identity"] + az.ClientSecretCredential.last_init_kwargs = None + az.ClientSecretCredential.last_scope = None + + factory = ServicePrincipalAuth.make_token_factory( + "11111111-2222-3333-4444-555555555555", "my-secret" + ) + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/", + "activedirectoryserviceprincipal", + ) + + assert az.ClientSecretCredential.last_init_kwargs == { + "tenant_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + "client_id": "11111111-2222-3333-4444-555555555555", + "client_secret": "my-secret", + } + + def test_factory_builds_scope_from_spn(self): + az = sys.modules["azure.identity"] + az.ClientSecretCredential.last_scope = None + + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant/", + "activedirectoryserviceprincipal", + ) + assert az.ClientSecretCredential.last_scope == "https://database.windows.net/.default" + + def test_factory_keeps_existing_default_suffix(self): + az = sys.modules["azure.identity"] + az.ClientSecretCredential.last_scope = None + + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + factory( + "https://database.windows.net/.default", + "https://login.microsoftonline.com/tenant/", + "activedirectoryserviceprincipal", + ) + assert az.ClientSecretCredential.last_scope == "https://database.windows.net/.default" + + def test_factory_errors_on_unparseable_sts_url(self): + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + with pytest.raises(RuntimeError, match="Could not extract tenant_id"): + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/", # no tenant segment + 
"activedirectoryserviceprincipal", + ) + + def test_factory_propagates_authentication_error(self): + from azure.core.exceptions import ClientAuthenticationError + + class FailingCred: + def __init__(self, **kwargs): + pass + + def get_token(self, scope): + raise ClientAuthenticationError("AADSTS7000215: Invalid client secret") + + original = sys.modules["azure.identity"].ClientSecretCredential + sys.modules["azure.identity"].ClientSecretCredential = FailingCred + try: + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + with pytest.raises(RuntimeError, match="ServicePrincipal authentication failed"): + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + finally: + sys.modules["azure.identity"].ClientSecretCredential = original + + def test_factory_does_not_leak_provider_message_in_runtime_error(self): + """The user-facing RuntimeError must not echo the provider message + (which can carry tenant ids, claims, or other sensitive context). 
+ Provider detail is preserved in debug logs only.""" + from azure.core.exceptions import ClientAuthenticationError + + secret_marker = "AADSTS7000215_SECRET_MARKER_in_provider_message" + + class FailingCred: + def __init__(self, **kwargs): + pass + + def get_token(self, scope): + raise ClientAuthenticationError(secret_marker) + + original = sys.modules["azure.identity"].ClientSecretCredential + sys.modules["azure.identity"].ClientSecretCredential = FailingCred + try: + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + try: + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + except RuntimeError as e: + full_chain = str(e) + cause = e.__cause__ + while cause is not None: + full_chain += " || " + str(cause) + cause = getattr(cause, "__cause__", None) + assert ( + secret_marker not in full_chain + ), f"Provider message leaked into surfaced exception chain: {full_chain}" + finally: + sys.modules["azure.identity"].ClientSecretCredential = original + + def test_factory_rejects_empty_spn(self): + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + with pytest.raises(RuntimeError, match="empty SPN"): + factory( + "", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + + def test_factory_caches_credential_per_tenant(self): + """ClientSecretCredential must be reused across calls for the same + tenant so azure-identity's per-instance token cache actually works.""" + az = sys.modules["azure.identity"] + construction_count = {"n": 0} + + original = az.ClientSecretCredential + + class _Tok: + token = SAMPLE_TOKEN + + class CountingCred: + def __init__(self, **kwargs): + construction_count["n"] += 1 + + def get_token(self, scope): + return _Tok() + + az.ClientSecretCredential = CountingCred + try: + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + sts = 
"https://login.microsoftonline.com/tenant-guid/" + for _ in range(3): + factory("https://database.windows.net/", sts, "activedirectoryserviceprincipal") + assert construction_count["n"] == 1, ( + f"Expected 1 ClientSecretCredential construction across 3 calls, " + f"got {construction_count['n']}" + ) + # A different tenant should produce a second instance. + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/other-tenant/", + "activedirectoryserviceprincipal", + ) + assert construction_count["n"] == 2 + finally: + az.ClientSecretCredential = original From 75ba64326e735f12d258108f7aeab6b8d1e1c957 Mon Sep 17 00:00:00 2001 From: Gaurav Sharma <8655500+bewithgaurav@users.noreply.github.com> Date: Mon, 25 May 2026 12:21:26 +0530 Subject: [PATCH 2/6] FIX: include client_secret hash in SP credential cache key Without this, a rotated client_secret was silently masked for up to ~1 hour: the (tenant_id, client_id) cache key matched the previously cached ClientSecretCredential, and azure-identity's internal token cache kept returning the token issued under the OLD secret until expiry. A user updating PWD in the connection string after an external secret rotation would observe bulkcopy continuing to work with the stale credential. Fix: hash client_secret (SHA-256) into the cache key so rotation produces a fresh ClientSecretCredential instance, defeating azure-identity's per-instance token cache. Hashing avoids storing the raw secret in the dict key. Tests: - test_factory_rotates_credential_when_secret_changes: two factories for same tenant+client_id with different secrets produce 2 distinct ClientSecretCredential instances. Re-calling either factory uses its own cache entry (does not construct again). - test_factory_cache_key_does_not_contain_raw_secret: asserts the raw secret is not present in any cache key (only its hash). 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- mssql_python/auth.py | 23 +++++++++----- tests/test_008_auth.py | 68 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 7 deletions(-) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index ceded2dff..f564bfb9a 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -4,6 +4,7 @@ This module handles authentication for the mssql_python package. """ +import hashlib import platform import struct import threading @@ -239,15 +240,23 @@ def _factory(spn: str, sts_url: str, auth_method: str) -> bytes: try: # Reuse the shared credential cache (introduced for MSI in PR #573) # so SP credentials get the same per-instance token reuse semantics - # as the other AD methods. Key includes tenant_id so a server that - # somehow returns different tenants on different handshakes still - # gets distinct credentials. client_secret is intentionally NOT in - # the key — credentials are looked up by identity, not by secret; - # if the secret rotates, the closure will still hold the old one - # and AAD will reject the token, surfacing as ClientAuthenticationError. + # as the other AD methods. + # + # The cache key includes a hash of client_secret so a rotated + # secret produces a different cache entry. Without this, an + # external secret rotation would not invalidate the cached + # ClientSecretCredential: azure-identity's internal token cache + # would keep returning the previously-issued token (good for + # up to ~1 hour) until expiry, masking the rotation. Hashing + # avoids storing the raw secret in the dict key. 
+ secret_hash = hashlib.sha256(client_secret.encode("utf-8")).hexdigest() cache_key = _credential_cache_key( "serviceprincipal", - {"tenant_id": tenant_id, "client_id": client_id}, + { + "tenant_id": tenant_id, + "client_id": client_id, + "secret_hash": secret_hash, + }, ) with _credential_cache_lock: credential = _credential_cache.get(cache_key) diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index 54f6236d7..7d013abb6 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -1280,3 +1280,71 @@ def get_token(self, scope): assert construction_count["n"] == 2 finally: az.ClientSecretCredential = original + + def test_factory_rotates_credential_when_secret_changes(self): + """A new client_secret for the same tenant+client_id MUST produce a new + ClientSecretCredential instance. Without this, an external secret + rotation would not invalidate the cached credential: azure-identity's + internal token cache would keep returning the previously-issued token + (good for up to ~1 hour) until expiry, masking the rotation.""" + az = sys.modules["azure.identity"] + construction_count = {"n": 0} + + original = az.ClientSecretCredential + + class _Tok: + token = SAMPLE_TOKEN + + class CountingCred: + def __init__(self, **kwargs): + construction_count["n"] += 1 + + def get_token(self, scope): + return _Tok() + + az.ClientSecretCredential = CountingCred + try: + sts = "https://login.microsoftonline.com/tenant-guid/" + spn = "https://database.windows.net/" + + # Old secret, two calls -> 1 construction (cached) + factory_old = ServicePrincipalAuth.make_token_factory("cid", "old-secret") + factory_old(spn, sts, "activedirectoryserviceprincipal") + factory_old(spn, sts, "activedirectoryserviceprincipal") + assert construction_count["n"] == 1 + + # Rotate the secret. Same tenant + client_id, different secret. + # MUST produce a fresh ClientSecretCredential so azure-identity + # cannot serve a stale token from its internal cache. 
+ factory_new = ServicePrincipalAuth.make_token_factory("cid", "new-secret") + factory_new(spn, sts, "activedirectoryserviceprincipal") + assert construction_count["n"] == 2, ( + f"Expected 2 ClientSecretCredential constructions after secret rotation, " + f"got {construction_count['n']}. A rotated secret was silently ignored." + ) + + # Calling the new factory again should hit cache (1 more = 2 total) + factory_new(spn, sts, "activedirectoryserviceprincipal") + assert construction_count["n"] == 2 + + # Calling the OLD factory again should still hit the OLD cache entry + # (it's keyed on the hash of "old-secret"), not construct again. + factory_old(spn, sts, "activedirectoryserviceprincipal") + assert construction_count["n"] == 2 + finally: + az.ClientSecretCredential = original + + def test_factory_cache_key_does_not_contain_raw_secret(self): + """The cache key must hash the secret, never store it raw. Otherwise + the secret is visible in process memory as part of the dict key.""" + from mssql_python.auth import _credential_cache + + secret_marker = "RAW_SECRET_MARKER_must_not_appear_in_cache_key" + factory = ServicePrincipalAuth.make_token_factory("cid", secret_marker) + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + for key in _credential_cache.keys(): + assert secret_marker not in repr(key), f"Raw secret leaked into cache key: {key!r}" From 763b565f6fc3a2c2634a558ac3b0375bf427d1b4 Mon Sep 17 00:00:00 2001 From: Gaurav Sharma <8655500+bewithgaurav@users.noreply.github.com> Date: Wed, 10 Jun 2026 14:52:20 +0530 Subject: [PATCH 3/6] REFACTOR: introduce _AuthInternal constants for normalized auth identifiers Replace free-floating short-string literals ("serviceprincipal", "interactive", "msi", etc.) with a new _AuthInternal class in constants.py. 
The strings now live in exactly one place; _AUTH_TYPE_MAP references them, and all five comparison sites in auth.py, connection.py, and cursor.py use the named constants. Prevents drift between the map and the comparisons (a typo would only surface at runtime on the affected auth path), and removes the per-PR temptation to add another bare string literal when wiring a new auth mode. No behavior change. 85 auth tests pass. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- mssql_python/auth.py | 27 ++++++++++++++------------- mssql_python/connection.py | 4 ++-- mssql_python/constants.py | 17 ++++++++++++++++- mssql_python/cursor.py | 3 ++- 4 files changed, 34 insertions(+), 17 deletions(-) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 67d772d0e..497fb9c73 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -14,6 +14,7 @@ from mssql_python.constants import ( AuthType, ConstantsDDBC, + _AuthInternal, _KEY_AUTHENTICATION, _KEY_UID, _KEY_PWD, @@ -35,11 +36,11 @@ # Map Authentication connection-string values to internal short names. 
_AUTH_TYPE_MAP: Dict[str, str] = { - AuthType.INTERACTIVE.value: "interactive", - AuthType.DEVICE_CODE.value: "devicecode", - AuthType.DEFAULT.value: "default", - AuthType.MSI.value: "msi", - AuthType.SERVICE_PRINCIPAL.value: "serviceprincipal", + AuthType.INTERACTIVE.value: _AuthInternal.INTERACTIVE, + AuthType.DEVICE_CODE.value: _AuthInternal.DEVICE_CODE, + AuthType.DEFAULT.value: _AuthInternal.DEFAULT, + AuthType.MSI.value: _AuthInternal.MSI, + AuthType.SERVICE_PRINCIPAL.value: _AuthInternal.SERVICE_PRINCIPAL, } @@ -113,10 +114,10 @@ def _acquire_token( # Mapping of auth types to credential classes credential_map = { - "default": DefaultAzureCredential, - "devicecode": DeviceCodeCredential, - "interactive": InteractiveBrowserCredential, - "msi": ManagedIdentityCredential, + _AuthInternal.DEFAULT: DefaultAzureCredential, + _AuthInternal.DEVICE_CODE: DeviceCodeCredential, + _AuthInternal.INTERACTIVE: InteractiveBrowserCredential, + _AuthInternal.MSI: ManagedIdentityCredential, } credential_class = credential_map.get(auth_type) @@ -269,7 +270,7 @@ def _factory(spn: str, sts_url: str, auth_method: str) -> bytes: # avoids storing the raw secret in the dict key. secret_hash = hashlib.sha256(client_secret.encode("utf-8")).hexdigest() cache_key = _credential_cache_key( - "serviceprincipal", + _AuthInternal.SERVICE_PRINCIPAL, { "tenant_id": tenant_id, "client_id": client_id, @@ -329,7 +330,7 @@ def process_auth_parameters(parsed_params: Dict[str, str]) -> Optional[str]: return None # On Windows, Interactive auth is handled natively by the ODBC driver. 
- if auth_type == "interactive" and platform.system().lower() == "windows": + if auth_type == _AuthInternal.INTERACTIVE and platform.system().lower() == "windows": logger.debug("process_auth_parameters: Windows platform - using native AADInteractive") return None @@ -339,7 +340,7 @@ def process_auth_parameters(parsed_params: Dict[str, str]) -> Optional[str]: # extract_auth_type for that), and the cursor.bulkcopy() Model B branch # registers an entra_id_token_factory callback because tenant_id is only # known from the STS URL the server returns during the FedAuth handshake. - if auth_type == "serviceprincipal": + if auth_type == _AuthInternal.SERVICE_PRINCIPAL: logger.debug("process_auth_parameters: ServicePrincipal - ODBC handles natively") return None @@ -362,7 +363,7 @@ def get_auth_token( return None # Handle platform-specific logic for interactive auth - if auth_type == "interactive" and platform.system().lower() == "windows": + if auth_type == _AuthInternal.INTERACTIVE and platform.system().lower() == "windows": logger.debug("get_auth_token: Windows interactive auth - delegating to native handler") return None # Let Windows handle AADInteractive natively diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 0d9b4692e..e411e616e 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -49,7 +49,7 @@ from mssql_python.constants import ConstantsDDBC, GetInfoConstants from mssql_python.connection_string_parser import _ConnectionStringParser from mssql_python.connection_string_builder import _ConnectionStringBuilder -from mssql_python.constants import _RESERVED_PARAMETERS, _KEY_AUTHENTICATION, _KEY_UID +from mssql_python.constants import _RESERVED_PARAMETERS, _KEY_AUTHENTICATION, _KEY_UID, _AuthInternal if TYPE_CHECKING: from mssql_python.row import Row @@ -344,7 +344,7 @@ def __init__( # Capture credential kwargs (e.g. user-assigned MSI client_id) # from the parsed dict *before* remove_sensitive_params strips UID. 
credential_kwargs: Optional[Dict[str, str]] = None - if auth_type == "msi": + if auth_type == _AuthInternal.MSI: uid = (parsed_params.get(_KEY_UID) or "").strip() if uid: credential_kwargs = {"client_id": uid} diff --git a/mssql_python/constants.py b/mssql_python/constants.py index ab944b215..401e434aa 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -332,7 +332,7 @@ class GetInfoConstants(Enum): class AuthType(Enum): - """Constants for authentication types""" + """Constants for authentication types (public/ODBC connection-string form).""" INTERACTIVE = "activedirectoryinteractive" DEVICE_CODE = "activedirectorydevicecode" @@ -341,6 +341,21 @@ class AuthType(Enum): SERVICE_PRINCIPAL = "activedirectoryserviceprincipal" +class _AuthInternal: + """Internal short-form auth identifiers used after normalization. + + Paired with :class:`AuthType` (the public ODBC-string form) by + ``mssql_python.auth._AUTH_TYPE_MAP``. Adding a new auth mode requires + one entry here, one in ``AuthType``, and one in ``_AUTH_TYPE_MAP``. + """ + + DEFAULT = "default" + DEVICE_CODE = "devicecode" + INTERACTIVE = "interactive" + MSI = "msi" + SERVICE_PRINCIPAL = "serviceprincipal" + + class SQLTypes: """Constants for valid SQL data types to use with setinputsizes""" diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index c1efe46e3..d15769224 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2978,8 +2978,9 @@ def bulkcopy( if self.connection._auth_type: # Fresh token acquisition for mssql-py-core connection from mssql_python.auth import AADAuth, ServicePrincipalAuth + from mssql_python.constants import _AuthInternal - if self.connection._auth_type == "serviceprincipal": + if self.connection._auth_type == _AuthInternal.SERVICE_PRINCIPAL: # Model B: callback-based. tenant_id is only known from the # STS URL the server returns mid-handshake, so we register a # factory that py-core invokes during FedAuth (workflow 0x02). 
From 213935586afcf497fb627d2fbc952f7f30a86351 Mon Sep 17 00:00:00 2001 From: Gaurav Sharma <8655500+bewithgaurav@users.noreply.github.com> Date: Wed, 10 Jun 2026 15:01:26 +0530 Subject: [PATCH 4/6] FIX: bound AAD round-trip + reject multi-tenant aliases for SP factory Two defensive hardening changes on the ServicePrincipal token factory introduced in #534, plus a TODO-style comment documenting one known limitation that is out of scope for this PR. 1. Bound the AAD network round-trip. The factory runs on a mssql-py-core blocking-pool worker (tokio spawn_blocking). Without explicit timeouts, azure-identity's defaults can let a slow / unreachable STS endpoint block that worker for tens of seconds. Pass a RequestsTransport with 10s connection / 15s read timeouts. SP is non-interactive so token issuance is typically <1s; these limits are generous and still bounded. 2. Reject multi-tenant aliases (common / organizations / consumers) in _parse_tenant_id. Confidential clients (SP) cannot authenticate against them; AAD returns AADSTS50194 which is cryptic. Failing fast in the factory surfaces a clearer error than the AAD round-trip. 3. Document the sovereign-cloud limitation in-line: ClientSecretCredential defaults to login.microsoftonline.com because authority= is not passed. Sovereign clouds (Azure US Gov, Azure China) will fail with 'tenant not found'. Tracked as a follow-up; the fix is to derive authority from urlparse(sts_url).netloc. 
Tests (+5, total 90): - 4 new TestParseTenantId cases for reserved aliases (common, organizations, consumers, case-insensitive) - 1 new TestServicePrincipalAuth case asserting RequestsTransport is constructed with finite connection_timeout + read_timeout and is passed through to ClientSecretCredential - Existing test_factory_forwards_credentials loosened from exact-dict equality to per-field assertions so future kwargs additions don't break it Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- mssql_python/auth.py | 40 ++++++++++++++++++++- tests/test_008_auth.py | 79 ++++++++++++++++++++++++++++++++++++++---- 2 files changed, 112 insertions(+), 7 deletions(-) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 497fb9c73..31f3ef630 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -174,6 +174,9 @@ def _acquire_token( raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e +_RESERVED_TENANTS = frozenset({"common", "organizations", "consumers"}) + + def _parse_tenant_id(sts_url: str) -> Optional[str]: """Extract tenant ID (GUID or domain) from a FedAuthInfo STS URL. @@ -184,6 +187,12 @@ def _parse_tenant_id(sts_url: str) -> Optional[str]: where <tenant> is either a GUID (e.g. ``aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee``) or a verified domain (e.g. ``contoso.onmicrosoft.com``). Both forms are accepted by ``azure.identity.ClientSecretCredential``. + + Returns ``None`` for the multi-tenant aliases ``common`` / ``organizations`` + / ``consumers``: confidential clients (SP) cannot authenticate against + them and AAD responds with a cryptic ``AADSTS50194`` ("application is not + configured as multi-tenant"). Failing fast in the factory surfaces a + clearer error than the AAD round-trip would.
""" # pylint: disable=import-outside-toplevel from urllib.parse import urlparse @@ -201,7 +210,11 @@ def _parse_tenant_id(sts_url: str) -> Optional[str]: if not path: return None first_segment = path.split("/", 1)[0] - return first_segment or None + if not first_segment: + return None + if first_segment.lower() in _RESERVED_TENANTS: + return None + return first_segment class ServicePrincipalAuth: @@ -236,6 +249,7 @@ def _factory(spn: str, sts_url: str, auth_method: str) -> bytes: try: from azure.identity import ClientSecretCredential from azure.core.exceptions import ClientAuthenticationError + from azure.core.pipeline.transport import RequestsTransport except ImportError as e: raise RuntimeError( "Azure authentication libraries are not installed. " @@ -280,10 +294,34 @@ def _factory(spn: str, sts_url: str, auth_method: str) -> bytes: with _credential_cache_lock: credential = _credential_cache.get(cache_key) if credential is None: + # Bound the AAD network round-trip. Without explicit + # timeouts, azure-identity's defaults can let an + # unreachable / slow STS endpoint block the calling + # thread for tens of seconds. The factory runs on a + # mssql-py-core blocking-pool worker (tokio + # spawn_blocking), so a stuck callback ties that + # worker up for the duration. SP is non-interactive + # and token issuance is typically <1s; 10s/15s is + # generous and still bounded. + transport = RequestsTransport( + connection_timeout=10, + read_timeout=15, + ) + # KNOWN LIMITATION: ``authority=`` is not passed, + # so this defaults to the public-cloud authority + # (login.microsoftonline.com). Sovereign clouds + # (Azure US Gov, Azure China) are not supported on + # this code path today: AAD will fail with + # "tenant not found" because the tenant lives in + # the sovereign cloud's AAD, not the public one. + # Tracked as a follow-up; the fix is to derive + # ``authority`` from ``urlparse(sts_url).netloc``. + # Out of scope for the initial #534 work. 
credential = ClientSecretCredential( tenant_id=tenant_id, client_id=client_id, client_secret=client_secret, + transport=transport, ) _credential_cache[cache_key] = credential # mssql-tds passes the resource SPN; azure-identity wants a scope. diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index 04af13c6b..10428f827 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -70,6 +70,14 @@ def __init__(self, **kwargs): def get_token(self, scope): return MockToken() + class MockRequestsTransport: + # Captures construction kwargs so the SP factory's timeout config + # can be asserted. + last_init_kwargs = None + + def __init__(self, **kwargs): + MockRequestsTransport.last_init_kwargs = kwargs + # Mock ClientAuthenticationError class MockClientAuthenticationError(Exception): pass @@ -85,6 +93,10 @@ class MockCore: class exceptions: ClientAuthenticationError = MockClientAuthenticationError + class pipeline: + class transport: + RequestsTransport = MockRequestsTransport + # Create mock azure module if it doesn't exist if "azure" not in sys.modules: sys.modules["azure"] = type("MockAzure", (), {})() @@ -93,11 +105,19 @@ class exceptions: sys.modules["azure.identity"] = MockIdentity() sys.modules["azure.core"] = MockCore() sys.modules["azure.core.exceptions"] = MockCore.exceptions() + sys.modules["azure.core.pipeline"] = MockCore.pipeline() + sys.modules["azure.core.pipeline.transport"] = MockCore.pipeline.transport() yield # Cleanup - for module in ["azure.identity", "azure.core", "azure.core.exceptions"]: + for module in [ + "azure.identity", + "azure.core", + "azure.core.exceptions", + "azure.core.pipeline", + "azure.core.pipeline.transport", + ]: if module in sys.modules: del sys.modules[module] @@ -1052,6 +1072,22 @@ def test_rejects_http_scheme(self): # a downgraded URL. 
assert _parse_tenant_id("http://login.microsoftonline.com/tenant/") is None + def test_rejects_common_alias(self): + # Multi-tenant alias — confidential clients (SP) cannot auth against + # it. Reject up front so the error surfaced is ours, not AADSTS50194. + assert _parse_tenant_id("https://login.microsoftonline.com/common/") is None + + def test_rejects_organizations_alias(self): + assert _parse_tenant_id("https://login.microsoftonline.com/organizations/") is None + + def test_rejects_consumers_alias(self): + assert _parse_tenant_id("https://login.microsoftonline.com/consumers/") is None + + def test_rejects_reserved_alias_case_insensitive(self): + # Defensive: AAD treats these as case-insensitive; we should too. + assert _parse_tenant_id("https://login.microsoftonline.com/Common/") is None + assert _parse_tenant_id("https://login.microsoftonline.com/COMMON/") is None + class TestServicePrincipalAuth: """Tests for the ActiveDirectoryServicePrincipal token factory.""" @@ -1095,11 +1131,42 @@ def test_factory_forwards_credentials_to_ClientSecretCredential(self): "activedirectoryserviceprincipal", ) - assert az.ClientSecretCredential.last_init_kwargs == { - "tenant_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", - "client_id": "11111111-2222-3333-4444-555555555555", - "client_secret": "my-secret", - } + kwargs = az.ClientSecretCredential.last_init_kwargs + # tenant/client/secret must match — transport is asserted separately. + assert kwargs["tenant_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + assert kwargs["client_id"] == "11111111-2222-3333-4444-555555555555" + assert kwargs["client_secret"] == "my-secret" + + def test_factory_passes_transport_with_explicit_timeouts(self): + # Without explicit timeouts, azure-identity defaults can block the + # mssql-py-core blocking-pool worker for tens of seconds on a slow + # AAD endpoint. The factory must pass a bounded RequestsTransport. 
+ from azure.core.pipeline.transport import RequestsTransport + + RequestsTransport.last_init_kwargs = None + az = sys.modules["azure.identity"] + az.ClientSecretCredential.last_init_kwargs = None + + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + + # Transport is constructed with finite connection + read timeouts. + t_kwargs = RequestsTransport.last_init_kwargs + assert t_kwargs is not None, "RequestsTransport was never constructed" + assert "connection_timeout" in t_kwargs + assert "read_timeout" in t_kwargs + assert isinstance(t_kwargs["connection_timeout"], (int, float)) + assert isinstance(t_kwargs["read_timeout"], (int, float)) + assert 0 < t_kwargs["connection_timeout"] <= 30 + assert 0 < t_kwargs["read_timeout"] <= 60 + + # Credential receives the transport. + cred_kwargs = az.ClientSecretCredential.last_init_kwargs + assert "transport" in cred_kwargs def test_factory_builds_scope_from_spn(self): az = sys.modules["azure.identity"] From 3d3912bfc7ff9c797e04aabb6b2f86ba13d4203f Mon Sep 17 00:00:00 2001 From: Gaurav Sharma <8655500+bewithgaurav@users.noreply.github.com> Date: Wed, 10 Jun 2026 16:09:01 +0530 Subject: [PATCH 5/6] DOC: clarify cert-based SP limitation + drop Model A/B jargon - Note in cursor.py that cert-based ServicePrincipal is not supported on the callback branch (msodbcsql also can't do it without msodbcsqlmsqa.dll, which we don't ship; see #480). - Update the missing-UID/PWD error to say 'currently supports client-secret only' so cert users get a clearer signal. - Replace 'Model A' / 'Model B' shorthand with descriptive labels ('pre-acquired token' / 'callback-based') in cursor.py and auth.py. The A/B naming was internal-vocabulary only. No behavior change. 90 auth tests pass. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- mssql_python/auth.py | 6 +++--- mssql_python/cursor.py | 14 ++++++++------ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 31f3ef630..e5bf09293 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -375,9 +375,9 @@ def process_auth_parameters(parsed_params: Dict[str, str]) -> Optional[str]: # ServicePrincipal: ODBC (msodbcsql 17.3+) handles this natively for # regular queries, so return None to let ODBC own the query path. Bulkcopy # still needs the auth type (Connection.__init__ falls back to - # extract_auth_type for that), and the cursor.bulkcopy() Model B branch - # registers an entra_id_token_factory callback because tenant_id is only - # known from the STS URL the server returns during the FedAuth handshake. + # extract_auth_type for that), and the cursor.bulkcopy() callback branch + # registers an entra_id_token_factory because tenant_id is only known + # from the STS URL the server returns during the FedAuth handshake. if auth_type == _AuthInternal.SERVICE_PRINCIPAL: logger.debug("process_auth_parameters: ServicePrincipal - ODBC handles natively") return None diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index d15769224..24d2d98cc 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2981,15 +2981,17 @@ def bulkcopy( from mssql_python.constants import _AuthInternal if self.connection._auth_type == _AuthInternal.SERVICE_PRINCIPAL: - # Model B: callback-based. tenant_id is only known from the - # STS URL the server returns mid-handshake, so we register a - # factory that py-core invokes during FedAuth (workflow 0x02). + # Callback-based path: tenant_id is only known from the STS URL + # the server returns mid-handshake, so we register a factory + # that py-core invokes during FedAuth (workflow 0x02). + # Cert-based ServicePrincipal is not supported on this path. 
client_id = params.get("uid", "") client_secret = params.get("pwd", "") if not client_id or not client_secret: raise RuntimeError( "Bulk copy with Authentication=ActiveDirectoryServicePrincipal " - "requires UID (client_id) and PWD (client_secret) in the " + "currently supports client-secret only. " + "Provide UID (client_id) and PWD (client_secret) in the " "connection string." ) try: @@ -3005,13 +3007,13 @@ def bulkcopy( # the factory is dispatched at handshake time. logger.debug("Bulk copy: registered ServicePrincipal token factory") else: - # Model A: pre-acquired token. Used for Default, DeviceCode, + # Pre-acquired token path. Used for Default, DeviceCode, # Interactive (non-Windows), MSI (system- or user-assigned), # and any other AD method whose tenant_id is discoverable # client-side via Azure Identity SDK. credential kwargs # (e.g. user-assigned MSI client_id) were captured by # Connection.__init__ before remove_sensitive_params stripped - # UID from connection_str — re-parsing here would miss them. + # UID from connection_str; re-parsing here would miss them. try: raw_token = AADAuth.get_raw_token( self.connection._auth_type, From 06ade95915f8d77eda6969288524531fec7436a6 Mon Sep 17 00:00:00 2001 From: Gaurav Sharma <8655500+bewithgaurav@users.noreply.github.com> Date: Wed, 10 Jun 2026 19:28:40 +0530 Subject: [PATCH 6/6] STYLE: apply black formatting to connection.py import The _AuthInternal import addition pushed the existing import line over the 100-char limit. Black splits it into a parenthesized multi-line import. No behavior change.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- mssql_python/connection.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index e411e616e..94fb0924c 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -49,7 +49,12 @@ from mssql_python.constants import ConstantsDDBC, GetInfoConstants from mssql_python.connection_string_parser import _ConnectionStringParser from mssql_python.connection_string_builder import _ConnectionStringBuilder -from mssql_python.constants import _RESERVED_PARAMETERS, _KEY_AUTHENTICATION, _KEY_UID, _AuthInternal +from mssql_python.constants import ( + _RESERVED_PARAMETERS, + _KEY_AUTHENTICATION, + _KEY_UID, + _AuthInternal, +) if TYPE_CHECKING: from mssql_python.row import Row