Skip to content

Commit 036f56f

Browse files
update based on the PR feedback, including removing code duplication
1 parent 06425e4 commit 036f56f

8 files changed

Lines changed: 71 additions & 65 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ python = ">=3.10,<3.12"
107107
pyyaml = "^6.0"
108108
salesforce-cdp-connector = ">=1.0.19"
109109
setuptools_scm = "^7.1.0"
110+
requests = "2.33.1"
110111

111112
[tool.poetry.group.dev.dependencies]
112113
build = "*"

src/datacustomcode/einstein_platform_client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,15 @@
2929

3030

3131
class EinsteinPlatformClient:
32-
EINSTEIN_PLATFORM_URL = "https://api.salesforce.com/einstein/platform/v1"
32+
EINSTEIN_PLATFORM_MODELS_URL = (
33+
"https://api.salesforce.com/einstein/platform/v1/models"
34+
)
3335

3436
def __init__(
3537
self,
3638
credentials_profile: Optional[str] = None,
3739
sf_cli_org: Optional[str] = None,
40+
**kwargs: Any,
3841
):
3942
if sf_cli_org:
4043
self._token_provider: TokenProvider = SFCLITokenProvider(sf_cli_org)
@@ -44,6 +47,7 @@ def __init__(
4447
self._token_provider = CredentialsTokenProvider(profile)
4548
logger.debug(f"Using credentials token provider with profile: {profile}")
4649
self.token_response = None
50+
super().__init__(**kwargs)
4751

4852
def get_headers(self):
4953
if self.token_response is None:
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from typing import (
17+
ClassVar,
18+
Optional,
19+
Type,
20+
cast,
21+
)
22+
23+
from datacustomcode.common_config import BaseObjectConfig
24+
25+
26+
class CredentialsObjectConfig(BaseObjectConfig):
27+
type_to_create: ClassVar[Type]
28+
credentials_profile: Optional[str] = None
29+
sf_cli_org: Optional[str] = None
30+
31+
def to_object(self):
32+
"""Create an object instance, automatically including credentials in options"""
33+
34+
options = self.options.copy()
35+
if self.credentials_profile is not None:
36+
options["credentials_profile"] = self.credentials_profile
37+
if self.sf_cli_org is not None:
38+
options["sf_cli_org"] = self.sf_cli_org
39+
40+
type_ = self.type_to_create.subclass_from_config_name(self.type_config_name)
41+
return cast(type_, type_(**options))

src/datacustomcode/einstein_predictions/impl/default.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
ClassVar,
1919
Dict,
2020
List,
21-
Optional,
2221
)
2322

2423
from loguru import logger
@@ -43,17 +42,6 @@ class DefaultEinsteinPredictions(EinsteinPlatformClient, EinsteinPredictions):
4342
PredictionType.MULTI_OUTCOME: "multi-outcome",
4443
}
4544

46-
def __init__(
47-
self,
48-
credentials_profile: Optional[str] = None,
49-
sf_cli_org: Optional[str] = None,
50-
**kwargs,
51-
):
52-
EinsteinPlatformClient.__init__(
53-
self, credentials_profile=credentials_profile, sf_cli_org=sf_cli_org
54-
)
55-
EinsteinPredictions.__init__(self, **kwargs)
56-
5745
def predict(self, request: PredictionRequest) -> PredictionResponse:
5846
endpoint = self.ENDPOINT_MAP.get(request.prediction_type)
5947
if not endpoint:
@@ -63,8 +51,7 @@ def predict(self, request: PredictionRequest) -> PredictionResponse:
6351
)
6452

6553
api_url = (
66-
f"{self.EINSTEIN_PLATFORM_URL}/models/"
67-
f"{request.model_api_name}/{endpoint}"
54+
f"{self.EINSTEIN_PLATFORM_MODELS_URL}/{request.model_api_name}/{endpoint}"
6855
)
6956

7057
prediction_columns: List[Dict[str, Any]] = []

src/datacustomcode/einstein_predictions_config.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,17 @@
1919
Type,
2020
TypeVar,
2121
Union,
22-
cast,
2322
)
2423

25-
from datacustomcode.common_config import (
26-
BaseConfig,
27-
BaseObjectConfig,
28-
default_config_file,
29-
)
24+
from datacustomcode.common_config import BaseConfig, default_config_file
25+
from datacustomcode.einstein_platform_config import CredentialsObjectConfig
3026
from datacustomcode.einstein_predictions.base import EinsteinPredictions
3127

3228
_E = TypeVar("_E", bound=EinsteinPredictions)
3329

3430

