Skip to content

Commit 28187f2

Browse files
more refactoring
1 parent 429fcd8 commit 28187f2

3 files changed

Lines changed: 21 additions & 37 deletions

File tree

src/datacustomcode/einstein_platform_client.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121

2222
from loguru import logger
23+
import requests
2324

2425
from datacustomcode.token_provider import (
2526
CredentialsTokenProvider,
@@ -49,7 +50,7 @@ def __init__(
4950
self.token_response = None
5051
super().__init__(**kwargs)
5152

52-
def get_headers(self):
53+
def _get_headers(self):
5354
if self.token_response is None:
5455
self.token_response = self._token_provider.get_token()
5556

@@ -60,6 +61,23 @@ def get_headers(self):
6061
"x-client-feature-id": "ai-platform-models-connected-app",
6162
}
6263

64+
def make_post_request(self, url, payload):
65+
try:
66+
response = requests.post(
67+
url, json=payload, headers=self._get_headers(), timeout=180
68+
)
69+
if not response.ok:
70+
error_msg = (
71+
f"Request to {url} failed. "
72+
f"Reason: {response.status_code} {response.reason} - "
73+
f"Response body: {response.text}"
74+
)
75+
logger.error(error_msg)
76+
return response
77+
except requests.exceptions.RequestException as e:
78+
logger.error(f"Request to {url} failed: {e}")
79+
raise RuntimeError(f"Request to {url} failed {e}") from e
80+
6381
def parse_response(self, response):
6482
response_data: Dict[str, Any] = {}
6583
if response.content:

src/datacustomcode/einstein_predictions/impl/default.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@
2020
List,
2121
)
2222

23-
from loguru import logger
24-
import requests
25-
2623
from datacustomcode.einstein_platform_client import EinsteinPlatformClient
2724
from datacustomcode.einstein_predictions.base import EinsteinPredictions
2825
from datacustomcode.einstein_predictions.types import (
@@ -74,21 +71,7 @@ def predict(self, request: PredictionRequest) -> PredictionResponse:
7471
if request.settings:
7572
payload["settings"] = request.settings
7673

77-
logger.debug(f"Making Einstein prediction request to: {api_url}")
78-
try:
79-
response = requests.post(
80-
api_url, json=payload, headers=self.get_headers(), timeout=180
81-
)
82-
if not response.ok and not response.text:
83-
error_msg = (
84-
f"Einstein Prediction request failed: {api_url} - "
85-
f"{response.status_code} {response.reason}"
86-
)
87-
logger.error(error_msg)
88-
except requests.exceptions.RequestException as e:
89-
logger.error(f"Einstein Prediction request failed: {api_url} {e}")
90-
raise RuntimeError(f"Einstein Prediction request failed: {e}") from e
91-
74+
response = self.make_post_request(api_url, payload)
9275
return PredictionResponse(
9376
prediction_type=request.prediction_type,
9477
status_code=response.status_code,

src/datacustomcode/llm_gateway/default.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@
1515

1616
from typing import Any, Dict
1717

18-
from loguru import logger
19-
import requests
20-
2118
from datacustomcode.einstein_platform_client import EinsteinPlatformClient
2219
from datacustomcode.llm_gateway.base import LLMGateway
2320
from datacustomcode.llm_gateway.types.generate_text_request import GenerateTextRequest
@@ -42,21 +39,7 @@ def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse:
4239
if request.tags:
4340
payload["tags"] = request.tags
4441

45-
logger.debug(f"Making Generate text request: {api_url}")
46-
try:
47-
response = requests.post(
48-
api_url, json=payload, headers=self.get_headers(), timeout=180
49-
)
50-
if not response.ok and not response.text:
51-
error_msg = (
52-
f"Generate text request failed: {api_url} - "
53-
f"{response.status_code} {response.reason}"
54-
)
55-
logger.error(error_msg)
56-
except requests.exceptions.RequestException as e:
57-
logger.error(f"Generate text request failed: {api_url} {e}")
58-
raise RuntimeError(f"Generate text request failed: {e}") from e
59-
42+
response = self.make_post_request(api_url, payload)
6043
response_dict = {
6144
"status_code": response.status_code,
6245
"data": self.parse_response(response),

0 commit comments

Comments
 (0)