Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions connect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
HandlerResult,
RSLVerificationResult,
SupertabConnectConfig,
UsageType,
)

__all__ = [
Expand All @@ -21,6 +22,7 @@
"SupertabConnect",
"SupertabConnectError",
"SupertabConnectConfig",
"UsageType",
"default_bot_detector",
"obtain_license_token",
"verify_license_token",
Expand Down
3 changes: 2 additions & 1 deletion connect/customer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Customer functionality for Supertab Connect."""

from connect.customer.token import obtain_license_token
from connect.types import UsageType

__all__ = ["obtain_license_token"]
__all__ = ["UsageType", "obtain_license_token"]
7 changes: 3 additions & 4 deletions connect/customer/content_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class _ContentBlock:
url_pattern: str
license_xml: str
server: str
server: str | None = None


def _clean_attribute(value: str | None) -> str | None:
Expand Down Expand Up @@ -47,9 +47,9 @@ def _parse_content_elements(xml: str, debug: bool = False) -> list[_ContentBlock
if license_el is None:
license_el = content_el.find("license", namespaces=_NS)

license_xml = ElementTree.tostring(license_el, encoding="unicode") if license_el is not None else None
license_xml = ElementTree.tostring(license_el, encoding="unicode").strip() if license_el is not None else None

if url_pattern and server and license_xml:
if url_pattern and license_xml:
content_blocks.append(
_ContentBlock(
url_pattern=url_pattern,
Expand All @@ -63,7 +63,6 @@ def _parse_content_elements(xml: str, debug: bool = False) -> list[_ContentBlock
value
for value in (
None if url_pattern else "url",
None if server else "server",
None if license_xml else "<license>",
)
if value is not None
Expand Down
193 changes: 147 additions & 46 deletions connect/customer/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
import urllib.parse
from dataclasses import dataclass
from xml.etree import ElementTree
from typing import Any
from weakref import WeakKeyDictionary

Expand All @@ -17,14 +18,18 @@
from connect.common import debug_log, error_log
from connect.exceptions import SupertabConnectError
from connect.customer.content_matcher import _find_best_matching_content
from connect.customer.content_parser import _ContentBlock
from connect.customer.content_parser import _parse_content_elements
from connect.types import UsageType

_SUPPORTED_ALGS = ("ES256", "RS256")
_DEFAULT_HTTP_TIMEOUT_SECONDS = 10.0
_LICENSE_TOKEN_CACHE: dict[tuple[str, str], "_CachedToken"] = {}
_LICENSE_TOKEN_LOCKS: WeakKeyDictionary[asyncio.AbstractEventLoop, dict[tuple[str, str], asyncio.Lock]] = (
_LICENSE_TOKEN_CACHE: dict[tuple[str, str, str], "_CachedToken"] = {}
_LICENSE_TOKEN_LOCKS: WeakKeyDictionary[asyncio.AbstractEventLoop, dict[tuple[str, str, str], asyncio.Lock]] = (
WeakKeyDictionary()
)
_LICENSE_XML_TTL_SECONDS = 15 * 60
_LICENSE_XML_CACHE: dict[str, "_CachedLicenseXml"] = {}


@dataclass(frozen=True)
Expand All @@ -33,14 +38,20 @@ class _CachedToken:
exp: int


@dataclass(frozen=True)
class _CachedLicenseXml:
xml: str
fetched_at: int


def _build_origin(resource_url: str) -> str:
parsed = urllib.parse.urlparse(resource_url)
if not parsed.scheme or not parsed.netloc:
raise SupertabConnectError(f"Invalid resource URL: {resource_url}")
return f"{parsed.scheme}://{parsed.netloc}"


def _get_cached_token(cache_key: tuple[str, str], debug: bool = False) -> str | None:
def _get_cached_token(cache_key: tuple[str, str, str], debug: bool = False) -> str | None:
cached = _LICENSE_TOKEN_CACHE.get(cache_key)
if cached is None:
return None
Expand All @@ -58,7 +69,7 @@ def _get_cached_token(cache_key: tuple[str, str], debug: bool = False) -> str |
return None


def _get_cache_lock(cache_key: tuple[str, str]) -> asyncio.Lock:
def _get_cache_lock(cache_key: tuple[str, str, str]) -> asyncio.Lock:
loop = asyncio.get_running_loop()
loop_locks = _LICENSE_TOKEN_LOCKS.get(loop)
if loop_locks is None:
Expand All @@ -73,6 +84,15 @@ def _get_cache_lock(cache_key: tuple[str, str]) -> asyncio.Lock:
return lock


def _evict_expired_license_xml() -> None:
now = int(time.time())
expired_origins = [
origin for origin, entry in _LICENSE_XML_CACHE.items() if now - entry.fetched_at >= _LICENSE_XML_TTL_SECONDS
]
for origin in expired_origins:
_LICENSE_XML_CACHE.pop(origin, None)


def _create_async_client(**kwargs: Any) -> httpx.AsyncClient:
kwargs.setdefault("follow_redirects", True)
kwargs.setdefault("timeout", httpx.Timeout(_DEFAULT_HTTP_TIMEOUT_SECONDS))
Expand Down Expand Up @@ -210,7 +230,22 @@ async def _fetch_license_xml(
resource_url: str,
debug: bool = False,
) -> str:
license_xml_url = f"{_build_origin(resource_url)}/license.xml"
origin = _build_origin(resource_url)
cached = _LICENSE_XML_CACHE.get(origin)
if cached is not None:
now = int(time.time())
age = now - cached.fetched_at
if age < _LICENSE_XML_TTL_SECONDS:
debug_log(
debug,
f"Using cached license.xml for origin {origin} (expires in {_LICENSE_XML_TTL_SECONDS - age}s)",
)
return cached.xml

debug_log(debug, f"Cached license.xml for origin {origin} expired, re-fetching")
_LICENSE_XML_CACHE.pop(origin, None)

license_xml_url = f"{origin}/license.xml"

try:
response = await client.get(license_xml_url)
Expand All @@ -230,50 +265,116 @@ async def _fetch_license_xml(
raise SupertabConnectError(message) from error

debug_log(debug, f"Fetched license.xml from {license_xml_url}")
_evict_expired_license_xml()
_LICENSE_XML_CACHE[origin] = _CachedLicenseXml(xml=xml, fetched_at=int(time.time()))
return xml


def _local_name(tag: str) -> str:
return tag.rsplit("}", maxsplit=1)[-1]


def _license_permits_usage(license_xml: str, usage: UsageType | str) -> bool:
try:
root = ElementTree.fromstring(license_xml)
except ElementTree.ParseError:
return False

usage_value = str(usage)
is_permitted = False

for element in root.iter():
if element.attrib.get("type") != "usage":
continue

tag = _local_name(element.tag)
if tag not in {"prohibits", "permits"}:
continue

usages = " ".join(element.itertext()).split()

if UsageType.ALL not in usages and usage_value not in usages:
continue

if tag == "prohibits":
return False

is_permitted = True

return is_permitted


def _find_serverless_usage_content(
content_blocks: list[_ContentBlock],
resource_url: str,
usage: UsageType | str,
debug: bool = False,
) -> _ContentBlock | None:
matching_usage_blocks = [
block for block in content_blocks if block.server is None and _license_permits_usage(block.license_xml, usage)
]

return _find_best_matching_content(matching_usage_blocks, resource_url, debug)


async def obtain_license_token(
*,
client_id: str,
client_secret: str,
resource_url: str,
usage: UsageType | str | None = None,
debug: bool = False,
) -> str:
) -> str | None:
"""Obtain a license token using the current client credentials flow.

This is the supported customer flow. The SDK fetches ``license.xml`` for
the requested resource, finds the best matching ``<content>`` block, and
exchanges the client credentials for a license token.
exchanges the client credentials for a license token. If ``usage`` is
provided and a matching serverless content block permits that usage, no
token is needed and ``None`` is returned.
"""
cache_key = (client_id, resource_url)
cached = _get_cached_token(cache_key, debug)
if cached is not None:
return cached
async with _create_async_client() as client:
xml = await _fetch_license_xml(client, resource_url, debug)
debug_log(debug, f"Fetched license.xml ({len(xml)} chars)")
content_blocks = _parse_content_elements(xml, debug)

if not content_blocks:
error_log(debug, "No valid <content> elements with <license> found in license.xml")
raise SupertabConnectError("No valid <content> elements with <license> found in license.xml")

if usage is not None:
serverless_usage_content = _find_serverless_usage_content(content_blocks, resource_url, usage, debug)
if serverless_usage_content is not None:
debug_log(
debug,
"Matched serverless content to usage and resource URL combination, skipping license token request.",
)
debug_log(debug, f"URL: {resource_url}, Usage: {usage}")
return None

token_content_blocks = [block for block in content_blocks if block.server]
matched_content = _find_best_matching_content(token_content_blocks, resource_url, debug)
if matched_content is None or matched_content.server is None:
patterns = ", ".join(block.url_pattern for block in token_content_blocks)
error_log(
debug,
f"No <content> element matches resource URL: {resource_url}. Available patterns: {patterns}",
)
raise SupertabConnectError(f"No <content> element in license.xml matches resource URL: {resource_url}")

debug_log(debug, f"Matched content block for resource URL: {resource_url}")
debug_log(debug, f"Using license XML: {matched_content.license_xml}")

lock = _get_cache_lock(cache_key)
async with lock:
cache_key = (client_id, matched_content.server, matched_content.url_pattern)
cached = _get_cached_token(cache_key, debug)
if cached is not None:
Comment on lines +368 to 370

Copilot AI Apr 28, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The token cache key no longer includes the resource origin (it’s now keyed by client_id + token server + matched url_pattern). For path-only patterns like "/" or "/articles/", this can collide across different origins and return a token minted for a different license.xml/resource set. Consider including the resource origin (or a stable identifier like the license.xml URL / a hash of the matched license XML) in the cache key to prevent cross-origin token reuse.

Copilot uses AI. Check for mistakes.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is something for us to take into account when we think of generalizing to not only Supertab Connect. Our server URL is already unique per website's base URL

return cached

async with _create_async_client() as client:
xml = await _fetch_license_xml(client, resource_url, debug)
debug_log(debug, f"Fetched license.xml ({len(xml)} chars)")
content_blocks = _parse_content_elements(xml, debug)

if not content_blocks:
error_log(debug, "No valid <content> elements with <license> found in license.xml")
raise SupertabConnectError("No valid <content> elements with <license> found in license.xml")

matched_content = _find_best_matching_content(content_blocks, resource_url, debug)
if matched_content is None:
patterns = ", ".join(block.url_pattern for block in content_blocks)
error_log(
debug,
f"No <content> element matches resource URL: {resource_url}. Available patterns: {patterns}",
)
raise SupertabConnectError(f"No <content> element in license.xml matches resource URL: {resource_url}")
lock = _get_cache_lock(cache_key)
async with lock:
cached = _get_cached_token(cache_key, debug)
if cached is not None:
return cached

token_endpoint = matched_content.server.rstrip("/") + "/token"
debug_log(debug, f"Requesting license token from {token_endpoint}")
Expand All @@ -295,22 +396,22 @@ async def obtain_license_token(
debug=debug,
)

try:
claims = jwt.decode(
token,
options={
"verify_signature": False,
"verify_exp": False,
"verify_aud": False,
"verify_iss": False,
},
algorithms=["HS256", "RS256", "ES256", "PS256"],
)
exp = claims.get("exp")
if isinstance(exp, int):
_LICENSE_TOKEN_CACHE[cache_key] = _CachedToken(token=token, exp=exp)
except (jwt.PyJWTError, ValueError, TypeError) as error:
debug_log(debug, f"Failed to decode token for caching, skipping cache: {error}")
try:
claims = jwt.decode(
token,
options={
"verify_signature": False,
"verify_exp": False,
"verify_aud": False,
"verify_iss": False,
},
algorithms=["HS256", "RS256", "ES256", "PS256"],
)
exp = claims.get("exp")
if isinstance(exp, int):
_LICENSE_TOKEN_CACHE[cache_key] = _CachedToken(token=token, exp=exp)
except (jwt.PyJWTError, ValueError, TypeError) as error:
debug_log(debug, f"Failed to decode token for caching, skipping cache: {error}")

return token

Expand Down
9 changes: 9 additions & 0 deletions connect/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ class HandlerAction(StrEnum):
BLOCK = "block"


class UsageType(StrEnum):
ALL = "all"
SEARCH = "search"
AI_ALL = "ai-all"
AI_TRAIN = "ai-train"
AI_INDEX = "ai-index"
AI_INPUT = "ai-input"


BotDetector: TypeAlias = Callable[[Request], bool]


Expand Down
4 changes: 4 additions & 0 deletions examples/obtain_and_verify_license_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ async def main() -> None:
debug=True,
)

assert isinstance(token, str), (
"Token was not generated. Check your credentials and whether you have a license to the resource URL"
)
Comment thread
nick434434 marked this conversation as resolved.

print(f"Generated license token: {token}")

result = await verify_license_token(
Expand Down
3 changes: 2 additions & 1 deletion tests/customer/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from connect.customer.token import _LICENSE_TOKEN_CACHE, _LICENSE_TOKEN_LOCKS
from connect.customer.token import _LICENSE_TOKEN_CACHE, _LICENSE_TOKEN_LOCKS, _LICENSE_XML_CACHE

SAMPLE_XML = """
<rsl xmlns="https://rslstandard.org/rsl">
Expand All @@ -27,3 +27,4 @@
def clear_token_cache() -> None:
_LICENSE_TOKEN_CACHE.clear()
_LICENSE_TOKEN_LOCKS.clear()
_LICENSE_XML_CACHE.clear()
Loading
Loading