Skip to content

Commit dcaa1f9

Browse files
make llm gateway implementation configurable
1 parent cd119d6 commit dcaa1f9

5 files changed

Lines changed: 227 additions & 3 deletions

File tree

src/datacustomcode/llm_gateway/base.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
from abc import abstractmethod
17+
from abc import ABC, abstractmethod
1818
from typing import TYPE_CHECKING
1919

20+
from datacustomcode.mixin import UserExtendableNamedConfigMixin
21+
2022
if TYPE_CHECKING:
2123
from datacustomcode.llm_gateway.types.generate_text_request import (
2224
GenerateTextRequest,
@@ -26,8 +28,10 @@
2628
)
2729

2830

29-
class LLMGateway:
30-
def __init__(self) -> None:
31+
class LLMGateway(ABC, UserExtendableNamedConfigMixin):
32+
CONFIG_NAME: str
33+
34+
def __init__(self, **kwargs):
3135
pass
3236

3337
@abstractmethod

src/datacustomcode/llm_gateway/default.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323

2424
class DefaultLLMGateway(LLMGateway):
25+
CONFIG_NAME = "DefaultLLMGateway"
26+
2527
def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse:
2628

2729
response_data = {
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from typing import (
17+
ClassVar,
18+
Generic,
19+
Type,
20+
TypeVar,
21+
Union,
22+
cast,
23+
)
24+
25+
from datacustomcode.common_config import (
26+
BaseConfig,
27+
BaseObjectConfig,
28+
default_config_file,
29+
)
30+
from datacustomcode.llm_gateway.base import LLMGateway
31+
32+
_E = TypeVar("_E", bound=LLMGateway)
33+
34+
35+
class LLMGatewayObjectConfig(BaseObjectConfig, Generic[_E]):
36+
type_base: ClassVar[Type[LLMGateway]] = LLMGateway # type: ignore[type-abstract]
37+
38+
def to_object(self) -> _E:
39+
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
40+
return cast(_E, type_(**self.options))
41+
42+
43+
class LLMGatewayConfig(BaseConfig):
44+
llm_gateway_config: Union[
45+
LLMGatewayObjectConfig[LLMGateway], None
46+
] = None
47+
48+
def update(self, other: "LLMGatewayConfig") -> "LLMGatewayConfig":
49+
def merge(
50+
config_a: Union[LLMGatewayObjectConfig, None],
51+
config_b: Union[LLMGatewayObjectConfig, None],
52+
) -> Union[LLMGatewayObjectConfig, None]:
53+
if config_a is not None and config_a.force:
54+
return config_a
55+
if config_b:
56+
return config_b
57+
return config_a
58+
59+
self.llm_gateway_config = merge(
60+
self.llm_gateway_config, other.llm_gateway_config
61+
)
62+
return self
63+
64+
65+
# Global LLM Gateway config instance
66+
llm_gateway_config = LLMGatewayConfig()
67+
llm_gateway_config.load(default_config_file())
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import os
17+
import tempfile
18+
19+
import yaml
20+
21+
from datacustomcode.llm_gateway.default import DefaultLLMGateway
22+
from datacustomcode.llm_gateway_config import (
23+
LLMGatewayConfig,
24+
LLMGatewayObjectConfig,
25+
)
26+
27+
28+
class TestLLMGatewayConfigUpdate:
29+
def test_update_replaces_config_without_force(self):
30+
config1 = LLMGatewayConfig(
31+
llm_gateway_config=LLMGatewayObjectConfig(
32+
type_config_name="OldImplementation", options={"old": True}
33+
)
34+
)
35+
36+
config2 = LLMGatewayConfig(
37+
llm_gateway_config=LLMGatewayObjectConfig(
38+
type_config_name="NewImplementation", options={"new": True}
39+
)
40+
)
41+
42+
config1.update(config2)
43+
44+
assert config1.llm_gateway_config.type_config_name == "NewImplementation"
45+
assert config1.llm_gateway_config.options == {"new": True}
46+
47+
def test_update_respects_force_flag(self):
48+
config1 = LLMGatewayConfig(
49+
llm_gateway_config=LLMGatewayObjectConfig(
50+
type_config_name="ForcedImplementation",
51+
options={"forced": True},
52+
force=True,
53+
)
54+
)
55+
56+
config2 = LLMGatewayConfig(
57+
llm_gateway_config=LLMGatewayObjectConfig(
58+
type_config_name="NewImplementation", options={"new": True}
59+
)
60+
)
61+
62+
config1.update(config2)
63+
64+
assert (
65+
config1.llm_gateway_config.type_config_name == "ForcedImplementation"
66+
)
67+
assert config1.llm_gateway_config.options == {"forced": True}
68+
assert config1.llm_gateway_config.force is True
69+
70+
71+
class TestLLMGatewayConfigLoad:
72+
def test_load_from_yaml_file(self):
73+
config_data = {
74+
"llm_gateway_config": {"type_config_name": "DefaultLLMGateway"}
75+
}
76+
77+
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
78+
yaml.dump(config_data, f)
79+
temp_file = f.name
80+
81+
try:
82+
config = LLMGatewayConfig()
83+
config.load(temp_file)
84+
85+
assert config.llm_gateway_config is not None
86+
assert (
87+
config.llm_gateway_config.type_config_name == "DefaultLLMGateway"
88+
)
89+
llm_gateway = config.llm_gateway_config.to_object()
90+
assert llm_gateway is not None
91+
assert isinstance(llm_gateway, DefaultLLMGateway)
92+
finally:
93+
os.unlink(temp_file)

tests/test_runtime_llm_gateway.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from datacustomcode.llm_gateway.base import LLMGateway
17+
from datacustomcode.llm_gateway.types.generate_text_request import GenerateTextRequest
18+
from datacustomcode.llm_gateway.types.generate_text_response import GenerateTextResponse
19+
from datacustomcode.llm_gateway_config import LLMGatewayObjectConfig
20+
21+
22+
class TestCustomLLMGatewayImplementation:
23+
def test_custom_implementation_is_discoverable(self):
24+
class CustomLLMGateway(LLMGateway):
25+
CONFIG_NAME = "CustomLLMGateway"
26+
27+
def __init__(self, custom_param: str = "default", **kwargs):
28+
super().__init__(**kwargs)
29+
self.custom_param = custom_param
30+
31+
def generate_text(
32+
self, request: GenerateTextRequest
33+
) -> GenerateTextResponse:
34+
return GenerateTextResponse(
35+
version="v1",
36+
status_code=200,
37+
data={"generation": {"generatedText": "Custom response"}},
38+
)
39+
40+
available_names = LLMGateway.available_config_names()
41+
assert "CustomLLMGateway" in available_names
42+
43+
cls = LLMGateway.subclass_from_config_name("CustomLLMGateway")
44+
assert cls == CustomLLMGateway
45+
46+
# Verify we can create via config
47+
llm_config = LLMGatewayObjectConfig(
48+
type_config_name="CustomLLMGateway",
49+
options={"custom_param": "my_value"},
50+
)
51+
instance = llm_config.to_object()
52+
assert isinstance(instance, CustomLLMGateway)
53+
assert instance.custom_param == "my_value"
54+
55+
request = GenerateTextRequest(model_name="test-model", prompt="Hello")
56+
response = instance.generate_text(request)
57+
assert response.is_success is True
58+
assert response.data["generation"]["generatedText"] == "Custom response"

0 commit comments

Comments
 (0)