35-
class EinsteinPredictionsObjectConfig(BaseObjectConfig, Generic[_E]):
36-
type_base: ClassVar[Type[EinsteinPredictions]] = EinsteinPredictions # type: ignore[type-abstract]
37-
38-
def to_object(self) -> _E:
39-
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
40-
return cast("_E", type_(**self.options))
31+
class EinsteinPredictionsObjectConfig(CredentialsObjectConfig, Generic[_E]):
32+
type_to_create: ClassVar[Type[EinsteinPredictions]] = EinsteinPredictions # type: ignore[type-abstract]
4133

4234

4335
class EinsteinPredictionsConfig(BaseConfig):

src/datacustomcode/llm_gateway/default.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
from typing import (
17-
Any,
18-
Dict,
19-
Optional,
20-
)
16+
from typing import Any, Dict
2117

2218
from loguru import logger
2319
import requests
@@ -26,25 +22,17 @@
2622
from datacustomcode.llm_gateway.base import LLMGateway
2723
from datacustomcode.llm_gateway.types.generate_text_request import GenerateTextRequest
2824
from datacustomcode.llm_gateway.types.generate_text_response import GenerateTextResponse
25+
from datacustomcode.llm_gateway.types.generate_text_response_builder import (
26+
GenerateTextResponseBuilder,
27+
)
2928

3029

3130
class DefaultLLMGateway(EinsteinPlatformClient, LLMGateway):
3231
CONFIG_NAME = "DefaultLLMGateway"
3332

34-
def __init__(
35-
self,
36-
credentials_profile: Optional[str] = None,
37-
sf_cli_org: Optional[str] = None,
38-
**kwargs,
39-
):
40-
EinsteinPlatformClient.__init__(
41-
self, credentials_profile=credentials_profile, sf_cli_org=sf_cli_org
42-
)
43-
LLMGateway.__init__(self, **kwargs)
44-
4533
def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse:
4634
api_url = (
47-
f"{self.EINSTEIN_PLATFORM_URL}/models/{request.model_name}/generations"
35+
f"{self.EINSTEIN_PLATFORM_MODELS_URL}/{request.model_name}/generations"
4836
)
4937

5038
payload: Dict[str, Any] = {"prompt": request.prompt}
@@ -69,6 +57,8 @@ def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse:
6957
logger.error(f"Generate text request failed: {api_url} {e}")
7058
raise RuntimeError(f"Generate text request failed: {e}") from e
7159

72-
return GenerateTextResponse(
73-
status_code=response.status_code, data=self.parse_response(response)
74-
)
60+
response_dict = {
61+
"status_code": response.status_code,
62+
"data": self.parse_response(response),
63+
}
64+
return GenerateTextResponseBuilder.build(response_dict)

src/datacustomcode/llm_gateway_config.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,17 @@
1919
Type,
2020
TypeVar,
2121
Union,
22-
cast,
2322
)
2423

25-
from datacustomcode.common_config import (
26-
BaseConfig,
27-
BaseObjectConfig,
28-
default_config_file,
29-
)
24+
from datacustomcode.common_config import BaseConfig, default_config_file
25+
from datacustomcode.einstein_platform_config import CredentialsObjectConfig
3026
from datacustomcode.llm_gateway.base import LLMGateway
3127

3228
_E = TypeVar("_E", bound=LLMGateway)
3329

3430

35-
class LLMGatewayObjectConfig(BaseObjectConfig, Generic[_E]):
36-
type_base: ClassVar[Type[LLMGateway]] = LLMGateway # type: ignore[type-abstract]
37-
38-
def to_object(self) -> _E:
39-
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
40-
return cast("_E", type_(**self.options))
31+
class LLMGatewayObjectConfig(CredentialsObjectConfig, Generic[_E]):
32+
type_to_create: ClassVar[Type[LLMGateway]] = LLMGateway # type: ignore[type-abstract]
4133

4234

4335
class LLMGatewayConfig(BaseConfig):

src/datacustomcode/templates/function/payload/entrypoint.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def make_einstein_prediction(runtime: Runtime) -> None:
5959
)
6060

6161
prediction_response = runtime.einstein_predictions.predict(prediction_request)
62-
print(
62+
logger.info(
6363
f"Einstein prediction results - success: [{prediction_response.is_success}] "
6464
f"response data: {prediction_response.data}"
6565
)
@@ -73,11 +73,10 @@ def generate_text(runtime: Runtime):
7373
.build()
7474
)
7575
llm_response = runtime.llm_gateway.generate_text(llm_request)
76-
77-
if llm_response.is_success:
78-
print(llm_response.text)
79-
else:
80-
print(llm_response.error_code)
76+
logger.info(
77+
f"LLM Gateway generate text results - success: [{llm_response.is_success}] "
78+
f"response data: {llm_response.data}"
79+
)
8180

8281

8382
def function(request: dict, runtime: Runtime) -> dict:

0 commit comments

Comments
 (0)