Skip to content

Commit cd119d6

Browse files
add example on how to use prediction api
1 parent ef5fe0c commit cd119d6

4 files changed

Lines changed: 32 additions & 3 deletions

File tree

src/datacustomcode/einstein_predictions/impl/default.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,5 @@ def predict(self, request: PredictionRequest) -> PredictionResponse:
3131
version="v1",
3232
prediction_type=request.prediction_type,
3333
status_code=200,
34-
data={"results": [{"prediction": {"predictedValue": "1"}}]},
34+
data={"results": [{"prediction": {"predictedValue": 1.0}}]},
3535
)

src/datacustomcode/einstein_predictions/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def set_prediction_type(
150150
self._prediction_type = prediction_type
151151
return self
152152

153-
def set_model(self, model_api_name: str) -> "PredictionRequestBuilder":
153+
def set_model_api_name(self, model_api_name: str) -> "PredictionRequestBuilder":
154154
self._model_api_name = model_api_name
155155
return self
156156

src/datacustomcode/templates/function/payload/entrypoint.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
from typing import List
33
from uuid import uuid4
44

5+
from datacustomcode.einstein_predictions.types import (
6+
PredictionColumBuilder,
7+
PredictionRequestBuilder,
8+
PredictionType,
9+
)
510
from datacustomcode.function import Runtime
611
from datacustomcode.llm_gateway.types.generate_text_request_builder import (
712
GenerateTextRequestBuilder,
@@ -38,6 +43,28 @@ def chunk_text(text: str, chunk_size: int = 1000) -> List[str]:
3843
return chunks
3944

4045

46+
def make_einstein_prediction(runtime: Runtime) -> None:
47+
column = (
48+
PredictionColumBuilder()
49+
.set_column_name("col1")
50+
.set_string_values(["str1", "str2"])
51+
.build()
52+
)
53+
prediction_request = (
54+
PredictionRequestBuilder()
55+
.set_prediction_type(PredictionType.REGRESSION)
56+
.set_model_api_name("regressionModel")
57+
.set_prediction_columns([column])
58+
.build()
59+
)
60+
61+
prediction_response = runtime.einstein_predictions.predict(prediction_request)
62+
print(
63+
f"Einstein prediction results - success: {prediction_response.is_success} \
64+
response data: {prediction_response.data}"
65+
)
66+
67+
4168
def function(request: dict, runtime: Runtime) -> dict:
4269
logger.info("Inside Function")
4370
logger.info(request)
@@ -46,6 +73,8 @@ def function(request: dict, runtime: Runtime) -> dict:
4673
output_chunks = []
4774
current_seq_no = 1 # Start sequence number from 1
4875

76+
make_einstein_prediction(runtime)
77+
4978
builder = GenerateTextRequestBuilder()
5079
llm_request = builder.set_prompt("Hello").set_model("modelName").build()
5180
llm_response = runtime.llm_gateway.generate_text(llm_request)

tests/test_einstein_predictions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def test_builder_creates_valid_request(self):
142142
request = (
143143
PredictionRequestBuilder()
144144
.set_prediction_type(PredictionType.CLUSTERING)
145-
.set_model("cluster_model")
145+
.set_model_api_name("cluster_model")
146146
.set_prediction_columns(
147147
[PredictionColumn(column_name="test_col", double_values=[1.0])]
148148
)

0 commit comments

Comments
 (0)