Skip to content

Commit e24ae3e

Browse files
Adding unit test
1 parent 3959bac commit e24ae3e

2 files changed

Lines changed: 329 additions & 0 deletions

File tree

tests/test_llm_gateway.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
from pydantic import ValidationError
5+
6+
from datacustomcode.llm_gateway.base import LLMGateway
7+
from datacustomcode.llm_gateway.default import DefaultLLMGateway
8+
from datacustomcode.llm_gateway.types.generate_text_request import GenerateTextRequest
9+
from datacustomcode.llm_gateway.types.generate_text_request_builder import GenerateTextRequestBuilder
10+
from datacustomcode.llm_gateway.types.generate_text_response import GenerateTextResponse
11+
from datacustomcode.llm_gateway.types.generate_text_response_builder import GenerateTextResponseBuilder
12+
13+
14+
class TestGenerateTextRequest:
15+
"""Test GenerateTextRequest model."""
16+
17+
def test_version_defaults_to_v1(self):
18+
"""Test version field defaults to v1."""
19+
request = GenerateTextRequest(model_name="gpt-4", prompt="Hello")
20+
assert request.version == "v1"
21+
22+
def test_model_name_required(self):
23+
"""Test model_name is required."""
24+
with pytest.raises(ValidationError) as exc_info:
25+
GenerateTextRequest(prompt="Hello")
26+
# Error message uses camelCase alias
27+
assert "modelName" in str(exc_info.value) or "model_name" in str(exc_info.value)
28+
29+
def test_prompt_required(self):
30+
"""Test prompt is required."""
31+
with pytest.raises(ValidationError) as exc_info:
32+
GenerateTextRequest(model_name="gpt-4")
33+
assert "prompt" in str(exc_info.value)
34+
35+
def test_version_must_be_v1(self):
36+
"""Test version must be literal 'v1'."""
37+
with pytest.raises(ValidationError) as exc_info:
38+
GenerateTextRequest(version="v2", model_name="gpt-4", prompt="Hello")
39+
assert "version" in str(exc_info.value)
40+
41+
def test_camel_case_serialization(self):
42+
"""Test serialization with camelCase aliases."""
43+
request = GenerateTextRequest(model_name="gpt-4", prompt="Hello")
44+
data = request.model_dump(by_alias=True)
45+
assert "modelName" in data
46+
assert data["modelName"] == "gpt-4"
47+
assert "model_name" not in data
48+
49+
def test_accepts_camel_case_input(self):
50+
"""Test model can accept camelCase field names."""
51+
request = GenerateTextRequest(modelName="gpt-4", prompt="Hello")
52+
assert request.model_name == "gpt-4"
53+
54+
55+
class TestGenerateTextRequestBuilder:
56+
"""Test GenerateTextRequestBuilder."""
57+
58+
def test_builder_basic_usage(self):
59+
"""Test basic builder pattern."""
60+
builder = GenerateTextRequestBuilder()
61+
request = builder.set_prompt("Hello").set_model("gpt-4").build()
62+
assert request.prompt == "Hello"
63+
assert request.model_name == "gpt-4"
64+
65+
def test_builder_with_localization_dict(self):
66+
"""Test builder with localization dictionary."""
67+
builder = GenerateTextRequestBuilder()
68+
localization = {"defaultLocale": "en-US", "timezone": "PST"}
69+
request = (
70+
builder.set_prompt("Hello")
71+
.set_model("gpt-4")
72+
.set_localization(localization=localization)
73+
.build()
74+
)
75+
assert request.localization == localization
76+
77+
def test_builder_with_locale_string(self):
78+
"""Test builder with simple locale string."""
79+
builder = GenerateTextRequestBuilder()
80+
request = (
81+
builder.set_prompt("Hello")
82+
.set_model("gpt-4")
83+
.set_localization(locale="en-US")
84+
.build()
85+
)
86+
# Verify the localization structure
87+
assert request.localization is not None
88+
assert request.localization["defaultLocale"] == "en-US"
89+
assert request.localization["inputLocales"] == [{"locale": "en-US", "probability": 1.0}]
90+
assert request.localization["expectedLocales"] == ["en-US"]
91+
92+
def test_builder_with_tags(self):
93+
"""Test builder with tags."""
94+
builder = GenerateTextRequestBuilder()
95+
tags = {"user": "test", "session": "123"}
96+
request = (
97+
builder.set_prompt("Hello")
98+
.set_model("gpt-4")
99+
.set_tags(tags)
100+
.build()
101+
)
102+
assert request.tags == tags
103+
104+
def test_builder_validates_on_build(self):
105+
"""Test builder validates request on build."""
106+
builder = GenerateTextRequestBuilder()
107+
with pytest.raises(ValidationError):
108+
builder.set_prompt("Hello").set_model("").build()
109+
110+
def test_builder_localization_requires_argument(self):
111+
"""Test set_localization requires either localization or locale."""
112+
builder = GenerateTextRequestBuilder()
113+
with pytest.raises(ValueError) as exc_info:
114+
builder.set_localization()
115+
assert "Must provide either localization or locale" in str(exc_info.value)
116+
117+
118+
class TestGenerateTextResponse:
119+
"""Test GenerateTextResponse model."""
120+
121+
def test_response_defaults(self):
122+
"""Test response field defaults."""
123+
response = GenerateTextResponse(status_code=200)
124+
assert response.version == "v1"
125+
assert response.data is None
126+
127+
def test_is_success_property(self):
128+
"""Test is_success property returns True for 200."""
129+
response = GenerateTextResponse(status_code=200)
130+
assert response.is_success is True
131+
assert response.is_error is False
132+
133+
def test_is_success_false_for_non_200(self):
134+
"""Test is_success property returns False for non-200."""
135+
response = GenerateTextResponse(status_code=400)
136+
assert response.is_success is False
137+
assert response.is_error is True
138+
139+
def test_text_property_success(self):
140+
"""Test text property extracts generated text on success."""
141+
response = GenerateTextResponse(
142+
status_code=200,
143+
data={"generation": {"generatedText": "Hello world"}}
144+
)
145+
assert response.text == "Hello world"
146+
147+
def test_text_property_returns_empty_on_error(self):
148+
"""Test text property returns empty string on error."""
149+
response = GenerateTextResponse(status_code=400, data={"error": "Bad request"})
150+
assert response.text == ""
151+
152+
def test_text_property_handles_missing_data(self):
153+
"""Test text property handles missing data gracefully."""
154+
response = GenerateTextResponse(status_code=200, data=None)
155+
assert response.text == ""
156+
157+
def test_text_property_handles_missing_nested_fields(self):
158+
"""Test text property handles missing nested fields."""
159+
response = GenerateTextResponse(status_code=200, data={"other": "data"})
160+
assert response.text == ""
161+
162+
def test_error_code_property(self):
163+
"""Test error_code property extracts error code on error."""
164+
response = GenerateTextResponse(
165+
status_code=400,
166+
data={"errorCode": "INVALID_REQUEST"}
167+
)
168+
assert response.error_code == "INVALID_REQUEST"
169+
170+
def test_error_code_falls_back_to_status_code(self):
171+
"""Test error_code falls back to status_code if no errorCode in data."""
172+
response = GenerateTextResponse(status_code=500, data={"message": "error"})
173+
assert response.error_code == '500'
174+
175+
def test_error_code_returns_empty_on_success(self):
176+
"""Test error_code returns empty string on success."""
177+
response = GenerateTextResponse(status_code=200)
178+
assert response.error_code == ""
179+
180+
def test_status_code_validation(self):
181+
"""Test status_code must be >= 0."""
182+
with pytest.raises(ValidationError) as exc_info:
183+
GenerateTextResponse(status_code=-1)
184+
assert "status_code" in str(exc_info.value)
185+
186+
187+
class TestGenerateTextResponseBuilder:
188+
"""Test GenerateTextResponseBuilder."""
189+
190+
def test_builder_build_from_dict(self):
191+
"""Test building response from dictionary."""
192+
response_dict = {
193+
"version": "v1",
194+
"status_code": 200,
195+
"data": {"generation": {"generatedText": "Hello world"}}
196+
}
197+
response = GenerateTextResponseBuilder.build(response_dict)
198+
assert isinstance(response, GenerateTextResponse)
199+
assert response.status_code == 200
200+
assert response.text == "Hello world"
201+
202+
def test_builder_validates_dict(self):
203+
"""Test builder validates the dictionary."""
204+
invalid_dict = {"version": "v1"} # Missing required status_code
205+
with pytest.raises(ValidationError):
206+
GenerateTextResponseBuilder.build(invalid_dict)
207+
208+
def test_builder_with_minimal_dict(self):
209+
"""Test builder with minimal required fields."""
210+
response_dict = {"status_code": 200}
211+
response = GenerateTextResponseBuilder.build(response_dict)
212+
assert response.status_code == 200
213+
assert response.version == "v1" # Default value
214+
215+
216+
class TestDefaultLLMGateway:
217+
"""Test DefaultLLMGateway implementation."""
218+
219+
def test_default_gateway_is_llm_gateway(self):
220+
"""Test DefaultLLMGateway inherits from LLMGateway."""
221+
gateway = DefaultLLMGateway()
222+
assert isinstance(gateway, LLMGateway)
223+
224+
def test_generate_text_returns_response(self):
225+
"""Test generate_text returns GenerateTextResponse."""
226+
gateway = DefaultLLMGateway()
227+
request = GenerateTextRequest(model_name="gpt-4", prompt="Hello")
228+
response = gateway.generate_text(request)
229+
assert isinstance(response, GenerateTextResponse)
230+
231+
def test_generate_text_success_response(self):
232+
"""Test generate_text returns successful response."""
233+
gateway = DefaultLLMGateway()
234+
request = GenerateTextRequest(model_name="gpt-4", prompt="Hello")
235+
response = gateway.generate_text(request)
236+
assert response.is_success is True
237+
assert response.status_code == 200
238+
assert len(response.text) > 0

