Skip to content

Commit 5f201e3

Browse files
test Einstein Prediction with actual production API
1 parent 643586a commit 5f201e3

12 files changed

Lines changed: 577 additions & 419 deletions

File tree

src/datacustomcode/cli.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -179,14 +179,15 @@ def deploy(
179179
function_invoke_opt: str,
180180
sf_cli_org: Optional[str],
181181
):
182-
from datacustomcode.credentials import Credentials
183182
from datacustomcode.deploy import (
184183
COMPUTE_TYPES,
185-
AccessTokenResponse,
186184
CodeExtensionMetadata,
187-
_retrieve_access_token_from_sf_cli,
188185
deploy_full,
189186
)
187+
from datacustomcode.token_provider import (
188+
CredentialsTokenProvider,
189+
SFCLITokenProvider,
190+
)
190191

191192
logger.debug("Deploying project")
192193

@@ -220,22 +221,15 @@ def deploy(
220221
function_invoke_options = function_invoke_opt.split(",")
221222
metadata.functionInvokeOptions = function_invoke_options
222223

223-
auth: Union[Credentials, AccessTokenResponse]
224-
if sf_cli_org:
225-
try:
226-
auth = _retrieve_access_token_from_sf_cli(sf_cli_org)
227-
except RuntimeError as e:
228-
click.secho(f"Error: {e}", fg="red")
229-
raise click.Abort() from None
230-
else:
231-
try:
232-
auth = Credentials.from_available(profile=profile)
233-
except ValueError as e:
234-
click.secho(
235-
f"Error: {e}",
236-
fg="red",
237-
)
238-
raise click.Abort() from None
224+
try:
225+
if sf_cli_org:
226+
auth = SFCLITokenProvider(sf_cli_org).get_token()
227+
else:
228+
auth = CredentialsTokenProvider(profile).get_token()
229+
except RuntimeError as e:
230+
click.secho(f"Error: {e}", fg="red")
231+
raise click.Abort() from None
232+
239233
deploy_full(path, metadata, auth, network)
240234

241235

src/datacustomcode/config.yaml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,10 @@ proxy_config:
2626

2727
einstein_predictions_config:
2828
type_config_name: DefaultEinsteinPredictions
29-
options: {}
29+
options:
30+
credentials_profile: default
31+
32+
llm_gateway_config:
33+
type_config_name: DefaultLLMGateway
34+
options:
35+
credentials_profile: default

src/datacustomcode/deploy.py

Lines changed: 1 addition & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@
1919
import os
2020
import re
2121
import shutil
22-
import subprocess
2322
import tempfile
2423
import time
2524
from typing import (
26-
TYPE_CHECKING,
2725
Any,
2826
Callable,
2927
Dict,
@@ -37,15 +35,10 @@
3735
import requests
3836

3937
from datacustomcode.cmd import cmd_output
40-
from datacustomcode.credentials import AuthType
4138
from datacustomcode.scan import find_base_directory, get_package_type
4239

43-
if TYPE_CHECKING:
44-
from datacustomcode.credentials import Credentials
45-
4640
DATA_CUSTOM_CODE_PATH = "services/data/v63.0/ssot/data-custom-code"
4741
DATA_TRANSFORMS_PATH = "services/data/v63.0/ssot/data-transforms"
48-
AUTH_PATH = "services/oauth2/token"
4942
WAIT_FOR_DEPLOYMENT_TIMEOUT = 3000
5043

5144
# Available compute types for Data Cloud deployments.
@@ -163,80 +156,6 @@ class AccessTokenResponse(BaseModel):
163156
instance_url: str
164157

165158

166-
def _retrieve_access_token(credentials: Credentials) -> AccessTokenResponse:
167-
"""Get an access token for the Salesforce API."""
168-
logger.debug("Getting oauth token...")
169-
170-
url = f"{credentials.login_url.rstrip('/')}/{AUTH_PATH.lstrip('/')}"
171-
172-
if credentials.auth_type == AuthType.OAUTH_TOKENS:
173-
data = {
174-
"grant_type": "refresh_token",
175-
"refresh_token": credentials.refresh_token,
176-
"client_id": credentials.client_id,
177-
"client_secret": credentials.client_secret,
178-
}
179-
elif credentials.auth_type == AuthType.CLIENT_CREDENTIALS:
180-
data = {
181-
"grant_type": "client_credentials",
182-
"client_id": credentials.client_id,
183-
"client_secret": credentials.client_secret,
184-
}
185-
else:
186-
raise ValueError(f"Unsupported auth_type: {credentials.auth_type}")
187-
188-
response = _make_api_call(url, "POST", data=data)
189-
return AccessTokenResponse(**response)
190-
191-
192-
def _retrieve_access_token_from_sf_cli(sf_cli_org: str) -> AccessTokenResponse:
193-
"""Get an access token from the Salesforce CLI."""
194-
try:
195-
result = subprocess.run(
196-
["sf", "org", "display", "--target-org", sf_cli_org, "--json"],
197-
capture_output=True,
198-
text=True,
199-
check=True,
200-
timeout=30,
201-
)
202-
except FileNotFoundError as exc:
203-
raise RuntimeError(
204-
"The 'sf' command was not found. "
205-
"Please install Salesforce CLI: https://developer.salesforce.com/tools/salesforcecli"
206-
) from exc
207-
except subprocess.TimeoutExpired as exc:
208-
raise RuntimeError(
209-
f"'sf org display' timed out for org '{sf_cli_org}'"
210-
) from exc
211-
except subprocess.CalledProcessError as exc:
212-
raise RuntimeError(
213-
f"'sf org display' failed for org '{sf_cli_org}'.\n"
214-
f"Ensure the org is authenticated via 'sf org login web'.\n"
215-
f"stderr: {exc.stderr.strip()}"
216-
) from exc
217-
218-
try:
219-
data = json.loads(result.stdout)
220-
except json.JSONDecodeError as exc:
221-
raise RuntimeError(f"Failed to parse 'sf org display' output: {exc}") from exc
222-
223-
if data.get("status") != 0:
224-
raise RuntimeError(
225-
f"SF CLI error for org '{sf_cli_org}': "
226-
f"{data.get('message', 'unknown error')}"
227-
)
228-
229-
org_result = data.get("result", {})
230-
access_token = org_result.get("accessToken")
231-
instance_url = org_result.get("instanceUrl")
232-
if not access_token or not instance_url:
233-
raise RuntimeError(
234-
f"'sf org display' did not return an access token or instance URL "
235-
f"for org '{sf_cli_org}'"
236-
)
237-
return AccessTokenResponse(access_token=access_token, instance_url=instance_url)
238-
239-
240159
class CreateDeploymentResponse(BaseModel):
241160
fileUploadUrl: str
242161

@@ -567,16 +486,11 @@ def zip(
567486
def deploy_full(
568487
directory: str,
569488
metadata: CodeExtensionMetadata,
570-
credentials: Union["Credentials", AccessTokenResponse],
489+
access_token: AccessTokenResponse,
571490
docker_network: str,
572491
callback=None,
573492
) -> AccessTokenResponse:
574493
"""Deploy a data transform in the DataCloud."""
575-
if isinstance(credentials, AccessTokenResponse):
576-
access_token = credentials
577-
else:
578-
access_token = _retrieve_access_token(credentials)
579-
580494
# prepare payload
581495
config = get_config(directory)
582496

@@ -587,7 +501,6 @@ def deploy_full(
587501
wait_for_deployment(access_token, metadata, callback)
588502

589503
# create data transform
590-
591504
if isinstance(config, DataTransformConfig):
592505
create_data_transform(directory, access_token, metadata, config)
593506
return access_token

src/datacustomcode/einstein_predictions/impl/default.py

Lines changed: 111 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,131 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
from typing import (
17+
Any,
18+
ClassVar,
19+
Dict,
20+
List,
21+
Optional,
22+
)
23+
24+
from loguru import logger
25+
import requests
26+
1627
from datacustomcode.einstein_predictions.base import EinsteinPredictions
1728
from datacustomcode.einstein_predictions.types import (
1829
PredictionRequest,
1930
PredictionResponse,
31+
PredictionType,
32+
)
33+
from datacustomcode.token_provider import (
34+
CredentialsTokenProvider,
35+
SFCLITokenProvider,
36+
TokenProvider,
2037
)
2138

2239

2340
class DefaultEinsteinPredictions(EinsteinPredictions):
2441
CONFIG_NAME = "DefaultEinsteinPredictions"
42+
EINSTEIN_PLATFORM_URL = "https://api.salesforce.com/einstein/platform/v1"
2543

26-
def __init__(self, **kwargs):
44+
ENDPOINT_MAP: ClassVar[dict[PredictionType, str]] = {
45+
PredictionType.REGRESSION: "regression",
46+
PredictionType.CLUSTERING: "clustering",
47+
PredictionType.CLASSIFICATION: "classification",
48+
PredictionType.BINARY_CLASSIFICATION: "binary-classification",
49+
PredictionType.MULTI_OUTCOME: "multi-outcome",
50+
}
51+
52+
def __init__(
53+
self,
54+
credentials_profile: Optional[str] = None,
55+
sf_cli_org: Optional[str] = None,
56+
**kwargs,
57+
):
2758
super().__init__(**kwargs)
2859

60+
if sf_cli_org:
61+
self._token_provider: TokenProvider = SFCLITokenProvider(sf_cli_org)
62+
logger.debug(f"Using SF CLI token provider for org: {sf_cli_org}")
63+
else:
64+
profile = credentials_profile or "default"
65+
self._token_provider = CredentialsTokenProvider(profile)
66+
logger.debug(f"Using credentials token provider with profile: {profile}")
67+
2968
def predict(self, request: PredictionRequest) -> PredictionResponse:
69+
"""Make a prediction request to the Einstein Predictions API"""
70+
token_response = self._token_provider.get_token()
71+
access_token = token_response.access_token
72+
73+
endpoint = self.ENDPOINT_MAP.get(request.prediction_type)
74+
if not endpoint:
75+
raise RuntimeError(
76+
f"Unknown prediction type: {request.prediction_type}. "
77+
f"Valid types: {list(self.ENDPOINT_MAP.keys())}"
78+
)
79+
80+
api_url = (
81+
f"{self.EINSTEIN_PLATFORM_URL}/models/"
82+
f"{request.model_api_name}/{endpoint}"
83+
)
84+
85+
prediction_columns: List[Dict[str, Any]] = []
86+
for col in request.prediction_columns:
87+
col_data: Dict[str, Any] = {"columnName": col.column_name}
88+
if col.string_values:
89+
col_data["stringValues"] = col.string_values
90+
if col.double_values:
91+
col_data["doubleValues"] = col.double_values
92+
if col.boolean_values:
93+
col_data["booleanValues"] = col.boolean_values
94+
if col.date_values:
95+
col_data["dateValues"] = col.date_values
96+
if col.datetime_values:
97+
col_data["datetimeValues"] = col.datetime_values
98+
prediction_columns.append(col_data)
99+
100+
payload: Dict[str, Any] = {"predictionColumns": prediction_columns}
101+
102+
if request.settings:
103+
payload["settings"] = request.settings
104+
105+
headers = {
106+
"Authorization": f"Bearer {access_token}",
107+
"Content-Type": "application/json",
108+
"x-sfdc-app-context": "EinsteinGPT",
109+
"x-client-feature-id": "ai-platform-models-connected-app",
110+
}
111+
112+
logger.debug(f"Making Einstein prediction request to: {api_url}")
113+
try:
114+
response = requests.post(api_url, json=payload, headers=headers, timeout=60)
115+
if not response.ok and not response.text:
116+
error_msg = (
117+
f"Einstein Prediction request failed: {api_url} - "
118+
f"{response.status_code} {response.reason}. "
119+
"If your code uses Einstein APIs, make sure you have "
120+
'configured the SDK to use "client_credentials" auth type. '
121+
"Refer to https://developer.salesforce.com/docs/ai/agentforce/"
122+
"guide/agent-api-get-started.html#create-a-salesforce-app "
123+
"to create your external client app."
124+
)
125+
logger.error(error_msg)
126+
except requests.exceptions.RequestException as e:
127+
logger.error(f"Prediction API request failed: {api_url} {e}")
128+
raise RuntimeError(f"Prediction API request failed: {e}") from e
129+
130+
response_data: Dict[str, Any] = {}
131+
if response.content:
132+
try:
133+
response_data = response.json()
134+
except ValueError:
135+
logger.warning("Failed to parse response as JSON")
136+
response_data = {"raw_response": response.text}
137+
30138
return PredictionResponse(
31139
version="v1",
32140
prediction_type=request.prediction_type,
33-
status_code=200,
34-
data={"results": [{"prediction": {"predictedValue": 1.0}}]},
141+
status_code=response.status_code,
142+
data=response_data,
35143
)

0 commit comments

Comments
 (0)