44import tracemalloc
55import asyncio
66import unittest
7- from unittest .mock import patch , MagicMock
7+ from unittest .mock import patch , MagicMock , AsyncMock
88from unittest import IsolatedAsyncioTestCase
99
1010from runpod .endpoint .asyncio .asyncio_runner import Job , Endpoint
@@ -19,37 +19,91 @@ async def test_status(self):
1919 '''
2020 Tests Job.status
2121 '''
22- with patch ("aiohttp.ClientSession" ) as mock_session :
23- mock_resp = MagicMock ()
24- mock_resp .json = MagicMock (return_value = asyncio .Future ())
25- mock_resp .json .return_value .set_result ({"status" : "COMPLETED" })
26- mock_session .get .return_value .__aenter__ .return_value = mock_resp
22+ with patch ("aiohttp.ClientSession" , new_callable = AsyncMock ) as mock_session_class :
23+ mock_session = mock_session_class .return_value
24+ mock_get = mock_session .get
25+ mock_resp = AsyncMock ()
26+
27+ mock_resp .json .return_value = {"status" : "COMPLETED" }
28+ mock_get .return_value = mock_resp
2729
2830 job = Job ("endpoint_id" , "job_id" , mock_session )
2931 status = await job .status ()
3032 assert status == "COMPLETED"
33+ assert await job .status () == "COMPLETED"
3134
3235 async def test_output (self ):
3336 '''
3437 Tests Job.output
3538 '''
3639 with patch ("runpod.endpoint.asyncio.asyncio_runner.asyncio.sleep" ) as mock_sleep , \
37- patch ("aiohttp.ClientSession" ) as mock_session :
38- mock_resp = MagicMock ()
40+ patch ("aiohttp.ClientSession" , new_callable = AsyncMock ) as mock_session_class :
41+ mock_session = mock_session_class .return_value
42+ mock_get = mock_session .get
43+ mock_resp = AsyncMock ()
3944
4045 async def json_side_effect ():
4146 if mock_sleep .call_count == 0 :
4247 return {"status" : "IN_PROGRESS" }
4348 return {"output" : "OUTPUT" , "status" : "COMPLETED" }
4449
45- mock_resp .json = json_side_effect
46- mock_session . get . return_value . __aenter__ .return_value = mock_resp
50+ mock_resp .json . side_effect = json_side_effect
51+ mock_get .return_value = mock_resp
4752
4853 job = Job ("endpoint_id" , "job_id" , mock_session )
49- output_task = asyncio .create_task (job .output ())
54+ output_task = asyncio .create_task (job .output (timeout = 3 ))
5055
5156 output = await output_task
5257 assert output == "OUTPUT"
58+ assert await job .output () == "OUTPUT"
59+
60+ async def test_output_timeout (self ):
61+ '''
62+ Tests Job.output with a timeout
63+ '''
64+ with patch ("aiohttp.ClientSession" , new_callable = AsyncMock ) as mock_session_class :
65+ mock_session = mock_session_class .return_value
66+ mock_get = mock_session .get
67+ mock_resp = AsyncMock ()
68+
69+ mock_resp .json .return_value = {"status" : "IN_PROGRESS" }
70+ mock_get .return_value = mock_resp
71+
72+ job = Job ("endpoint_id" , "job_id" , mock_session )
73+ output_task = asyncio .create_task (job .output (timeout = 1 ))
74+
75+ with self .assertRaises (TimeoutError ):
76+ await output_task
77+
78+ async def test_stream (self ):
79+ '''
80+ Tests Job.stream
81+ '''
82+ with patch ("aiohttp.ClientSession" , new_callable = AsyncMock ) as mock_session_class :
83+ mock_session = mock_session_class .return_value
84+ mock_get = mock_session .get
85+ mock_resp = AsyncMock ()
86+
87+ responses = [
88+ {"stream" : [{"output" : "OUTPUT1" }], "status" : "IN_PROGRESS" },
89+ {"stream" : [{"output" : "OUTPUT2" }], "status" : "IN_PROGRESS" },
90+ ]
91+
92+ async def json_side_effect ():
93+ return responses .pop (0 ) if responses else {"stream" : [], "status" : "COMPLETED" }
94+
95+ mock_resp .json .side_effect = json_side_effect
96+ mock_get .return_value = mock_resp
97+
98+ job = Job ("endpoint_id" , "job_id" , mock_session )
99+
100+ outputs = []
101+ async for stream_output in job .stream ():
102+ outputs .append (stream_output )
103+ if not responses : # Break the loop when responses are exhausted
104+ break
105+
106+ assert outputs == ["OUTPUT1" , "OUTPUT2" ]
53107
54108 async def test_cancel (self ):
55109 '''
@@ -67,27 +121,25 @@ async def test_cancel(self):
67121
68122 async def test_output_in_progress_then_completed (self ):
69123 '''Tests Job.output when status is initially IN_PROGRESS and then changes to COMPLETED'''
70- with patch ("runpod.endpoint.asyncio.asyncio_runner.asyncio.sleep" ) as mock_sleep , \
71- patch ("aiohttp.ClientSession" ) as mock_session :
72- mock_resp = MagicMock ()
124+ with patch ("aiohttp.ClientSession" , new_callable = AsyncMock ) as mock_session_class :
125+ mock_session = mock_session_class .return_value
126+ mock_get = mock_session .get
127+ mock_resp = AsyncMock ()
128+
73129 responses = [
74130 {"status" : "IN_PROGRESS" },
75- {"status" : "COMPLETED" },
76- {"output" : "OUTPUT" }
131+ {"status" : "COMPLETED" , "output" : "OUTPUT" }
77132 ]
78133
79134 async def json_side_effect ():
80- if responses :
81- return responses .pop (0 )
82- return {"status" : "IN_PROGRESS" }
135+ return responses .pop (0 ) if responses else {"status" : "COMPLETED" , "output" : "OUTPUT" } # pylint: disable=line-too-long
83136
84- mock_resp .json = json_side_effect
85- mock_session . get . return_value . __aenter__ .return_value = mock_resp
137+ mock_resp .json . side_effect = json_side_effect
138+ mock_get .return_value = mock_resp
86139
87140 job = Job ("endpoint_id" , "job_id" , mock_session )
88- output = await job .output ()
141+ output = await job .output (timeout = 3 )
89142 assert output == "OUTPUT"
90- mock_sleep .assert_called_once_with (1 )
91143
92144class TestEndpoint (IsolatedAsyncioTestCase ):
93145 ''' Unit tests for the Endpoint class. '''
@@ -106,6 +158,34 @@ async def test_run(self):
106158 job = await endpoint .run ({"input" : "INPUT" })
107159 assert job .job_id == "job_id"
108160
161+ async def test_health (self ):
162+ '''
163+ Tests Endpoint.health
164+ '''
165+ with patch ("aiohttp.ClientSession" ) as mock_session :
166+ mock_resp = MagicMock ()
167+ mock_resp .json = MagicMock (return_value = asyncio .Future ())
168+ mock_resp .json .return_value .set_result ({"status" : "HEALTHY" })
169+ mock_session .get .return_value .__aenter__ .return_value = mock_resp
170+
171+ endpoint = Endpoint ("endpoint_id" , mock_session )
172+ health = await endpoint .health ()
173+ assert health == {"status" : "HEALTHY" }
174+
175+ async def test_purge_queue (self ):
176+ '''
177+ Tests Endpoint.purge_queue
178+ '''
179+ with patch ("aiohttp.ClientSession" ) as mock_session :
180+ mock_resp = MagicMock ()
181+ mock_resp .json = MagicMock (return_value = asyncio .Future ())
182+ mock_resp .json .return_value .set_result ({"result" : "PURGED" })
183+ mock_session .post .return_value .__aenter__ .return_value = mock_resp
184+
185+ endpoint = Endpoint ("endpoint_id" , mock_session )
186+ purge_result = await endpoint .purge_queue ()
187+ assert purge_result == {"result" : "PURGED" }
188+
109189class TestEndpointInitialization (unittest .TestCase ):
110190 '''Tests for the Endpoint class initialization.'''
111191
0 commit comments