tests/test_runtime.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from __future__ import annotations
2+
3+
import threading
4+
import pytest
5+
6+
from datacustomcode.runtime.base import BaseRuntime
7+
from datacustomcode.runtime.function import Runtime
8+
from datacustomcode.llm_gateway.default import DefaultLLMGateway
9+
from datacustomcode.file.path.default import DefaultFindFilePath
10+
11+
12+
class TestRuntimeSingleton:
13+
"""Test Runtime singleton pattern."""
14+
15+
def setup_method(self):
16+
"""Reset singleton before each test."""
17+
# Reset the singleton instance for each test
18+
Runtime._instance = None
19+
20+
def test_second_instantiation_raises_error(self):
21+
"""Test creating Runtime twice raises RuntimeError."""
22+
runtime1 = Runtime()
23+
assert runtime1 is not None
24+
25+
with pytest.raises(RuntimeError) as exc_info:
26+
Runtime()
27+
28+
assert "can only be instantiated once" in str(exc_info.value)
29+
assert "Do not instantiate it yourself" in str(exc_info.value)
30+
31+
def test_concurrent_instantiation_thread_safe(self):
32+
"""Test singleton is thread-safe during concurrent instantiation."""
33+
results = []
34+
errors = []
35+
36+
def create_runtime():
37+
try:
38+
runtime = Runtime()
39+
results.append(runtime)
40+
except RuntimeError as e:
41+
errors.append(e)
42+
43+
# Create 10 threads trying to instantiate Runtime
44+
threads = []
45+
for _ in range(10):
46+
thread = threading.Thread(target=create_runtime)
47+
threads.append(thread)
48+
thread.start()
49+
50+
# Wait for all threads to complete
51+
for thread in threads:
52+
thread.join()
53+
54+
# Exactly one should succeed, rest should fail
55+
assert len(results) == 1, f"Expected 1 success, got {len(results)}"
56+
assert len(errors) == 9, f"Expected 9 errors, got {len(errors)}"
57+
58+
# All errors should be RuntimeError about singleton
59+
for error in errors:
60+
assert "can only be instantiated once" in str(error)
61+
62+
63+
class TestRuntimeProperties:
64+
"""Test Runtime properties and methods."""
65+
66+
def setup_method(self):
67+
"""Reset singleton and create fresh instance for each test."""
68+
Runtime._instance = None
69+
self.runtime = Runtime()
70+
71+
def test_runtime_has_llm_gateway(self):
72+
"""Test Runtime has llm_gateway property."""
73+
assert hasattr(self.runtime, 'llm_gateway')
74+
assert isinstance(self.runtime.llm_gateway, DefaultLLMGateway)
75+
76+
def test_runtime_has_file(self):
77+
"""Test Runtime has file property."""
78+
assert hasattr(self.runtime, 'file')
79+
assert isinstance(self.runtime.file, DefaultFindFilePath)
80+
81+
def test_runtime_initializes_only_once(self):
82+
"""Test Runtime.__init__ prevents re-initialization."""
83+
# First init happens in setup_method
84+
llm_gateway_before = self.runtime.llm_gateway
85+
86+
# Call __init__ again (shouldn't re-initialize)
87+
self.runtime.__init__()
88+
89+
# Should still be the same instances
90+
llm_gateway_after = self.runtime.llm_gateway
91+
assert llm_gateway_before is llm_gateway_after

0 commit comments

Comments
 (0)