Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ workflow_payload_offloading = [
workflow_payload_encryption = [
"cryptography>=41.0.0,<47.0.0",
]
workflow_payload_compression = [
"zstandard>=0.25.0,<0.26",
]


[project.urls]
Expand All @@ -69,6 +72,7 @@ dev = [
"griffe>=1.7.3,<2",
"authlib>=1.5.2,<2",
"websockets >=13.0",
"zstandard>=0.25.0,<0.26",
]
lint = [
"ruff>=0.11.10,<0.12",
Expand Down
4 changes: 4 additions & 0 deletions src/mistralai/extra/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ class WorkflowPayloadEncryptionException(MistralClientException):
"""Workflow payload encryption exception"""


class WorkflowPayloadCompressionException(MistralClientException):
    """Raised for workflow payload compression failures.

    The payload compressor raises it when the optional ``zstandard``
    dependency is missing, when the compression config attached to a payload
    is missing or invalid, or when the algorithm is unsupported.
    """


class RunException(MistralClientException):
"""Conversation run errors."""

Expand Down
45 changes: 43 additions & 2 deletions src/mistralai/extra/tests/test_workflow_encoding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for workflow encoding configuration lifecycle."""

import asyncio
import gc

import pytest
Expand Down Expand Up @@ -29,7 +30,9 @@ def encryption_config() -> WorkflowEncodingConfig:
)


def test_payload_encoder_cleanup_on_client_gc(encryption_config: WorkflowEncodingConfig):
def test_payload_encoder_cleanup_on_client_gc(
encryption_config: WorkflowEncodingConfig,
):
"""Test that PayloadEncoder is cleaned up when client is garbage collected."""
initial_config_count = len(_workflow_configs)

Expand All @@ -56,7 +59,9 @@ def test_payload_encoder_cleanup_on_client_gc(encryption_config: WorkflowEncodin
assert len(_workflow_configs) == initial_config_count


def test_multiple_clients_independent_configs(encryption_config: WorkflowEncodingConfig):
def test_multiple_clients_independent_configs(
encryption_config: WorkflowEncodingConfig,
):
"""Test that multiple clients have independent configs."""
initial_config_count = len(_workflow_configs)

Expand Down Expand Up @@ -132,3 +137,39 @@ def test_reconfigure_same_client(encryption_config: WorkflowEncodingConfig):
del client
gc.collect()
assert config_id not in _workflow_configs


def test_payload_encoder_compresses_network_inputs():
    """Round-trip a payload through a PayloadEncoder configured with zstd.

    Checks that the encoded payload is flagged as compressed, that the
    serialized algorithm config travels in the encoding metadata, and that
    decoding the JSON-dumped result gives back the original payload.
    """
    from mistralai.extra.workflows import (
        PayloadCompressionConfig,
        ZstdCompressionConfig,
    )
    from mistralai.extra.workflows.encoding.models import (
        EncodedPayloadOptions,
        WorkflowContext,
    )
    from mistralai.extra.workflows.encoding.payload_encoder import PayloadEncoder

    # min_size_bytes=1 is far below the ~20 kB payload built below.
    compression = PayloadCompressionConfig(
        min_size_bytes=1, algorithm_config=ZstdCompressionConfig(level=3)
    )
    encoder = PayloadEncoder(
        encoding_config=WorkflowEncodingConfig(payload_compression=compression)
    )
    original = {"data": "x" * 20_000}
    context = WorkflowContext(namespace="test", execution_id="exec")

    encoded = asyncio.run(encoder.encode_network_input(original, context))

    assert EncodedPayloadOptions.COMPRESSED in encoded.encoding_options
    expected_metadata = {"compression_config": '{"algorithm":"zstd","level":3}'}
    assert encoded.encoding_metadata == expected_metadata

    # Decode from the JSON-mode dump, i.e. the form that would go over the wire.
    wire_form = encoded.model_dump(mode="json")
    decoded = asyncio.run(encoder.decode_network_result(wire_form))
    assert decoded == original
14 changes: 10 additions & 4 deletions src/mistralai/extra/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@
WorkflowExtensions,
)
from .encoding import (
WorkflowEncodingConfig,
PayloadOffloadingConfig,
AlgorithmConfig,
BlobStorageConfig,
EncryptedStrField,
PayloadCompressionConfig,
PayloadEncryptionConfig,
PayloadEncryptionMode,
BlobStorageConfig,
PayloadOffloadingConfig,
StorageProvider,
EncryptedStrField,
WorkflowEncodingConfig,
ZstdCompressionConfig,
configure_workflow_encoding,
generate_two_part_id,
)
Expand All @@ -27,10 +30,13 @@
"ConnectorSlot",
"WorkflowExtensions",
"execute_with_connector_auth_async",
"AlgorithmConfig",
"WorkflowEncodingConfig",
"PayloadOffloadingConfig",
"PayloadEncryptionConfig",
"PayloadEncryptionMode",
"PayloadCompressionConfig",
"ZstdCompressionConfig",
"BlobStorageConfig",
"StorageProvider",
"EncryptedStrField",
Expand Down
12 changes: 9 additions & 3 deletions src/mistralai/extra/workflows/encoding/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
from .config import (
WorkflowEncodingConfig,
PayloadOffloadingConfig,
AlgorithmConfig,
BlobStorageConfig,
PayloadCompressionConfig,
PayloadEncryptionConfig,
PayloadEncryptionMode,
BlobStorageConfig,
PayloadOffloadingConfig,
StorageProvider,
WorkflowEncodingConfig,
ZstdCompressionConfig,
)
from .models import EncryptedStrField
from .payload_encoder import PayloadEncoder
from .helpers import configure_workflow_encoding, generate_two_part_id

__all__ = [
"AlgorithmConfig",
"WorkflowEncodingConfig",
"PayloadOffloadingConfig",
"PayloadEncryptionConfig",
"PayloadEncryptionMode",
"PayloadCompressionConfig",
"ZstdCompressionConfig",
"BlobStorageConfig",
"StorageProvider",
"EncryptedStrField",
Expand Down
22 changes: 20 additions & 2 deletions src/mistralai/extra/workflows/encoding/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from enum import Enum
from pydantic import SecretStr, BaseModel
from typing import Optional
from typing import Annotated, Literal, Optional, Union

from pydantic import BaseModel, Field, SecretStr


class StorageProvider(str, Enum):
Expand Down Expand Up @@ -47,6 +48,23 @@ class PayloadEncryptionConfig(BaseModel):
secondary_key: Optional[SecretStr] = None


class ZstdCompressionConfig(BaseModel):
    """Settings for the ``zstd`` payload compression algorithm."""

    # Constant tag naming the algorithm; AlgorithmConfig uses this field as
    # its discriminator.
    algorithm: Literal["zstd"] = "zstd"
    # Compression level; validation accepts integers from 1 to 22 and the
    # default is 3.
    level: int = Field(default=3, ge=1, le=22)


# Type alias for "any supported compression algorithm config", discriminated
# on the "algorithm" field. ZstdCompressionConfig is the only member today;
# presumably it is kept as a Union so further algorithms can be added later.
AlgorithmConfig = Annotated[
    Union[ZstdCompressionConfig], Field(discriminator="algorithm")
]


class PayloadCompressionConfig(BaseModel):
    """Configuration for compressing workflow payloads."""

    # Master switch; build_compressor() returns no compressor when this is False.
    enabled: bool = True
    # Size threshold in bytes. Presumably payloads smaller than this are left
    # uncompressed -- confirm against the payload encoder.
    min_size_bytes: int = 1024 * 1024  # 1 MiB (1,048,576 bytes)
    # Algorithm-specific settings; a default ZstdCompressionConfig is created
    # when none is given.
    algorithm_config: AlgorithmConfig = Field(default_factory=ZstdCompressionConfig)


class WorkflowEncodingConfig(BaseModel):
    """Groups the optional payload offloading, encryption and compression settings."""

    # Every section defaults to None, i.e. no configuration is supplied for
    # that feature.
    payload_offloading: PayloadOffloadingConfig | None = None
    payload_encryption: PayloadEncryptionConfig | None = None
    payload_compression: PayloadCompressionConfig | None = None
19 changes: 17 additions & 2 deletions src/mistralai/extra/workflows/encoding/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class EncodedPayloadOptions(str, Enum):
OFFLOADED = "offloaded"
ENCRYPTED = "encrypted"
PARTIALLY_ENCRYPTED = "encrypted-partial"
COMPRESSED = "compressed"


class EncryptableFieldTypes(str, Enum):
Expand All @@ -35,6 +36,10 @@ class EncodedPayload(BaseModel):
encoding_options: list[EncodedPayloadOptions] = Field(
description="The encoding of the payload", default=[]
)
encoding_metadata: dict[str, str] = Field(
description="Additional metadata required to decode the payload",
default_factory=dict,
)
payload: bytes = Field(description="The encoded payload")


Expand All @@ -43,6 +48,10 @@ class NetworkEncodedBase(BaseModel):
encoding_options: list[EncodedPayloadOptions] = Field(
description="The encoding of the payload", default=[]
)
encoding_metadata: dict[str, str] = Field(
description="Additional metadata required to decode the payload",
default_factory=dict,
)

def get_payload(self) -> bytes:
return base64.b64decode(self.b64payload)
Expand All @@ -52,11 +61,12 @@ class NetworkEncodedInput(NetworkEncodedBase):
empty: bool = Field(description="Whether the payload is empty", default=False)

def to_encoded_payload(
self, namespace: str, execution_id: str, execution_token: str | None = None
self, namespace: str, execution_id: str, execution_token: Optional[str]
) -> EncodedPayload:
return EncodedPayload(
payload=base64.b64decode(self.b64payload),
encoding_options=self.encoding_options,
encoding_metadata=self.encoding_metadata,
context=WorkflowContext(
namespace=namespace,
execution_id=execution_id,
Expand All @@ -69,15 +79,19 @@ def from_encoded_payload(encoded_payload: EncodedPayload) -> "NetworkEncodedInpu
return NetworkEncodedInput(
b64payload=base64.b64encode(encoded_payload.payload).decode("utf-8"),
encoding_options=encoded_payload.encoding_options,
encoding_metadata=encoded_payload.encoding_metadata,
)

@staticmethod
def from_data(
data: bytes, encoding_options: list[EncodedPayloadOptions]
data: bytes,
encoding_options: list[EncodedPayloadOptions],
encoding_metadata: dict[str, str] | None = None,
) -> "NetworkEncodedInput":
return NetworkEncodedInput(
b64payload=base64.b64encode(data).decode("utf-8"),
encoding_options=encoding_options,
encoding_metadata=encoding_metadata or {},
)


Expand All @@ -87,4 +101,5 @@ def from_encoded_payload(encoded_payload: EncodedPayload) -> "NetworkEncodedResu
return NetworkEncodedResult(
b64payload=base64.b64encode(encoded_payload.payload).decode("utf-8"),
encoding_options=encoded_payload.encoding_options,
encoding_metadata=encoded_payload.encoding_metadata,
)
100 changes: 100 additions & 0 deletions src/mistralai/extra/workflows/encoding/payload_compressor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from functools import lru_cache
from importlib import import_module
from types import ModuleType

from pydantic import TypeAdapter, ValidationError

from mistralai.extra.exceptions import WorkflowPayloadCompressionException
from mistralai.extra.workflows.encoding.config import (
AlgorithmConfig,
PayloadCompressionConfig,
ZstdCompressionConfig,
)

# Validator for the AlgorithmConfig union; used below to parse the JSON
# algorithm config that accompanies a compressed payload.
_ALGORITHM_CONFIG_ADAPTER: TypeAdapter[AlgorithmConfig] = TypeAdapter(AlgorithmConfig)
# Key of the encoding-metadata entry that holds the serialized algorithm config.
COMPRESSION_CONFIG_METADATA_KEY = "compression_config"


class Compressor(ABC):
    """Interface that every payload compression algorithm implements."""

    @property
    @abstractmethod
    def tag(self) -> str:
        """Short identifier naming the compression algorithm (e.g. ``"zstd"``)."""

    @abstractmethod
    def serialize_config(self) -> str:
        """Return the algorithm config as a Pydantic-serialised string.

        The string is stored with the payload so it can be decoded without
        access to the configuration that produced it.
        """

    @abstractmethod
    def compress(self, data: bytes) -> bytes:
        """Return the compressed form of ``data``."""

    @abstractmethod
    def decompress(self, data: bytes) -> bytes:
        """Return the original bytes for previously compressed ``data``."""


def _require_zstandard() -> ModuleType:
    """Import the optional ``zstandard`` package and return the module.

    Raises:
        WorkflowPayloadCompressionException: If ``zstandard`` cannot be
            imported; the message tells the user which extra to install.
    """
    try:
        zstd_module = import_module("zstandard")
    except ImportError:
        # ``from None`` drops the ImportError from the chained traceback so
        # only the install hint is reported.
        raise WorkflowPayloadCompressionException(
            "Payload compression requires installing mistralai[workflow-payload-compression]"
        ) from None
    return zstd_module


class ZstdCompressor(Compressor):
    """Compressor backed by the optional ``zstandard`` package.

    NOTE(review): instances are reused through the ``lru_cache`` on
    ``_build_compressor_for_config``. The python-zstandard documentation
    advises against operating on one compressor/decompressor object from
    several threads at the same time -- confirm callers never do that.
    """

    tag = "zstd"

    def __init__(self, cfg: ZstdCompressionConfig) -> None:
        """Create the zstd compressor/decompressor pair for ``cfg``.

        Raises:
            WorkflowPayloadCompressionException: If ``zstandard`` is not
                installed (raised by ``_require_zstandard``).
        """
        # Resolve the optional dependency first so a missing package fails
        # before any state is set on the instance.
        zstd_module = _require_zstandard()
        self._config = cfg
        self._compressor = zstd_module.ZstdCompressor(level=cfg.level)
        self._decompressor = zstd_module.ZstdDecompressor()

    def serialize_config(self) -> str:
        """Return the zstd config as JSON, as produced by ``model_dump_json``."""
        return self._config.model_dump_json()

    def compress(self, data: bytes) -> bytes:
        """Compress ``data`` with the configured level."""
        # The annotated local pins the return type of the zstandard call.
        compressed: bytes = self._compressor.compress(data)
        return compressed

    def decompress(self, data: bytes) -> bytes:
        """Decompress zstd-compressed ``data``."""
        restored: bytes = self._decompressor.decompress(data)
        return restored


@lru_cache(maxsize=8)
def _build_compressor_for_config(config_json: str) -> Compressor:
    """Build, and cache per distinct JSON string, the compressor for a config.

    Args:
        config_json: JSON-serialized algorithm config.

    Returns:
        The compressor matching the parsed config. Up to 8 distinct configs
        are kept in the cache.

    Raises:
        WorkflowPayloadCompressionException: If the JSON does not validate as
            an ``AlgorithmConfig`` or names an algorithm with no compressor.
    """
    try:
        parsed_config = _ALGORITHM_CONFIG_ADAPTER.validate_json(config_json)
    except ValidationError as exc:
        raise WorkflowPayloadCompressionException(
            f"Invalid compression config in payload metadata: {exc}"
        ) from exc

    # Guard clause: anything that is not a known config type is rejected.
    if not isinstance(parsed_config, ZstdCompressionConfig):
        raise WorkflowPayloadCompressionException(
            f"Unsupported compression algorithm: {parsed_config.algorithm!r}"
        )
    return ZstdCompressor(parsed_config)


def decompress_by_metadata(data: bytes, encoding_metadata: dict[str, str]) -> bytes:
    """Decompress ``data`` using the algorithm config stored in its metadata.

    Args:
        data: The compressed payload bytes.
        encoding_metadata: Metadata accompanying the payload; the serialized
            algorithm config is read from ``COMPRESSION_CONFIG_METADATA_KEY``.

    Returns:
        The decompressed bytes.

    Raises:
        WorkflowPayloadCompressionException: If the metadata has no
            compression config, or the config is invalid or unsupported.
    """
    serialized_config = encoding_metadata.get(COMPRESSION_CONFIG_METADATA_KEY)
    if serialized_config is not None:
        compressor = _build_compressor_for_config(serialized_config)
        return compressor.decompress(data)
    raise WorkflowPayloadCompressionException(
        "Payload is marked as compressed but has no compression config in metadata"
    )


def build_compressor(
    compression_config: PayloadCompressionConfig | None,
) -> Compressor | None:
    """Return the compressor for ``compression_config``, or ``None`` if disabled.

    Args:
        compression_config: The compression settings, or ``None`` when
            compression is not configured.

    Returns:
        ``None`` when no config is given or its ``enabled`` flag is false;
        otherwise the (cached) compressor for the config's algorithm settings.
    """
    if compression_config is not None and compression_config.enabled:
        # The JSON form of the algorithm config doubles as the cache key.
        serialized_config = compression_config.algorithm_config.model_dump_json()
        return _build_compressor_for_config(serialized_config)
    return None
Loading