Skip to content

Commit 44f1f72

Browse files
add support for einstein predict
1 parent a68b42c commit 44f1f72

4 files changed

Lines changed: 394 additions & 0 deletions

File tree

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from abc import ABC, abstractmethod
17+
18+
from datacustomcode.einstein_predictions.types import (PredictionRequest, PredictionResponse)
19+
20+
class EinsteinPredictions(ABC):
21+
@abstractmethod
22+
def predict(self, request: PredictionRequest) -> PredictionResponse: ...
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
2+
# Copyright (c) 2025, Salesforce, Inc.
3+
# SPDX-License-Identifier: Apache-2
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from datacustomcode.einstein_predictions.types import (
18+
PredictionRequest,
19+
PredictionResponse
20+
)
21+
22+
class DefaultEinsteinPredictions:
23+
def __init__(self, base_url: str, access_token: str) -> None:
24+
pass
25+
26+
def predict(self, request: PredictionRequest) -> PredictionResponse:
27+
pass
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from enum import Enum, unique
17+
from pydantic import (
18+
BaseModel,
19+
Field,
20+
model_validator,
21+
)
22+
23+
from typing import (
24+
Literal,
25+
Optional,
26+
Dict,
27+
Any
28+
)
29+
30+
@unique
31+
class PredictionType(Enum):
32+
REGRESSION = 1
33+
CLUSTERING = 2
34+
CLASSIFICATION = 3
35+
MULTI_OUTCOME = 4
36+
BINARY_CLASSIFICATION = 5
37+
38+
class PredictionColumn(BaseModel):
39+
column_name: str = Field(min_length=1, description="Column name")
40+
string_values: Optional[list[str]] = Field(default = None, min_length=1, description="Column string values")
41+
double_values: Optional[list[float]] = Field(default = None, min_length=1, description="Column double values")
42+
boolean_values: Optional[list[bool]] = Field(default = None, min_length=1, description="Column boolean values")
43+
date_values: Optional[list[str]] = Field(default = None, min_length=1, description="Column date values")
44+
datetime_values: Optional[list[str]] = Field(default = None, min_length=1, description="Column datetime values")
45+
46+
@model_validator(mode='after')
47+
def validate_exactly_one_value_type(self):
48+
set_count = sum([
49+
self.string_values is not None,
50+
self.double_values is not None,
51+
self.boolean_values is not None,
52+
self.date_values is not None,
53+
self.datetime_values is not None
54+
])
55+
56+
if set_count != 1:
57+
raise ValueError("Exactly one value type must be set")
58+
59+
return self
60+
61+
class PredictionColumBuilder:
62+
def __init__(self) -> None:
63+
self._column_name: str = None
64+
self._string_values: list[str] = None
65+
self._double_values: list[float] = None
66+
self._boolean_values: list[bool] = None
67+
self._date_values: list[str] = None
68+
self._datetime_values: list[str] = None
69+
70+
def set_column_name(self, column_name: str) -> "PredictionColumBuilder":
71+
self._column_name = column_name
72+
return self
73+
74+
def set_string_values(self, string_values: list[str]) -> "PredictionColumBuilder":
75+
self._string_values = string_values
76+
return self
77+
78+
def set_double_values(self, double_values: list[float]) -> "PredictionColumBuilder":
79+
self._double_values = double_values
80+
return self
81+
82+
def set_boolean_values(self, boolean_values: list[bool]) -> "PredictionColumBuilder":
83+
self._boolean_values = boolean_values
84+
return self
85+
86+
def set_date_values(self, date_values: list[str]) -> "PredictionColumBuilder":
87+
self._date_values = date_values
88+
return self
89+
90+
def set_datetime_values(self, datetime_values: list[str]) -> "PredictionColumBuilder":
91+
self._datetime_values = datetime_values
92+
return self
93+
94+
def build(self) -> PredictionColumn:
95+
return PredictionColumn(
96+
column_name = self._column_name,
97+
string_values = self._string_values,
98+
double_values = self._double_values,
99+
boolean_values = self._boolean_values,
100+
date_values = self._date_values,
101+
datetime_values = self._datetime_values
102+
)
103+
104+
class PredictionRequest(BaseModel):
105+
version: Literal["v1"] = Field(
106+
default="v1", description="API version, must be 'v1'"
107+
)
108+
prediction_type: PredictionType = Field(description="Prediction type")
109+
model_api_name: str = Field(min_length=1, description="API name of the model to use")
110+
prediction_columns: list[PredictionColumn] = Field(min_length=1, description="List of prediction columns")
111+
settings: Optional[Dict[str, Any]] = Field(default=None, description="Settings for the prediction request")
112+
113+
class PredictionRequestBuilder:
114+
def __init__(self) -> None:
115+
self._prediction_type: PredictionType = None
116+
self._model_api_name: str = None
117+
self._prediction_columns: list[PredictionColumn] = []
118+
self._settings: Dict[str, Any] = None
119+
120+
def set_prediction_type(self, prediction_type: PredictionType) -> "PredictionRequestBuilder":
121+
self._prediction_type = prediction_type
122+
return self
123+
124+
def set_model(self, model_api_name: str) -> "PredictionRequestBuilder":
125+
self._model_api_name = model_api_name
126+
return self
127+
128+
def set_prediction_columns(
129+
self,
130+
prediction_columns: list[PredictionColumn]
131+
) -> "PredictionRequestBuilder":
132+
self._prediction_columns = prediction_columns
133+
return self
134+
135+
def set_settings(self, settings: Dict[str, Any]):
136+
self._settings = settings
137+
return self
138+
139+
def build(self) -> PredictionRequest:
140+
return PredictionRequest(
141+
prediction_type=self._prediction_type,
142+
model_api_name=self._model_api_name,
143+
prediction_columns=self._prediction_columns,
144+
settings=self._settings
145+
)
146+
147+
class PredictionResponse(BaseModel):
148+
version: Literal["v1"] = Field(default="v1", description="API version")
149+
prediction_type: PredictionType = Field(description="Prediction type")
150+
status_code: int = Field(description="HTTP status code")
151+
data: Optional[Dict[str, Any]] = Field(default=None, description="Response data")
152+
153+
@property
154+
def is_success(self) -> bool:
155+
return self.status_code == 200
156+

