
Commit 7117670

Merge pull request #68 from forcedotcom/jo_run_function

mock run function

2 parents 087eaeb + b45e8c8

11 files changed: 208 additions & 34 deletions


src/datacustomcode/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -17,11 +17,13 @@
 from datacustomcode.credentials import AuthType, Credentials
 from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
 from datacustomcode.io.writer.print import PrintDataCloudWriter
+from datacustomcode.proxy.client.local_proxy_client import LocalProxyClientProvider

 __all__ = [
     "AuthType",
     "Client",
     "Credentials",
+    "LocalProxyClientProvider",
     "PrintDataCloudWriter",
     "QueryAPIDataCloudReader",
 ]

src/datacustomcode/client.py

Lines changed: 47 additions & 2 deletions
@@ -33,6 +33,7 @@
 from datacustomcode.io.reader.base import BaseDataCloudReader
 from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
+from datacustomcode.proxy.client.base import BaseProxyClient
 from datacustomcode.spark.base import BaseSparkSessionProvider

@@ -106,17 +107,20 @@ class Client:
     _reader: BaseDataCloudReader
     _writer: BaseDataCloudWriter
     _file: DefaultFindFilePath
+    _proxy: BaseProxyClient
     _data_layer_history: dict[DataCloudObjectType, set[str]]

     def __new__(
         cls,
         reader: Optional[BaseDataCloudReader] = None,
         writer: Optional["BaseDataCloudWriter"] = None,
+        proxy: Optional[BaseProxyClient] = None,
         spark_provider: Optional["BaseSparkSessionProvider"] = None,
     ) -> Client:
         if cls._instance is None:
             cls._instance = super().__new__(cls)

+            spark = None
             # Initialize Readers and Writers from config
             # and/or provided reader and writer
             if reader is None or writer is None:
@@ -135,6 +139,22 @@ def __new__(
                     provider = DefaultSparkSessionProvider()

                 spark = provider.get_session(config.spark_config)
+            elif (
+                proxy is None
+                and config.proxy_config is not None
+                and config.spark_config is not None
+            ):
+                # Both reader and writer provided; we still need spark for proxy init
+                provider = (
+                    spark_provider
+                    if spark_provider is not None
+                    else (
+                        config.spark_provider_config.to_object()
+                        if config.spark_provider_config is not None
+                        else DefaultSparkSessionProvider()
+                    )
+                )
+                spark = provider.get_session(config.spark_config)

             if config.reader_config is None and reader is None:
                 raise ValueError(
@@ -143,22 +163,44 @@ def __new__(
             elif reader is None or (
                 config.reader_config is not None and config.reader_config.force
             ):
-                reader_init = config.reader_config.to_object(spark)  # type: ignore
+                if config.proxy_config is None:
+                    raise ValueError(
+                        "Proxy config is required when reader is built from config"
+                    )
+                assert (
+                    spark is not None
+                )  # set in "reader is None or writer is None" branch
+                assert config.reader_config is not None  # ensured by branch condition
+                proxy_init = config.proxy_config.to_object(spark)
+
+                reader_init = config.reader_config.to_object(spark)
             else:
                 reader_init = reader
+                if proxy is not None:
+                    proxy_init = proxy
+                elif config.proxy_config is None:
+                    raise ValueError("Proxy config is required when reader is provided")
+                else:
+                    assert (
+                        spark is not None
+                    )  # set in "both provided; proxy from config" branch
+                    proxy_init = config.proxy_config.to_object(spark)
             if config.writer_config is None and writer is None:
                 raise ValueError(
                     "Writer config is required when writer is not provided"
                 )
             elif writer is None or (
                 config.writer_config is not None and config.writer_config.force
             ):
-                writer_init = config.writer_config.to_object(spark)  # type: ignore
+                assert spark is not None  # set when reader or writer from config
+                assert config.writer_config is not None  # ensured by branch condition
+                writer_init = config.writer_config.to_object(spark)
             else:
                 writer_init = writer
             cls._instance._reader = reader_init
             cls._instance._writer = writer_init
             cls._instance._file = DefaultFindFilePath()
+            cls._instance._proxy = proxy_init
             cls._instance._data_layer_history = {
                 DataCloudObjectType.DLO: set(),
                 DataCloudObjectType.DMO: set(),
@@ -217,6 +259,9 @@ def write_to_dmo(
         self._validate_data_layer_history_does_not_contain(DataCloudObjectType.DLO)
         return self._writer.write_to_dmo(name, dataframe, write_mode, **kwargs)

+    def call_llm_gateway(self, LLM_MODEL_ID: str, prompt: str, maxTokens: int) -> str:
+        return self._proxy.call_llm_gateway(LLM_MODEL_ID, prompt, maxTokens)
+
     def find_file_path(self, file_name: str) -> Path:
         """Return a file path"""
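
With the proxy wired into the singleton, the gateway call is available straight off the client. A minimal usage sketch, assuming a config.yaml that supplies reader, writer, spark, and proxy sections so Client() can build everything from config ("example-model" is a placeholder id, not a real model name):

from datacustomcode.client import Client

# Sketch: Client() constructs reader, writer, and proxy from config;
# with the default LocalProxyClientProvider this returns the mock string.
client = Client()
reply = client.call_llm_gateway(
    LLM_MODEL_ID="example-model",  # placeholder model id, for illustration only
    prompt="Summarize yesterday's engagement data.",
    maxTokens=256,
)
print(reply)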

src/datacustomcode/config.py

Lines changed: 3 additions & 0 deletions
@@ -38,6 +38,7 @@
 from datacustomcode.io.base import BaseDataAccessLayer
 from datacustomcode.io.reader.base import BaseDataCloudReader  # noqa: TCH001
 from datacustomcode.io.writer.base import BaseDataCloudWriter  # noqa: TCH001
+from datacustomcode.proxy.client.base import BaseProxyClient  # noqa: TCH001
 from datacustomcode.spark.base import BaseSparkSessionProvider

 DEFAULT_CONFIG_NAME = "config.yaml"
@@ -109,6 +110,7 @@ def to_object(self) -> _P:
 class ClientConfig(BaseModel):
     reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None
     writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None
+    proxy_config: Union[AccessLayerObjectConfig[BaseProxyClient], None] = None
     spark_config: Union[SparkConfig, None] = None
     spark_provider_config: Union[
         SparkProviderConfig[BaseSparkSessionProvider], None
@@ -136,6 +138,7 @@ def merge(

         self.reader_config = merge(self.reader_config, other.reader_config)
         self.writer_config = merge(self.writer_config, other.writer_config)
+        self.proxy_config = merge(self.proxy_config, other.proxy_config)
         self.spark_config = merge(self.spark_config, other.spark_config)
         self.spark_provider_config = merge(
             self.spark_provider_config, other.spark_provider_config
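
For illustration, the new field can be populated the same way as reader_config and writer_config. A sketch, assuming AccessLayerObjectConfig validates the type_config_name/options shape used in config.yaml and that pydantic coerces the dict (neither detail is shown in this diff):

from datacustomcode.config import ClientConfig

# Hypothetical construction, mirroring config.yaml's proxy_config section.
override = ClientConfig(
    proxy_config={
        "type_config_name": "LocalProxyClientProvider",
        "options": {"credentials_profile": "default"},
    }
)
base = ClientConfig()
base.merge(override)  # proxy_config now merges alongside reader/writer/spark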

src/datacustomcode/config.yaml

Lines changed: 5 additions & 0 deletions
@@ -17,3 +17,8 @@ spark_config:
   spark.submit.deployMode: client
   spark.sql.execution.arrow.pyspark.enabled: 'true'
   spark.driver.extraJavaOptions: -Djava.security.manager=allow
+
+proxy_config:
+  type_config_name: LocalProxyClientProvider
+  options:
+    credentials_profile: default
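
A short sketch of how this section is consumed, assuming the global config object is pointed at this file (the path argument here is illustrative):

from datacustomcode.config import config

# Sketch: load the YAML above; proxy_config should resolve to the
# LocalProxyClientProvider entry with credentials_profile=default.
config.load("src/datacustomcode/config.yaml")
print(config.proxy_config)
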
src/datacustomcode/proxy/__init__.py

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+# Copyright (c) 2025, Salesforce, Inc.
+# SPDX-License-Identifier: Apache-2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

src/datacustomcode/proxy/base.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+# Copyright (c) 2025, Salesforce, Inc.
+# SPDX-License-Identifier: Apache-2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from abc import ABC
+
+from datacustomcode.mixin import UserExtendableNamedConfigMixin
+
+
+class BaseDataAccessLayer(ABC, UserExtendableNamedConfigMixin):
+    def __init__(self):
+        pass
src/datacustomcode/proxy/client/__init__.py

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+# Copyright (c) 2025, Salesforce, Inc.
+# SPDX-License-Identifier: Apache-2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
src/datacustomcode/proxy/client/base.py

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+# Copyright (c) 2025, Salesforce, Inc.
+# SPDX-License-Identifier: Apache-2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from abc import abstractmethod
+
+from datacustomcode.io.base import BaseDataAccessLayer
+
+
+class BaseProxyClient(BaseDataAccessLayer):
+    def __init__(self, spark=None, **kwargs):
+        if spark is not None:
+            super().__init__(spark)
+
+    @abstractmethod
+    def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str: ...
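
Concrete proxies only need to implement call_llm_gateway. A hypothetical subclass, shown purely to illustrate the abstract contract (EchoProxyClient and its CONFIG_NAME are invented for this sketch):

from datacustomcode.proxy.client.base import BaseProxyClient

class EchoProxyClient(BaseProxyClient):
    """Invented example: echoes the request instead of calling a gateway."""

    CONFIG_NAME = "EchoProxyClient"  # follows LocalProxyClientProvider's pattern

    def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str:
        # A real client would forward the prompt to an LLM gateway endpoint;
        # this stub truncates the prompt and returns it locally.
        return f"[{llmModelId}] {prompt[:maxTokens]}"
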
src/datacustomcode/proxy/client/local_proxy_client.py

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+# Copyright (c) 2025, Salesforce, Inc.
+# SPDX-License-Identifier: Apache-2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from datacustomcode.proxy.client.base import BaseProxyClient
+
+
+class LocalProxyClientProvider(BaseProxyClient):
+    """Default proxy client provider."""
+
+    CONFIG_NAME = "LocalProxyClientProvider"
+
+    def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str:
+        return f"Hello, thanks for using {llmModelId}. So many tokens: {maxTokens}"

src/datacustomcode/run.py

Lines changed: 20 additions & 15 deletions
@@ -21,6 +21,7 @@
 from typing import List, Union

 from datacustomcode.config import config
+from datacustomcode.scan import get_package_type


 def _set_config_option(config_obj, key: str, value: str) -> None:
@@ -60,6 +61,8 @@ def run_entrypoint(
             f"config.json not found at {config_json_path}. config.json is required."
         )

+    package_type = get_package_type(entrypoint_dir)
+
     try:
         with open(config_json_path, "r") as f:
             config_json = json.load(f)
@@ -68,21 +71,23 @@ def run_entrypoint(
             f"config.json at {config_json_path} is not valid JSON"
         ) from err

-    # Require dataspace to be present in config.json
-    dataspace = config_json.get("dataspace")
-    if not dataspace:
-        raise ValueError(
-            f"config.json at {config_json_path} is missing required field 'dataspace'. "
-            f"Please ensure config.json contains a 'dataspace' field."
-        )
-
-    # Load config file first
-    if config_file:
-        config.load(config_file)
-
-    # Add dataspace to reader and writer config options
-    _set_config_option(config.reader_config, "dataspace", dataspace)
-    _set_config_option(config.writer_config, "dataspace", dataspace)
+    if package_type == "script":
+        # Require dataspace to be present in config.json
+        dataspace = config_json.get("dataspace")
+        if not dataspace:
+            raise ValueError(
+                f"config.json at {config_json_path} is missing required "
+                f"field 'dataspace'. "
+                f"Please ensure config.json contains a 'dataspace' field."
+            )
+
+        # Load config file first
+        if config_file:
+            config.load(config_file)
+
+        # Add dataspace to reader and writer config options
+        _set_config_option(config.reader_config, "dataspace", dataspace)
+        _set_config_option(config.writer_config, "dataspace", dataspace)

     if profile != "default":
         _set_config_option(config.reader_config, "credentials_profile", profile)
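
The dataspace requirement is now gated on package type. A standalone sketch of the new check, not the library API (get_package_type's return values other than "script" are not shown in this diff):

import json
from pathlib import Path
from typing import Optional

def require_dataspace(config_json_path: Path, package_type: str) -> Optional[str]:
    """Mirrors run_entrypoint's gating: only 'script' packages must declare
    a 'dataspace' in config.json; other package types skip the check."""
    config_json = json.loads(config_json_path.read_text())
    if package_type != "script":
        return None
    dataspace = config_json.get("dataspace")
    if not dataspace:
        raise ValueError(f"config.json at {config_json_path} is missing 'dataspace'")
    return dataspace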
