2424from loguru import logger
2525import requests
2626
27+ from datacustomcode .einstein_platform_client import EinsteinPlatformClient
2728from datacustomcode .einstein_predictions .base import EinsteinPredictions
2829from datacustomcode .einstein_predictions .types import (
2930 PredictionRequest ,
3031 PredictionResponse ,
3132 PredictionType ,
3233)
33- from datacustomcode .token_provider import (
34- CredentialsTokenProvider ,
35- SFCLITokenProvider ,
36- TokenProvider ,
37- )
3834
3935
40- class DefaultEinsteinPredictions (EinsteinPredictions ):
36+ class DefaultEinsteinPredictions (EinsteinPlatformClient , EinsteinPredictions ):
4137 CONFIG_NAME = "DefaultEinsteinPredictions"
42- EINSTEIN_PLATFORM_URL = "https://api.salesforce.com/einstein/platform/v1"
43-
4438 ENDPOINT_MAP : ClassVar [dict [PredictionType , str ]] = {
4539 PredictionType .REGRESSION : "regression" ,
4640 PredictionType .CLUSTERING : "clustering" ,
@@ -55,21 +49,12 @@ def __init__(
5549 sf_cli_org : Optional [str ] = None ,
5650 ** kwargs ,
5751 ):
58- super ().__init__ (** kwargs )
59-
60- if sf_cli_org :
61- self ._token_provider : TokenProvider = SFCLITokenProvider (sf_cli_org )
62- logger .debug (f"Using SF CLI token provider for org: { sf_cli_org } " )
63- else :
64- profile = credentials_profile or "default"
65- self ._token_provider = CredentialsTokenProvider (profile )
66- logger .debug (f"Using credentials token provider with profile: { profile } " )
52+ EinsteinPlatformClient .__init__ (
53+ self , credentials_profile = credentials_profile , sf_cli_org = sf_cli_org
54+ )
55+ EinsteinPredictions .__init__ (self , ** kwargs )
6756
6857 def predict (self , request : PredictionRequest ) -> PredictionResponse :
69- """Make a prediction request to the Einstein Predictions API"""
70- token_response = self ._token_provider .get_token ()
71- access_token = token_response .access_token
72-
7358 endpoint = self .ENDPOINT_MAP .get (request .prediction_type )
7459 if not endpoint :
7560 raise RuntimeError (
@@ -102,42 +87,24 @@ def predict(self, request: PredictionRequest) -> PredictionResponse:
10287 if request .settings :
10388 payload ["settings" ] = request .settings
10489
105- headers = {
106- "Authorization" : f"Bearer { access_token } " ,
107- "Content-Type" : "application/json" ,
108- "x-sfdc-app-context" : "EinsteinGPT" ,
109- "x-client-feature-id" : "ai-platform-models-connected-app" ,
110- }
111-
11290 logger .debug (f"Making Einstein prediction request to: { api_url } " )
11391 try :
114- response = requests .post (api_url , json = payload , headers = headers , timeout = 60 )
92+ response = requests .post (
93+ api_url , json = payload , headers = self .get_headers (), timeout = 180
94+ )
11595 if not response .ok and not response .text :
11696 error_msg = (
11797 f"Einstein Prediction request failed: { api_url } - "
11898 f"{ response .status_code } { response .reason } . "
119- "If your code uses Einstein APIs, make sure you have "
120- 'configured the SDK to use "client_credentials" auth type. '
121- "Refer to https://developer.salesforce.com/docs/ai/agentforce/"
122- "guide/agent-api-get-started.html#create-a-salesforce-app "
123- "to create your external client app."
99+ f"{ self .EINSTEIN_WARNING_MESSAGE } "
124100 )
125101 logger .error (error_msg )
126102 except requests .exceptions .RequestException as e :
127- logger .error (f"Prediction API request failed: { api_url } { e } " )
128- raise RuntimeError (f"Prediction API request failed: { e } " ) from e
129-
130- response_data : Dict [str , Any ] = {}
131- if response .content :
132- try :
133- response_data = response .json ()
134- except ValueError :
135- logger .warning ("Failed to parse response as JSON" )
136- response_data = {"raw_response" : response .text }
103+ logger .error (f"Einstein Prediction request failed: { api_url } { e } " )
104+ raise RuntimeError (f"Einstein Prediction request failed: { e } " ) from e
137105
138106 return PredictionResponse (
139- version = "v1" ,
140107 prediction_type = request .prediction_type ,
141108 status_code = response .status_code ,
142- data = response_data ,
109+ data = self . parse_response ( response ) ,
143110 )
0 commit comments