Skip to content

Commit 732a998

Browse files
fix lint errors
1 parent 54d81a6 commit 732a998

13 files changed

Lines changed: 212 additions & 143 deletions

src/datacustomcode/common_config.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,26 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
115
import os
2-
import yaml
16+
from typing import Any
17+
318
from pydantic import (
419
BaseModel,
520
ConfigDict,
621
Field,
722
)
8-
9-
from typing import Any
23+
import yaml
1024

1125
DEFAULT_CONFIG_NAME = "config.yaml"
1226

@@ -18,7 +32,8 @@ def default_config_file() -> str:
1832
class ForceableConfig(BaseModel):
1933
force: bool = Field(
2034
default=False,
21-
description="If True, this takes precedence over parameters passed to the initializer of the client",
35+
description="If True, this takes precedence over parameters passed to the "
36+
"initializer of the client",
2237
)
2338

2439

@@ -34,6 +49,9 @@ class BaseObjectConfig(ForceableConfig):
3449

3550

3651
class BaseConfig(BaseModel):
52+
def update(self, other: Any) -> "BaseConfig":
53+
raise NotImplementedError("Subclasses must implement update method")
54+
3755
def load(self, config_path: str) -> "BaseConfig":
3856
"""Load configuration from a YAML file and merge with existing config"""
3957
with open(config_path, "r") as f:

src/datacustomcode/config.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,13 @@
2525
cast,
2626
)
2727

