Skip to content

Commit 7076dfa

Browse files
Restoring changes done to script client
1 parent 9867549 commit 7076dfa

10 files changed

Lines changed: 186 additions & 33 deletions

File tree

src/datacustomcode/__init__.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,15 @@
1717
from datacustomcode.credentials import AuthType, Credentials
1818
from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
1919
from datacustomcode.io.writer.print import PrintDataCloudWriter
20-
# TODO: Restore proxy/LLM gateway integration
21-
# from datacustomcode.proxy.client.LocalProxyClientProvider import (
22-
# LocalProxyClientProvider,
23-
# )
20+
from datacustomcode.proxy.client.LocalProxyClientProvider import (
21+
LocalProxyClientProvider,
22+
)
2423

2524
__all__ = [
2625
"AuthType",
2726
"Client",
2827
"Credentials",
29-
# "LocalProxyClientProvider", # TODO: Restore
28+
"LocalProxyClientProvider",
3029
"PrintDataCloudWriter",
3130
"QueryAPIDataCloudReader",
3231
]

src/datacustomcode/client.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
from datacustomcode.io.reader.base import BaseDataCloudReader
3535
from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
36+
from datacustomcode.proxy.client.base import BaseProxyClient
3637
from datacustomcode.spark.base import BaseSparkSessionProvider
3738

3839

@@ -106,13 +107,15 @@ class Client:
106107
_reader: BaseDataCloudReader
107108
_writer: BaseDataCloudWriter
108109
_file: DefaultFindFilePath
110+
_proxy: Optional[BaseProxyClient]
109111
_data_layer_history: dict[DataCloudObjectType, set[str]]
110112
_code_type: str
111113

