Skip to content

Commit 8f0049e

Browse files
authored
Merge pull request #70 from forcedotcom/jo_func
fine tuning functions
2 parents 7117670 + 92a0c50 commit 8f0049e

6 files changed

Lines changed: 51 additions & 51 deletions

File tree

src/datacustomcode/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
from datacustomcode.credentials import AuthType, Credentials
1818
from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
1919
from datacustomcode.io.writer.print import PrintDataCloudWriter
20-
from datacustomcode.proxy.client.local_proxy_client import LocalProxyClientProvider
20+
from datacustomcode.proxy.client.LocalProxyClientProvider import (
21+
LocalProxyClientProvider,
22+
)
2123

2224
__all__ = [
2325
"AuthType",

src/datacustomcode/client.py

Lines changed: 21 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -107,20 +107,23 @@ class Client:
107107
_reader: BaseDataCloudReader
108108
_writer: BaseDataCloudWriter
109109
_file: DefaultFindFilePath
110-
_proxy: BaseProxyClient
110+
_proxy: Optional[BaseProxyClient]
111111
_data_layer_history: dict[DataCloudObjectType, set[str]]
112+
_code_type: str
112113

113114
def __new__(
114115
cls,
115116
reader: Optional[BaseDataCloudReader] = None,
116117
writer: Optional["BaseDataCloudWriter"] = None,
117118
proxy: Optional[BaseProxyClient] = None,
118119
spark_provider: Optional["BaseSparkSessionProvider"] = None,
120+
code_type: str = "script",
119121
) -> Client:
122+
if "function" in code_type:
123+
return cls._new_function_client()
124+
120125
if cls._instance is None:
121126
cls._instance = super().__new__(cls)
122-
123-
spark = None
124127
# Initialize Readers and Writers from config
125128
# and/or provided reader and writer
126129
if reader is None or writer is None:
@@ -139,22 +142,6 @@ def __new__(
139142
provider = DefaultSparkSessionProvider()
140143

141144
spark = provider.get_session(config.spark_config)
142-
elif (
143-
proxy is None
144-
and config.proxy_config is not None
145-
and config.spark_config is not None
146-
):
147-
# Both reader and writer provided; we still need spark for proxy init
148-
provider = (
149-
spark_provider
150-
if spark_provider is not None
151-
else (
152-
config.spark_provider_config.to_object()
153-
if config.spark_provider_config is not None
154-
else DefaultSparkSessionProvider()
155-
)
156-
)
157-
spark = provider.get_session(config.spark_config)
158145

159146
if config.reader_config is None and reader is None:
160147
raise ValueError(
@@ -163,44 +150,23 @@ def __new__(
163150
elif reader is None or (
164151
config.reader_config is not None and config.reader_config.force
165152
):
166-
if config.proxy_config is None:
167-
raise ValueError(
168-
"Proxy config is required when reader is built from config"
169-
)
170-
assert (
171-
spark is not None
172-
) # set in "reader is None or writer is None" branch
173-
assert config.reader_config is not None # ensured by branch condition
174-
proxy_init = config.proxy_config.to_object(spark)
175-
176-
reader_init = config.reader_config.to_object(spark)
153+
reader_init = config.reader_config.to_object(spark) # type: ignore
177154
else:
178155
reader_init = reader
179-
if proxy is not None:
180-
proxy_init = proxy
181-
elif config.proxy_config is None:
182-
raise ValueError("Proxy config is required when reader is provided")
183-
else:
184-
assert (
185-
spark is not None
186-
) # set in "both provided; proxy from config" branch
187-
proxy_init = config.proxy_config.to_object(spark)
188156
if config.writer_config is None and writer is None:
189157
raise ValueError(
190158
"Writer config is required when writer is not provided"
191159
)
192160
elif writer is None or (
193161
config.writer_config is not None and config.writer_config.force
194162
):
195-
assert spark is not None # set when reader or writer from config
196-
assert config.writer_config is not None # ensured by branch condition
197-
writer_init = config.writer_config.to_object(spark)
163+
writer_init = config.writer_config.to_object(spark) # type: ignore
198164
else:
199165
writer_init = writer
166+
200167
cls._instance._reader = reader_init
201168
cls._instance._writer = writer_init
202169
cls._instance._file = DefaultFindFilePath()
203-
cls._instance._proxy = proxy_init
204170
cls._instance._data_layer_history = {
205171
DataCloudObjectType.DLO: set(),
206172
DataCloudObjectType.DMO: set(),
@@ -209,6 +175,16 @@ def __new__(
209175
raise ValueError("Cannot set reader or writer after client is initialized")
210176
return cls._instance
211177

178+
@classmethod
179+
def _new_function_client(cls) -> Client:
180+
cls._instance = super().__new__(cls)
181+
cls._instance._proxy = (
182+
config.proxy_config.to_object() # type: ignore
183+
if config.proxy_config is not None
184+
else None
185+
)
186+
return cls._instance
187+
212188
def read_dlo(self, name: str) -> PySparkDataFrame:
213189
"""Read a DLO from Data Cloud.
214190
@@ -260,6 +236,8 @@ def write_to_dmo(
260236
return self._writer.write_to_dmo(name, dataframe, write_mode, **kwargs)
261237

262238
def call_llm_gateway(self, LLM_MODEL_ID: str, prompt: str, maxTokens: int) -> str:
239+
if self._proxy is None:
240+
raise ValueError("No proxy configured; set proxy or proxy_config")
263241
return self._proxy.call_llm_gateway(LLM_MODEL_ID, prompt, maxTokens)
264242

265243
def find_file_path(self, file_name: str) -> Path:

src/datacustomcode/config.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from datacustomcode.io.base import BaseDataAccessLayer
3939
from datacustomcode.io.reader.base import BaseDataCloudReader # noqa: TCH001
4040
from datacustomcode.io.writer.base import BaseDataCloudWriter # noqa: TCH001
41+
from datacustomcode.proxy.base import BaseProxyAccessLayer
4142
from datacustomcode.proxy.client.base import BaseProxyClient # noqa: TCH001
4243
from datacustomcode.spark.base import BaseSparkSessionProvider
4344

@@ -93,6 +94,23 @@ class SparkConfig(ForceableConfig):
9394

9495
_P = TypeVar("_P", bound=BaseSparkSessionProvider)
9596

97+
_PX = TypeVar("_PX", bound=BaseProxyAccessLayer)
98+
99+
100+
class ProxyAccessLayerObjectConfig(ForceableConfig, Generic[_PX]):
101+
"""Config for proxy clients that take no constructor args (e.g. no spark)."""
102+
103+
model_config = ConfigDict(validate_default=True, extra="forbid")
104+
type_base: ClassVar[Type[BaseProxyAccessLayer]] = BaseProxyAccessLayer
105+
type_config_name: str = Field(
106+
description="CONFIG_NAME of the proxy client (e.g. 'LocalProxyClient').",
107+
)
108+
options: dict[str, Any] = Field(default_factory=dict)
109+
110+
def to_object(self) -> _PX:
111+
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
112+
return cast(_PX, type_(**self.options))
113+
96114

97115
class SparkProviderConfig(ForceableConfig, Generic[_P]):
98116
model_config = ConfigDict(validate_default=True, extra="forbid")
@@ -110,7 +128,7 @@ def to_object(self) -> _P:
110128
class ClientConfig(BaseModel):
111129
reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None
112130
writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None
113-
proxy_config: Union[AccessLayerObjectConfig[BaseProxyClient], None] = None
131+
proxy_config: Union[ProxyAccessLayerObjectConfig[BaseProxyClient], None] = None
114132
spark_config: Union[SparkConfig, None] = None
115133
spark_provider_config: Union[
116134
SparkProviderConfig[BaseSparkSessionProvider], None

src/datacustomcode/proxy/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,6 @@
1919
from datacustomcode.mixin import UserExtendableNamedConfigMixin
2020

2121

22-
class BaseDataAccessLayer(ABC, UserExtendableNamedConfigMixin):
22+
class BaseProxyAccessLayer(ABC, UserExtendableNamedConfigMixin):
2323
def __init__(self):
2424
pass

src/datacustomcode/proxy/client/local_proxy_client.py renamed to src/datacustomcode/proxy/client/LocalProxyClientProvider.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,8 @@ class LocalProxyClientProvider(BaseProxyClient):
2222

2323
CONFIG_NAME = "LocalProxyClientProvider"
2424

25+
def __init__(self, **kwargs: object) -> None:
26+
pass
27+
2528
def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str:
2629
return f"Hello, thanks for using {llmModelId}. So many tokens: {maxTokens}"

src/datacustomcode/proxy/client/base.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,12 @@
1616

1717
from abc import abstractmethod
1818

19-
from datacustomcode.io.base import BaseDataAccessLayer
19+
from datacustomcode.proxy.base import BaseProxyAccessLayer
2020

2121

22-
class BaseProxyClient(BaseDataAccessLayer):
23-
def __init__(self, spark=None, **kwargs):
24-
if spark is not None:
25-
super().__init__(spark)
22+
class BaseProxyClient(BaseProxyAccessLayer):
23+
def __init__(self):
24+
pass
2625

2726
@abstractmethod
2827
def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str: ...

0 commit comments

Comments (0)