Skip to content

Commit ea1c41c

Browse files
add EP to configuration
1 parent 7161e25 commit ea1c41c

12 files changed

Lines changed: 334 additions & 76 deletions
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import os
2+
import yaml
3+
from pydantic import (
4+
BaseModel,
5+
ConfigDict,
6+
Field,
7+
)
8+
9+
from typing import Any
10+
11+
DEFAULT_CONFIG_NAME = "config.yaml"
12+
13+
14+
def default_config_file() -> str:
15+
return os.path.join(os.path.dirname(__file__), DEFAULT_CONFIG_NAME)
16+
17+
18+
class ForceableConfig(BaseModel):
19+
force: bool = Field(
20+
default=False,
21+
description="If True, this takes precedence over parameters passed to the initializer of the client",
22+
)
23+
24+
25+
class BaseObjectConfig(ForceableConfig):
26+
model_config = ConfigDict(validate_default=True, extra="forbid")
27+
type_config_name: str = Field(
28+
description="The config name of the object to create",
29+
)
30+
options: dict[str, Any] = Field(
31+
default_factory=dict,
32+
description="Options passed to the constructor.",
33+
)
34+
35+
36+
class BaseConfig(BaseModel):
37+
def load(self, config_path: str) -> "BaseConfig":
38+
"""Load configuration from a YAML file and merge with existing config"""
39+
with open(config_path, "r") as f:
40+
config_data = yaml.safe_load(f)
41+
42+
loaded_config = self.__class__.model_validate(config_data)
43+
self.update(loaded_config)
44+
return self

src/datacustomcode/config.py

Lines changed: 7 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
import os
1817
from typing import (
1918
TYPE_CHECKING,
2019
Any,
@@ -27,11 +26,8 @@
2726
)
2827

2928
from pydantic import (
30-
BaseModel,
31-
ConfigDict,
3229
Field,
3330
)
34-
import yaml
3531

3632
# This lets all readers and writers to be findable via config
3733
from datacustomcode.io import * # noqa: F403
@@ -41,38 +37,18 @@
4137
from datacustomcode.proxy.base import BaseProxyAccessLayer
4238
from datacustomcode.proxy.client.base import BaseProxyClient # noqa: TCH002
4339
from datacustomcode.spark.base import BaseSparkSessionProvider
44-
45-
DEFAULT_CONFIG_NAME = "config.yaml"
40+
from datacustomcode.common_config import ForceableConfig, BaseObjectConfig, BaseConfig, default_config_file
4641

4742

4843
if TYPE_CHECKING:
4944
from pyspark.sql import SparkSession
5045

5146

52-
class ForceableConfig(BaseModel):
53-
force: bool = Field(
54-
default=False,
55-
description="If True, this takes precedence over parameters passed to the "
56-
"initializer of the client.",
57-
)
58-
59-
6047
_T = TypeVar("_T", bound="BaseDataAccessLayer")
6148

6249

63-
class AccessLayerObjectConfig(ForceableConfig, Generic[_T]):
64-
model_config = ConfigDict(validate_default=True, extra="forbid")
50+
class AccessLayerObjectConfig(BaseObjectConfig, Generic[_T]):
6551
type_base: ClassVar[Type[BaseDataAccessLayer]] = BaseDataAccessLayer
66-
type_config_name: str = Field(
67-
description="The config name of the object to create. "
68-
"For metrics, this would might be 'ipmnormal'. For custom classes, you can "
69-
"assign a name to a class variable `CONFIG_NAME` and reference it here.",
70-
)
71-
options: dict[str, Any] = Field(
72-
default_factory=dict,
73-
description="Options passed to the constructor.",
74-
)
75-
7652
def to_object(self, spark: SparkSession) -> _T:
7753
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
7854
return cast(_T, type_(spark=spark, **self.options))
@@ -97,35 +73,22 @@ class SparkConfig(ForceableConfig):
9773
_PX = TypeVar("_PX", bound=BaseProxyAccessLayer)
9874

9975