28-
from pydantic import (
29-
Field,
28+
from pydantic import Field
29+
30+
from datacustomcode.common_config import (
31+
BaseConfig,
32+
BaseObjectConfig,
33+
ForceableConfig,
34+
default_config_file,
3035
)
3136

3237
# This lets all readers and writers to be findable via config
@@ -37,8 +42,6 @@
3742
from datacustomcode.proxy.base import BaseProxyAccessLayer
3843
from datacustomcode.proxy.client.base import BaseProxyClient # noqa: TCH002
3944
from datacustomcode.spark.base import BaseSparkSessionProvider
40-
from datacustomcode.common_config import ForceableConfig, BaseObjectConfig, BaseConfig, default_config_file
41-
4245

4346
if TYPE_CHECKING:
4447
from pyspark.sql import SparkSession
@@ -49,6 +52,7 @@
4952

5053
class AccessLayerObjectConfig(BaseObjectConfig, Generic[_T]):
5154
type_base: ClassVar[Type[BaseDataAccessLayer]] = BaseDataAccessLayer
55+
5256
def to_object(self, spark: SparkSession) -> _T:
5357
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
5458
return cast(_T, type_(spark=spark, **self.options))
@@ -75,14 +79,17 @@ class SparkConfig(ForceableConfig):
7579

7680
class ProxyAccessLayerObjectConfig(BaseObjectConfig, Generic[_PX]):
7781
"""Config for proxy clients that take no constructor args (e.g. no spark)."""
82+
7883
type_base: ClassVar[Type[BaseProxyAccessLayer]] = BaseProxyAccessLayer
84+
7985
def to_object(self) -> _PX:
8086
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
8187
return cast(_PX, type_(**self.options))
8288

8389

8490
class SparkProviderConfig(BaseObjectConfig, Generic[_P]):
8591
type_base: ClassVar[Type[BaseSparkSessionProvider]] = BaseSparkSessionProvider
92+
8693
def to_object(self) -> _P:
8794
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
8895
return cast(_P, type_(**self.options))
@@ -126,6 +133,7 @@ def merge(
126133
)
127134
return self
128135

136+
129137
"""Global config object.
130138
131139
This is the object that makes config accessible globally and globally mutable.

src/datacustomcode/einstein_predictions/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@
1919
__all__ = [
2020
"EinsteinPredictions",
2121
"DefaultEinsteinPredictions",
22-
]
22+
]

src/datacustomcode/einstein_predictions/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,13 @@
1515

1616
from abc import ABC, abstractmethod
1717

18-
from datacustomcode.einstein_predictions.types import (PredictionRequest, PredictionResponse)
18+
from datacustomcode.einstein_predictions.types import (
19+
PredictionRequest,
20+
PredictionResponse,
21+
)
1922
from datacustomcode.mixin import UserExtendableNamedConfigMixin
2023

24+
2125
class EinsteinPredictions(ABC, UserExtendableNamedConfigMixin):
2226
CONFIG_NAME: str
2327

src/datacustomcode/einstein_predictions/impl/default.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
# Copyright (c) 2025, Salesforce, Inc.
32
# SPDX-License-Identifier: Apache-2
43
#
@@ -17,9 +16,10 @@
1716
from datacustomcode.einstein_predictions.base import EinsteinPredictions
1817
from datacustomcode.einstein_predictions.types import (
1918
PredictionRequest,
20-
PredictionResponse
19+
PredictionResponse,
2120
)
2221

22+
2323
class DefaultEinsteinPredictions(EinsteinPredictions):
2424
CONFIG_NAME = "DefaultEinsteinPredictions"
2525

@@ -28,8 +28,8 @@ def __init__(self, **kwargs):
2828

2929
def predict(self, request: PredictionRequest) -> PredictionResponse:
3030
return PredictionResponse(
31-
version="v1",
32-
prediction_type=request.prediction_type,
33-
status_code=200,
34-
data={"results": [{"prediction": {"predictedValue": "1"}}]}
35-
)
31+
version="v1",
32+
prediction_type=request.prediction_type,
33+
status_code=200,
34+
data={"results": [{"prediction": {"predictedValue": "1"}}]},
35+
)

src/datacustomcode/einstein_predictions/types.py

Lines changed: 73 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,19 @@
1414
# limitations under the License.
1515

1616
from enum import Enum, unique
17+
from typing import (
18+
Any,
19+
Dict,
20+
Literal,
21+
Optional,
22+
)
23+
1724
from pydantic import (
1825
BaseModel,
1926
Field,
2027
model_validator,
2128
)
2229

23-
from typing import (
24-
Literal,
25-
Optional,
26-
Dict,
27-
Any
28-
)
2930

3031
@unique
3132
class PredictionType(Enum):
@@ -35,37 +36,51 @@ class PredictionType(Enum):
3536
MULTI_OUTCOME = 4
3637
BINARY_CLASSIFICATION = 5
3738

39+
3840
class PredictionColumn(BaseModel):
3941
column_name: str = Field(min_length=1, description="Column name")
40-
string_values: Optional[list[str]] = Field(default = None, min_length=1, description="Column string values")
41-
double_values: Optional[list[float]] = Field(default = None, min_length=1, description="Column double values")
42-
boolean_values: Optional[list[bool]] = Field(default = None, min_length=1, description="Column boolean values")
43-
date_values: Optional[list[str]] = Field(default = None, min_length=1, description="Column date values")
44-
datetime_values: Optional[list[str]] = Field(default = None, min_length=1, description="Column datetime values")
42+
string_values: Optional[list[str]] = Field(
43+
default=None, min_length=1, description="Column string values"
44+
)
45+
double_values: Optional[list[float]] = Field(
46+
default=None, min_length=1, description="Column double values"
47+
)
48+
boolean_values: Optional[list[bool]] = Field(
49+
default=None, min_length=1, description="Column boolean values"
50+
)
51+
date_values: Optional[list[str]] = Field(
52+
default=None, min_length=1, description="Column date values"
53+
)
54+
datetime_values: Optional[list[str]] = Field(
55+
default=None, min_length=1, description="Column datetime values"
56+
)
4557

46-
@model_validator(mode='after')
58+
@model_validator(mode="after")
4759
def validate_exactly_one_value_type(self):
48-
set_count = sum([
49-
self.string_values is not None,
50-
self.double_values is not None,
51-
self.boolean_values is not None,
52-
self.date_values is not None,
53-
self.datetime_values is not None
54-
])
60+
set_count = sum(
61+
[
62+
self.string_values is not None,
63+
self.double_values is not None,
64+
self.boolean_values is not None,
65+
self.date_values is not None,
66+
self.datetime_values is not None,
67+
]
68+
)
5569

5670
if set_count != 1:
5771
raise ValueError("Exactly one value type must be set")
5872

5973
return self
6074

75+
6176
class PredictionColumBuilder:
6277
def __init__(self) -> None:
63-
self._column_name: str = None
64-
self._string_values: list[str] = None
65-
self._double_values: list[float] = None
66-
self._boolean_values: list[bool] = None
67-
self._date_values: list[str] = None
68-
self._datetime_values: list[str] = None
78+
self._column_name: Optional[str] = None
79+
self._string_values: Optional[list[str]] = None
80+
self._double_values: Optional[list[float]] = None
81+
self._boolean_values: Optional[list[bool]] = None
82+
self._date_values: Optional[list[str]] = None
83+
self._datetime_values: Optional[list[str]] = None
6984

7085
def set_column_name(self, column_name: str) -> "PredictionColumBuilder":
7186
self._column_name = column_name
@@ -79,45 +94,59 @@ def set_double_values(self, double_values: list[float]) -> "PredictionColumBuild
7994
self._double_values = double_values
8095
return self
8196

82-
def set_boolean_values(self, boolean_values: list[bool]) -> "PredictionColumBuilder":
97+
def set_boolean_values(
98+
self, boolean_values: list[bool]
99+
) -> "PredictionColumBuilder":
83100
self._boolean_values = boolean_values
84101
return self
85102

86103
def set_date_values(self, date_values: list[str]) -> "PredictionColumBuilder":
87104
self._date_values = date_values
88105
return self
89106

90-
def set_datetime_values(self, datetime_values: list[str]) -> "PredictionColumBuilder":
107+
def set_datetime_values(
108+
self, datetime_values: list[str]
109+
) -> "PredictionColumBuilder":
91110
self._datetime_values = datetime_values
92111
return self
93112

94113
def build(self) -> PredictionColumn:
95114
return PredictionColumn(
96-
column_name = self._column_name,
97-
string_values = self._string_values,
98-
double_values = self._double_values,
99-
boolean_values = self._boolean_values,
100-
date_values = self._date_values,
101-
datetime_values = self._datetime_values
115+
column_name=self._column_name,
116+
string_values=self._string_values,
117+
double_values=self._double_values,
118+
boolean_values=self._boolean_values,
119+
date_values=self._date_values,
120+
datetime_values=self._datetime_values,
102121
)
103122

123+
104124
class PredictionRequest(BaseModel):
105125
version: Literal["v1"] = Field(
106126
default="v1", description="API version, must be 'v1'"
107127
)
108128
prediction_type: PredictionType = Field(description="Prediction type")
109-
model_api_name: str = Field(min_length=1, description="API name of the model to use")
110-
prediction_columns: list[PredictionColumn] = Field(min_length=1, description="List of prediction columns")
111-
settings: Optional[Dict[str, Any]] = Field(default=None, description="Settings for the prediction request")
129+
model_api_name: str = Field(
130+
min_length=1, description="API name of the model to use"
131+
)
132+
prediction_columns: list[PredictionColumn] = Field(
133+
min_length=1, description="List of prediction columns"
134+
)
135+
settings: Optional[Dict[str, Any]] = Field(
136+
default=None, description="Settings for the prediction request"
137+
)
138+
112139

113140
class PredictionRequestBuilder:
114141
def __init__(self) -> None:
115-
self._prediction_type: PredictionType = None
116-
self._model_api_name: str = None
142+
self._prediction_type: Optional[PredictionType] = None
143+
self._model_api_name: Optional[str] = None
117144
self._prediction_columns: list[PredictionColumn] = []
118-
self._settings: Dict[str, Any] = None
145+
self._settings: Optional[Dict[str, Any]] = None
119146

120-
def set_prediction_type(self, prediction_type: PredictionType) -> "PredictionRequestBuilder":
147+
def set_prediction_type(
148+
self, prediction_type: PredictionType
149+
) -> "PredictionRequestBuilder":
121150
self._prediction_type = prediction_type
122151
return self
123152

@@ -126,8 +155,7 @@ def set_model(self, model_api_name: str) -> "PredictionRequestBuilder":
126155
return self
127156

128157
def set_prediction_columns(
129-
self,
130-
prediction_columns: list[PredictionColumn]
158+
self, prediction_columns: list[PredictionColumn]
131159
) -> "PredictionRequestBuilder":
132160
self._prediction_columns = prediction_columns
133161
return self
@@ -141,9 +169,10 @@ def build(self) -> PredictionRequest:
141169
prediction_type=self._prediction_type,
142170
model_api_name=self._model_api_name,
143171
prediction_columns=self._prediction_columns,
144-
settings=self._settings
172+
settings=self._settings,
145173
)
146-
174+
175+
147176
class PredictionResponse(BaseModel):
148177
version: Literal["v1"] = Field(default="v1", description="API version")
149178
prediction_type: PredictionType = Field(description="Prediction type")
@@ -153,4 +182,3 @@ class PredictionResponse(BaseModel):
153182
@property
154183
def is_success(self) -> bool:
155184
return self.status_code == 200
156-

0 commit comments

Comments
 (0)