Skip to content

Commit 9b3cd22

Browse files
Adding testcase for function_utils.py
1 parent b0608ea commit 9b3cd22

1 file changed

Lines changed: 247 additions & 0 deletions

File tree

tests/test_function_utils.py

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import json
17+
import os
18+
import shutil
19+
import sys
20+
import tempfile
21+
import textwrap
22+
from typing import List
23+
24+
import pytest
25+
from pydantic import BaseModel
26+
27+
from datacustomcode import function_utils
28+
29+
30+
class SampleRequest(BaseModel):
31+
message: str
32+
count: int = 5
33+
tags: List[str] = []
34+
version: str = "v1"
35+
36+
37+
@pytest.fixture
38+
def sample_entrypoint():
39+
"""Create a temporary entrypoint file with a function."""
40+
with tempfile.NamedTemporaryFile(
41+
mode="w", suffix=".py", delete=False
42+
) as temp_file:
43+
entrypoint_content = textwrap.dedent(
44+
"""
45+
from typing import List
46+
from pydantic import BaseModel
47+
48+
class SampleRequest(BaseModel):
49+
message: str
50+
count: int = 5
51+
tags: List[str] = []
52+
version: str = "v1"
53+
54+
class SampleResponse(BaseModel):
55+
result: str
56+
success: bool = True
57+
58+
def function(request: SampleRequest) -> SampleResponse:
59+
return SampleResponse(result=f"Processed {request.message}")
60+
"""
61+
)
62+
temp_file.write(entrypoint_content)
63+
temp_file_path = temp_file.name
64+
65+
yield temp_file_path
66+
67+
if os.path.exists(temp_file_path):
68+
os.unlink(temp_file_path)
69+
70+
71+
@pytest.fixture
72+
def entrypoint_no_annotations():
73+
"""Create an entrypoint with no type annotations."""
74+
with tempfile.NamedTemporaryFile(
75+
mode="w", suffix=".py", delete=False
76+
) as temp_file:
77+
entrypoint_content = textwrap.dedent(
78+
"""
79+
def function(request):
80+
return {"result": "no annotations"}
81+
"""
82+
)
83+
temp_file.write(entrypoint_content)
84+
temp_file_path = temp_file.name
85+
86+
yield temp_file_path
87+
88+
if os.path.exists(temp_file_path):
89+
os.unlink(temp_file_path)
90+
91+
92+
def test_get_function_signature_types(sample_entrypoint, entrypoint_no_annotations):
93+
"""Test extracting request and response types from function signatures."""
94+
module = function_utils.load_function_module(sample_entrypoint)
95+
func = function_utils.get_function_callable(module)
96+
req_type, resp_type, req_name, resp_name = (
97+
function_utils.get_function_signature_types(func)
98+
)
99+
100+
assert req_name == "SampleRequest"
101+
assert resp_name == "SampleResponse"
102+
assert req_type is not None
103+
assert resp_type is not None
104+
105+
module_no_annot = function_utils.load_function_module(entrypoint_no_annotations)
106+
func_no_annot = function_utils.get_function_callable(module_no_annot)
107+
req_type, resp_type, req_name, resp_name = (
108+
function_utils.get_function_signature_types(func_no_annot)
109+
)
110+
111+
assert req_name is None
112+
assert resp_name is None
113+
114+
115+
def test_inspect_function_types_static(sample_entrypoint, entrypoint_no_annotations):
116+
"""Test static AST-based inspection of function types."""
117+
req_name, resp_name = function_utils.inspect_function_types_static(
118+
sample_entrypoint
119+
)
120+
assert req_name == "SampleRequest"
121+
assert resp_name == "SampleResponse"
122+
123+
req_name, resp_name = function_utils.inspect_function_types_static(
124+
entrypoint_no_annotations
125+
)
126+
assert req_name is None
127+
assert resp_name is None
128+
129+
def test_inspect_function_types(sample_entrypoint):
130+
"""Test dynamic inspection of function types."""
131+
req_name, resp_name = function_utils.inspect_function_types(sample_entrypoint)
132+
assert req_name == "SampleRequest"
133+
assert resp_name == "SampleResponse"
134+
135+
req_name, resp_name = function_utils.inspect_function_types("/nonexistent/file.py")
136+
assert req_name is None
137+
assert resp_name is None
138+
139+
140+
def test_get_request_type(sample_entrypoint, entrypoint_no_annotations):
141+
"""Test getting request type from entrypoint."""
142+
req_type = function_utils.get_request_type(sample_entrypoint)
143+
assert req_type is not None
144+
assert hasattr(req_type, "model_fields")
145+
146+
with pytest.raises(ValueError, match="must have a type annotation"):
147+
function_utils.get_request_type(entrypoint_no_annotations)
148+
149+
150+
def test_generate_test_json():
151+
"""Test generating test.json file from entrypoint with simple and complex nested types."""
152+
temp_dir = tempfile.mkdtemp()
153+
models_file = os.path.join(temp_dir, "test_models.py")
154+
155+
try:
156+
# Test 1: Simple request type
157+
entrypoint_simple = os.path.join(temp_dir, "entrypoint_simple.py")
158+
output_simple = os.path.join(temp_dir, "test_simple.json")
159+
160+
with open(models_file, "w") as f:
161+
models_content = textwrap.dedent(
162+
"""
163+
from pydantic import BaseModel
164+
from typing import List
165+
166+
class SimpleRequest(BaseModel):
167+
message: str
168+
count: int = 5
169+
tags: List[str] = []
170+
version: str = "v1"
171+
172+
class NestedConfig(BaseModel):
173+
host: str
174+
port: int = 8080
175+
enabled: bool = True
176+
177+
class ComplexRequest(BaseModel):
178+
name: str
179+
max_items: int = 100
180+
config: NestedConfig
181+
metadata: dict = {}
182+
"""
183+
)
184+
f.write(models_content)
185+
186+
with open(entrypoint_simple, "w") as f:
187+
entrypoint_content = textwrap.dedent(
188+
"""
189+
from test_models import SimpleRequest
190+
191+
def function(request: SimpleRequest):
192+
return {"result": "ok"}
193+
"""
194+
)
195+
f.write(entrypoint_content)
196+
197+
sys.path.insert(0, temp_dir)
198+
199+
function_utils.generate_test_json(entrypoint_simple, output_simple)
200+
assert os.path.exists(output_simple)
201+
202+
with open(output_simple, "r") as f:
203+
data = json.load(f)
204+
205+
assert "message" in data
206+
assert data["count"] == 5
207+
assert data["version"] == "v1"
208+
assert data["tags"] == []
209+
210+
# Test 2: Complex request type with nested models
211+
entrypoint_complex = os.path.join(temp_dir, "entrypoint_complex.py")
212+
output_complex = os.path.join(temp_dir, "test_complex.json")
213+
214+
with open(entrypoint_complex, "w") as f:
215+
entrypoint_content = textwrap.dedent(
216+
"""
217+
from test_models import ComplexRequest
218+
219+
def function(request: ComplexRequest):
220+
return {"result": "ok"}
221+
"""
222+
)
223+
f.write(entrypoint_content)
224+
225+
function_utils.generate_test_json(entrypoint_complex, output_complex)
226+
assert os.path.exists(output_complex)
227+
228+
with open(output_complex, "r") as f:
229+
complex_data = json.load(f)
230+
231+
assert "name" in complex_data
232+
assert "max_items" in complex_data
233+
assert complex_data["max_items"] == 100
234+
assert "config" in complex_data
235+
assert isinstance(complex_data["config"], dict)
236+
assert "host" in complex_data["config"]
237+
assert "port" in complex_data["config"]
238+
assert complex_data["config"]["port"] == 8080
239+
assert complex_data["config"]["enabled"] is True
240+
assert "metadata" in complex_data
241+
assert complex_data["metadata"] == {}
242+
243+
finally:
244+
if temp_dir in sys.path:
245+
sys.path.remove(temp_dir)
246+
if os.path.exists(temp_dir):
247+
shutil.rmtree(temp_dir)

0 commit comments

Comments
 (0)