Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 155 additions & 20 deletions msal/application.py

Large diffs are not rendered by default.

81 changes: 77 additions & 4 deletions msal/managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from urllib.parse import urlparse # Python 3+
from collections import UserDict # Python 3+
from typing import List, Optional, Union # Needed in Python 3.7 & 3.8
from .token_cache import TokenCache
from .token_cache import TokenCache, _compute_ext_cache_key, _parse_claims_or_raise
from .individual_cache import _IndividualCache as IndividualCache
from .throttled_http_client import ThrottledHttpClientBase, RetryAfterParser
from .cloudshell import _is_running_in_cloud_shell
Expand All @@ -26,6 +26,14 @@ class ManagedIdentityError(ValueError):
pass


# Name of the network security perimeter claim. It is the only top-level key
# that _validate_msiv1_claims() accepts in client_claims sent to IMDS v1.
_XMS_AZ_NWPERIMID = "xms_az_nwperimid"

# Error message template raised when client_claims is supplied but the detected
# managed identity source is not IMDS. Fill in the human-readable source name
# with .format(source=...).
_CLIENT_CLAIMS_UNSUPPORTED_SOURCE = (
"client_claims is only supported for the IMDS (Azure VM) managed identity "
"source. The detected source ({source}) does not support forwarding "
"client-originated claims.")


class ManagedIdentity(UserDict):
"""Feed an instance of this class to :class:`msal.ManagedIdentityClient`
to acquire token for the specified managed identity.
Expand Down Expand Up @@ -261,6 +269,7 @@ def acquire_token_for_client(
*,
resource: str, # If/when we support scope, resource will become optional
claims_challenge: Optional[str] = None,
client_claims: Optional[str] = None,
):
"""Acquire token for the managed identity.

