-
Notifications
You must be signed in to change notification settings - Fork 50
FEAT: Add ActiveDirectoryServicePrincipal support for bulk copy #576
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
28f9569
75ba643
7e146b5
01a2ab3
8e5c664
fbd333f
763b565
2139355
3d3912b
3358bf2
06ade95
9f688a1
5bed26d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
| This module handles authentication for the mssql_python package. | ||
| """ | ||
|
|
||
| import hashlib | ||
| import platform | ||
| import struct | ||
| import threading | ||
|
|
@@ -13,6 +14,7 @@ | |
| from mssql_python.constants import ( | ||
| AuthType, | ||
| ConstantsDDBC, | ||
| _AuthInternal, | ||
| _KEY_AUTHENTICATION, | ||
| _KEY_UID, | ||
| _KEY_PWD, | ||
|
|
@@ -34,10 +36,11 @@ | |
|
|
||
| # Map Authentication connection-string values to internal short names. | ||
| _AUTH_TYPE_MAP: Dict[str, str] = { | ||
| AuthType.INTERACTIVE.value: "interactive", | ||
| AuthType.DEVICE_CODE.value: "devicecode", | ||
| AuthType.DEFAULT.value: "default", | ||
| AuthType.MSI.value: "msi", | ||
| AuthType.INTERACTIVE.value: _AuthInternal.INTERACTIVE, | ||
| AuthType.DEVICE_CODE.value: _AuthInternal.DEVICE_CODE, | ||
| AuthType.DEFAULT.value: _AuthInternal.DEFAULT, | ||
| AuthType.MSI.value: _AuthInternal.MSI, | ||
| AuthType.SERVICE_PRINCIPAL.value: _AuthInternal.SERVICE_PRINCIPAL, | ||
| } | ||
|
|
||
|
|
||
|
|
@@ -111,10 +114,10 @@ def _acquire_token( | |
|
|
||
| # Mapping of auth types to credential classes | ||
| credential_map = { | ||
| "default": DefaultAzureCredential, | ||
| "devicecode": DeviceCodeCredential, | ||
| "interactive": InteractiveBrowserCredential, | ||
| "msi": ManagedIdentityCredential, | ||
| _AuthInternal.DEFAULT: DefaultAzureCredential, | ||
| _AuthInternal.DEVICE_CODE: DeviceCodeCredential, | ||
| _AuthInternal.INTERACTIVE: InteractiveBrowserCredential, | ||
| _AuthInternal.MSI: ManagedIdentityCredential, | ||
| } | ||
|
|
||
| credential_class = credential_map.get(auth_type) | ||
|
|
@@ -171,6 +174,181 @@ def _acquire_token( | |
| raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e | ||
|
|
||
|
|
||
| _RESERVED_TENANTS = frozenset({"common", "organizations", "consumers"}) | ||
|
|
||
|
|
||
| def _parse_tenant_id(sts_url: str) -> Optional[str]: | ||
| """Extract tenant ID (GUID or domain) from a FedAuthInfo STS URL. | ||
|
|
||
| Expected formats: | ||
| https://login.microsoftonline.com/<tenant>/ | ||
| https://login.microsoftonline.com/<tenant>/?... | ||
| https://login.microsoftonline.com/<tenant> | ||
| where <tenant> is either a GUID (e.g. ``aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee``) | ||
| or a verified domain (e.g. ``contoso.onmicrosoft.com``). Both forms are | ||
| accepted by ``azure.identity.ClientSecretCredential``. | ||
|
|
||
| Returns ``None`` for the multi-tenant aliases ``common`` / ``organizations`` | ||
| / ``consumers``: confidential clients (SP) cannot authenticate against | ||
| them and AAD responds with a cryptic ``AADSTS50194`` ("application is not | ||
| configured as multi-tenant"). Failing fast in the factory surfaces a | ||
| clearer error than the AAD round-trip would. | ||
| """ | ||
| # pylint: disable=import-outside-toplevel | ||
| from urllib.parse import urlparse | ||
|
|
||
| try: | ||
| parsed = urlparse(sts_url) | ||
| except (ValueError, AttributeError): | ||
| return None | ||
| # Reject anything that isn't an https URL with a netloc. ``urlparse`` will | ||
| # happily put a bare string like ``"tenant-guid"`` into ``path``, which | ||
| # would then look like a valid tenant. Azure AD STS URLs are always https. | ||
| if parsed.scheme != "https" or not parsed.netloc: | ||
| return None | ||
| path = (parsed.path or "").strip("/") | ||
| if not path: | ||
| return None | ||
| first_segment = path.split("/", 1)[0] | ||
| if not first_segment: | ||
| return None | ||
| if first_segment.lower() in _RESERVED_TENANTS: | ||
| return None | ||
| return first_segment | ||
|
|
||
|
|
||
class ServicePrincipalAuth:
    """Builds an ``entra_id_token_factory`` callable for ActiveDirectoryServicePrincipal.

    The bulkcopy path through mssql-py-core uses callback-based token
    acquisition (FedAuth workflow ``0x02``) because tenant_id is only known
    from the STS URL that the server returns during the TDS handshake.
    """

    @staticmethod
    def make_token_factory(client_id: str, client_secret: str):
        """Return a callable suitable for ``entra_id_token_factory``.

        The returned callable has the signature
        ``(spn: str, sts_url: str, auth_method: str) -> bytes`` and returns the
        JWT encoded as UTF-16LE bytes (the TDS FedAuth wire format).

        NOTE(review): per PR review feedback, py-core handles the FedAuth
        length-prefix wrapping, so the bytes are returned bare -- do not
        pre-wrap them here. Confirm against mssql-py-core before changing.

        ``ClientSecretCredential`` instances are reused across calls via the
        module-level ``_credential_cache``. The cache key is built by
        ``_credential_cache_key`` from the service-principal auth type plus
        ``tenant_id``, ``client_id`` and a SHA-256 hash of ``client_secret``,
        so that azure-identity's in-memory token cache (which is
        per-credential-instance) actually works across handshake retries,
        reconnects, and separate bulkcopy invocations using the same identity,
        while a rotated secret still gets a fresh credential.

        Args:
            client_id: Application (client) ID of the service principal (UID).
            client_secret: Client secret of the service principal (PWD).

        Returns:
            The token factory callable described above. When invoked it raises
            ``RuntimeError`` if the Azure libraries are missing, the SPN is
            empty, no tenant can be extracted from the STS URL, or Azure AD
            rejects the credentials.

        Raises:
            ValueError: If ``client_id`` or ``client_secret`` is empty.
        """
        if not client_id:
            raise ValueError("ServicePrincipal auth requires a non-empty client_id (UID)")
        if not client_secret:
            raise ValueError("ServicePrincipal auth requires a non-empty client_secret (PWD)")

        # The cache key includes a hash of client_secret so a rotated secret
        # produces a different cache entry. Without this, an external secret
        # rotation would not invalidate the cached ClientSecretCredential:
        # azure-identity's internal token cache would keep returning the
        # previously-issued token (good for up to ~1 hour) until expiry,
        # masking the rotation. Hashing avoids storing the raw secret in the
        # dict key. The secret is fixed for the lifetime of this factory, so
        # the hash is computed once here rather than on every callback.
        secret_hash = hashlib.sha256(client_secret.encode("utf-8")).hexdigest()

        def _factory(spn: str, sts_url: str, auth_method: str) -> bytes:
            # pylint: disable=import-outside-toplevel,unused-argument
            try:
                from azure.identity import ClientSecretCredential
                from azure.core.exceptions import ClientAuthenticationError
                from azure.core.pipeline.transport import RequestsTransport
            except ImportError as e:
                raise RuntimeError(
                    "Azure authentication libraries are not installed. "
                    "Please install with: pip install azure-identity azure-core"
                ) from e

            if not spn:
                raise RuntimeError(
                    "ServicePrincipal token factory: empty SPN from server "
                    "(cannot construct token scope)"
                )
            tenant_id = _parse_tenant_id(sts_url)
            if not tenant_id:
                raise RuntimeError(f"Could not extract tenant_id from STS URL: {sts_url!r}")

            logger.info(
                "ServicePrincipal token factory: acquiring token for tenant=%s, spn=%s",
                tenant_id,
                spn,
            )
            try:
                # Reuse the shared credential cache (introduced for MSI in PR #573)
                # so SP credentials get the same per-instance token reuse semantics
                # as the other AD methods.
                cache_key = _credential_cache_key(
                    _AuthInternal.SERVICE_PRINCIPAL,
                    {
                        "tenant_id": tenant_id,
                        "client_id": client_id,
                        "secret_hash": secret_hash,
                    },
                )
                # Only the cache lookup/insert runs under the lock; the token
                # request below is a network call and must not serialize other
                # threads that merely need the cache.
                with _credential_cache_lock:
                    credential = _credential_cache.get(cache_key)
                    if credential is None:
                        # Bound the AAD network round-trip. Without explicit
                        # timeouts, azure-identity's defaults can let an
                        # unreachable / slow STS endpoint block the calling
                        # thread for tens of seconds. The factory runs on a
                        # mssql-py-core blocking-pool worker (tokio
                        # spawn_blocking), so a stuck callback ties that
                        # worker up for the duration. SP is non-interactive
                        # and token issuance is typically <1s; 10s/15s is
                        # generous and still bounded.
                        transport = RequestsTransport(
                            connection_timeout=10,
                            read_timeout=15,
                        )
                        # KNOWN LIMITATION: ``authority=`` is not passed,
                        # so this defaults to the public-cloud authority
                        # (login.microsoftonline.com). Sovereign clouds
                        # (Azure US Gov, Azure China) are not supported on
                        # this code path today: AAD will fail with
                        # "tenant not found" because the tenant lives in
                        # the sovereign cloud's AAD, not the public one.
                        # Tracked as a follow-up; the fix is to derive
                        # ``authority`` from ``urlparse(sts_url).netloc``.
                        # Out of scope for the initial #534 work.
                        credential = ClientSecretCredential(
                            tenant_id=tenant_id,
                            client_id=client_id,
                            client_secret=client_secret,
                            transport=transport,
                        )
                        # NOTE(review): a reviewer suggested evicting prior
                        # entries for the same identity when a rotated secret
                        # adds a new one (presumably same tenant_id/client_id,
                        # different secret_hash). Not done here because the key
                        # layout is owned by ``_credential_cache_key`` -- TODO
                        # confirm the desired eviction policy.
                        _credential_cache[cache_key] = credential

                # mssql-tds passes the resource SPN; azure-identity wants a scope.
                scope = spn if spn.endswith("/.default") else spn.rstrip("/") + "/.default"
                token = credential.get_token(scope).token
                logger.info(
                    "ServicePrincipal token factory: token acquired, length=%d chars",
                    len(token),
                )
                return token.encode("utf-16-le")
            except ClientAuthenticationError as e:
                # Keep the detailed provider error in debug logs only. The
                # error-level line and the surfaced message are intentionally
                # generic so that any secret-bearing provider text never
                # reaches default-level logs or the user-facing exception
                # chain (hence ``from None``).
                logger.error(
                    "ServicePrincipal authentication failed: tenant=%s",
                    tenant_id,
                )
                logger.debug(
                    "ServicePrincipal authentication failed: tenant=%s, error=%s",
                    tenant_id,
                    str(e),
                )
                raise RuntimeError(
                    "ServicePrincipal authentication failed; see debug logs for provider details"
                ) from None

        return _factory
|
|
||
|
|
||
| def process_auth_parameters(parsed_params: Dict[str, str]) -> Optional[str]: | ||
| """ | ||
| Extract authentication type from parsed connection parameters. | ||
|
|
@@ -190,10 +368,20 @@ def process_auth_parameters(parsed_params: Dict[str, str]) -> Optional[str]: | |
| return None | ||
|
|
||
| # On Windows, Interactive auth is handled natively by the ODBC driver. | ||
| if auth_type == "interactive" and platform.system().lower() == "windows": | ||
| if auth_type == _AuthInternal.INTERACTIVE and platform.system().lower() == "windows": | ||
| logger.debug("process_auth_parameters: Windows platform - using native AADInteractive") | ||
| return None | ||
|
|
||
| # ServicePrincipal: ODBC (msodbcsql 17.3+) handles this natively for | ||
| # regular queries, so return None to let ODBC own the query path. Bulkcopy | ||
| # still needs the auth type (Connection.__init__ falls back to | ||
| # extract_auth_type for that), and the cursor.bulkcopy() callback branch | ||
| # registers an entra_id_token_factory because tenant_id is only known | ||
| # from the STS URL the server returns during the FedAuth handshake. | ||
| if auth_type == _AuthInternal.SERVICE_PRINCIPAL: | ||
| logger.debug("process_auth_parameters: ServicePrincipal - ODBC handles natively") | ||
| return None | ||
|
|
||
| logger.debug("process_auth_parameters: auth_type=%s", auth_type) | ||
| return auth_type | ||
|
|
||
|
|
@@ -213,7 +401,7 @@ def get_auth_token( | |
| return None | ||
|
|
||
| # Handle platform-specific logic for interactive auth | ||
| if auth_type == "interactive" and platform.system().lower() == "windows": | ||
| if auth_type == _AuthInternal.INTERACTIVE and platform.system().lower() == "windows": | ||
| logger.debug("get_auth_token: Windows interactive auth - delegating to native handler") | ||
| return None # Let Windows handle AADInteractive natively | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.