Skip to content

Commit afa1200

Browse files
use production llm gateway
1 parent 5aae4ab commit afa1200

5 files changed

Lines changed: 144 additions & 83 deletions

File tree

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from typing import (
17+
Any,
18+
Dict,
19+
Optional,
20+
)
21+
22+
from abc import ABC
23+
from loguru import logger
24+
25+
from datacustomcode.token_provider import (
26+
CredentialsTokenProvider,
27+
SFCLITokenProvider,
28+
TokenProvider,
29+
)
30+
31+
32+
class EinsteinPlatformClient(ABC):
    """Shared base for clients of the Salesforce Einstein Platform API.

    Responsibilities: choose a token provider (SF CLI org vs. stored
    credentials profile), build the standard request headers with a lazily
    fetched bearer token, and decode HTTP responses to a dict.
    """

    EINSTEIN_PLATFORM_URL = "https://api.salesforce.com/einstein/platform/v1"
    EINSTEIN_WARNING_MESSAGE = (
        "If your code uses Einstein APIs, make sure you have "
        'configured the SDK to use "client_credentials" auth type. '
        "Refer to https://developer.salesforce.com/docs/ai/agentforce/"
        "guide/agent-api-get-started.html#create-a-salesforce-app "
        "to create your external client app."
    )

    def __init__(
        self,
        credentials_profile: Optional[str] = None,
        sf_cli_org: Optional[str] = None,
    ):
        """Select a token provider.

        An explicit ``sf_cli_org`` wins; otherwise fall back to the named
        credentials profile (``"default"`` when none is given).
        """
        if not sf_cli_org:
            profile = credentials_profile or "default"
            self._token_provider: TokenProvider = CredentialsTokenProvider(profile)
            logger.debug(f"Using credentials token provider with profile: {profile}")
        else:
            self._token_provider = SFCLITokenProvider(sf_cli_org)
            logger.debug(f"Using SF CLI token provider for org: {sf_cli_org}")
        # Token is fetched lazily on the first get_headers() call.
        self.token_response = None

    def get_headers(self):
        """Return the standard Einstein API headers, fetching a token once.

        NOTE(review): the token is cached for the client's lifetime and never
        refreshed — presumably acceptable for short-lived jobs; confirm
        against token expiry behavior.
        """
        token = self.token_response
        if token is None:
            token = self._token_provider.get_token()
            self.token_response = token
        headers = {
            "Authorization": f"Bearer {token.access_token}",
            "Content-Type": "application/json",
            "x-sfdc-app-context": "EinsteinGPT",
            "x-client-feature-id": "ai-platform-models-connected-app",
        }
        return headers

    def parse_response(self, response):
        """Decode *response* as JSON.

        Returns an empty dict for an empty body, and wraps non-JSON bodies as
        ``{"raw_response": <text>}`` instead of raising.
        """
        if not response.content:
            return {}
        try:
            parsed: Dict[str, Any] = response.json()
        except ValueError:
            logger.warning("Failed to parse response as JSON")
            return {"raw_response": response.text}
        return parsed

src/datacustomcode/einstein_predictions/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@
1717
from datacustomcode.einstein_predictions.impl.default import DefaultEinsteinPredictions
1818

1919
# Public names re-exported by this package (kept alphabetically sorted).
__all__ = [
    "DefaultEinsteinPredictions",
    "EinsteinPredictions",
]

src/datacustomcode/einstein_predictions/impl/default.py

Lines changed: 13 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,17 @@
2424
from loguru import logger
2525
import requests
2626

27+
from datacustomcode.einstein_platform_client import EinsteinPlatformClient
2728
from datacustomcode.einstein_predictions.base import EinsteinPredictions
2829
from datacustomcode.einstein_predictions.types import (
2930
PredictionRequest,
3031
PredictionResponse,
3132
PredictionType,
3233
)
33-
from datacustomcode.token_provider import (
34-
CredentialsTokenProvider,
35-
SFCLITokenProvider,
36-
TokenProvider,
37-
)
3834

3935