112114
def __new__(
113115
cls,
114116
reader: Optional[BaseDataCloudReader] = None,
115117
writer: Optional["BaseDataCloudWriter"] = None,
118+
proxy: Optional[BaseProxyClient] = None,
116119
spark_provider: Optional["BaseSparkSessionProvider"] = None,
117120
code_type: str = "script",
118121
) -> Client:
@@ -175,6 +178,11 @@ def __new__(
175178
@classmethod
176179
def _new_function_client(cls) -> Client:
177180
cls._instance = super().__new__(cls)
181+
cls._instance._proxy = (
182+
config.proxy_config.to_object() # type: ignore
183+
if config.proxy_config is not None
184+
else None
185+
)
178186
return cls._instance
179187

180188
def read_dlo(self, name: str, row_limit: int = 1000) -> PySparkDataFrame:
@@ -229,6 +237,11 @@ def write_to_dmo(
229237
self._validate_data_layer_history_does_not_contain(DataCloudObjectType.DLO)
230238
return self._writer.write_to_dmo(name, dataframe, write_mode, **kwargs)
231239

240+
def call_llm_gateway(self, LLM_MODEL_ID: str, prompt: str, maxTokens: int) -> str:
241+
if self._proxy is None:
242+
raise ValueError("No proxy configured; set proxy or proxy_config")
243+
return self._proxy.call_llm_gateway(LLM_MODEL_ID, prompt, maxTokens)
244+
232245
def find_file_path(self, file_name: str) -> Path:
233246
"""Return a file path"""
234247

src/datacustomcode/config.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
from datacustomcode.io.base import BaseDataAccessLayer
3939
from datacustomcode.io.reader.base import BaseDataCloudReader # noqa: TCH001
4040
from datacustomcode.io.writer.base import BaseDataCloudWriter # noqa: TCH001
41+
from datacustomcode.proxy.base import BaseProxyAccessLayer
42+
from datacustomcode.proxy.client.base import BaseProxyClient # noqa: TCH001
4143
from datacustomcode.spark.base import BaseSparkSessionProvider
4244

4345
DEFAULT_CONFIG_NAME = "config.yaml"
@@ -92,6 +94,23 @@ class SparkConfig(ForceableConfig):
9294

9395
_P = TypeVar("_P", bound=BaseSparkSessionProvider)
9496

97+
_PX = TypeVar("_PX", bound=BaseProxyAccessLayer)
98+
99+
100+
class ProxyAccessLayerObjectConfig(ForceableConfig, Generic[_PX]):
101+
"""Config for proxy clients that take no constructor args (e.g. no spark)."""
102+
103+
model_config = ConfigDict(validate_default=True, extra="forbid")
104+
type_base: ClassVar[Type[BaseProxyAccessLayer]] = BaseProxyAccessLayer
105+
type_config_name: str = Field(
106+
description="CONFIG_NAME of the proxy client (e.g. 'LocalProxyClient').",
107+
)
108+
options: dict[str, Any] = Field(default_factory=dict)
109+
110+
def to_object(self) -> _PX:
111+
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
112+
return cast(_PX, type_(**self.options))
113+
95114

96115
class SparkProviderConfig(ForceableConfig, Generic[_P]):
97116
model_config = ConfigDict(validate_default=True, extra="forbid")
@@ -109,6 +128,7 @@ def to_object(self) -> _P:
109128
class ClientConfig(BaseModel):
110129
reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None
111130
writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None
131+
proxy_config: Union[ProxyAccessLayerObjectConfig[BaseProxyClient], None] = None
112132
spark_config: Union[SparkConfig, None] = None
113133
spark_provider_config: Union[
114134
SparkProviderConfig[BaseSparkSessionProvider], None
@@ -136,6 +156,7 @@ def merge(
136156

137157
self.reader_config = merge(self.reader_config, other.reader_config)
138158
self.writer_config = merge(self.writer_config, other.writer_config)
159+
self.proxy_config = merge(self.proxy_config, other.proxy_config)
139160
self.spark_config = merge(self.spark_config, other.spark_config)
140161
self.spark_provider_config = merge(
141162
self.spark_provider_config, other.spark_provider_config

src/datacustomcode/file/base.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
from __future__ import annotations
16-
from abc import abstractmethod
17-
from pathlib import Path
1816

1917

2018
class BaseDataAccessLayer:
21-
def __init__(self):
22-
pass
23-
24-
@abstractmethod
25-
def find_file_path(self, file_name: str) -> Path: ...
19+
"""Base class for data access layer implementations."""
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.

src/datacustomcode/proxy/base.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from __future__ import annotations
16+
17+
from abc import ABC
18+
19+
from datacustomcode.mixin import UserExtendableNamedConfigMixin
20+
21+
22+
class BaseProxyAccessLayer(ABC, UserExtendableNamedConfigMixin):
23+
def __init__(self):
24+
pass
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from __future__ import annotations
16+
17+
from datacustomcode.proxy.client.base import BaseProxyClient
18+
19+
20+
class LocalProxyClientProvider(BaseProxyClient):
21+
"""Default proxy client provider."""
22+
23+
CONFIG_NAME = "LocalProxyClientProvider"
24+
25+
def __init__(self, **kwargs: object) -> None:
26+
pass
27+
28+
def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str:
29+
return f"Hello, thanks for using {llmModelId}. So many tokens: {maxTokens}"
30+
31+
def llm_gateway_generate_text(
32+
self, template, values, llmModelId: str, maxTokens: int
33+
):
34+
return f"Using Generate Text with {llmModelId} and maxTokens: {maxTokens}"
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from __future__ import annotations
16+
17+
from abc import abstractmethod
18+
19+
from datacustomcode.proxy.base import BaseProxyAccessLayer
20+
21+
22+
class BaseProxyClient(BaseProxyAccessLayer):
23+
def __init__(self):
24+
pass
25+
26+
@abstractmethod
27+
def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str: ...
28+
29+
@abstractmethod
30+
def llm_gateway_generate_text(
31+
self, template, values, llmModelId: str, maxTokens: int
32+
): ...

0 commit comments

Comments
 (0)