1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- from typing import (
17- Any ,
18- Dict ,
19- Optional ,
20- )
16+ from typing import Any , Dict
2117
2218from loguru import logger
2319import requests
2622from datacustomcode .llm_gateway .base import LLMGateway
2723from datacustomcode .llm_gateway .types .generate_text_request import GenerateTextRequest
2824from datacustomcode .llm_gateway .types .generate_text_response import GenerateTextResponse
25+ from datacustomcode .llm_gateway .types .generate_text_response_builder import (
26+ GenerateTextResponseBuilder ,
27+ )
2928
3029
3130class DefaultLLMGateway (EinsteinPlatformClient , LLMGateway ):
3231 CONFIG_NAME = "DefaultLLMGateway"
3332
34- def __init__ (
35- self ,
36- credentials_profile : Optional [str ] = None ,
37- sf_cli_org : Optional [str ] = None ,
38- ** kwargs ,
39- ):
40- EinsteinPlatformClient .__init__ (
41- self , credentials_profile = credentials_profile , sf_cli_org = sf_cli_org
42- )
43- LLMGateway .__init__ (self , ** kwargs )
44-
4533 def generate_text (self , request : GenerateTextRequest ) -> GenerateTextResponse :
4634 api_url = (
47- f"{ self .EINSTEIN_PLATFORM_URL } /models /{ request .model_name } /generations"
35+ f"{ self .EINSTEIN_PLATFORM_MODELS_URL } /{ request .model_name } /generations"
4836 )
4937
5038 payload : Dict [str , Any ] = {"prompt" : request .prompt }
@@ -69,6 +57,8 @@ def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse:
6957 logger .error (f"Generate text request failed: { api_url } { e } " )
7058 raise RuntimeError (f"Generate text request failed: { e } " ) from e
7159
72- return GenerateTextResponse (
73- status_code = response .status_code , data = self .parse_response (response )
74- )
60+ response_dict = {
61+ "status_code" : response .status_code ,
62+ "data" : self .parse_response (response ),
63+ }
64+ return GenerateTextResponseBuilder .build (response_dict )
0 commit comments