1+ # Copyright (c) 2025, Salesforce, Inc.
2+ # SPDX-License-Identifier: Apache-2
3+ #
4+ # Licensed under the Apache License, Version 2.0 (the "License");
5+ # you may not use this file except in compliance with the License.
6+ # You may obtain a copy of the License at
7+ #
8+ # http://www.apache.org/licenses/LICENSE-2.0
9+ #
10+ # Unless required by applicable law or agreed to in writing, software
11+ # distributed under the License is distributed on an "AS IS" BASIS,
12+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ # See the License for the specific language governing permissions and
14+ # limitations under the License.
15+
16+ import json
17+ import os
18+ import shutil
19+ import sys
20+ import tempfile
21+ import textwrap
22+ from typing import List
23+
24+ import pytest
25+ from pydantic import BaseModel
26+
27+ from datacustomcode import function_utils
28+
29+
30+ class SampleRequest (BaseModel ):
31+ message : str
32+ count : int = 5
33+ tags : List [str ] = []
34+ version : str = "v1"
35+
36+
37+ @pytest .fixture
38+ def sample_entrypoint ():
39+ """Create a temporary entrypoint file with a function."""
40+ with tempfile .NamedTemporaryFile (
41+ mode = "w" , suffix = ".py" , delete = False
42+ ) as temp_file :
43+ entrypoint_content = textwrap .dedent (
44+ """
45+ from typing import List
46+ from pydantic import BaseModel
47+
48+ class SampleRequest(BaseModel):
49+ message: str
50+ count: int = 5
51+ tags: List[str] = []
52+ version: str = "v1"
53+
54+ class SampleResponse(BaseModel):
55+ result: str
56+ success: bool = True
57+
58+ def function(request: SampleRequest) -> SampleResponse:
59+ return SampleResponse(result=f"Processed {request.message}")
60+ """
61+ )
62+ temp_file .write (entrypoint_content )
63+ temp_file_path = temp_file .name
64+
65+ yield temp_file_path
66+
67+ if os .path .exists (temp_file_path ):
68+ os .unlink (temp_file_path )
69+
70+
71+ @pytest .fixture
72+ def entrypoint_no_annotations ():
73+ """Create an entrypoint with no type annotations."""
74+ with tempfile .NamedTemporaryFile (
75+ mode = "w" , suffix = ".py" , delete = False
76+ ) as temp_file :
77+ entrypoint_content = textwrap .dedent (
78+ """
79+ def function(request):
80+ return {"result": "no annotations"}
81+ """
82+ )
83+ temp_file .write (entrypoint_content )
84+ temp_file_path = temp_file .name
85+
86+ yield temp_file_path
87+
88+ if os .path .exists (temp_file_path ):
89+ os .unlink (temp_file_path )
90+
91+
92+ def test_get_function_signature_types (sample_entrypoint , entrypoint_no_annotations ):
93+ """Test extracting request and response types from function signatures."""
94+ module = function_utils .load_function_module (sample_entrypoint )
95+ func = function_utils .get_function_callable (module )
96+ req_type , resp_type , req_name , resp_name = (
97+ function_utils .get_function_signature_types (func )
98+ )
99+
100+ assert req_name == "SampleRequest"
101+ assert resp_name == "SampleResponse"
102+ assert req_type is not None
103+ assert resp_type is not None
104+
105+ module_no_annot = function_utils .load_function_module (entrypoint_no_annotations )
106+ func_no_annot = function_utils .get_function_callable (module_no_annot )
107+ req_type , resp_type , req_name , resp_name = (
108+ function_utils .get_function_signature_types (func_no_annot )
109+ )
110+
111+ assert req_name is None
112+ assert resp_name is None
113+
114+
115+ def test_inspect_function_types_static (sample_entrypoint , entrypoint_no_annotations ):
116+ """Test static AST-based inspection of function types."""
117+ req_name , resp_name = function_utils .inspect_function_types_static (
118+ sample_entrypoint
119+ )
120+ assert req_name == "SampleRequest"
121+ assert resp_name == "SampleResponse"
122+
123+ req_name , resp_name = function_utils .inspect_function_types_static (
124+ entrypoint_no_annotations
125+ )
126+ assert req_name is None
127+ assert resp_name is None
128+
129+ def test_inspect_function_types (sample_entrypoint ):
130+ """Test dynamic inspection of function types."""
131+ req_name , resp_name = function_utils .inspect_function_types (sample_entrypoint )
132+ assert req_name == "SampleRequest"
133+ assert resp_name == "SampleResponse"
134+
135+ req_name , resp_name = function_utils .inspect_function_types ("/nonexistent/file.py" )
136+ assert req_name is None
137+ assert resp_name is None
138+
139+
140+ def test_get_request_type (sample_entrypoint , entrypoint_no_annotations ):
141+ """Test getting request type from entrypoint."""
142+ req_type = function_utils .get_request_type (sample_entrypoint )
143+ assert req_type is not None
144+ assert hasattr (req_type , "model_fields" )
145+
146+ with pytest .raises (ValueError , match = "must have a type annotation" ):
147+ function_utils .get_request_type (entrypoint_no_annotations )
148+
149+
150+ def test_generate_test_json ():
151+ """Test generating test.json file from entrypoint with simple and complex nested types."""
152+ temp_dir = tempfile .mkdtemp ()
153+ models_file = os .path .join (temp_dir , "test_models.py" )
154+
155+ try :
156+ # Test 1: Simple request type
157+ entrypoint_simple = os .path .join (temp_dir , "entrypoint_simple.py" )
158+ output_simple = os .path .join (temp_dir , "test_simple.json" )
159+
160+ with open (models_file , "w" ) as f :
161+ models_content = textwrap .dedent (
162+ """
163+ from pydantic import BaseModel
164+ from typing import List
165+
166+ class SimpleRequest(BaseModel):
167+ message: str
168+ count: int = 5
169+ tags: List[str] = []
170+ version: str = "v1"
171+
172+ class NestedConfig(BaseModel):
173+ host: str
174+ port: int = 8080
175+ enabled: bool = True
176+
177+ class ComplexRequest(BaseModel):
178+ name: str
179+ max_items: int = 100
180+ config: NestedConfig
181+ metadata: dict = {}
182+ """
183+ )
184+ f .write (models_content )
185+
186+ with open (entrypoint_simple , "w" ) as f :
187+ entrypoint_content = textwrap .dedent (
188+ """
189+ from test_models import SimpleRequest
190+
191+ def function(request: SimpleRequest):
192+ return {"result": "ok"}
193+ """
194+ )
195+ f .write (entrypoint_content )
196+
197+ sys .path .insert (0 , temp_dir )
198+
199+ function_utils .generate_test_json (entrypoint_simple , output_simple )
200+ assert os .path .exists (output_simple )
201+
202+ with open (output_simple , "r" ) as f :
203+ data = json .load (f )
204+
205+ assert "message" in data
206+ assert data ["count" ] == 5
207+ assert data ["version" ] == "v1"
208+ assert data ["tags" ] == []
209+
210+ # Test 2: Complex request type with nested models
211+ entrypoint_complex = os .path .join (temp_dir , "entrypoint_complex.py" )
212+ output_complex = os .path .join (temp_dir , "test_complex.json" )
213+
214+ with open (entrypoint_complex , "w" ) as f :
215+ entrypoint_content = textwrap .dedent (
216+ """
217+ from test_models import ComplexRequest
218+
219+ def function(request: ComplexRequest):
220+ return {"result": "ok"}
221+ """
222+ )
223+ f .write (entrypoint_content )
224+
225+ function_utils .generate_test_json (entrypoint_complex , output_complex )
226+ assert os .path .exists (output_complex )
227+
228+ with open (output_complex , "r" ) as f :
229+ complex_data = json .load (f )
230+
231+ assert "name" in complex_data
232+ assert "max_items" in complex_data
233+ assert complex_data ["max_items" ] == 100
234+ assert "config" in complex_data
235+ assert isinstance (complex_data ["config" ], dict )
236+ assert "host" in complex_data ["config" ]
237+ assert "port" in complex_data ["config" ]
238+ assert complex_data ["config" ]["port" ] == 8080
239+ assert complex_data ["config" ]["enabled" ] is True
240+ assert "metadata" in complex_data
241+ assert complex_data ["metadata" ] == {}
242+
243+ finally :
244+ if temp_dir in sys .path :
245+ sys .path .remove (temp_dir )
246+ if os .path .exists (temp_dir ):
247+ shutil .rmtree (temp_dir )
0 commit comments