Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Added
- New feature: Support for macOS and Linux.
- Documentation: Added API documentation in the Wiki.
- Bulk copy now supports `Authentication=ActiveDirectoryServicePrincipal`
via an `entra_id_token_factory` callback registered on the mssql-py-core
connection. The callback is invoked by mssql-tds mid-handshake (FedAuth
workflow 0x02) so the tenant id can be resolved from the server-supplied
STS URL. Requires `mssql-py-core` 0.1.5+. Partial fix for #534.

### Changed
- Improved error handling in the connection module.
Expand Down
208 changes: 198 additions & 10 deletions mssql_python/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
This module handles authentication for the mssql_python package.
"""

import hashlib
import platform
import struct
import threading
Expand All @@ -13,6 +14,7 @@
from mssql_python.constants import (
AuthType,
ConstantsDDBC,
_AuthInternal,
_KEY_AUTHENTICATION,
_KEY_UID,
_KEY_PWD,
Expand All @@ -34,10 +36,11 @@

# Map Authentication connection-string values (the public ``AuthType`` form)
# to the internal short names defined on ``_AuthInternal``.
#
# Each key appears exactly once: a dict literal with repeated keys silently
# keeps only the last value, so the superseded string-literal entries were
# dead code. The ``_AuthInternal`` constants carry the same short strings.
_AUTH_TYPE_MAP: Dict[str, str] = {
    AuthType.INTERACTIVE.value: _AuthInternal.INTERACTIVE,
    AuthType.DEVICE_CODE.value: _AuthInternal.DEVICE_CODE,
    AuthType.DEFAULT.value: _AuthInternal.DEFAULT,
    AuthType.MSI.value: _AuthInternal.MSI,
    AuthType.SERVICE_PRINCIPAL.value: _AuthInternal.SERVICE_PRINCIPAL,
}


Expand Down Expand Up @@ -111,10 +114,10 @@ def _acquire_token(

# Mapping of auth types to credential classes
credential_map = {
"default": DefaultAzureCredential,
"devicecode": DeviceCodeCredential,
"interactive": InteractiveBrowserCredential,
"msi": ManagedIdentityCredential,
_AuthInternal.DEFAULT: DefaultAzureCredential,
_AuthInternal.DEVICE_CODE: DeviceCodeCredential,
_AuthInternal.INTERACTIVE: InteractiveBrowserCredential,
_AuthInternal.MSI: ManagedIdentityCredential,
}

credential_class = credential_map.get(auth_type)
Expand Down Expand Up @@ -171,6 +174,181 @@ def _acquire_token(
raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e


_RESERVED_TENANTS = frozenset({"common", "organizations", "consumers"})


def _parse_tenant_id(sts_url: str) -> Optional[str]:
    """Extract the tenant ID (GUID or domain) from a FedAuthInfo STS URL.

    Expected formats:
        https://login.microsoftonline.com/<tenant>/
        https://login.microsoftonline.com/<tenant>/?...
        https://login.microsoftonline.com/<tenant>
    where <tenant> is either a GUID (e.g. ``aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee``)
    or a verified domain (e.g. ``contoso.onmicrosoft.com``). Both forms are
    accepted by ``azure.identity.ClientSecretCredential``.

    Args:
        sts_url: The STS URL supplied by the server during the handshake.

    Returns:
        The first path segment of the URL with surrounding whitespace
        removed, or ``None`` when:

        * ``sts_url`` is not a ``str`` or cannot be parsed,
        * the URL is not ``https`` or has no host,
        * the path has no usable first segment, or
        * the segment is one of the multi-tenant aliases ``common`` /
          ``organizations`` / ``consumers`` (compared case-insensitively).

        The aliases are rejected because confidential clients (SP) cannot
        authenticate against them and AAD responds with a cryptic
        ``AADSTS50194`` ("application is not configured as multi-tenant").
        Failing fast in the factory surfaces a clearer error than the AAD
        round-trip would.
    """
    # pylint: disable=import-outside-toplevel
    from urllib.parse import urlparse

    # Non-string input (None, bytes, ...) can never yield a usable tenant.
    if not isinstance(sts_url, str):
        return None
    try:
        parsed = urlparse(sts_url)
    except ValueError:
        # e.g. a malformed bracketed IPv6 host.
        return None

    # Reject anything that isn't an https URL with a netloc. ``urlparse`` will
    # happily put a bare string like ``"tenant-guid"`` into ``path``, which
    # would then look like a valid tenant. Azure AD STS URLs are always https.
    if parsed.scheme != "https" or not parsed.netloc:
        return None

    # Strip whitespace *before* the alias check: ``urlparse`` can leave
    # trailing spaces in the path, and a padded ``"common "`` must not slip
    # past the reserved-alias filter or be handed on as a tenant id.
    tenant = parsed.path.strip("/").split("/", 1)[0].strip()
    if not tenant or tenant.lower() in _RESERVED_TENANTS:
        return None
    return tenant


class ServicePrincipalAuth:
    """Builds an ``entra_id_token_factory`` callable for ActiveDirectoryServicePrincipal.

    The bulkcopy path through mssql-py-core uses callback-based token
    acquisition (FedAuth workflow ``0x02``) because tenant_id is only known
    from the STS URL that the server returns during the TDS handshake.
    """

    @staticmethod
    def make_token_factory(client_id: str, client_secret: str):
        """Return a callable suitable for ``entra_id_token_factory``.

        Args:
            client_id: Service principal client id (the connection string UID).
            client_secret: Service principal secret (the connection string PWD).

        Returns:
            A callable with signature
            ``(spn: str, sts_url: str, auth_method: str) -> bytes`` that
            returns the JWT encoded as UTF-16LE bytes (the TDS FedAuth wire
            format).

        Raises:
            ValueError: If ``client_id`` or ``client_secret`` is empty.

        ``ClientSecretCredential`` instances are reused across calls via the
        module-level ``_credential_cache``. The key is built by
        ``_credential_cache_key`` from the tenant id, the client id and a
        SHA-256 hash of the secret, so that azure-identity's in-memory token
        cache (which is per-credential-instance) actually works across
        handshake retries, reconnects, and separate bulkcopy invocations using
        the same identity, while a rotated secret gets a fresh entry.

        NOTE(review): the returned bytes carry no length prefix, unlike the
        ``<I`` length-prefixed struct that review feedback attributes to
        ``AADAuth.get_token_struct``. Presumably mssql-py-core adds the
        FedAuth length framing itself -- confirm before "fixing" this to
        match.
        """
        if not client_id:
            raise ValueError("ServicePrincipal auth requires a non-empty client_id (UID)")
        if not client_secret:
            raise ValueError("ServicePrincipal auth requires a non-empty client_secret (PWD)")

        # client_secret is fixed for the lifetime of the returned closure, so
        # hash it once here instead of on every factory invocation.
        #
        # The hash is part of the cache key so a rotated secret produces a
        # different cache entry. Without this, an external secret rotation
        # would not invalidate the cached ClientSecretCredential:
        # azure-identity's internal token cache would keep returning the
        # previously-issued token (good for up to ~1 hour) until expiry,
        # masking the rotation. Hashing avoids storing the raw secret in the
        # dict key.
        secret_hash = hashlib.sha256(client_secret.encode("utf-8")).hexdigest()

        def _factory(spn: str, sts_url: str, auth_method: str) -> bytes:
            """Acquire a token for ``spn`` in the tenant named by ``sts_url``."""
            # pylint: disable=import-outside-toplevel,unused-argument
            # ``auth_method`` is part of the callback signature documented
            # above; this implementation does not use it.
            try:
                from azure.identity import ClientSecretCredential
                from azure.core.exceptions import ClientAuthenticationError
                from azure.core.pipeline.transport import RequestsTransport
            except ImportError as e:
                raise RuntimeError(
                    "Azure authentication libraries are not installed. "
                    "Please install with: pip install azure-identity azure-core"
                ) from e

            if not spn:
                raise RuntimeError(
                    "ServicePrincipal token factory: empty SPN from server "
                    "(cannot construct token scope)"
                )
            tenant_id = _parse_tenant_id(sts_url)
            if not tenant_id:
                raise RuntimeError(f"Could not extract tenant_id from STS URL: {sts_url!r}")

            logger.info(
                "ServicePrincipal token factory: acquiring token for tenant=%s, spn=%s",
                tenant_id,
                spn,
            )
            try:
                # Reuse the shared credential cache (introduced for MSI in PR #573)
                # so SP credentials get the same per-instance token reuse semantics
                # as the other AD methods.
                cache_key = _credential_cache_key(
                    _AuthInternal.SERVICE_PRINCIPAL,
                    {
                        "tenant_id": tenant_id,
                        "client_id": client_id,
                        "secret_hash": secret_hash,
                    },
                )
                with _credential_cache_lock:
                    credential = _credential_cache.get(cache_key)
                    if credential is None:
                        # Bound the AAD network round-trip. Without explicit
                        # timeouts, azure-identity's defaults can let an
                        # unreachable / slow STS endpoint block the calling
                        # thread for tens of seconds. The factory runs on a
                        # mssql-py-core blocking-pool worker (tokio
                        # spawn_blocking), so a stuck callback ties that
                        # worker up for the duration. SP is non-interactive
                        # and token issuance is typically <1s; 10s/15s is
                        # generous and still bounded.
                        transport = RequestsTransport(
                            connection_timeout=10,
                            read_timeout=15,
                        )
                        # KNOWN LIMITATION: ``authority=`` is not passed,
                        # so this defaults to the public-cloud authority
                        # (login.microsoftonline.com). Sovereign clouds
                        # (Azure US Gov, Azure China) are not supported on
                        # this code path today: AAD will fail with
                        # "tenant not found" because the tenant lives in
                        # the sovereign cloud's AAD, not the public one.
                        # Tracked as a follow-up; the fix is to derive
                        # ``authority`` from ``urlparse(sts_url).netloc``.
                        # Out of scope for the initial #534 work.
                        credential = ClientSecretCredential(
                            tenant_id=tenant_id,
                            client_id=client_id,
                            client_secret=client_secret,
                            transport=transport,
                        )
                        # KNOWN LIMITATION: nothing on this path evicts
                        # entries. After a secret rotation the credential
                        # built from the old secret stays in the cache (with
                        # its token cache) for the life of the process, so
                        # long-lived processes that rotate secrets grow the
                        # cache by one entry per rotation.
                        _credential_cache[cache_key] = credential

                # The network call is made after the cache lock is released so
                # a slow STS endpoint does not serialize other callers.
                # mssql-tds passes the resource SPN; azure-identity wants a scope.
                scope = spn if spn.endswith("/.default") else spn.rstrip("/") + "/.default"
                token = credential.get_token(scope).token
                logger.info(
                    "ServicePrincipal token factory: token acquired, length=%d chars",
                    len(token),
                )
                return token.encode("utf-16-le")
            except ClientAuthenticationError as e:
                # Keep the detailed provider error in debug logs only. The
                # surfaced message (and the error-level log line) is
                # intentionally generic so that any secret-bearing provider
                # text never reaches the user-facing exception chain; the
                # message below points at the debug logs accordingly.
                logger.error(
                    "ServicePrincipal authentication failed: tenant=%s",
                    tenant_id,
                )
                logger.debug(
                    "ServicePrincipal authentication failed: tenant=%s, error=%s",
                    tenant_id,
                    str(e),
                )
                raise RuntimeError(
                    "ServicePrincipal authentication failed; "
                    "see debug logs for provider details"
                ) from None

        return _factory


def process_auth_parameters(parsed_params: Dict[str, str]) -> Optional[str]:
"""
Extract authentication type from parsed connection parameters.
Expand All @@ -190,10 +368,20 @@ def process_auth_parameters(parsed_params: Dict[str, str]) -> Optional[str]:
return None