100-
class ProxyAccessLayerObjectConfig(ForceableConfig, Generic[_PX]):
76+
class ProxyAccessLayerObjectConfig(BaseObjectConfig, Generic[_PX]):
10177
"""Config for proxy clients that take no constructor args (e.g. no spark)."""
102-
103-
model_config = ConfigDict(validate_default=True, extra="forbid")
10478
type_base: ClassVar[Type[BaseProxyAccessLayer]] = BaseProxyAccessLayer
105-
type_config_name: str = Field(
106-
description="CONFIG_NAME of the proxy client (e.g. 'LocalProxyClient').",
107-
)
108-
options: dict[str, Any] = Field(default_factory=dict)
109-
11079
def to_object(self) -> _PX:
11180
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
11281
return cast(_PX, type_(**self.options))
11382

11483

115-
class SparkProviderConfig(ForceableConfig, Generic[_P]):
116-
model_config = ConfigDict(validate_default=True, extra="forbid")
84+
class SparkProviderConfig(BaseObjectConfig, Generic[_P]):
11785
type_base: ClassVar[Type[BaseSparkSessionProvider]] = BaseSparkSessionProvider
118-
type_config_name: str = Field(
119-
description="CONFIG_NAME of the Spark session provider."
120-
)
121-
options: dict[str, Any] = Field(default_factory=dict)
122-
12386
def to_object(self) -> _P:
12487
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
12588
return cast(_P, type_(**self.options))
12689

12790

128-
class ClientConfig(BaseModel):
91+
class ClientConfig(BaseConfig):
12992
reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None
13093
writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None
13194
proxy_config: Union[ProxyAccessLayerObjectConfig[BaseProxyClient], None] = None
@@ -163,31 +126,9 @@ def merge(
163126
)
164127
return self
165128

166-
def load(self, config_path: str) -> ClientConfig:
167-
"""Load a config from a file and update this config with it.
168-
169-
Args:
170-
config_path: The path to the config file
171-
172-
Returns:
173-
Self, with updated values from the loaded config.
174-
"""
175-
with open(config_path, "r") as f:
176-
config_data = yaml.safe_load(f)
177-
loaded_config = ClientConfig.model_validate(config_data)
178-
179-
return self.update(loaded_config)
180-
181-
182-
config = ClientConfig()
183129
"""Global config object.
184130
185131
This is the object that makes config accessible globally and globally mutable.
186132
"""
187-
188-
189-
def _defaults() -> str:
190-
return os.path.join(os.path.dirname(__file__), DEFAULT_CONFIG_NAME)
191-
192-
193-
config.load(_defaults())
133+
config = ClientConfig()
134+
config.load(default_config_file())

src/datacustomcode/config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,7 @@ proxy_config:
2323
type_config_name: LocalProxyClientProvider
2424
options:
2525
credentials_profile: default
26+
27+
einstein_predictions_config:
28+
type_config_name: DefaultEinsteinPredictions
29+
options: {}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from datacustomcode.einstein_predictions.base import EinsteinPredictions
17+
from datacustomcode.einstein_predictions.impl.default import DefaultEinsteinPredictions
18+
19+
__all__ = [
20+
"EinsteinPredictions",
21+
"DefaultEinsteinPredictions",
22+
]

src/datacustomcode/einstein_predictions/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@
1616
from abc import ABC, abstractmethod
1717

1818
from datacustomcode.einstein_predictions.types import (PredictionRequest, PredictionResponse)
19+
from datacustomcode.mixin import UserExtendableNamedConfigMixin
20+
21+
class EinsteinPredictions(ABC, UserExtendableNamedConfigMixin):
22+
CONFIG_NAME: str
23+
24+
def __init__(self, **kwargs):
25+
pass
1926

20-
class EinsteinPredictions(ABC):
2127
@abstractmethod
2228
def predict(self, request: PredictionRequest) -> PredictionResponse: ...

src/datacustomcode/einstein_predictions/impl/default.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,22 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
from datacustomcode.einstein_predictions.base import EinsteinPredictions
1718
from datacustomcode.einstein_predictions.types import (
1819
PredictionRequest,
1920
PredictionResponse
2021
)
2122

22-
class DefaultEinsteinPredictions:
23-
def __init__(self, base_url: str, access_token: str) -> None:
24-
pass
23+
class DefaultEinsteinPredictions(EinsteinPredictions):
24+
CONFIG_NAME = "DefaultEinsteinPredictions"
25+
26+
def __init__(self, **kwargs):
27+
super().__init__(**kwargs)
2528