tests/test_einstein_predictions.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
import pytest
2+
from pydantic import ValidationError
3+
4+
from datacustomcode.einstein_predictions.types import (
5+
PredictionColumn,
6+
PredictionRequest,
7+
PredictionResponse,
8+
PredictionType,
9+
PredictionColumBuilder,
10+
PredictionRequestBuilder,
11+
)
12+
13+
class TestPredictionColumnValidation:
14+
def test_string_values_only(self):
15+
column = PredictionColumn(
16+
column_name="test_col",
17+
string_values=["a", "b", "c"]
18+
)
19+
assert column.column_name == "test_col"
20+
assert column.string_values == ["a", "b", "c"]
21+
assert column.double_values is None
22+
assert column.boolean_values is None
23+
assert column.date_values is None
24+
assert column.datetime_values is None
25+
26+
def test_double_values_only(self):
27+
column = PredictionColumn(
28+
column_name="test_col",
29+
double_values=[1.0, 2.5, 3.7]
30+
)
31+
assert column.double_values == [1.0, 2.5, 3.7]
32+
assert column.string_values is None
33+
assert column.boolean_values is None
34+
assert column.date_values is None
35+
assert column.datetime_values is None
36+
37+
def test_boolean_values_only(self):
38+
column = PredictionColumn(
39+
column_name="test_col",
40+
boolean_values=[True, False, True]
41+
)
42+
assert column.boolean_values == [True, False, True]
43+
assert column.string_values is None
44+
assert column.double_values is None
45+
assert column.date_values is None
46+
assert column.datetime_values is None
47+
48+
49+
def test_date_values_only(self):
50+
column = PredictionColumn(
51+
column_name="test_col",
52+
date_values=["2024-01-01", "2024-01-02"]
53+
)
54+
assert column.date_values == ["2024-01-01", "2024-01-02"]
55+
assert column.string_values is None
56+
assert column.double_values is None
57+
assert column.boolean_values is None
58+
assert column.datetime_values is None
59+
60+
def test_datetime_values_only(self):
61+
column = PredictionColumn(
62+
column_name="test_col",
63+
datetime_values=["2024-01-01T12:00:00", "2024-01-02T13:00:00"]
64+
)
65+
assert column.datetime_values == ["2024-01-01T12:00:00", "2024-01-02T13:00:00"]
66+
assert column.string_values is None
67+
assert column.double_values is None
68+
assert column.boolean_values is None
69+
assert column.date_values is None
70+
71+
def test_no_column_name_raises_error(self):
72+
with pytest.raises(ValidationError) as exc_info:
73+
PredictionColumn(
74+
column_name="",
75+
string_values=["a", "b"],
76+
double_values=[1.0, 2.0]
77+
)
78+
79+
assert str(exc_info.value) is not None
80+
81+
def test_no_values_raises_error(self):
82+
with pytest.raises(ValidationError) as exc_info:
83+
PredictionColumn(column_name="test_col")
84+
85+
assert str(exc_info.value) is not None
86+
87+
def test_string_and_double_raises_error(self):
88+
with pytest.raises(ValidationError) as exc_info:
89+
PredictionColumn(
90+
column_name="test_col",
91+
string_values=["a", "b"],
92+
double_values=[1.0, 2.0]
93+
)
94+
95+
assert str(exc_info.value) is not None
96+
97+
def test_empty_values_raises_error(self):
98+
with pytest.raises(ValidationError) as exc_info:
99+
PredictionColumn(
100+
column_name="test_col",
101+
string_values=[]
102+
)
103+
104+
assert str(exc_info.value) is not None
105+
106+
class TestPredictionColumnBuilder:
107+
def test_builder_with_string_values(self):
108+
column = (PredictionColumBuilder()
109+
.set_column_name("test_col")
110+
.set_string_values(["a", "b"])
111+
.build())
112+
113+
assert column.column_name == "test_col"
114+
assert column.string_values == ["a", "b"]
115+
116+
class TestPredictionRequest:
117+
def test_request_with_multiple_columns(self):
118+
request = PredictionRequest(
119+
prediction_type=PredictionType.CLASSIFICATION,
120+
model_api_name="classifier",
121+
prediction_columns=[
122+
PredictionColumn(column_name="col1", string_values=["a"]),
123+
PredictionColumn(column_name="col2", double_values=[1.0]),
124+
PredictionColumn(column_name="col3", boolean_values=[True])
125+
]
126+
)
127+
128+
assert len(request.prediction_columns) == 3
129+
130+
def test_request_requires_model_api_name(self):
131+
with pytest.raises(ValidationError):
132+
PredictionRequest(
133+
prediction_type=PredictionType.REGRESSION,
134+
model_api_name="",
135+
prediction_columns=[
136+
PredictionColumn(column_name="col1", double_values=[1.0])
137+
]
138+
)
139+
140+
def test_request_requires_prediction_columns(self):
141+
with pytest.raises(ValidationError):
142+
PredictionRequest(
143+
prediction_type=PredictionType.REGRESSION,
144+
model_api_name="model",
145+
prediction_columns=[]
146+
)
147+
148+
class TestPredictionRequestBuilder:
149+
def test_builder_creates_valid_request(self):
150+
request = (PredictionRequestBuilder()
151+
.set_prediction_type(PredictionType.CLUSTERING)
152+
.set_model("cluster_model")
153+
.set_prediction_columns([
154+
PredictionColumn(column_name="test_col", double_values=[1.0])
155+
])
156+
.set_settings({
157+
'maxTopContributors': 20
158+
})
159+
.build())
160+
161+
assert request.prediction_type == PredictionType.CLUSTERING
162+
assert request.model_api_name == "cluster_model"
163+
assert len(request.prediction_columns) == 1
164+
assert request.settings == { 'maxTopContributors': 20 }
165+
166+
167+
class TestPredictionResponse:
168+
def test_successful_response(self):
169+
response = PredictionResponse(
170+
version="v1",
171+
prediction_type=PredictionType.REGRESSION,
172+
status_code=200,
173+
data={"results": [{"prediction": {"value": 42.5}}]}
174+
)
175+
176+
assert response.is_success
177+
assert response.status_code == 200
178+
assert response.data is not None
179+
180+
def test_failed_response(self):
181+
response = PredictionResponse(
182+
version="v1",
183+
prediction_type=PredictionType.REGRESSION,
184+
status_code=500,
185+
data={"error": "Internal server error"}
186+
)
187+
188+
assert not response.is_success
189+
assert response.status_code == 500

0 commit comments

Comments
 (0)