Expand All @@ -280,6 +289,22 @@ def acquire_token_for_client(
even if the app developer did not opt in for the "CP1" client capability.
Upon receiving a `claims_challenge`, MSAL will attempt to acquire a new token.

:param client_claims:
Optional.
A string representation of a JSON object containing
*client-originated* claims to forward to the identity endpoint
(for example a network security perimeter ``xms_az_nwperimid`` claim).

Unlike ``claims_challenge`` (server-issued, which bypasses the cache),
tokens acquired with ``client_claims`` **are cached**, and the cache
entry is keyed on the claims value. Different ``client_claims`` values
produce separate cache entries, so use stable, non-dynamic values to
avoid unbounded cache growth.

Only the IMDS (Azure VM) managed identity source supports this
parameter; other sources raise an error. On IMDS v1, the claims JSON
may contain only the ``xms_az_nwperimid`` key.

.. note::

Known issue: When an Azure VM has only one user-assigned managed identity,
Expand All @@ -294,6 +319,17 @@ def acquire_token_for_client(
client_id_in_cache = self._managed_identity.get(
ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY")
now = time.time()
if client_claims is not None:
if not isinstance(client_claims, str):
raise ValueError(
"client_claims must be a string, got {}".format(
type(client_claims).__name__))
_parse_claims_or_raise(client_claims) # Fail fast on malformed JSON
Comment thread
Copilot marked this conversation as resolved.
# Client-originated claims isolate the cache: a distinct claims value gets
Comment thread
Robbie-Microsoft marked this conversation as resolved.
# a distinct cache entry. (Server-issued claims_challenge, by contrast,
# bypasses the cache and is keyed normally.)
ext_cache_key = _compute_ext_cache_key(
{"client_claims": client_claims}) if client_claims else None
if True: # Attempt cache search even if receiving claims_challenge,
# because we want to locate the existing token (if any) and refresh it
matches = self._token_cache.search(
Expand All @@ -304,6 +340,7 @@ def acquire_token_for_client(
environment=self.__instance,
realm=self._tenant,
home_account_id=None,
**({"ext_cache_key": ext_cache_key} if ext_cache_key else {}),
),
)
for entry in matches:
Expand Down Expand Up @@ -334,6 +371,7 @@ def acquire_token_for_client(
access_token_to_refresh.encode("utf-8")).hexdigest()
if access_token_to_refresh else None,
client_capabilities=self._client_capabilities,
client_claims=client_claims,
)
if "access_token" in result:
expires_in = result.get("expires_in", 3600)
Expand All @@ -346,7 +384,7 @@ def acquire_token_for_client(
self.__instance, self._tenant),
response=result,
params={},
data={},
data={"client_claims": client_claims} if client_claims else {},
))
if "refresh_in" in result:
result["refresh_on"] = int(now + result["refresh_in"])
Expand Down Expand Up @@ -414,10 +452,14 @@ def _obtain_token(
*,
access_token_sha256_to_refresh: Optional[str] = None,
client_capabilities: Optional[List[str]] = None,
client_claims: Optional[str] = None,
):
if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ
and "IDENTITY_SERVER_THUMBPRINT" in os.environ
):
if client_claims:
raise ManagedIdentityError(
_CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source="Service Fabric"))
if managed_identity:
logger.debug(
"Ignoring managed_identity parameter. "
Expand All @@ -434,6 +476,9 @@ def _obtain_token(
client_capabilities=client_capabilities,
)
if "IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ:
if client_claims:
raise ManagedIdentityError(
_CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source="App Service"))
return _obtain_token_on_app_service(
http_client,
os.environ["IDENTITY_ENDPOINT"],
Expand All @@ -442,6 +487,9 @@ def _obtain_token(
resource,
)
if "MSI_ENDPOINT" in os.environ and "MSI_SECRET" in os.environ:
if client_claims:
raise ManagedIdentityError(
_CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source="Machine Learning"))
# Back ported from https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.15.0/sdk/identity/azure-identity/azure/identity/_credentials/azure_ml.py
return _obtain_token_on_machine_learning(
http_client,
Expand All @@ -452,14 +500,18 @@ def _obtain_token(
)
arc_endpoint = _get_arc_endpoint()
if arc_endpoint:
if client_claims:
raise ManagedIdentityError(
_CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source="Azure Arc"))
if ManagedIdentity.is_user_assigned(managed_identity):
raise ManagedIdentityError( # Note: Azure Identity for Python raised exception too
"Invalid managed_identity parameter. "
"Azure Arc supports only system-assigned managed identity, "
"See also "
"https://learn.microsoft.com/en-us/azure/service-fabric/configure-existing-cluster-enable-managed-identity-token-service")
return _obtain_token_on_arc(http_client, arc_endpoint, resource)
return _obtain_token_on_azure_vm(http_client, managed_identity, resource)
return _obtain_token_on_azure_vm(
http_client, managed_identity, resource, client_claims=client_claims)


def _adjust_param(params, managed_identity, types_mapping=None):
Expand All @@ -469,14 +521,35 @@ def _adjust_param(params, managed_identity, types_mapping=None):
if id_name:
params[id_name] = managed_identity[ManagedIdentity.ID]

def _obtain_token_on_azure_vm(http_client, managed_identity, resource):
def _validate_msiv1_claims(client_claims):
    """Ensure ``client_claims`` is acceptable to MSIv1 (IMDS v1).

    MSIv1 understands a single custom claim, ``xms_az_nwperimid``. If the
    claims JSON carries any other top-level key, IMDS answers with HTTP 400
    and no useful diagnostic, so the check is performed here to surface a
    clear error instead. Counterpart of MSAL .NET's
    ``AbstractManagedIdentity.ValidateMsiv1Claims``.

    :param client_claims: A string containing the claims JSON object.
    :raises ManagedIdentityError:
        On the first top-level key that is not ``xms_az_nwperimid``.
    :raises ValueError:
        Propagated from ``_parse_claims_or_raise`` when the value is not a
        JSON object.
    """
    message_template = (
        "MSIv1 (IMDS v1) only supports the `{expected}` custom claim. "
        "The claims JSON contained the unsupported key `{actual}`. "
        "Remove all keys other than `{expected}` when using client_claims "
        "with MSIv1.")
    for claim_name in _parse_claims_or_raise(client_claims):
        if claim_name == _XMS_AZ_NWPERIMID:
            continue  # The one claim MSIv1 supports
        raise ManagedIdentityError(message_template.format(
            expected=_XMS_AZ_NWPERIMID, actual=claim_name))