2629
def predict(self, request: PredictionRequest) -> PredictionResponse:
27-
pass
30+
return PredictionResponse(
31+
version="v1",
32+
prediction_type=request.prediction_type,
33+
status_code=200,
34+
data={"results": [{"prediction": {"predictedValue": "1"}}]}
35+
)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from typing import (
17+
ClassVar,
18+
Generic,
19+
Type,
20+
TypeVar,
21+
Union,
22+
cast,
23+
)
24+
25+
from datacustomcode.einstein_predictions.base import EinsteinPredictions
26+
from datacustomcode.common_config import BaseObjectConfig, BaseConfig, default_config_file
27+
28+
_E = TypeVar("_E", bound=EinsteinPredictions)
29+
30+
class EinsteinPredictionsObjectConfig(BaseObjectConfig, Generic[_E]):
31+
type_base: ClassVar[Type[EinsteinPredictions]] = EinsteinPredictions
32+
def to_object(self) -> _E:
33+
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
34+
return cast(_E, type_(**self.options))
35+
36+
37+
class EinsteinPredictionsConfig(BaseConfig):
38+
einstein_predictions_config: Union[EinsteinPredictionsObjectConfig[EinsteinPredictions], None] = None
39+
def update(self, other: "EinsteinPredictionsConfig") -> "EinsteinPredictionsConfig":
40+
def merge(
41+
config_a: Union[EinsteinPredictionsObjectConfig, None],
42+
config_b: Union[EinsteinPredictionsObjectConfig, None]
43+
) -> Union[EinsteinPredictionsObjectConfig, None]:
44+
if config_a is not None and config_a.force:
45+
return config_a
46+
if config_b:
47+
return config_b
48+
return config_a
49+
50+
self.einstein_predictions_config = merge(
51+
self.einstein_predictions_config, other.einstein_predictions_config
52+
)
53+
return self
54+
55+
# Global Einstein Predictions config instance
56+
einstein_predictions_config = EinsteinPredictionsConfig()
57+
einstein_predictions_config.load(default_config_file())

src/datacustomcode/function/runtime.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import threading
1818
from typing import Optional
1919

20+
from datacustomcode.einstein_predictions.base import EinsteinPredictions
21+
from datacustomcode.einstein_predictions_config import einstein_predictions_config
2022
from datacustomcode.file.path.default import DefaultFindFilePath
2123
from datacustomcode.function.base import BaseRuntime
2224
from datacustomcode.llm_gateway.default import DefaultLLMGateway
@@ -65,6 +67,7 @@ def __init__(self) -> None:
6567
# Initialize resources
6668
self._llm_gateway = DefaultLLMGateway()
6769
self._file = DefaultFindFilePath()
70+
self._einstein_predictions: Optional[EinsteinPredictions] = None
6871

6972
@property
7073
def llm_gateway(self) -> DefaultLLMGateway:
@@ -75,3 +78,13 @@ def llm_gateway(self) -> DefaultLLMGateway:
7578
def file(self) -> DefaultFindFilePath:
7679
"""Access file operations."""
7780
return self._file
81+
82+
@property
83+
def einstein_predictions(self) -> EinsteinPredictions:
84+
if self._einstein_predictions is None:
85+
if einstein_predictions_config.einstein_predictions_config is None:
86+
raise RuntimeError(
87+
"Einstein Predictions is not configured.Add 'einstein_predictions_config' section to config.yaml"
88+
)
89+
self._einstein_predictions = einstein_predictions_config.einstein_predictions_config.to_object()
90+
return self._einstein_predictions

tests/test_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
AccessLayerObjectConfig,
1212
ClientConfig,
1313
SparkConfig,
14-
_defaults,
14+
default_config_file,
1515
config,
1616
)
1717
from datacustomcode.io.base import BaseDataAccessLayer
@@ -298,8 +298,8 @@ def test_load_config_from_file(self):
298298
os.unlink(temp_path)
299299

300300
def test_defaults(self):
301-
# Just verify that _defaults function exists and returns a string path
302-
result = _defaults()
301+
# Just verify that default_config_file function exists and returns a string path
302+
result = default_config_file()
303303
assert isinstance(result, str)
304304
assert result.endswith("config.yaml")
305305

0 commit comments

Comments
 (0)