40-
class DefaultEinsteinPredictions(EinsteinPredictions):
36+
class DefaultEinsteinPredictions(EinsteinPlatformClient, EinsteinPredictions):
4137
CONFIG_NAME = "DefaultEinsteinPredictions"
42-
EINSTEIN_PLATFORM_URL = "https://api.salesforce.com/einstein/platform/v1"
43-
4438
ENDPOINT_MAP: ClassVar[dict[PredictionType, str]] = {
4539
PredictionType.REGRESSION: "regression",
4640
PredictionType.CLUSTERING: "clustering",
@@ -55,21 +49,12 @@ def __init__(
5549
sf_cli_org: Optional[str] = None,
5650
**kwargs,
5751
):
58-
super().__init__(**kwargs)
59-
60-
if sf_cli_org:
61-
self._token_provider: TokenProvider = SFCLITokenProvider(sf_cli_org)
62-
logger.debug(f"Using SF CLI token provider for org: {sf_cli_org}")
63-
else:
64-
profile = credentials_profile or "default"
65-
self._token_provider = CredentialsTokenProvider(profile)
66-
logger.debug(f"Using credentials token provider with profile: {profile}")
52+
EinsteinPlatformClient.__init__(
53+
self, credentials_profile=credentials_profile, sf_cli_org=sf_cli_org
54+
)
55+
EinsteinPredictions.__init__(self, **kwargs)
6756

6857
def predict(self, request: PredictionRequest) -> PredictionResponse:
69-
"""Make a prediction request to the Einstein Predictions API"""
70-
token_response = self._token_provider.get_token()
71-
access_token = token_response.access_token
72-
7358
endpoint = self.ENDPOINT_MAP.get(request.prediction_type)
7459
if not endpoint:
7560
raise RuntimeError(
@@ -102,42 +87,24 @@ def predict(self, request: PredictionRequest) -> PredictionResponse:
10287
if request.settings:
10388
payload["settings"] = request.settings
10489

105-
headers = {
106-
"Authorization": f"Bearer {access_token}",
107-
"Content-Type": "application/json",
108-
"x-sfdc-app-context": "EinsteinGPT",
109-
"x-client-feature-id": "ai-platform-models-connected-app",
110-
}
111-
11290
logger.debug(f"Making Einstein prediction request to: {api_url}")
11391
try:
114-
response = requests.post(api_url, json=payload, headers=headers, timeout=60)
92+
response = requests.post(
93+
api_url, json=payload, headers=self.get_headers(), timeout=180
94+
)
11595
if not response.ok and not response.text:
11696
error_msg = (
11797
f"Einstein Prediction request failed: {api_url} - "
11898
f"{response.status_code} {response.reason}. "
119-
"If your code uses Einstein APIs, make sure you have "
120-
'configured the SDK to use "client_credentials" auth type. '
121-
"Refer to https://developer.salesforce.com/docs/ai/agentforce/"
122-
"guide/agent-api-get-started.html#create-a-salesforce-app "
123-
"to create your external client app."
99+
f"{self.EINSTEIN_WARNING_MESSAGE}"
124100
)
125101
logger.error(error_msg)
126102
except requests.exceptions.RequestException as e:
127-
logger.error(f"Prediction API request failed: {api_url} {e}")
128-
raise RuntimeError(f"Prediction API request failed: {e}") from e
129-
130-
response_data: Dict[str, Any] = {}
131-
if response.content:
132-
try:
133-
response_data = response.json()
134-
except ValueError:
135-
logger.warning("Failed to parse response as JSON")
136-
response_data = {"raw_response": response.text}
103+
logger.error(f"Einstein Prediction request failed: {api_url} {e}")
104+
raise RuntimeError(f"Einstein Prediction request failed: {e}") from e
137105

138106
return PredictionResponse(
139-
version="v1",
140107
prediction_type=request.prediction_type,
141108
status_code=response.status_code,
142-
data=response_data,
109+
data=self.parse_response(response),
143110
)

src/datacustomcode/llm_gateway/default.py

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,69 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
from typing import (
17+
Any,
18+
Dict,
19+
Optional,
20+
)
21+
22+
from loguru import logger
23+
import requests
24+
25+
from datacustomcode.einstein_platform_client import EinsteinPlatformClient
26+
27+
1628
from datacustomcode.llm_gateway.base import LLMGateway
1729
from datacustomcode.llm_gateway.types.generate_text_request import GenerateTextRequest
1830
from datacustomcode.llm_gateway.types.generate_text_response import GenerateTextResponse
19-
from datacustomcode.llm_gateway.types.generate_text_response_builder import (
20-
GenerateTextResponseBuilder,
21-
)
2231

2332

24-
class DefaultLLMGateway(LLMGateway):
class DefaultLLMGateway(EinsteinPlatformClient, LLMGateway):
    """LLM gateway backed by the production Einstein Platform models API.

    Inherits token handling, header construction, and response parsing from
    EinsteinPlatformClient; implements the LLMGateway text-generation contract.
    """

    CONFIG_NAME = "DefaultLLMGateway"

    def __init__(
        self,
        credentials_profile: Optional[str] = None,
        sf_cli_org: Optional[str] = None,
        **kwargs,
    ):
        # Both bases are initialized explicitly: the platform client sets up
        # token acquisition; LLMGateway consumes the remaining kwargs.
        EinsteinPlatformClient.__init__(
            self, credentials_profile=credentials_profile, sf_cli_org=sf_cli_org
        )
        LLMGateway.__init__(self, **kwargs)

    def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse:
        """POST a generation request for ``request.model_name`` and wrap the
        HTTP result in a GenerateTextResponse.

        Raises RuntimeError on transport-level failures (connection, timeout).
        """
        api_url = (
            f"{self.EINSTEIN_PLATFORM_URL}/models/"
            f"{request.model_name}/generations"
        )

        payload: Dict[str, Any] = {
            "prompt": request.prompt
        }

        # Optional request fields are only included when truthy.
        if request.localization:
            payload["localization"] = request.localization
        if request.tags:
            payload["tags"] = request.tags

        logger.debug(f"Making Generate text request: {api_url}")
        try:
            response = requests.post(
                api_url, json=payload, headers=self.get_headers(), timeout=180
            )
            # An error status with an empty body typically indicates an
            # auth/app-setup problem; log the guidance message but still
            # fall through and return the response object below.
            if not response.ok and not response.text:
                error_msg = (
                    f"Generate text request failed: {api_url} - "
                    f"{response.status_code} {response.reason}. "
                    f"{self.EINSTEIN_WARNING_MESSAGE}"
                )
                logger.error(error_msg)
        except requests.exceptions.RequestException as e:
            logger.error(f"Generate text request failed: {api_url} {e}")
            raise RuntimeError(f"Generate text request failed: {e}") from e

        return GenerateTextResponse(
            status_code=response.status_code,
            data=self.parse_response(response)
        )

tests/test_llm_gateway.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
from pydantic import ValidationError
44
import pytest
55

6-
from datacustomcode.llm_gateway.base import LLMGateway
7-
from datacustomcode.llm_gateway.default import DefaultLLMGateway
86
from datacustomcode.llm_gateway.types.generate_text_request import GenerateTextRequest
97
from datacustomcode.llm_gateway.types.generate_text_request_builder import (
108
GenerateTextRequestBuilder,
@@ -210,28 +208,3 @@ def test_builder_with_minimal_dict(self):
210208
response = GenerateTextResponseBuilder.build(response_dict)
211209
assert response.status_code == 200
212210
assert response.version == "v1" # Default value
213-
214-
215-
class TestDefaultLLMGateway:
216-
"""Test DefaultLLMGateway implementation."""
217-
218-
def test_default_gateway_is_llm_gateway(self):
219-
"""Test DefaultLLMGateway inherits from LLMGateway."""
220-
gateway = DefaultLLMGateway()
221-
assert isinstance(gateway, LLMGateway)
222-
223-
def test_generate_text_returns_response(self):
224-
"""Test generate_text returns GenerateTextResponse."""
225-
gateway = DefaultLLMGateway()
226-
request = GenerateTextRequest(model_name="gpt-4", prompt="Hello")
227-
response = gateway.generate_text(request)
228-
assert isinstance(response, GenerateTextResponse)
229-
230-
def test_generate_text_success_response(self):
231-
"""Test generate_text returns successful response."""
232-
gateway = DefaultLLMGateway()
233-
request = GenerateTextRequest(model_name="gpt-4", prompt="Hello")
234-
response = gateway.generate_text(request)
235-
assert response.is_success is True
236-
assert response.status_code == 200
237-
assert len(response.text) > 0

0 commit comments

Comments
 (0)