1+ from __future__ import annotations
2+
3+ import pytest
4+ from pydantic import ValidationError
5+
6+ from datacustomcode .llm_gateway .base import LLMGateway
7+ from datacustomcode .llm_gateway .default import DefaultLLMGateway
8+ from datacustomcode .llm_gateway .types .generate_text_request import GenerateTextRequest
9+ from datacustomcode .llm_gateway .types .generate_text_request_builder import GenerateTextRequestBuilder
10+ from datacustomcode .llm_gateway .types .generate_text_response import GenerateTextResponse
11+ from datacustomcode .llm_gateway .types .generate_text_response_builder import GenerateTextResponseBuilder
12+
13+
14+ class TestGenerateTextRequest :
15+ """Test GenerateTextRequest model."""
16+
17+ def test_version_defaults_to_v1 (self ):
18+ """Test version field defaults to v1."""
19+ request = GenerateTextRequest (model_name = "gpt-4" , prompt = "Hello" )
20+ assert request .version == "v1"
21+
22+ def test_model_name_required (self ):
23+ """Test model_name is required."""
24+ with pytest .raises (ValidationError ) as exc_info :
25+ GenerateTextRequest (prompt = "Hello" )
26+ # Error message uses camelCase alias
27+ assert "modelName" in str (exc_info .value ) or "model_name" in str (exc_info .value )
28+
29+ def test_prompt_required (self ):
30+ """Test prompt is required."""
31+ with pytest .raises (ValidationError ) as exc_info :
32+ GenerateTextRequest (model_name = "gpt-4" )
33+ assert "prompt" in str (exc_info .value )
34+
35+ def test_version_must_be_v1 (self ):
36+ """Test version must be literal 'v1'."""
37+ with pytest .raises (ValidationError ) as exc_info :
38+ GenerateTextRequest (version = "v2" , model_name = "gpt-4" , prompt = "Hello" )
39+ assert "version" in str (exc_info .value )
40+
41+ def test_camel_case_serialization (self ):
42+ """Test serialization with camelCase aliases."""
43+ request = GenerateTextRequest (model_name = "gpt-4" , prompt = "Hello" )
44+ data = request .model_dump (by_alias = True )
45+ assert "modelName" in data
46+ assert data ["modelName" ] == "gpt-4"
47+ assert "model_name" not in data
48+
49+ def test_accepts_camel_case_input (self ):
50+ """Test model can accept camelCase field names."""
51+ request = GenerateTextRequest (modelName = "gpt-4" , prompt = "Hello" )
52+ assert request .model_name == "gpt-4"
53+
54+
55+ class TestGenerateTextRequestBuilder :
56+ """Test GenerateTextRequestBuilder."""
57+
58+ def test_builder_basic_usage (self ):
59+ """Test basic builder pattern."""
60+ builder = GenerateTextRequestBuilder ()
61+ request = builder .set_prompt ("Hello" ).set_model ("gpt-4" ).build ()
62+ assert request .prompt == "Hello"
63+ assert request .model_name == "gpt-4"
64+
65+ def test_builder_with_localization_dict (self ):
66+ """Test builder with localization dictionary."""
67+ builder = GenerateTextRequestBuilder ()
68+ localization = {"defaultLocale" : "en-US" , "timezone" : "PST" }
69+ request = (
70+ builder .set_prompt ("Hello" )
71+ .set_model ("gpt-4" )
72+ .set_localization (localization = localization )
73+ .build ()
74+ )
75+ assert request .localization == localization
76+
77+ def test_builder_with_locale_string (self ):
78+ """Test builder with simple locale string."""
79+ builder = GenerateTextRequestBuilder ()
80+ request = (
81+ builder .set_prompt ("Hello" )
82+ .set_model ("gpt-4" )
83+ .set_localization (locale = "en-US" )
84+ .build ()
85+ )
86+ # Verify the localization structure
87+ assert request .localization is not None
88+ assert request .localization ["defaultLocale" ] == "en-US"
89+ assert request .localization ["inputLocales" ] == [{"locale" : "en-US" , "probability" : 1.0 }]
90+ assert request .localization ["expectedLocales" ] == ["en-US" ]
91+
92+ def test_builder_with_tags (self ):
93+ """Test builder with tags."""
94+ builder = GenerateTextRequestBuilder ()
95+ tags = {"user" : "test" , "session" : "123" }
96+ request = (
97+ builder .set_prompt ("Hello" )
98+ .set_model ("gpt-4" )
99+ .set_tags (tags )
100+ .build ()
101+ )
102+ assert request .tags == tags
103+
104+ def test_builder_validates_on_build (self ):
105+ """Test builder validates request on build."""
106+ builder = GenerateTextRequestBuilder ()
107+ with pytest .raises (ValidationError ):
108+ builder .set_prompt ("Hello" ).set_model ("" ).build ()
109+
110+ def test_builder_localization_requires_argument (self ):
111+ """Test set_localization requires either localization or locale."""
112+ builder = GenerateTextRequestBuilder ()
113+ with pytest .raises (ValueError ) as exc_info :
114+ builder .set_localization ()
115+ assert "Must provide either localization or locale" in str (exc_info .value )
116+
117+
118+ class TestGenerateTextResponse :
119+ """Test GenerateTextResponse model."""
120+
121+ def test_response_defaults (self ):
122+ """Test response field defaults."""
123+ response = GenerateTextResponse (status_code = 200 )
124+ assert response .version == "v1"
125+ assert response .data is None
126+
127+ def test_is_success_property (self ):
128+ """Test is_success property returns True for 200."""
129+ response = GenerateTextResponse (status_code = 200 )
130+ assert response .is_success is True
131+ assert response .is_error is False
132+
133+ def test_is_success_false_for_non_200 (self ):
134+ """Test is_success property returns False for non-200."""
135+ response = GenerateTextResponse (status_code = 400 )
136+ assert response .is_success is False
137+ assert response .is_error is True
138+
139+ def test_text_property_success (self ):
140+ """Test text property extracts generated text on success."""
141+ response = GenerateTextResponse (
142+ status_code = 200 ,
143+ data = {"generation" : {"generatedText" : "Hello world" }}
144+ )
145+ assert response .text == "Hello world"
146+
147+ def test_text_property_returns_empty_on_error (self ):
148+ """Test text property returns empty string on error."""
149+ response = GenerateTextResponse (status_code = 400 , data = {"error" : "Bad request" })
150+ assert response .text == ""
151+
152+ def test_text_property_handles_missing_data (self ):
153+ """Test text property handles missing data gracefully."""
154+ response = GenerateTextResponse (status_code = 200 , data = None )
155+ assert response .text == ""
156+
157+ def test_text_property_handles_missing_nested_fields (self ):
158+ """Test text property handles missing nested fields."""
159+ response = GenerateTextResponse (status_code = 200 , data = {"other" : "data" })
160+ assert response .text == ""
161+
162+ def test_error_code_property (self ):
163+ """Test error_code property extracts error code on error."""
164+ response = GenerateTextResponse (
165+ status_code = 400 ,
166+ data = {"errorCode" : "INVALID_REQUEST" }
167+ )
168+ assert response .error_code == "INVALID_REQUEST"
169+
170+ def test_error_code_falls_back_to_status_code (self ):
171+ """Test error_code falls back to status_code if no errorCode in data."""
172+ response = GenerateTextResponse (status_code = 500 , data = {"message" : "error" })
173+ assert response .error_code == '500'
174+
175+ def test_error_code_returns_empty_on_success (self ):
176+ """Test error_code returns empty string on success."""
177+ response = GenerateTextResponse (status_code = 200 )
178+ assert response .error_code == ""
179+
180+ def test_status_code_validation (self ):
181+ """Test status_code must be >= 0."""
182+ with pytest .raises (ValidationError ) as exc_info :
183+ GenerateTextResponse (status_code = - 1 )
184+ assert "status_code" in str (exc_info .value )
185+
186+
187+ class TestGenerateTextResponseBuilder :
188+ """Test GenerateTextResponseBuilder."""
189+
190+ def test_builder_build_from_dict (self ):
191+ """Test building response from dictionary."""
192+ response_dict = {
193+ "version" : "v1" ,
194+ "status_code" : 200 ,
195+ "data" : {"generation" : {"generatedText" : "Hello world" }}
196+ }
197+ response = GenerateTextResponseBuilder .build (response_dict )
198+ assert isinstance (response , GenerateTextResponse )
199+ assert response .status_code == 200
200+ assert response .text == "Hello world"
201+
202+ def test_builder_validates_dict (self ):
203+ """Test builder validates the dictionary."""
204+ invalid_dict = {"version" : "v1" } # Missing required status_code
205+ with pytest .raises (ValidationError ):
206+ GenerateTextResponseBuilder .build (invalid_dict )
207+
208+ def test_builder_with_minimal_dict (self ):
209+ """Test builder with minimal required fields."""
210+ response_dict = {"status_code" : 200 }
211+ response = GenerateTextResponseBuilder .build (response_dict )
212+ assert response .status_code == 200
213+ assert response .version == "v1" # Default value
214+
215+
216+ class TestDefaultLLMGateway :
217+ """Test DefaultLLMGateway implementation."""
218+
219+ def test_default_gateway_is_llm_gateway (self ):
220+ """Test DefaultLLMGateway inherits from LLMGateway."""
221+ gateway = DefaultLLMGateway ()
222+ assert isinstance (gateway , LLMGateway )
223+
224+ def test_generate_text_returns_response (self ):
225+ """Test generate_text returns GenerateTextResponse."""
226+ gateway = DefaultLLMGateway ()
227+ request = GenerateTextRequest (model_name = "gpt-4" , prompt = "Hello" )
228+ response = gateway .generate_text (request )
229+ assert isinstance (response , GenerateTextResponse )
230+
231+ def test_generate_text_success_response (self ):
232+ """Test generate_text returns successful response."""
233+ gateway = DefaultLLMGateway ()
234+ request = GenerateTextRequest (model_name = "gpt-4" , prompt = "Hello" )
235+ response = gateway .generate_text (request )
236+ assert response .is_success is True
237+ assert response .status_code == 200
238+ assert len (response .text ) > 0
0 commit comments