def _obtain_token_on_azure_vm(http_client, managed_identity, resource, client_claims=None):
# Based on https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http
logger.debug("Obtaining token via managed identity on Azure VM")
params = {
"api-version": "2018-02-01",
"resource": resource,
}
_adjust_param(params, managed_identity)
if client_claims:
# IMDS v1 (MSIv1) only supports the single xms_az_nwperimid claim.
_validate_msiv1_claims(client_claims)
params["claims"] = client_claims # http_client.get url-encodes query params
resp = http_client.get(
os.getenv(
"AZURE_POD_IDENTITY_AUTHORITY_HOST", "http://169.254.169.254"
Expand Down
8 changes: 8 additions & 0 deletions msal/oauth2cli/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,14 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
_data.update(data or {}) # So the content in data param prevails
_data = {k: v for k, v in _data.items() if v} # Clean up None values

# "client_claims" is a cache-key-only pseudo-parameter: callers merge its
# value into the standard "claims" body parameter upstream, and it is kept
# in the request data solely so it contributes to the extended cache key.
# It must not be sent on the wire. Popping it here (from this method's own
# local copy) keeps the wire body clean while the caller's data dict — used
# for the cache-add event — still carries it.
_data.pop("client_claims", None)

if _data.get('scope'):
_data['scope'] = self._stringify(_data['scope'])

Expand Down
55 changes: 55 additions & 0 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,61 @@ def _compute_ext_cache_key(data):
return base64.urlsafe_b64encode(hash_bytes).rstrip(b"=").decode("ascii").lower()


def _parse_claims_or_raise(claims):
"""Parse a claims JSON string into a dict, or raise a friendly ``ValueError``.

The raw claims value is never included in the error message because it may
contain sensitive data. Mirrors MSAL .NET's ``ClaimsHelper.ParseClaimsOrThrow``.
"""
try:
parsed = json.loads(claims)
except (ValueError, TypeError) as ex:
# json.JSONDecodeError (malformed JSON) is a subclass of ValueError;
# TypeError is raised when *claims* is not a str/bytes/bytearray. Both
# are surfaced as the same friendly ValueError so every caller behaves
# consistently regardless of the bad input's type.
raise ValueError(
"The claims value is not valid JSON. "
"See https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter."
) from ex
if not isinstance(parsed, dict):
# A valid JSON array, scalar, or the literal "null" is not a claims object.
raise ValueError(
"The claims value is not a valid JSON object. "
"See https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter.")
return parsed


def _deep_merge_dict(base, overlay):
"""Recursively merge ``overlay`` into ``base``, returning a new dict.

Nested dicts are merged; for any other value type, ``overlay`` wins.
"""
result = dict(base)
for key, value in overlay.items():
if (key in result
and isinstance(result[key], dict) and isinstance(value, dict)):
result[key] = _deep_merge_dict(result[key], value)
else:
result[key] = value
return result


def _merge_claims(claims_a, claims_b):
    """Combine two claims JSON strings into one JSON string.

    When only one side carries a value, that side is handed back untouched
    (it is neither parsed nor re-serialized). When both carry a value, each
    is parsed as a JSON object and ``claims_b`` is deep-merged over
    ``claims_a``. Mirrors MSAL .NET's ``ClaimsHelper.MergeClaimsObjects``.
    """
    if claims_a and claims_b:
        combined = _deep_merge_dict(
            _parse_claims_or_raise(claims_a), _parse_claims_or_raise(claims_b))
        return json.dumps(combined)
    # At least one side is empty/None: return the other one as-is.
    return claims_a if claims_a else claims_b


def is_subdict_of(small, big):
    """Return True when every key/value pair of ``small`` is also in ``big``.

    The check is done per key instead of via ``dict(big, **small) == big``,
    because that form copies all of ``big`` on every call and raises
    ``TypeError`` when ``small`` has non-string keys (keyword expansion
    requires strings).

    :param small: The mapping whose items must all be present in ``big``.
    :param big: The mapping to look the items up in.
    :return: ``True`` if ``small`` is a sub-dict of ``big``; an empty
        ``small`` is a sub-dict of anything.
    """
    return all(
        # Identity-or-equality mirrors how dict comparison matches values.
        key in big and (value is big[key] or value == big[key])
        for key, value in small.items())

Expand Down
Loading
Loading