diff --git a/CHANGELOG.md b/CHANGELOG.md index 517a60bf..4fa44f85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Added - New feature: Support for macOS and Linux. - Documentation: Added API documentation in the Wiki. +- Bulk copy now supports `Authentication=ActiveDirectoryServicePrincipal` + via an `entra_id_token_factory` callback registered on the mssql-py-core + connection. The callback is invoked by mssql-tds mid-handshake (FedAuth + workflow 0x02) so the tenant id can be resolved from the server-supplied + STS URL. Requires `mssql-py-core` 0.1.5+. Partial fix for #534. ### Changed - Improved error handling in the connection module. diff --git a/mssql_python/auth.py b/mssql_python/auth.py index dd716c2c..e5bf0929 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -4,6 +4,7 @@ This module handles authentication for the mssql_python package. """ +import hashlib import platform import struct import threading @@ -13,6 +14,7 @@ from mssql_python.constants import ( AuthType, ConstantsDDBC, + _AuthInternal, _KEY_AUTHENTICATION, _KEY_UID, _KEY_PWD, @@ -34,10 +36,11 @@ # Map Authentication connection-string values to internal short names. 
_AUTH_TYPE_MAP: Dict[str, str] = { - AuthType.INTERACTIVE.value: "interactive", - AuthType.DEVICE_CODE.value: "devicecode", - AuthType.DEFAULT.value: "default", - AuthType.MSI.value: "msi", + AuthType.INTERACTIVE.value: _AuthInternal.INTERACTIVE, + AuthType.DEVICE_CODE.value: _AuthInternal.DEVICE_CODE, + AuthType.DEFAULT.value: _AuthInternal.DEFAULT, + AuthType.MSI.value: _AuthInternal.MSI, + AuthType.SERVICE_PRINCIPAL.value: _AuthInternal.SERVICE_PRINCIPAL, } @@ -111,10 +114,10 @@ def _acquire_token( # Mapping of auth types to credential classes credential_map = { - "default": DefaultAzureCredential, - "devicecode": DeviceCodeCredential, - "interactive": InteractiveBrowserCredential, - "msi": ManagedIdentityCredential, + _AuthInternal.DEFAULT: DefaultAzureCredential, + _AuthInternal.DEVICE_CODE: DeviceCodeCredential, + _AuthInternal.INTERACTIVE: InteractiveBrowserCredential, + _AuthInternal.MSI: ManagedIdentityCredential, } credential_class = credential_map.get(auth_type) @@ -171,6 +174,181 @@ def _acquire_token( raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e +_RESERVED_TENANTS = frozenset({"common", "organizations", "consumers"}) + + +def _parse_tenant_id(sts_url: str) -> Optional[str]: + """Extract tenant ID (GUID or domain) from a FedAuthInfo STS URL. + + Expected formats: + https://login.microsoftonline.com/<tenant>/ + https://login.microsoftonline.com/<tenant>/?... + https://login.microsoftonline.com/<tenant> + where <tenant> is either a GUID (e.g. ``aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee``) + or a verified domain (e.g. ``contoso.onmicrosoft.com``). Both forms are + accepted by ``azure.identity.ClientSecretCredential``. + + Returns ``None`` for the multi-tenant aliases ``common`` / ``organizations`` + / ``consumers``: confidential clients (SP) cannot authenticate against + them and AAD responds with a cryptic ``AADSTS50194`` ("application is not + configured as multi-tenant"). Failing fast in the factory surfaces a + clearer error than the AAD round-trip would. 
+ """ + # pylint: disable=import-outside-toplevel + from urllib.parse import urlparse + + try: + parsed = urlparse(sts_url) + except (ValueError, AttributeError): + return None + # Reject anything that isn't an https URL with a netloc. ``urlparse`` will + # happily put a bare string like ``"tenant-guid"`` into ``path``, which + # would then look like a valid tenant. Azure AD STS URLs are always https. + if parsed.scheme != "https" or not parsed.netloc: + return None + path = (parsed.path or "").strip("/") + if not path: + return None + first_segment = path.split("/", 1)[0] + if not first_segment: + return None + if first_segment.lower() in _RESERVED_TENANTS: + return None + return first_segment + + +class ServicePrincipalAuth: + """Builds an ``entra_id_token_factory`` callable for ActiveDirectoryServicePrincipal. + + The bulkcopy path through mssql-py-core uses callback-based token + acquisition (FedAuth workflow ``0x02``) because tenant_id is only known + from the STS URL that the server returns during the TDS handshake. + """ + + @staticmethod + def make_token_factory(client_id: str, client_secret: str): + """Return a callable suitable for ``entra_id_token_factory``. + + Signature: ``(spn: str, sts_url: str, auth_method: str) -> bytes``. + Returns the JWT encoded as UTF-16LE bytes (the TDS FedAuth wire format). + + ``ClientSecretCredential`` instances are reused across calls via the + module-level ``_credential_cache``, keyed by auth type, ``tenant_id``, + ``client_id`` and a SHA-256 hash of ``client_secret``, so that azure-identity's + in-memory token cache (which is per-credential-instance) actually + works across handshake retries, reconnects, and separate bulkcopy + invocations using the same identity. 
+ """ + if not client_id: + raise ValueError("ServicePrincipal auth requires a non-empty client_id (UID)") + if not client_secret: + raise ValueError("ServicePrincipal auth requires a non-empty client_secret (PWD)") + + def _factory(spn: str, sts_url: str, auth_method: str) -> bytes: + # pylint: disable=import-outside-toplevel,unused-argument + try: + from azure.identity import ClientSecretCredential + from azure.core.exceptions import ClientAuthenticationError + from azure.core.pipeline.transport import RequestsTransport + except ImportError as e: + raise RuntimeError( + "Azure authentication libraries are not installed. " + "Please install with: pip install azure-identity azure-core" + ) from e + + if not spn: + raise RuntimeError( + "ServicePrincipal token factory: empty SPN from server " + "(cannot construct token scope)" + ) + tenant_id = _parse_tenant_id(sts_url) + if not tenant_id: + raise RuntimeError(f"Could not extract tenant_id from STS URL: {sts_url!r}") + + logger.info( + "ServicePrincipal token factory: acquiring token for tenant=%s, spn=%s", + tenant_id, + spn, + ) + try: + # Reuse the shared credential cache (introduced for MSI in PR #573) + # so SP credentials get the same per-instance token reuse semantics + # as the other AD methods. + # + # The cache key includes a hash of client_secret so a rotated + # secret produces a different cache entry. Without this, an + # external secret rotation would not invalidate the cached + # ClientSecretCredential: azure-identity's internal token cache + # would keep returning the previously-issued token (good for + # up to ~1 hour) until expiry, masking the rotation. Hashing + # avoids storing the raw secret in the dict key. 
+ secret_hash = hashlib.sha256(client_secret.encode("utf-8")).hexdigest() + cache_key = _credential_cache_key( + _AuthInternal.SERVICE_PRINCIPAL, + { + "tenant_id": tenant_id, + "client_id": client_id, + "secret_hash": secret_hash, + }, + ) + with _credential_cache_lock: + credential = _credential_cache.get(cache_key) + if credential is None: + # Bound the AAD network round-trip. Without explicit + # timeouts, azure-identity's defaults can let an + # unreachable / slow STS endpoint block the calling + # thread for tens of seconds. The factory runs on a + # mssql-py-core blocking-pool worker (tokio + # spawn_blocking), so a stuck callback ties that + # worker up for the duration. SP is non-interactive + # and token issuance is typically <1s; 10s/15s is + # generous and still bounded. + transport = RequestsTransport( + connection_timeout=10, + read_timeout=15, + ) + # KNOWN LIMITATION: ``authority=`` is not passed, + # so this defaults to the public-cloud authority + # (login.microsoftonline.com). Sovereign clouds + # (Azure US Gov, Azure China) are not supported on + # this code path today: AAD will fail with + # "tenant not found" because the tenant lives in + # the sovereign cloud's AAD, not the public one. + # Tracked as a follow-up; the fix is to derive + # ``authority`` from ``urlparse(sts_url).netloc``. + # Out of scope for the initial #534 work. + credential = ClientSecretCredential( + tenant_id=tenant_id, + client_id=client_id, + client_secret=client_secret, + transport=transport, + ) + _credential_cache[cache_key] = credential + # mssql-tds passes the resource SPN; azure-identity wants a scope. + scope = spn if spn.endswith("/.default") else spn.rstrip("/") + "/.default" + token = credential.get_token(scope).token + logger.info( + "ServicePrincipal token factory: token acquired, length=%d chars", + len(token), + ) + return token.encode("utf-16-le") + except ClientAuthenticationError as e: + # Keep the detailed provider error in debug logs only. 
The + # surfaced message is intentionally generic so that any + # secret-bearing provider text never reaches the user-facing + # exception chain. + logger.error( + "ServicePrincipal authentication failed: tenant=%s, error=%s", + tenant_id, + str(e), + ) + raise RuntimeError( + "ServicePrincipal authentication failed; " "see debug logs for provider details" + ) from None + + return _factory + + def process_auth_parameters(parsed_params: Dict[str, str]) -> Optional[str]: """ Extract authentication type from parsed connection parameters. @@ -190,10 +368,20 @@ def process_auth_parameters(parsed_params: Dict[str, str]) -> Optional[str]: return None # On Windows, Interactive auth is handled natively by the ODBC driver. - if auth_type == "interactive" and platform.system().lower() == "windows": + if auth_type == _AuthInternal.INTERACTIVE and platform.system().lower() == "windows": logger.debug("process_auth_parameters: Windows platform - using native AADInteractive") return None + # ServicePrincipal: ODBC (msodbcsql 17.3+) handles this natively for + # regular queries, so return None to let ODBC own the query path. Bulkcopy + # still needs the auth type (Connection.__init__ falls back to + # extract_auth_type for that), and the cursor.bulkcopy() callback branch + # registers an entra_id_token_factory because tenant_id is only known + # from the STS URL the server returns during the FedAuth handshake. 
+ if auth_type == _AuthInternal.SERVICE_PRINCIPAL: + logger.debug("process_auth_parameters: ServicePrincipal - ODBC handles natively") + return None + logger.debug("process_auth_parameters: auth_type=%s", auth_type) return auth_type @@ -213,7 +401,7 @@ def get_auth_token( return None # Handle platform-specific logic for interactive auth - if auth_type == "interactive" and platform.system().lower() == "windows": + if auth_type == _AuthInternal.INTERACTIVE and platform.system().lower() == "windows": logger.debug("get_auth_token: Windows interactive auth - delegating to native handler") return None # Let Windows handle AADInteractive natively diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 0d9b4692..94fb0924 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -49,7 +49,12 @@ from mssql_python.constants import ConstantsDDBC, GetInfoConstants from mssql_python.connection_string_parser import _ConnectionStringParser from mssql_python.connection_string_builder import _ConnectionStringBuilder -from mssql_python.constants import _RESERVED_PARAMETERS, _KEY_AUTHENTICATION, _KEY_UID +from mssql_python.constants import ( + _RESERVED_PARAMETERS, + _KEY_AUTHENTICATION, + _KEY_UID, + _AuthInternal, +) if TYPE_CHECKING: from mssql_python.row import Row @@ -344,7 +349,7 @@ def __init__( # Capture credential kwargs (e.g. user-assigned MSI client_id) # from the parsed dict *before* remove_sensitive_params strips UID. 
credential_kwargs: Optional[Dict[str, str]] = None - if auth_type == "msi": + if auth_type == _AuthInternal.MSI: uid = (parsed_params.get(_KEY_UID) or "").strip() if uid: credential_kwargs = {"client_id": uid} diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 3bfd3948..401e434a 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -332,12 +332,28 @@ class GetInfoConstants(Enum): class AuthType(Enum): - """Constants for authentication types""" + """Constants for authentication types (public/ODBC connection-string form).""" INTERACTIVE = "activedirectoryinteractive" DEVICE_CODE = "activedirectorydevicecode" DEFAULT = "activedirectorydefault" MSI = "activedirectorymsi" + SERVICE_PRINCIPAL = "activedirectoryserviceprincipal" + + +class _AuthInternal: + """Internal short-form auth identifiers used after normalization. + + Paired with :class:`AuthType` (the public ODBC-string form) by + ``mssql_python.auth._AUTH_TYPE_MAP``. Adding a new auth mode requires + one entry here, one in ``AuthType``, and one in ``_AUTH_TYPE_MAP``. + """ + + DEFAULT = "default" + DEVICE_CODE = "devicecode" + INTERACTIVE = "interactive" + MSI = "msi" + SERVICE_PRINCIPAL = "serviceprincipal" class SQLTypes: diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index b8ea7d5d..4beb3fa7 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -3000,31 +3000,63 @@ def bulkcopy( # Token acquisition — only thing cursor must handle (needs azure-identity SDK) if self.connection._auth_type: - # Fresh token acquisition for mssql-py-core connection. credential - # kwargs (e.g. user-assigned MSI client_id) were captured by - # Connection.__init__ before remove_sensitive_params stripped UID - # from connection_str — re-parsing here would miss them. 
- from mssql_python.auth import AADAuth - - try: - raw_token = AADAuth.get_raw_token( + # Fresh token acquisition for mssql-py-core connection + from mssql_python.auth import AADAuth, ServicePrincipalAuth + from mssql_python.constants import _AuthInternal + + if self.connection._auth_type == _AuthInternal.SERVICE_PRINCIPAL: + # Callback-based path: tenant_id is only known from the STS URL + # the server returns mid-handshake, so we register a factory + # that py-core invokes during FedAuth (workflow 0x02). + # Cert-based ServicePrincipal is not supported on this path. + client_id = params.get("uid", "") + client_secret = params.get("pwd", "") + if not client_id or not client_secret: + raise RuntimeError( + "Bulk copy with Authentication=ActiveDirectoryServicePrincipal " + "currently supports client-secret only. " + "Provide UID (client_id) and PWD (client_secret) in the " + "connection string." + ) + try: + factory = ServicePrincipalAuth.make_token_factory(client_id, client_secret) + except (RuntimeError, ValueError) as e: + raise RuntimeError( + f"Bulk copy failed: unable to build ServicePrincipal token factory: {e}" + ) from e + pycore_context["entra_id_token_factory"] = factory + # Keep authentication/user_name/password in pycore_context — + # py-core's auth validator + transformer need them to resolve + # the auth method to ActiveDirectoryServicePrincipal before + # the factory is dispatched at handshake time. + logger.debug("Bulk copy: registered ServicePrincipal token factory") + else: + # Pre-acquired token path. Used for Default, DeviceCode, + # Interactive (non-Windows), MSI (system- or user-assigned), + # and any other AD method whose tenant_id is discoverable + # client-side via Azure Identity SDK. credential kwargs + # (e.g. user-assigned MSI client_id) were captured by + # Connection.__init__ before remove_sensitive_params stripped + # UID from connection_str. re-parsing here would miss them. 
+ try: + raw_token = AADAuth.get_raw_token( + self.connection._auth_type, + self.connection._credential_kwargs, + ) + except (RuntimeError, ValueError) as e: + raise RuntimeError( + f"Bulk copy failed: unable to acquire Azure AD token " + f"for auth_type '{self.connection._auth_type}': {e}" + ) from e + pycore_context["access_token"] = raw_token + # Token replaces credential fields — py-core's validator rejects + # access_token combined with authentication/user_name/password. + for key in ("authentication", "user_name", "password"): + pycore_context.pop(key, None) + logger.debug( + "Bulk copy: acquired fresh Azure AD token for auth_type=%s", self.connection._auth_type, - self.connection._credential_kwargs, ) - except (RuntimeError, ValueError) as e: - raise RuntimeError( - f"Bulk copy failed: unable to acquire Azure AD token " - f"for auth_type '{self.connection._auth_type}': {e}" - ) from e - pycore_context["access_token"] = raw_token - # Token replaces credential fields — py-core's validator rejects - # access_token combined with authentication/user_name/password. - for key in ("authentication", "user_name", "password"): - pycore_context.pop(key, None) - logger.debug( - "Bulk copy: acquired fresh Azure AD token for auth_type=%s", - self.connection._auth_type, - ) pycore_connection = None pycore_cursor = None @@ -3098,9 +3130,17 @@ def _prepare_row_iterator(iterable): raise type(e)(str(e)) from None finally: - # Clear sensitive data to minimize memory exposure + # Clear sensitive data to minimize memory exposure. The + # entra_id_token_factory closure captures client_secret, so drop + # our dict reference to it (Rust still holds an Arc until the + # connection is dropped, but at least we don't keep an extra ref). 
if pycore_context: - for key in ("password", "user_name", "access_token"): + for key in ( + "password", + "user_name", + "access_token", + "entra_id_token_factory", + ): pycore_context.pop(key, None) # Clean up bulk copy resources for resource in (pycore_cursor, pycore_connection): diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index b127133a..10428f82 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -11,6 +11,8 @@ from unittest.mock import patch, MagicMock from mssql_python.auth import ( AADAuth, + ServicePrincipalAuth, + _parse_tenant_id, process_auth_parameters, remove_sensitive_params, get_auth_token, @@ -43,6 +45,20 @@ class MockInteractiveBrowserCredential: def get_token(self, scope): return MockToken() + class MockClientSecretCredential: + # Captures construction kwargs and get_token args so ServicePrincipal + # tests can assert the right tenant/client_id/secret/scope flowed + # through from the connection string + STS URL. + last_init_kwargs = None + last_scope = None + + def __init__(self, **kwargs): + MockClientSecretCredential.last_init_kwargs = kwargs + + def get_token(self, scope): + MockClientSecretCredential.last_scope = scope + return MockToken() + class MockManagedIdentityCredential: # Captures construction kwargs so user-assigned MSI tests can assert # client_id was forwarded correctly. @@ -54,6 +70,14 @@ def __init__(self, **kwargs): def get_token(self, scope): return MockToken() + class MockRequestsTransport: + # Captures construction kwargs so the SP factory's timeout config + # can be asserted. 
+ last_init_kwargs = None + + def __init__(self, **kwargs): + MockRequestsTransport.last_init_kwargs = kwargs + # Mock ClientAuthenticationError class MockClientAuthenticationError(Exception): pass @@ -62,12 +86,17 @@ class MockIdentity: DefaultAzureCredential = MockDefaultAzureCredential DeviceCodeCredential = MockDeviceCodeCredential InteractiveBrowserCredential = MockInteractiveBrowserCredential + ClientSecretCredential = MockClientSecretCredential ManagedIdentityCredential = MockManagedIdentityCredential class MockCore: class exceptions: ClientAuthenticationError = MockClientAuthenticationError + class pipeline: + class transport: + RequestsTransport = MockRequestsTransport + # Create mock azure module if it doesn't exist if "azure" not in sys.modules: sys.modules["azure"] = type("MockAzure", (), {})() @@ -76,11 +105,19 @@ class exceptions: sys.modules["azure.identity"] = MockIdentity() sys.modules["azure.core"] = MockCore() sys.modules["azure.core.exceptions"] = MockCore.exceptions() + sys.modules["azure.core.pipeline"] = MockCore.pipeline() + sys.modules["azure.core.pipeline.transport"] = MockCore.pipeline.transport() yield # Cleanup - for module in ["azure.identity", "azure.core", "azure.core.exceptions"]: + for module in [ + "azure.identity", + "azure.core", + "azure.core.exceptions", + "azure.core.pipeline", + "azure.core.pipeline.transport", + ]: if module in sys.modules: del sys.modules[module] @@ -99,6 +136,7 @@ def test_auth_type_constants(self): assert AuthType.DEVICE_CODE.value == "activedirectorydevicecode" assert AuthType.DEFAULT.value == "activedirectorydefault" assert AuthType.MSI.value == "activedirectorymsi" + assert AuthType.SERVICE_PRINCIPAL.value == "activedirectoryserviceprincipal" class TestAADAuth: @@ -327,6 +365,20 @@ def test_default_auth(self): auth_type = process_auth_parameters(params) assert auth_type == "default" + def test_service_principal_auth_leaves_odbc_path_alone(self): + """ServicePrincipal is handled natively by ODBC. 
process_auth_parameters + must return None so the ODBC path doesn't pre-acquire a token (which + would require tenant_id we don't have client-side). Bulkcopy still + gets "serviceprincipal" from extract_auth_type.""" + params = {"Authentication": "ActiveDirectoryServicePrincipal", "Server": "test"} + auth_type = process_auth_parameters(params) + assert auth_type is None + + def test_service_principal_auth_case_insensitive(self): + params = {"Authentication": "activedirectoryserviceprincipal", "Server": "test"} + auth_type = process_auth_parameters(params) + assert auth_type is None + def test_msi_auth(self): params = {"Authentication": "ActiveDirectoryMSI", "Server": "test"} auth_type = process_auth_parameters(params) @@ -380,6 +432,14 @@ def test_devicecode(self): == "devicecode" ) + def test_serviceprincipal(self): + assert ( + extract_auth_type( + {"Server": "test", "Authentication": "ActiveDirectoryServicePrincipal"} + ) + == "serviceprincipal" + ) + def test_msi(self): assert ( extract_auth_type({"Server": "test", "Authentication": "ActiveDirectoryMSI"}) == "msi" @@ -970,3 +1030,348 @@ def test_token_output_correct_on_cache_miss_and_hit(self): # Same credential instance for both assert "default" in _credential_cache + + +class TestParseTenantId: + def test_guid_tenant(self): + url = "https://login.microsoftonline.com/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/" + assert _parse_tenant_id(url) == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + def test_guid_tenant_no_trailing_slash(self): + url = "https://login.microsoftonline.com/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + assert _parse_tenant_id(url) == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + def test_domain_tenant(self): + url = "https://login.microsoftonline.com/contoso.onmicrosoft.com/" + assert _parse_tenant_id(url) == "contoso.onmicrosoft.com" + + def test_tenant_with_query_string(self): + url = "https://login.microsoftonline.com/tenant-guid/?foo=bar" + assert _parse_tenant_id(url) == "tenant-guid" + + def 
test_extra_path_segments_after_tenant(self): + url = "https://login.microsoftonline.com/tenant-guid/oauth2/authorize" + assert _parse_tenant_id(url) == "tenant-guid" + + def test_empty_string(self): + assert _parse_tenant_id("") is None + + def test_no_path(self): + assert _parse_tenant_id("https://login.microsoftonline.com/") is None + + def test_rejects_bare_string_without_scheme(self): + # urlparse puts a bare string into path; without a scheme/netloc check + # this would be silently treated as a tenant id. + assert _parse_tenant_id("tenant-guid") is None + + def test_rejects_path_only_url(self): + assert _parse_tenant_id("/tenant-guid/oauth2") is None + + def test_rejects_http_scheme(self): + # Azure AD STS URLs are always https. Reject http to avoid trusting + # a downgraded URL. + assert _parse_tenant_id("http://login.microsoftonline.com/tenant/") is None + + def test_rejects_common_alias(self): + # Multi-tenant alias — confidential clients (SP) cannot auth against + # it. Reject up front so the error surfaced is ours, not AADSTS50194. + assert _parse_tenant_id("https://login.microsoftonline.com/common/") is None + + def test_rejects_organizations_alias(self): + assert _parse_tenant_id("https://login.microsoftonline.com/organizations/") is None + + def test_rejects_consumers_alias(self): + assert _parse_tenant_id("https://login.microsoftonline.com/consumers/") is None + + def test_rejects_reserved_alias_case_insensitive(self): + # Defensive: AAD treats these as case-insensitive; we should too. 
+ assert _parse_tenant_id("https://login.microsoftonline.com/Common/") is None + assert _parse_tenant_id("https://login.microsoftonline.com/COMMON/") is None + + +class TestServicePrincipalAuth: + """Tests for the ActiveDirectoryServicePrincipal token factory.""" + + def test_make_token_factory_returns_callable(self): + factory = ServicePrincipalAuth.make_token_factory("client-id", "client-secret") + assert callable(factory) + + def test_factory_requires_client_id(self): + with pytest.raises(ValueError, match="client_id"): + ServicePrincipalAuth.make_token_factory("", "client-secret") + + def test_factory_requires_client_secret(self): + with pytest.raises(ValueError, match="client_secret"): + ServicePrincipalAuth.make_token_factory("client-id", "") + + def test_factory_returns_utf16le_bytes(self): + factory = ServicePrincipalAuth.make_token_factory("client-id", "client-secret") + result = factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + assert isinstance(result, bytes) + # SAMPLE_TOKEN is hex chars (ASCII). UTF-16LE encoding doubles each byte + # and inserts a 0x00 high byte after each ASCII char. + assert result == SAMPLE_TOKEN.encode("utf-16-le") + assert len(result) == len(SAMPLE_TOKEN) * 2 + + def test_factory_forwards_credentials_to_ClientSecretCredential(self): + az = sys.modules["azure.identity"] + az.ClientSecretCredential.last_init_kwargs = None + az.ClientSecretCredential.last_scope = None + + factory = ServicePrincipalAuth.make_token_factory( + "11111111-2222-3333-4444-555555555555", "my-secret" + ) + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/", + "activedirectoryserviceprincipal", + ) + + kwargs = az.ClientSecretCredential.last_init_kwargs + # tenant/client/secret must match — transport is asserted separately. 
+ assert kwargs["tenant_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + assert kwargs["client_id"] == "11111111-2222-3333-4444-555555555555" + assert kwargs["client_secret"] == "my-secret" + + def test_factory_passes_transport_with_explicit_timeouts(self): + # Without explicit timeouts, azure-identity defaults can block the + # mssql-py-core blocking-pool worker for tens of seconds on a slow + # AAD endpoint. The factory must pass a bounded RequestsTransport. + from azure.core.pipeline.transport import RequestsTransport + + RequestsTransport.last_init_kwargs = None + az = sys.modules["azure.identity"] + az.ClientSecretCredential.last_init_kwargs = None + + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + + # Transport is constructed with finite connection + read timeouts. + t_kwargs = RequestsTransport.last_init_kwargs + assert t_kwargs is not None, "RequestsTransport was never constructed" + assert "connection_timeout" in t_kwargs + assert "read_timeout" in t_kwargs + assert isinstance(t_kwargs["connection_timeout"], (int, float)) + assert isinstance(t_kwargs["read_timeout"], (int, float)) + assert 0 < t_kwargs["connection_timeout"] <= 30 + assert 0 < t_kwargs["read_timeout"] <= 60 + + # Credential receives the transport. 
+ cred_kwargs = az.ClientSecretCredential.last_init_kwargs + assert "transport" in cred_kwargs + + def test_factory_builds_scope_from_spn(self): + az = sys.modules["azure.identity"] + az.ClientSecretCredential.last_scope = None + + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant/", + "activedirectoryserviceprincipal", + ) + assert az.ClientSecretCredential.last_scope == "https://database.windows.net/.default" + + def test_factory_keeps_existing_default_suffix(self): + az = sys.modules["azure.identity"] + az.ClientSecretCredential.last_scope = None + + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + factory( + "https://database.windows.net/.default", + "https://login.microsoftonline.com/tenant/", + "activedirectoryserviceprincipal", + ) + assert az.ClientSecretCredential.last_scope == "https://database.windows.net/.default" + + def test_factory_errors_on_unparseable_sts_url(self): + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + with pytest.raises(RuntimeError, match="Could not extract tenant_id"): + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/", # no tenant segment + "activedirectoryserviceprincipal", + ) + + def test_factory_propagates_authentication_error(self): + from azure.core.exceptions import ClientAuthenticationError + + class FailingCred: + def __init__(self, **kwargs): + pass + + def get_token(self, scope): + raise ClientAuthenticationError("AADSTS7000215: Invalid client secret") + + original = sys.modules["azure.identity"].ClientSecretCredential + sys.modules["azure.identity"].ClientSecretCredential = FailingCred + try: + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + with pytest.raises(RuntimeError, match="ServicePrincipal authentication failed"): + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant-guid/", + 
"activedirectoryserviceprincipal", + ) + finally: + sys.modules["azure.identity"].ClientSecretCredential = original + + def test_factory_does_not_leak_provider_message_in_runtime_error(self): + """The user-facing RuntimeError must not echo the provider message + (which can carry tenant ids, claims, or other sensitive context). + Provider detail is preserved in debug logs only.""" + from azure.core.exceptions import ClientAuthenticationError + + secret_marker = "AADSTS7000215_SECRET_MARKER_in_provider_message" + + class FailingCred: + def __init__(self, **kwargs): + pass + + def get_token(self, scope): + raise ClientAuthenticationError(secret_marker) + + original = sys.modules["azure.identity"].ClientSecretCredential + sys.modules["azure.identity"].ClientSecretCredential = FailingCred + try: + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + try: + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + except RuntimeError as e: + full_chain = str(e) + cause = e.__cause__ + while cause is not None: + full_chain += " || " + str(cause) + cause = getattr(cause, "__cause__", None) + assert ( + secret_marker not in full_chain + ), f"Provider message leaked into surfaced exception chain: {full_chain}" + finally: + sys.modules["azure.identity"].ClientSecretCredential = original + + def test_factory_rejects_empty_spn(self): + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + with pytest.raises(RuntimeError, match="empty SPN"): + factory( + "", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + + def test_factory_caches_credential_per_tenant(self): + """ClientSecretCredential must be reused across calls for the same + tenant so azure-identity's per-instance token cache actually works.""" + az = sys.modules["azure.identity"] + construction_count = {"n": 0} + + original = az.ClientSecretCredential + + class _Tok: 
+ token = SAMPLE_TOKEN + + class CountingCred: + def __init__(self, **kwargs): + construction_count["n"] += 1 + + def get_token(self, scope): + return _Tok() + + az.ClientSecretCredential = CountingCred + try: + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + sts = "https://login.microsoftonline.com/tenant-guid/" + for _ in range(3): + factory("https://database.windows.net/", sts, "activedirectoryserviceprincipal") + assert construction_count["n"] == 1, ( + f"Expected 1 ClientSecretCredential construction across 3 calls, " + f"got {construction_count['n']}" + ) + # A different tenant should produce a second instance. + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/other-tenant/", + "activedirectoryserviceprincipal", + ) + assert construction_count["n"] == 2 + finally: + az.ClientSecretCredential = original + + def test_factory_rotates_credential_when_secret_changes(self): + """A new client_secret for the same tenant+client_id MUST produce a new + ClientSecretCredential instance. 
Without this, an external secret + rotation would not invalidate the cached credential: azure-identity's + internal token cache would keep returning the previously-issued token + (good for up to ~1 hour) until expiry, masking the rotation.""" + az = sys.modules["azure.identity"] + construction_count = {"n": 0} + + original = az.ClientSecretCredential + + class _Tok: + token = SAMPLE_TOKEN + + class CountingCred: + def __init__(self, **kwargs): + construction_count["n"] += 1 + + def get_token(self, scope): + return _Tok() + + az.ClientSecretCredential = CountingCred + try: + sts = "https://login.microsoftonline.com/tenant-guid/" + spn = "https://database.windows.net/" + + # Old secret, two calls -> 1 construction (cached) + factory_old = ServicePrincipalAuth.make_token_factory("cid", "old-secret") + factory_old(spn, sts, "activedirectoryserviceprincipal") + factory_old(spn, sts, "activedirectoryserviceprincipal") + assert construction_count["n"] == 1 + + # Rotate the secret. Same tenant + client_id, different secret. + # MUST produce a fresh ClientSecretCredential so azure-identity + # cannot serve a stale token from its internal cache. + factory_new = ServicePrincipalAuth.make_token_factory("cid", "new-secret") + factory_new(spn, sts, "activedirectoryserviceprincipal") + assert construction_count["n"] == 2, ( + f"Expected 2 ClientSecretCredential constructions after secret rotation, " + f"got {construction_count['n']}. A rotated secret was silently ignored." + ) + + # Calling the new factory again should hit cache (1 more = 2 total) + factory_new(spn, sts, "activedirectoryserviceprincipal") + assert construction_count["n"] == 2 + + # Calling the OLD factory again should still hit the OLD cache entry + # (it's keyed on the hash of "old-secret"), not construct again. 
+ factory_old(spn, sts, "activedirectoryserviceprincipal") + assert construction_count["n"] == 2 + finally: + az.ClientSecretCredential = original + + def test_factory_cache_key_does_not_contain_raw_secret(self): + """The cache key must hash the secret, never store it raw. Otherwise + the secret is visible in process memory as part of the dict key.""" + from mssql_python.auth import _credential_cache + + secret_marker = "RAW_SECRET_MARKER_must_not_appear_in_cache_key" + factory = ServicePrincipalAuth.make_token_factory("cid", secret_marker) + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + for key in _credential_cache.keys(): + assert secret_marker not in repr(key), f"Raw secret leaked into cache key: {key!r}"