Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/datacustomcode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
from datacustomcode.credentials import AuthType, Credentials
from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
from datacustomcode.io.writer.print import PrintDataCloudWriter
from datacustomcode.proxy.client.local_proxy_client import LocalProxyClientProvider
from datacustomcode.proxy.client.LocalProxyClientProvider import (
LocalProxyClientProvider,
)

__all__ = [
"AuthType",
Expand Down
79 changes: 36 additions & 43 deletions src/datacustomcode/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

from enum import Enum
import importlib
from typing import (
TYPE_CHECKING,
ClassVar,
Expand Down Expand Up @@ -107,20 +108,23 @@ class Client:
_reader: BaseDataCloudReader
_writer: BaseDataCloudWriter
_file: DefaultFindFilePath
_proxy: BaseProxyClient
_proxy: Optional[BaseProxyClient]
_data_layer_history: dict[DataCloudObjectType, set[str]]
_code_type: str

def __new__(
cls,
reader: Optional[BaseDataCloudReader] = None,
writer: Optional["BaseDataCloudWriter"] = None,
proxy: Optional[BaseProxyClient] = None,
spark_provider: Optional["BaseSparkSessionProvider"] = None,
code_type: str = "script",
) -> Client:
if "function" in code_type:
return cls._new_function_client()

if cls._instance is None:
cls._instance = super().__new__(cls)

spark = None
Comment thread
joroscoSF marked this conversation as resolved.
# Initialize Readers and Writers from config
# and/or provided reader and writer
if reader is None or writer is None:
Expand All @@ -139,22 +143,6 @@ def __new__(
provider = DefaultSparkSessionProvider()

spark = provider.get_session(config.spark_config)
elif (
proxy is None
and config.proxy_config is not None
and config.spark_config is not None
):
# Both reader and writer provided; we still need spark for proxy init
provider = (
spark_provider
if spark_provider is not None
else (
config.spark_provider_config.to_object()
if config.spark_provider_config is not None
else DefaultSparkSessionProvider()
)
)
spark = provider.get_session(config.spark_config)

if config.reader_config is None and reader is None:
raise ValueError(
Expand All @@ -163,44 +151,23 @@ def __new__(
elif reader is None or (
config.reader_config is not None and config.reader_config.force
):
if config.proxy_config is None:
raise ValueError(
"Proxy config is required when reader is built from config"
)
assert (
spark is not None
) # set in "reader is None or writer is None" branch
assert config.reader_config is not None # ensured by branch condition
proxy_init = config.proxy_config.to_object(spark)

reader_init = config.reader_config.to_object(spark)
reader_init = config.reader_config.to_object(spark) # type: ignore
else:
reader_init = reader
if proxy is not None:
proxy_init = proxy
elif config.proxy_config is None:
raise ValueError("Proxy config is required when reader is provided")
else:
assert (
spark is not None
) # set in "both provided; proxy from config" branch
proxy_init = config.proxy_config.to_object(spark)
if config.writer_config is None and writer is None:
raise ValueError(
"Writer config is required when writer is not provided"
)
elif writer is None or (
config.writer_config is not None and config.writer_config.force
):
assert spark is not None # set when reader or writer from config
assert config.writer_config is not None # ensured by branch condition
writer_init = config.writer_config.to_object(spark)
writer_init = config.writer_config.to_object(spark) # type: ignore
else:
writer_init = writer

cls._instance._reader = reader_init
cls._instance._writer = writer_init
cls._instance._file = DefaultFindFilePath()
cls._instance._proxy = proxy_init
cls._instance._data_layer_history = {
DataCloudObjectType.DLO: set(),
DataCloudObjectType.DMO: set(),
Expand All @@ -209,6 +176,30 @@ def __new__(
raise ValueError("Cannot set reader or writer after client is initialized")
return cls._instance

@classmethod
def _new_function_client(cls) -> Client:
    """Build a Client for the "function" code type.

    Unlike the config-driven path in ``__new__``, this variant only:
    1. imports every entry in ``config.dependencies`` up front, and
    2. attaches a proxy client built from ``config.proxy_config`` (or
       ``None`` when no proxy is configured).

    NOTE(review): this path does not set ``_reader``/``_writer``/``_file``/
    ``_data_layer_history``; confirm that function-mode callers never touch
    the read/write APIs on the returned instance.

    Returns:
        The (re)created singleton ``Client`` instance.

    Raises:
        ModuleNotFoundError: if a dependency is neither an importable module
            nor a ``module.attribute`` path whose module imports.
        AttributeError: if the dependency's module imports but does not
            expose the named attribute.
    """
    for dependency in config.dependencies:
        try:
            # Fast path: the dependency string is itself a module.
            importlib.import_module(dependency)
        except ModuleNotFoundError as exc:
            try:
                if "." in dependency:
                    # Fallback: treat the last segment as an attribute of
                    # the parent module (e.g. "pkg.mod.ClassName").
                    module_name, object_name = dependency.rsplit(".", 1)
                    module = importlib.import_module(module_name)
                    getattr(module, object_name)
                else:
                    # No dot to split on — re-raise the original error.
                    raise exc
            except AttributeError as inner_exc:
                # Attribute missing on the parent module; chain the original
                # import failure for a fuller traceback.
                raise inner_exc from exc

    # Unconditionally replaces any existing singleton — unlike __new__,
    # which reuses cls._instance when it is already set.
    cls._instance = super().__new__(cls)
    cls._instance._proxy = (
        config.proxy_config.to_object()  # type: ignore
        if config.proxy_config is not None
        else None
    )
    return cls._instance

def read_dlo(self, name: str) -> PySparkDataFrame:
"""Read a DLO from Data Cloud.

Expand Down Expand Up @@ -260,6 +251,8 @@ def write_to_dmo(
return self._writer.write_to_dmo(name, dataframe, write_mode, **kwargs)

def call_llm_gateway(self, LLM_MODEL_ID: str, prompt: str, maxTokens: int) -> str:
    """Forward an LLM gateway request to the configured proxy client.

    Args:
        LLM_MODEL_ID: Identifier of the model to invoke.
        prompt: Prompt text forwarded to the gateway.
        maxTokens: Token budget forwarded to the gateway.

    Returns:
        The gateway's response string.

    Raises:
        ValueError: if this client was built without any proxy.
    """
    proxy = self._proxy
    if proxy is None:
        raise ValueError("No proxy configured; set proxy or proxy_config")
    return proxy.call_llm_gateway(LLM_MODEL_ID, prompt, maxTokens)

def find_file_path(self, file_name: str) -> Path:
Expand Down
28 changes: 27 additions & 1 deletion src/datacustomcode/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from datacustomcode.io.base import BaseDataAccessLayer
from datacustomcode.io.reader.base import BaseDataCloudReader # noqa: TCH001
from datacustomcode.io.writer.base import BaseDataCloudWriter # noqa: TCH001
from datacustomcode.proxy.base import BaseProxyAccessLayer
from datacustomcode.proxy.client.base import BaseProxyClient # noqa: TCH001
from datacustomcode.spark.base import BaseSparkSessionProvider

Expand Down Expand Up @@ -93,6 +94,23 @@ class SparkConfig(ForceableConfig):

_P = TypeVar("_P", bound=BaseSparkSessionProvider)

_PX = TypeVar("_PX", bound=BaseProxyAccessLayer)


class ProxyAccessLayerObjectConfig(ForceableConfig, Generic[_PX]):
    """Config for proxy clients that take no constructor args (e.g. no spark).

    Mirrors ``AccessLayerObjectConfig`` but targets ``BaseProxyAccessLayer``
    subclasses, whose constructors are spark-free, so ``to_object`` takes no
    arguments.
    """

    model_config = ConfigDict(validate_default=True, extra="forbid")
    # Root of the registry used to resolve type_config_name to a class.
    type_base: ClassVar[Type[BaseProxyAccessLayer]] = BaseProxyAccessLayer
    # Fixed: the example now matches the shipped local client's CONFIG_NAME
    # ("LocalProxyClientProvider"), not the nonexistent "LocalProxyClient".
    type_config_name: str = Field(
        description=(
            "CONFIG_NAME of the proxy client (e.g. 'LocalProxyClientProvider')."
        ),
    )
    # Keyword arguments forwarded verbatim to the resolved class's constructor.
    options: dict[str, Any] = Field(default_factory=dict)

    def to_object(self) -> _PX:
        """Instantiate the configured proxy client.

        Resolves ``type_config_name`` via ``type_base``'s registered
        subclasses and constructs the class with ``options`` as kwargs.
        """
        type_ = self.type_base.subclass_from_config_name(self.type_config_name)
        return cast(_PX, type_(**self.options))


class SparkProviderConfig(ForceableConfig, Generic[_P]):
model_config = ConfigDict(validate_default=True, extra="forbid")
Expand All @@ -110,11 +128,18 @@ def to_object(self) -> _P:
class ClientConfig(BaseModel):
reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None
writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None
proxy_config: Union[AccessLayerObjectConfig[BaseProxyClient], None] = None
proxy_config: Union[ProxyAccessLayerObjectConfig[BaseProxyClient], None] = None
spark_config: Union[SparkConfig, None] = None
spark_provider_config: Union[
SparkProviderConfig[BaseSparkSessionProvider], None
] = None
dependencies: list[str] = Field(
default_factory=list,
description="""
Extra modules to import before running the entrypoint
(merged with --dependencies from CLI).
""",
)

def update(self, other: ClientConfig) -> ClientConfig:
"""Merge this ClientConfig with another, respecting force flags.
Expand Down Expand Up @@ -143,6 +168,7 @@ def merge(
self.spark_provider_config = merge(
self.spark_provider_config, other.spark_provider_config
)
self.dependencies = list(dict.fromkeys(self.dependencies + other.dependencies))
return self

def load(self, config_path: str) -> ClientConfig:
Expand Down
2 changes: 1 addition & 1 deletion src/datacustomcode/proxy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@
from datacustomcode.mixin import UserExtendableNamedConfigMixin


class BaseDataAccessLayer(ABC, UserExtendableNamedConfigMixin):
class BaseProxyAccessLayer(ABC, UserExtendableNamedConfigMixin):
def __init__(self):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,8 @@ class LocalProxyClientProvider(BaseProxyClient):

CONFIG_NAME = "LocalProxyClientProvider"

def __init__(self, **kwargs: object) -> None:
    """Accept and ignore arbitrary config options.

    Config-driven construction passes ``**options`` (see
    ProxyAccessLayerObjectConfig.to_object); this local stub needs none
    of them, so everything is swallowed.
    """
    pass

def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str:
    """Return a canned local response; ``prompt`` is deliberately ignored."""
    return (
        "Hello, thanks for using "
        + llmModelId
        + ". So many tokens: "
        + str(maxTokens)
    )
9 changes: 4 additions & 5 deletions src/datacustomcode/proxy/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@

from abc import abstractmethod

from datacustomcode.io.base import BaseDataAccessLayer
from datacustomcode.proxy.base import BaseProxyAccessLayer


class BaseProxyClient(BaseDataAccessLayer):
def __init__(self, spark=None, **kwargs):
if spark is not None:
super().__init__(spark)
class BaseProxyClient(BaseProxyAccessLayer):
    """Abstract interface for proxy clients (LLM gateway access)."""

    def __init__(self) -> None:
        # No construction-time state; concrete clients add their own.
        pass

    @abstractmethod
    def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str: ...
13 changes: 8 additions & 5 deletions src/datacustomcode/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ def run_entrypoint(
"""
add_py_folder(entrypoint)

# Load config file (so we can merge config.dependencies with CLI deps)
if config_file:
config.load(config_file)

# Merge dependencies from config and CLI (config first, then CLI, deduped)
merged_dependencies = list(dict.fromkeys(config.dependencies + list(dependencies)))

# Read dataspace from config.json (required)
entrypoint_dir = os.path.dirname(entrypoint)
config_json_path = os.path.join(entrypoint_dir, "config.json")
Expand Down Expand Up @@ -81,18 +88,14 @@ def run_entrypoint(
f"Please ensure config.json contains a 'dataspace' field."
)

# Load config file first
if config_file:
config.load(config_file)

# Add dataspace to reader and writer config options
_set_config_option(config.reader_config, "dataspace", dataspace)
_set_config_option(config.writer_config, "dataspace", dataspace)

if profile != "default":
_set_config_option(config.reader_config, "credentials_profile", profile)
_set_config_option(config.writer_config, "credentials_profile", profile)
for dependency in dependencies:
for dependency in merged_dependencies:
try:
importlib.import_module(dependency)
except ModuleNotFoundError as exc:
Expand Down