From 1e184bd85573e943a82fd44082f5b8a430c33bec Mon Sep 17 00:00:00 2001 From: Nikita Kovalev Date: Tue, 28 Apr 2026 18:03:42 +0200 Subject: [PATCH 1/2] Add usage param and serverless content parsing --- connect/__init__.py | 2 + connect/customer/__init__.py | 3 +- connect/customer/content_parser.py | 7 +- connect/customer/token.py | 186 +++++++++---- connect/types.py | 9 + examples/obtain_and_verify_license_token.py | 4 + tests/customer/conftest.py | 3 +- tests/customer/test_content_parser.py | 24 +- tests/customer/test_tokens.py | 279 ++++++++++++++++++++ 9 files changed, 459 insertions(+), 58 deletions(-) diff --git a/connect/__init__.py b/connect/__init__.py index 2064af6..9204341 100644 --- a/connect/__init__.py +++ b/connect/__init__.py @@ -11,6 +11,7 @@ HandlerResult, RSLVerificationResult, SupertabConnectConfig, + UsageType, ) __all__ = [ @@ -21,6 +22,7 @@ "SupertabConnect", "SupertabConnectError", "SupertabConnectConfig", + "UsageType", "default_bot_detector", "obtain_license_token", "verify_license_token", diff --git a/connect/customer/__init__.py b/connect/customer/__init__.py index c5aa640..e875cc6 100644 --- a/connect/customer/__init__.py +++ b/connect/customer/__init__.py @@ -1,5 +1,6 @@ """Customer functionality for Supertab Connect.""" from connect.customer.token import obtain_license_token +from connect.types import UsageType -__all__ = ["obtain_license_token"] +__all__ = ["UsageType", "obtain_license_token"] diff --git a/connect/customer/content_parser.py b/connect/customer/content_parser.py index bc7cf64..4bad3c4 100644 --- a/connect/customer/content_parser.py +++ b/connect/customer/content_parser.py @@ -14,7 +14,7 @@ class _ContentBlock: url_pattern: str license_xml: str - server: str + server: str | None = None def _clean_attribute(value: str | None) -> str | None: @@ -47,9 +47,9 @@ def _parse_content_elements(xml: str, debug: bool = False) -> list[_ContentBlock if license_el is None: license_el = content_el.find("license", namespaces=_NS) - license_xml = ElementTree.tostring(license_el, encoding="unicode") if license_el is not None else None + license_xml = ElementTree.tostring(license_el, encoding="unicode").strip() if license_el is not None else None - if url_pattern and server and license_xml: + if url_pattern and license_xml: content_blocks.append( _ContentBlock( url_pattern=url_pattern, @@ -63,7 +63,6 @@ def _parse_content_elements(xml: str, debug: bool = False) -> list[_ContentBlock value for value in ( None if url_pattern else "url", - None if server else "server", None if license_xml else "", ) if value is not None diff --git a/connect/customer/token.py b/connect/customer/token.py index d38e87b..01f1d36 100644 --- a/connect/customer/token.py +++ b/connect/customer/token.py @@ -6,6 +6,7 @@ import time import urllib.parse from dataclasses import dataclass +from xml.etree import ElementTree from typing import Any from weakref import WeakKeyDictionary @@ -17,14 +18,18 @@ from connect.common import debug_log, error_log from connect.exceptions import SupertabConnectError from connect.customer.content_matcher import _find_best_matching_content +from connect.customer.content_parser import _ContentBlock from connect.customer.content_parser import _parse_content_elements +from connect.types import UsageType _SUPPORTED_ALGS = ("ES256", "RS256") _DEFAULT_HTTP_TIMEOUT_SECONDS = 10.0 -_LICENSE_TOKEN_CACHE: dict[tuple[str, str], "_CachedToken"] = {} -_LICENSE_TOKEN_LOCKS: WeakKeyDictionary[asyncio.AbstractEventLoop, dict[tuple[str, str], asyncio.Lock]] = ( +_LICENSE_TOKEN_CACHE: dict[tuple[str, str, str], "_CachedToken"] = {} +_LICENSE_TOKEN_LOCKS: WeakKeyDictionary[asyncio.AbstractEventLoop, dict[tuple[str, str, str], asyncio.Lock]] = ( WeakKeyDictionary() ) +_LICENSE_XML_TTL_SECONDS = 15 * 60 +_LICENSE_XML_CACHE: dict[str, "_CachedLicenseXml"] = {} @dataclass(frozen=True) @@ -33,6 +38,12 @@ class _CachedToken: exp: int +@dataclass(frozen=True) +class _CachedLicenseXml: + xml: str + fetched_at: int + + def _build_origin(resource_url: str) -> str: parsed = urllib.parse.urlparse(resource_url) if not parsed.scheme or not parsed.netloc: @@ -40,7 +51,7 @@ def _build_origin(resource_url: str) -> str: return f"{parsed.scheme}://{parsed.netloc}" -def _get_cached_token(cache_key: tuple[str, str], debug: bool = False) -> str | None: +def _get_cached_token(cache_key: tuple[str, str, str], debug: bool = False) -> str | None: cached = _LICENSE_TOKEN_CACHE.get(cache_key) if cached is None: return None @@ -58,7 +69,7 @@ def _get_cached_token(cache_key: tuple[str, str], debug: bool = False) -> str | return None -def _get_cache_lock(cache_key: tuple[str, str]) -> asyncio.Lock: +def _get_cache_lock(cache_key: tuple[str, str, str]) -> asyncio.Lock: loop = asyncio.get_running_loop() loop_locks = _LICENSE_TOKEN_LOCKS.get(loop) if loop_locks is None: @@ -73,6 +84,15 @@ def _get_cache_lock(cache_key: tuple[str, str]) -> asyncio.Lock: return lock +def _evict_expired_license_xml() -> None: + now = int(time.time()) + expired_origins = [ + origin for origin, entry in _LICENSE_XML_CACHE.items() if now - entry.fetched_at >= _LICENSE_XML_TTL_SECONDS + ] + for origin in expired_origins: + _LICENSE_XML_CACHE.pop(origin, None) + + def _create_async_client(**kwargs: Any) -> httpx.AsyncClient: kwargs.setdefault("follow_redirects", True) kwargs.setdefault("timeout", httpx.Timeout(_DEFAULT_HTTP_TIMEOUT_SECONDS)) @@ -210,7 +230,22 @@ async def _fetch_license_xml( resource_url: str, debug: bool = False, ) -> str: - license_xml_url = f"{_build_origin(resource_url)}/license.xml" + origin = _build_origin(resource_url) + cached = _LICENSE_XML_CACHE.get(origin) + if cached is not None: + now = int(time.time()) + age = now - cached.fetched_at + if age < _LICENSE_XML_TTL_SECONDS: + debug_log( + debug, + f"Using cached license.xml for origin {origin} (expires in {_LICENSE_XML_TTL_SECONDS - age}s)", + ) + return cached.xml + + debug_log(debug, f"Cached license.xml for origin {origin} expired, re-fetching") + _LICENSE_XML_CACHE.pop(origin, None) + + license_xml_url = f"{origin}/license.xml" try: response = await client.get(license_xml_url) @@ -230,50 +265,109 @@ async def _fetch_license_xml( raise SupertabConnectError(message) from error debug_log(debug, f"Fetched license.xml from {license_xml_url}") + _evict_expired_license_xml() + _LICENSE_XML_CACHE[origin] = _CachedLicenseXml(xml=xml, fetched_at=int(time.time())) return xml +def _local_name(tag: str) -> str: + return tag.rsplit("}", maxsplit=1)[-1] + + +def _license_permits_usage(license_xml: str, usage: UsageType | str) -> bool: + try: + root = ElementTree.fromstring(license_xml) + except ElementTree.ParseError: + return False + + usage_value = str(usage) + + for element in root.iter(): + if _local_name(element.tag) == "prohibits" and element.attrib.get("type") == "usage": + prohibited_usages = " ".join(element.itertext()).split() + if UsageType.ALL in prohibited_usages or usage_value in prohibited_usages: + return False + + for element in root.iter(): + if _local_name(element.tag) == "permits" and element.attrib.get("type") == "usage": + permitted_usages = " ".join(element.itertext()).split() + if UsageType.ALL in permitted_usages or usage_value in permitted_usages: + return True + + return False + + +def _find_serverless_usage_content( + content_blocks: list[_ContentBlock], + resource_url: str, + usage: UsageType | str, + debug: bool = False, +) -> _ContentBlock | None: + matching_usage_blocks = [ + block for block in content_blocks if block.server is None and _license_permits_usage(block.license_xml, usage) + ] + + return _find_best_matching_content(matching_usage_blocks, resource_url, debug) + + async def obtain_license_token( *, client_id: str, client_secret: str, resource_url: str, + usage: UsageType | str | None = None, debug: bool = False, -) -> str: +) -> str | None: """Obtain a license token using the current client credentials flow. This is the supported customer flow. The SDK fetches ``license.xml`` for the requested resource, finds the best matching ```` block, and - exchanges the client credentials for a license token. + exchanges the client credentials for a license token. If ``usage`` is + provided and a matching serverless content block permits that usage, no + token is needed and ``None`` is returned. """ - cache_key = (client_id, resource_url) - cached = _get_cached_token(cache_key, debug) - if cached is not None: - return cached + async with _create_async_client() as client: + xml = await _fetch_license_xml(client, resource_url, debug) + debug_log(debug, f"Fetched license.xml ({len(xml)} chars)") + content_blocks = _parse_content_elements(xml, debug) + + if not content_blocks: + error_log(debug, "No valid elements with found in license.xml") + raise SupertabConnectError("No valid elements with found in license.xml") + + if usage is not None: + serverless_usage_content = _find_serverless_usage_content(content_blocks, resource_url, usage, debug) + if serverless_usage_content is not None: + debug_log( + debug, + "Matched serverless content to usage and resource URL combination, skipping license token request.", + ) + debug_log(debug, f"URL: {resource_url}, Usage: {usage}") + return None + + token_content_blocks = [block for block in content_blocks if block.server] + matched_content = _find_best_matching_content(token_content_blocks, resource_url, debug) + if matched_content is None or matched_content.server is None: + patterns = ", ".join(block.url_pattern for block in token_content_blocks) + error_log( + debug, + f"No element matches resource URL: {resource_url}. Available patterns: {patterns}", + ) + raise SupertabConnectError(f"No element in license.xml matches resource URL: {resource_url}") + + debug_log(debug, f"Matched content block for resource URL: {resource_url}") + debug_log(debug, f"Using license XML: {matched_content.license_xml}") - lock = _get_cache_lock(cache_key) - async with lock: + cache_key = (client_id, matched_content.server, matched_content.url_pattern) cached = _get_cached_token(cache_key, debug) if cached is not None: return cached - async with _create_async_client() as client: - xml = await _fetch_license_xml(client, resource_url, debug) - debug_log(debug, f"Fetched license.xml ({len(xml)} chars)") - content_blocks = _parse_content_elements(xml, debug) - - if not content_blocks: - error_log(debug, "No valid elements with found in license.xml") - raise SupertabConnectError("No valid elements with found in license.xml") - - matched_content = _find_best_matching_content(content_blocks, resource_url, debug) - if matched_content is None: - patterns = ", ".join(block.url_pattern for block in content_blocks) - error_log( - debug, - f"No element matches resource URL: {resource_url}. Available patterns: {patterns}", - ) - raise SupertabConnectError(f"No element in license.xml matches resource URL: {resource_url}") + lock = _get_cache_lock(cache_key) + async with lock: + cached = _get_cached_token(cache_key, debug) + if cached is not None: + return cached token_endpoint = matched_content.server.rstrip("/") + "/token" debug_log(debug, f"Requesting license token from {token_endpoint}") @@ -295,22 +389,22 @@ async def obtain_license_token( debug=debug, ) - try: - claims = jwt.decode( - token, - options={ - "verify_signature": False, - "verify_exp": False, - "verify_aud": False, - "verify_iss": False, - }, - algorithms=["HS256", "RS256", "ES256", "PS256"], - ) - exp = claims.get("exp") - if isinstance(exp, int): - _LICENSE_TOKEN_CACHE[cache_key] = _CachedToken(token=token, exp=exp) - except (jwt.PyJWTError, ValueError, TypeError) as error: - debug_log(debug, f"Failed to decode token for caching, skipping cache: {error}") + try: + claims = jwt.decode( + token, + options={ + "verify_signature": False, + "verify_exp": False, + "verify_aud": False, + "verify_iss": False, + }, + algorithms=["HS256", "RS256", "ES256", "PS256"], + ) + exp = claims.get("exp") + if isinstance(exp, int): + _LICENSE_TOKEN_CACHE[cache_key] = _CachedToken(token=token, exp=exp) + except (jwt.PyJWTError, ValueError, TypeError) as error: + debug_log(debug, f"Failed to decode token for caching, skipping cache: {error}") return token diff --git a/connect/types.py b/connect/types.py index 8bc7299..1f7a348 100644 --- a/connect/types.py +++ b/connect/types.py @@ -31,6 +31,15 @@ class HandlerAction(StrEnum): BLOCK = "block" +class UsageType(StrEnum): + ALL = "all" + SEARCH = "search" + AI_ALL = "ai-all" + AI_TRAIN = "ai-train" + AI_INDEX = "ai-index" + AI_INPUT = "ai-input" + + BotDetector: TypeAlias = Callable[[Request], bool] diff --git a/examples/obtain_and_verify_license_token.py b/examples/obtain_and_verify_license_token.py index 17f1124..54fb634 100644 --- a/examples/obtain_and_verify_license_token.py +++ b/examples/obtain_and_verify_license_token.py @@ -17,6 +17,10 @@ async def main() -> None: debug=True, ) + assert isinstance(token, str), ( + "Token was not generated. Check your credentials and whether you have a license to the resource URL" + ) + print(f"Generated license token: {token}") result = await verify_license_token( diff --git a/tests/customer/conftest.py b/tests/customer/conftest.py index 95be7f1..284df57 100644 --- a/tests/customer/conftest.py +++ b/tests/customer/conftest.py @@ -1,6 +1,6 @@ import pytest -from connect.customer.token import _LICENSE_TOKEN_CACHE, _LICENSE_TOKEN_LOCKS +from connect.customer.token import _LICENSE_TOKEN_CACHE, _LICENSE_TOKEN_LOCKS, _LICENSE_XML_CACHE SAMPLE_XML = """ @@ -27,3 +27,4 @@ def clear_token_cache() -> None: _LICENSE_TOKEN_CACHE.clear() _LICENSE_TOKEN_LOCKS.clear() + _LICENSE_XML_CACHE.clear() diff --git a/tests/customer/test_content_parser.py b/tests/customer/test_content_parser.py index 86d3efc..807d9b7 100644 --- a/tests/customer/test_content_parser.py +++ b/tests/customer/test_content_parser.py @@ -33,12 +33,6 @@ def test_parse_content_elements_parses_multiple_blocks() -> None: """, - # Missing required server attribute on the element. - """ - - - - """, # Reject whitespace-only attributes that are present but effectively empty. """ @@ -54,3 +48,21 @@ def test_parse_content_elements_parses_multiple_blocks() -> None: def test_parse_content_elements_skips_invalid_content(xml: str) -> None: """Invalid or incomplete content elements produce no blocks.""" assert _parse_content_elements(xml) == [] + + +def test_parse_content_elements_keeps_serverless_content() -> None: + """Serverless content is valid for usage grants.""" + xml = """ + + + + + + """ + + blocks = _parse_content_elements(xml) + + assert len(blocks) == 1 + assert blocks[0].url_pattern == "http://example.com/*" + assert blocks[0].server is None + assert blocks[0].license_xml == '' diff --git a/tests/customer/test_tokens.py b/tests/customer/test_tokens.py index d149a6b..059871e 100644 --- a/tests/customer/test_tokens.py +++ b/tests/customer/test_tokens.py @@ -12,6 +12,7 @@ from connect.customer.token import _create_async_client, _generate_license_token, obtain_license_token from connect.exceptions import SupertabConnectError +from connect.types import UsageType from tests.customer.conftest import SAMPLE_XML @@ -98,6 +99,106 @@ async def run() -> None: asyncio.run(run()) +def test_obtain_license_token_fetches_license_xml_once_per_origin( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Caches license.xml by origin while still fetching tokens per matched pattern.""" + exp = int(time.time()) + 3600 + access_token = jwt.encode({"exp": exp}, "x" * 32, algorithm="HS256") + origin = "http://cachetest.example" + token_endpoint = "http://token-server.com/token" + calls = {"license_xml": 0, "token": 0} + xml = f""" + + + + + + + + + """ + + async def handler(request: httpx.Request) -> httpx.Response: + if str(request.url) == f"{origin}/license.xml": + calls["license_xml"] += 1 + return httpx.Response(200, text=xml, request=request) + + if str(request.url) == token_endpoint: + calls["token"] += 1 + return httpx.Response(200, json={"access_token": access_token}, request=request) + + raise AssertionError(f"Unexpected URL: {request.url!s}") + + _install_mock_transport(monkeypatch, handler) + + asyncio.run( + obtain_license_token( + client_id="client", + client_secret="secret", + resource_url=f"{origin}/articles/foo", + ) + ) + asyncio.run( + obtain_license_token( + client_id="client", + client_secret="secret", + resource_url=f"{origin}/news/bar", + ) + ) + + assert calls == {"license_xml": 1, "token": 2} + + +def test_obtain_license_token_reuses_token_for_same_matched_pattern( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Caches license tokens by client, token server, and matched URL pattern.""" + exp = int(time.time()) + 3600 + access_token = jwt.encode({"exp": exp}, "x" * 32, algorithm="HS256") + origin = "http://pattern-cache.example" + calls = {"license_xml": 0, "token": 0} + xml = f""" + + + + + + """ + + async def handler(request: httpx.Request) -> httpx.Response: + if str(request.url) == f"{origin}/license.xml": + calls["license_xml"] += 1 + return httpx.Response(200, text=xml, request=request) + + if str(request.url) == "http://token-server.com/token": + calls["token"] += 1 + return httpx.Response(200, json={"access_token": access_token}, request=request) + + raise AssertionError(f"Unexpected URL: {request.url!s}") + + _install_mock_transport(monkeypatch, handler) + + first = asyncio.run( + obtain_license_token( + client_id="client", + client_secret="secret", + resource_url=f"{origin}/articles/foo", + ) + ) + second = asyncio.run( + obtain_license_token( + client_id="client", + client_secret="secret", + resource_url=f"{origin}/articles/bar", + ) + ) + + assert first == access_token + assert second == access_token + assert calls == {"license_xml": 1, "token": 1} + + def test_obtain_license_token_follows_redirects( monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -154,6 +255,184 @@ async def handler(request: httpx.Request) -> httpx.Response: } +def test_obtain_license_token_returns_none_for_matching_serverless_usage( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Skips token exchange when serverless content permits the requested usage.""" + origin = "http://search-serverless-match.example" + xml = """ + + + + search + + + + + + + """ + calls = {"license_xml": 0, "token": 0} + + async def handler(request: httpx.Request) -> httpx.Response: + if str(request.url) == f"{origin}/license.xml": + calls["license_xml"] += 1 + return httpx.Response(200, text=xml, request=request) + + if str(request.url) == "http://token-server.com/token": + calls["token"] += 1 + return httpx.Response(200, json={"access_token": "unused"}, request=request) + + raise AssertionError(f"Unexpected URL: {request.url!s}") + + _install_mock_transport(monkeypatch, handler) + + token = asyncio.run( + obtain_license_token( + client_id="client", + client_secret="secret", + resource_url=f"{origin}/articles/foo", + usage=UsageType.SEARCH, + ) + ) + + assert token is None + assert calls == {"license_xml": 1, "token": 0} + + +def test_obtain_license_token_requests_token_when_serverless_usage_does_not_match( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Falls back to token exchange if the serverless usage grant matches a different resource.""" + exp = int(time.time()) + 3600 + access_token = jwt.encode({"exp": exp}, "x" * 32, algorithm="HS256") + origin = "http://search-serverless-miss.example" + xml = """ + + + + search + + + + + + + """ + calls = {"license_xml": 0, "token": 0} + + async def handler(request: httpx.Request) -> httpx.Response: + if str(request.url) == f"{origin}/license.xml": + calls["license_xml"] += 1 + return httpx.Response(200, text=xml, request=request) + + if str(request.url) == "http://token-server.com/token": + calls["token"] += 1 + return httpx.Response(200, json={"access_token": access_token}, request=request) + + raise AssertionError(f"Unexpected URL: {request.url!s}") + + _install_mock_transport(monkeypatch, handler) + + token = asyncio.run( + obtain_license_token( + client_id="client", + client_secret="secret", + resource_url=f"{origin}/articles/foo", + usage=UsageType.SEARCH, + ) + ) + + assert token == access_token + assert calls == {"license_xml": 1, "token": 1} + + +def test_obtain_license_token_requests_token_when_usage_is_prohibited( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A matching prohibit entry wins over permits and forces token exchange.""" + exp = int(time.time()) + 3600 + access_token = jwt.encode({"exp": exp}, "x" * 32, algorithm="HS256") + origin = "http://usage-prohibited.example" + xml = """ + + + + ai-train search + ai-train + + + + + + + """ + calls = {"license_xml": 0, "token": 0} + + async def handler(request: httpx.Request) -> httpx.Response: + if str(request.url) == f"{origin}/license.xml": + calls["license_xml"] += 1 + return httpx.Response(200, text=xml, request=request) + + if str(request.url) == "http://token-server.com/token": + calls["token"] += 1 + return httpx.Response(200, json={"access_token": access_token}, request=request) + + raise AssertionError(f"Unexpected URL: {request.url!s}") + + _install_mock_transport(monkeypatch, handler) + + token = asyncio.run( + obtain_license_token( + client_id="client", + client_secret="secret", + resource_url=f"{origin}/articles/foo", + usage=UsageType.AI_TRAIN, + ) + ) + + assert token == access_token + assert calls == {"license_xml": 1, "token": 1} + + +def test_obtain_license_token_accepts_all_usage_grant( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """The all usage grant permits any specific usage type.""" + origin = "http://usage-all.example" + xml = """ + + + + all + + + + + + + """ + + async def handler(request: httpx.Request) -> httpx.Response: + if str(request.url) == f"{origin}/license.xml": + return httpx.Response(200, text=xml, request=request) + + raise AssertionError(f"Unexpected URL: {request.url!s}") + + _install_mock_transport(monkeypatch, handler) + + token = asyncio.run( + obtain_license_token( + client_id="client", + client_secret="secret", + resource_url=f"{origin}/articles/foo", + usage=UsageType.AI_INPUT, + ) + ) + + assert token is None + + @pytest.mark.parametrize( ("algorithm", "key_factory"), [ From 005da434133bc3a3955cc562fa23dc48e6a8ae06 Mon Sep 17 00:00:00 2001 From: Nikita Kovalev Date: Wed, 29 Apr 2026 11:22:07 +0200 Subject: [PATCH 2/2] Optimize _license_permits_usage as suggested in PR comments --- connect/customer/token.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/connect/customer/token.py b/connect/customer/token.py index 01f1d36..cef8fed 100644 --- a/connect/customer/token.py +++ b/connect/customer/token.py @@ -281,20 +281,27 @@ def _license_permits_usage(license_xml: str, usage: UsageType | str) -> bool: return False usage_value = str(usage) + is_permitted = False for element in root.iter(): - if _local_name(element.tag) == "prohibits" and element.attrib.get("type") == "usage": - prohibited_usages = " ".join(element.itertext()).split() - if UsageType.ALL in prohibited_usages or usage_value in prohibited_usages: - return False + if element.attrib.get("type") != "usage": + continue - for element in root.iter(): - if _local_name(element.tag) == "permits" and element.attrib.get("type") == "usage": - permitted_usages = " ".join(element.itertext()).split() - if UsageType.ALL in permitted_usages or usage_value in permitted_usages: - return True + tag = _local_name(element.tag) + if tag not in {"prohibits", "permits"}: + continue + + usages = " ".join(element.itertext()).split() + + if UsageType.ALL not in usages and usage_value not in usages: + continue + + if tag == "prohibits": + return False + + is_permitted = True - return False + return is_permitted def _find_serverless_usage_content(