Skip to content

Commit 07e66c0

Browse files
Merge pull request #263 from runpod/update-async-runner
Update async runner
2 parents 39bc115 + 7af6136 commit 07e66c0

4 files changed

Lines changed: 202 additions & 62 deletions

File tree

runpod/endpoint/asyncio/asyncio_runner.py

Lines changed: 72 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
'''
2-
Author: Oleg Rybalko
3-
Github: https://github.com/SkullMag
4-
Date: 2023-03-27
5-
'''
1+
""" Module for running endpoints asynchronously. """
62
# pylint: disable=too-few-public-methods,R0801
73

4+
from typing import Any, Dict
85
import asyncio
96
import aiohttp
107

8+
from runpod.endpoint.helpers import FINAL_STATES, is_completed
9+
1110

1211
class Job:
1312
"""Class representing a job for an asynchronous endpoint"""
@@ -17,45 +16,84 @@ def __init__(self, endpoint_id: str, job_id: str, session: aiohttp.ClientSession
1716

1817
self.endpoint_id = endpoint_id
1918
self.job_id = job_id
20-
self.status_url = f"{endpoint_url_base}/{self.endpoint_id}/status/{self.job_id}"
21-
self.cancel_url = f"{endpoint_url_base}/{self.endpoint_id}/cancel/{self.job_id}"
2219
self.headers = {
2320
"Content-Type": "application/json",
2421
"Authorization": f"Bearer {api_key}"
2522
}
2623
self.session = session
24+
self.endpoint_url_base = endpoint_url_base
25+
26+
self.job_status = None
27+
self.job_output = None
28+
29+
async def _fetch_job(self, source: str = "status") -> Dict[str, Any]:
30+
""" Returns the raw json of the status, raises an exception if invalid.
31+
32+
Args:
33+
source: The URL source path of the job status.
34+
"""
35+
status_url = f"{self.endpoint_url_base}/{self.endpoint_id}/{source}/{self.job_id}"
36+
job_state = await self.session.get(status_url, headers=self.headers)
37+
job_state = await job_state.json()
38+
39+
if is_completed(job_state["status"]):
40+
self.job_status = job_state["status"]
41+
self.job_output = job_state.get("output", None)
42+
43+
return job_state
2744

2845
async def status(self) -> str:
2946
"""Gets the job's status
3047
3148
Returns:
3249
COMPLETED, FAILED or IN_PROGRESS
3350
"""
34-
async with self.session.get(self.status_url, headers=self.headers) as resp:
35-
return (await resp.json())["status"]
51+
if self.job_status is not None:
52+
return self.job_status
53+
54+
job_state = await self._fetch_job()
55+
return job_state["status"]
56+
57+
async def _wait_for_completion(self):
58+
while not is_completed(await self.status()):
59+
await asyncio.sleep(1)
3660

37-
async def output(self) -> any:
61+
async def output(self, timeout: int = 0) -> Any:
3862
"""Waits for serverless API job to complete or fail
3963
4064
Returns:
4165
Output of job
4266
Raises:
4367
KeyError if job Failed
4468
"""
45-
while await self.status() not in ["COMPLETED", "FAILED"]:
46-
await asyncio.sleep(1)
69+
if self.job_output is not None:
70+
return self.job_output
71+
72+
try:
73+
await asyncio.wait_for(self._wait_for_completion(), timeout)
74+
except asyncio.TimeoutError as exc:
75+
raise TimeoutError("Job timed out.") from exc
4776

48-
async with self.session.get(self.status_url, headers=self.headers) as resp:
49-
return (await resp.json())["output"]
77+
job_data = await self._fetch_job()
78+
return job_data.get("output", None)
79+
80+
async def stream(self) -> Any:
81+
""" Returns a generator that yields the output of the job request. """
82+
while True:
83+
await asyncio.sleep(1)
84+
stream_partial = await self._fetch_job(source="stream")
85+
if stream_partial["status"] not in FINAL_STATES:
86+
for chunk in stream_partial.get("stream", []):
87+
yield chunk["output"]
5088

5189
async def cancel(self) -> dict:
5290
"""Cancels current job
5391
5492
Returns:
5593
Output of cancel operation
5694
"""
57-
58-
async with self.session.post(self.cancel_url, headers=self.headers) as resp:
95+
cancel_url = f"{self.endpoint_url_base}/{self.endpoint_id}/cancel/{self.job_id}"
96+
async with self.session.post(cancel_url, headers=self.headers) as resp:
5997
return await resp.json()
6098

6199

@@ -88,3 +126,21 @@ async def run(self, endpoint_input: dict) -> Job:
88126
json_resp = await resp.json()
89127

90128
return Job(self.endpoint_id, json_resp["id"], self.session)
129+
130+
async def health(self) -> dict:
131+
"""Checks health of endpoint
132+
133+
Returns:
134+
Health of endpoint
135+
"""
136+
async with self.session.get(f"{self.endpoint_id}/health", headers=self.headers) as resp:
137+
return await resp.json()
138+
139+
async def purge_queue(self) -> dict:
140+
"""Purges queue of endpoint
141+
142+
Returns:
143+
Purge status
144+
"""
145+
async with self.session.post(f"{self.endpoint_id}/purge", headers=self.headers) as resp:
146+
return await resp.json()

runpod/endpoint/helpers.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
""" Helper functions for the RunPod Endpoint API. """
2+
3+
FINAL_STATES = ["COMPLETED", "FAILED", "TIMED_OUT"]
4+
5+
# Exception Messages
6+
UNAUTHORIZED_MSG = "401 Unauthorized | Make sure Runpod API key is set and valid."
7+
API_KEY_NOT_SET_MSG = ("Expected `run_pod.api_key` to be initialized. "
8+
"You can solve this by setting `run_pod.api_key = 'your-key'. "
9+
"An API key can be generated at "
10+
"https://runpod.io/console/user/settings")
11+
12+
def is_completed(status: str) -> bool:
13+
"""Returns true if status is one of the possible final states for a serverless request."""
14+
return status in ["COMPLETED", "FAILED", "TIMED_OUT", "CANCELLED"]

runpod/endpoint/runner.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,9 @@
66
import requests
77
from requests.adapters import HTTPAdapter, Retry
88

9-
FINAL_STATES = ["COMPLETED", "FAILED", "TIMED_OUT"]
10-
11-
# Exception Messages
12-
UNAUTHORIZED_MSG = "401 Unauthorized | Make sure Runpod API key is set and valid."
13-
API_KEY_NOT_SET_MSG = ("Expected `run_pod.api_key` to be initialized. "
14-
"You can solve this by setting `run_pod.api_key = 'your-key'. "
15-
"An API key can be generated at "
16-
"https://runpod.io/console/user/settings")
17-
18-
19-
def is_completed(status: str) -> bool:
20-
"""Returns true if status is one of the possible final states for a serverless request."""
21-
return status in ["COMPLETED", "FAILED", "TIMED_OUT", "CANCELLED"]
9+
from runpod.endpoint.helpers import (
10+
FINAL_STATES, UNAUTHORIZED_MSG, API_KEY_NOT_SET_MSG, is_completed
11+
)
2212

2313

2414
# ---------------------------------------------------------------------------- #
@@ -146,16 +136,6 @@ def output(self, timeout: int = 0) -> Any:
146136

147137
return self._fetch_job().get("output", None)
148138

149-
def cancel(self, timeout: int = 3) -> Any:
150-
"""
151-
Cancels the job and returns the result of the cancellation request.
152-
153-
Args:
154-
timeout: The number of seconds to wait for the server to respond before giving up.
155-
"""
156-
return self.rp_client.post(f"{self.endpoint_id}/cancel/{self.job_id}",
157-
data=None, timeout=timeout)
158-
159139
def stream(self) -> Any:
160140
""" Returns a generator that yields the output of the job request. """
161141
while True:
@@ -167,6 +147,16 @@ def stream(self) -> Any:
167147
elif stream_partial["status"] in FINAL_STATES:
168148
break
169149

150+
def cancel(self, timeout: int = 3) -> Any:
151+
"""
152+
Cancels the job and returns the result of the cancellation request.
153+
154+
Args:
155+
timeout: The number of seconds to wait for the server to respond before giving up.
156+
"""
157+
return self.rp_client.post(f"{self.endpoint_id}/cancel/{self.job_id}",
158+
data=None, timeout=timeout)
159+
170160

171161
# ---------------------------------------------------------------------------- #
172162
# Endpoint #

tests/test_endpoint/test_asyncio_runner.py

Lines changed: 103 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import tracemalloc
55
import asyncio
66
import unittest
7-
from unittest.mock import patch, MagicMock
7+
from unittest.mock import patch, MagicMock, AsyncMock
88
from unittest import IsolatedAsyncioTestCase
99

1010
from runpod.endpoint.asyncio.asyncio_runner import Job, Endpoint
@@ -19,37 +19,91 @@ async def test_status(self):
1919
'''
2020
Tests Job.status
2121
'''
22-
with patch("aiohttp.ClientSession") as mock_session:
23-
mock_resp = MagicMock()
24-
mock_resp.json = MagicMock(return_value=asyncio.Future())
25-
mock_resp.json.return_value.set_result({"status": "COMPLETED"})
26-
mock_session.get.return_value.__aenter__.return_value = mock_resp
22+
with patch("aiohttp.ClientSession", new_callable=AsyncMock) as mock_session_class:
23+
mock_session = mock_session_class.return_value
24+
mock_get = mock_session.get
25+
mock_resp = AsyncMock()
26+
27+
mock_resp.json.return_value = {"status": "COMPLETED"}
28+
mock_get.return_value = mock_resp
2729

2830
job = Job("endpoint_id", "job_id", mock_session)
2931
status = await job.status()
3032
assert status == "COMPLETED"
33+
assert await job.status() == "COMPLETED"
3134

3235
async def test_output(self):
3336
'''
3437
Tests Job.output
3538
'''
3639
with patch("runpod.endpoint.asyncio.asyncio_runner.asyncio.sleep") as mock_sleep, \
37-
patch("aiohttp.ClientSession") as mock_session:
38-
mock_resp = MagicMock()
40+
patch("aiohttp.ClientSession", new_callable=AsyncMock) as mock_session_class:
41+
mock_session = mock_session_class.return_value
42+
mock_get = mock_session.get
43+
mock_resp = AsyncMock()
3944

4045
async def json_side_effect():
4146
if mock_sleep.call_count == 0:
4247
return {"status": "IN_PROGRESS"}
4348
return {"output": "OUTPUT", "status": "COMPLETED"}
4449

45-
mock_resp.json = json_side_effect
46-
mock_session.get.return_value.__aenter__.return_value = mock_resp
50+
mock_resp.json.side_effect = json_side_effect
51+
mock_get.return_value = mock_resp
4752

4853
job = Job("endpoint_id", "job_id", mock_session)
49-
output_task = asyncio.create_task(job.output())
54+
output_task = asyncio.create_task(job.output(timeout=3))
5055

5156
output = await output_task
5257
assert output == "OUTPUT"
58+
assert await job.output() == "OUTPUT"
59+
60+
async def test_output_timeout(self):
61+
'''
62+
Tests Job.output with a timeout
63+
'''
64+
with patch("aiohttp.ClientSession", new_callable=AsyncMock) as mock_session_class:
65+
mock_session = mock_session_class.return_value
66+
mock_get = mock_session.get
67+
mock_resp = AsyncMock()
68+
69+
mock_resp.json.return_value = {"status": "IN_PROGRESS"}
70+
mock_get.return_value = mock_resp
71+
72+
job = Job("endpoint_id", "job_id", mock_session)
73+
output_task = asyncio.create_task(job.output(timeout=1))
74+
75+
with self.assertRaises(TimeoutError):
76+
await output_task
77+
78+
async def test_stream(self):
79+
'''
80+
Tests Job.stream
81+
'''
82+
with patch("aiohttp.ClientSession", new_callable=AsyncMock) as mock_session_class:
83+
mock_session = mock_session_class.return_value
84+
mock_get = mock_session.get
85+
mock_resp = AsyncMock()
86+
87+
responses = [
88+
{"stream": [{"output": "OUTPUT1"}], "status": "IN_PROGRESS"},
89+
{"stream": [{"output": "OUTPUT2"}], "status": "IN_PROGRESS"},
90+
]
91+
92+
async def json_side_effect():
93+
return responses.pop(0) if responses else {"stream": [], "status": "COMPLETED"}
94+
95+
mock_resp.json.side_effect = json_side_effect
96+
mock_get.return_value = mock_resp
97+
98+
job = Job("endpoint_id", "job_id", mock_session)
99+
100+
outputs = []
101+
async for stream_output in job.stream():
102+
outputs.append(stream_output)
103+
if not responses: # Break the loop when responses are exhausted
104+
break
105+
106+
assert outputs == ["OUTPUT1", "OUTPUT2"]
53107

54108
async def test_cancel(self):
55109
'''
@@ -67,27 +121,25 @@ async def test_cancel(self):
67121

68122
async def test_output_in_progress_then_completed(self):
69123
'''Tests Job.output when status is initially IN_PROGRESS and then changes to COMPLETED'''
70-
with patch("runpod.endpoint.asyncio.asyncio_runner.asyncio.sleep") as mock_sleep, \
71-
patch("aiohttp.ClientSession") as mock_session:
72-
mock_resp = MagicMock()
124+
with patch("aiohttp.ClientSession", new_callable=AsyncMock) as mock_session_class:
125+
mock_session = mock_session_class.return_value
126+
mock_get = mock_session.get
127+
mock_resp = AsyncMock()
128+
73129
responses = [
74130
{"status": "IN_PROGRESS"},
75-
{"status": "COMPLETED"},
76-
{"output": "OUTPUT"}
131+
{"status": "COMPLETED", "output": "OUTPUT"}
77132
]
78133

79134
async def json_side_effect():
80-
if responses:
81-
return responses.pop(0)
82-
return {"status": "IN_PROGRESS"}
135+
return responses.pop(0) if responses else {"status": "COMPLETED", "output": "OUTPUT"} # pylint: disable=line-too-long
83136

84-
mock_resp.json = json_side_effect
85-
mock_session.get.return_value.__aenter__.return_value = mock_resp
137+
mock_resp.json.side_effect = json_side_effect
138+
mock_get.return_value = mock_resp
86139

87140
job = Job("endpoint_id", "job_id", mock_session)
88-
output = await job.output()
141+
output = await job.output(timeout=3)
89142
assert output == "OUTPUT"
90-
mock_sleep.assert_called_once_with(1)
91143

92144
class TestEndpoint(IsolatedAsyncioTestCase):
93145
''' Unit tests for the Endpoint class. '''
@@ -106,6 +158,34 @@ async def test_run(self):
106158
job = await endpoint.run({"input": "INPUT"})
107159
assert job.job_id == "job_id"
108160

161+
async def test_health(self):
162+
'''
163+
Tests Endpoint.health
164+
'''
165+
with patch("aiohttp.ClientSession") as mock_session:
166+
mock_resp = MagicMock()
167+
mock_resp.json = MagicMock(return_value=asyncio.Future())
168+
mock_resp.json.return_value.set_result({"status": "HEALTHY"})
169+
mock_session.get.return_value.__aenter__.return_value = mock_resp
170+
171+
endpoint = Endpoint("endpoint_id", mock_session)
172+
health = await endpoint.health()
173+
assert health == {"status": "HEALTHY"}
174+
175+
async def test_purge_queue(self):
176+
'''
177+
Tests Endpoint.purge_queue
178+
'''
179+
with patch("aiohttp.ClientSession") as mock_session:
180+
mock_resp = MagicMock()
181+
mock_resp.json = MagicMock(return_value=asyncio.Future())
182+
mock_resp.json.return_value.set_result({"result": "PURGED"})
183+
mock_session.post.return_value.__aenter__.return_value = mock_resp
184+
185+
endpoint = Endpoint("endpoint_id", mock_session)
186+
purge_result = await endpoint.purge_queue()
187+
assert purge_result == {"result": "PURGED"}
188+
109189
class TestEndpointInitialization(unittest.TestCase):
110190
'''Tests for the Endpoint class initialization.'''
111191

0 commit comments

Comments
 (0)