# On Windows, Interactive auth is handled natively by the ODBC driver.
if auth_type == "interactive" and platform.system().lower() == "windows":
if auth_type == _AuthInternal.INTERACTIVE and platform.system().lower() == "windows":
logger.debug("process_auth_parameters: Windows platform - using native AADInteractive")
return None

# ServicePrincipal: ODBC (msodbcsql 17.3+) handles this natively for
# regular queries, so return None to let ODBC own the query path. Bulkcopy
# still needs the auth type (Connection.__init__ falls back to
# extract_auth_type for that), and the cursor.bulkcopy() callback branch
# registers an entra_id_token_factory because tenant_id is only known
# from the STS URL the server returns during the FedAuth handshake.
if auth_type == _AuthInternal.SERVICE_PRINCIPAL:
logger.debug("process_auth_parameters: ServicePrincipal - ODBC handles natively")
return None

logger.debug("process_auth_parameters: auth_type=%s", auth_type)
return auth_type

Expand All @@ -213,7 +401,7 @@ def get_auth_token(
return None

# Handle platform-specific logic for interactive auth
if auth_type == "interactive" and platform.system().lower() == "windows":
if auth_type == _AuthInternal.INTERACTIVE and platform.system().lower() == "windows":
logger.debug("get_auth_token: Windows interactive auth - delegating to native handler")
return None # Let Windows handle AADInteractive natively

Expand Down
9 changes: 7 additions & 2 deletions mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@
from mssql_python.constants import ConstantsDDBC, GetInfoConstants
from mssql_python.connection_string_parser import _ConnectionStringParser
from mssql_python.connection_string_builder import _ConnectionStringBuilder
from mssql_python.constants import _RESERVED_PARAMETERS, _KEY_AUTHENTICATION, _KEY_UID
from mssql_python.constants import (
_RESERVED_PARAMETERS,
_KEY_AUTHENTICATION,
_KEY_UID,
_AuthInternal,
)

if TYPE_CHECKING:
from mssql_python.row import Row
Expand Down Expand Up @@ -344,7 +349,7 @@ def __init__(
# Capture credential kwargs (e.g. user-assigned MSI client_id)
# from the parsed dict *before* remove_sensitive_params strips UID.
credential_kwargs: Optional[Dict[str, str]] = None
if auth_type == "msi":
if auth_type == _AuthInternal.MSI:
uid = (parsed_params.get(_KEY_UID) or "").strip()
if uid:
credential_kwargs = {"client_id": uid}
Expand Down
18 changes: 17 additions & 1 deletion mssql_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,12 +332,28 @@ class GetInfoConstants(Enum):


class AuthType(Enum):
    """Constants for authentication types (public/ODBC connection-string form).

    Each member's value is the lowercase spelling of an ``Authentication``
    connection-string value.
    """

    INTERACTIVE = "activedirectoryinteractive"
    DEVICE_CODE = "activedirectorydevicecode"
    DEFAULT = "activedirectorydefault"
    MSI = "activedirectorymsi"
    SERVICE_PRINCIPAL = "activedirectoryserviceprincipal"


class _AuthInternal:
    """Short internal names for authentication modes.

    These plain-string constants are the normalized counterparts of the
    public :class:`AuthType` values; ``mssql_python.auth._AUTH_TYPE_MAP``
    pairs the two. To add an auth mode, add one constant here, one member to
    ``AuthType``, and one entry to ``_AUTH_TYPE_MAP``.
    """

    # Listed in the same order as the members of ``AuthType``.
    INTERACTIVE = "interactive"
    DEVICE_CODE = "devicecode"
    DEFAULT = "default"
    MSI = "msi"
    SERVICE_PRINCIPAL = "serviceprincipal"


class SQLTypes:
Expand Down
Loading
Loading