Skip to content

Commit 9e11c99

Browse files
Merge pull request #295 from runpod/280-enable-webhooks-for-local-testing-of-serverless-workers
280 enable webhooks for local testing of serverless workers
2 parents 5a396ed + 5f2f7c6 commit 9e11c99

4 files changed

Lines changed: 153 additions & 34 deletions

File tree

runpod/serverless/modules/rp_fastapi.py

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33

44
import os
55
import uuid
6+
import threading
67
from dataclasses import dataclass
78
from typing import Union, Optional, Dict, Any
89

910
import uvicorn
11+
import requests
1012
from fastapi import FastAPI, APIRouter
1113
from fastapi.encoders import jsonable_encoder
1214
from fastapi.responses import RedirectResponse
@@ -53,12 +55,14 @@ class TestJob:
5355
'''
5456
id: Optional[str] = None
5557
input: Optional[Union[dict, list, str, int, float, bool]] = None
58+
webhook: Optional[str] = None
5659

5760

5861
@dataclass
59-
class DefaultInput:
62+
class DefaultRequest:
6063
""" Represents a test input. """
6164
input: Dict[str, Any]
65+
webhook: Optional[str] = None
6266

6367

6468
# ------------------------------ Output Objects ------------------------------ #
@@ -80,6 +84,28 @@ class StreamOutput:
8084
error: Optional[str] = None
8185

8286

87+
# ------------------------------ Webhook Sender ------------------------------ #
88+
def _send_webhook(url: str, payload: Dict[str, Any]) -> bool:
89+
"""
90+
Sends a webhook to the provided URL.
91+
92+
Args:
93+
url (str): The URL to send the webhook to.
94+
payload (Dict[str, Any]): The JSON payload to send.
95+
96+
Returns:
97+
bool: True if the request was successful, False otherwise.
98+
"""
99+
with requests.Session() as session:
100+
try:
101+
response = session.post(url, json=payload, timeout=10)
102+
response.raise_for_status() # Raises exception for 4xx/5xx responses
103+
return True
104+
except requests.RequestException as err:
105+
print(f"Request to {url} failed: {err}")
106+
return False
107+
108+
83109
# ---------------------------------------------------------------------------- #
84110
# API Worker #
85111
# ---------------------------------------------------------------------------- #
@@ -176,17 +202,17 @@ async def _realtime(self, job: Job):
176202
# ---------------------------------------------------------------------------- #
177203

178204
# ------------------------------------ run ----------------------------------- #
179-
async def _sim_run(self, job_input: DefaultInput) -> JobOutput:
205+
async def _sim_run(self, job_request: DefaultRequest) -> JobOutput:
180206
""" Development endpoint to simulate run behavior. """
181207
assigned_job_id = f"test-{uuid.uuid4()}"
182-
job_list.add_job(assigned_job_id, job_input.input)
208+
job_list.add_job(assigned_job_id, job_request.input, job_request.webhook)
183209
return jsonable_encoder({"id": assigned_job_id, "status": "IN_PROGRESS"})
184210

185211
# ---------------------------------- runsync --------------------------------- #
186-
async def _sim_runsync(self, job_input: DefaultInput) -> JobOutput:
212+
async def _sim_runsync(self, job_request: DefaultRequest) -> JobOutput:
187213
""" Development endpoint to simulate runsync behavior. """
188214
assigned_job_id = f"test-{uuid.uuid4()}"
189-
job = TestJob(id=assigned_job_id, input=job_input.input)
215+
job = TestJob(id=assigned_job_id, input=job_request.input)
190216

191217
if is_generator(self.config["handler"]):
192218
generator_output = run_job_generator(self.config["handler"], job.__dict__)
@@ -203,6 +229,12 @@ async def _sim_runsync(self, job_input: DefaultInput) -> JobOutput:
203229
"error": job_output['error']
204230
})
205231

232+
if job_request.webhook:
233+
thread = threading.Thread(
234+
target=_send_webhook,
235+
args=(job_request.webhook, job_output), daemon=True)
236+
thread.start()
237+
206238
return jsonable_encoder({
207239
"id": job.id,
208240
"status": "COMPLETED",
@@ -212,15 +244,15 @@ async def _sim_runsync(self, job_input: DefaultInput) -> JobOutput:
212244
# ---------------------------------- stream ---------------------------------- #
213245
async def _sim_stream(self, job_id: str) -> StreamOutput:
214246
""" Development endpoint to simulate stream behavior. """
215-
job_input = job_list.get_job_input(job_id)
216-
if job_input is None:
247+
stashed_job = job_list.get_job(job_id)
248+
if stashed_job is None:
217249
return jsonable_encoder({
218250
"id": job_id,
219251
"status": "FAILED",
220252
"error": "Job ID not found"
221253
})
222254

223-
job = TestJob(id=job_id, input=job_input)
255+
job = TestJob(id=job_id, input=stashed_job.input)
224256

225257
if is_generator(self.config["handler"]):
226258
generator_output = run_job_generator(self.config["handler"], job.__dict__)
@@ -236,6 +268,12 @@ async def _sim_stream(self, job_id: str) -> StreamOutput:
236268

237269
job_list.remove_job(job.id)
238270

271+
if stashed_job.webhook:
272+
thread = threading.Thread(
273+
target=_send_webhook,
274+
args=(stashed_job.webhook, stream_accumulator), daemon=True)
275+
thread.start()
276+
239277
return jsonable_encoder({
240278
"id": job_id,
241279
"status": "COMPLETED",
@@ -245,15 +283,15 @@ async def _sim_stream(self, job_id: str) -> StreamOutput:
245283
# ---------------------------------- status ---------------------------------- #
246284
async def _sim_status(self, job_id: str) -> JobOutput:
247285
""" Development endpoint to simulate status behavior. """
248-
job_input = job_list.get_job_input(job_id)
249-
if job_input is None:
286+
stashed_job = job_list.get_job(job_id)
287+
if stashed_job is None:
250288
return jsonable_encoder({
251289
"id": job_id,
252290
"status": "FAILED",
253291
"error": "Job ID not found"
254292
})
255293

256-
job = TestJob(id=job_id, input=job_input)
294+
job = TestJob(id=stashed_job.id, input=stashed_job.input)
257295

258296
if is_generator(self.config["handler"]):
259297
generator_output = run_job_generator(self.config["handler"], job.__dict__)
@@ -272,6 +310,12 @@ async def _sim_status(self, job_id: str) -> JobOutput:
272310
"error": job_output['error']
273311
})
274312

313+
if stashed_job.webhook:
314+
thread = threading.Thread(
315+
target=_send_webhook,
316+
args=(stashed_job.webhook, job_output), daemon=True)
317+
thread.start()
318+
275319
return jsonable_encoder({
276320
"id": job_id,
277321
"status": "COMPLETED",

runpod/serverless/modules/worker_state.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,40 @@ def get_auth_header():
2525

2626
# ------------------------------- Job Tracking ------------------------------- #
2727
class Job:
28-
""" Represents a job. """
29-
30-
def __init__(self, job_id: str, job_input: Optional[Dict[str, Any]] = None) -> None:
31-
self.job_id = job_id
32-
self.job_input = job_input
28+
"""
29+
Represents a job object.
30+
31+
Args:
32+
job_id: The id of the job, a unique string.
33+
job_input: The input to the job.
34+
webhook: The webhook to send the job output to.
35+
"""
36+
37+
def __init__(
38+
self,
39+
job_id: str,
40+
job_input: Optional[Dict[str, Any]] = None,
41+
webhook: Optional[str] = None,
42+
) -> None:
43+
self.id = job_id
44+
self.input = job_input
45+
self.webhook = webhook
3346

3447
def __eq__(self, other: object) -> bool:
3548
if isinstance(other, Job):
36-
return self.job_id == other.job_id
49+
return self.id == other.id
3750
return False
3851

3952
def __hash__(self) -> int:
40-
return hash(self.job_id)
53+
return hash(self.id)
4154

4255
def __str__(self) -> str:
43-
return self.job_id
56+
return self.id
4457

4558

59+
# ---------------------------------------------------------------------------- #
60+
# Tracker #
61+
# ---------------------------------------------------------------------------- #
4662
class Jobs:
4763
''' Track the state of current jobs.'''
4864

@@ -55,26 +71,26 @@ def __new__(cls):
5571
Jobs._instance.jobs = set()
5672
return Jobs._instance
5773

58-
def add_job(self, job_id, job_input=None):
74+
def add_job(self, job_id, job_input=None, webhook=None):
5975
'''
6076
Adds a job to the list of jobs.
6177
'''
62-
self.jobs.add(Job(job_id, job_input))
78+
self.jobs.add(Job(job_id, job_input, webhook))
6379

6480
def remove_job(self, job_id):
6581
'''
6682
Removes a job from the list of jobs.
6783
'''
6884
self.jobs.remove(Job(job_id))
6985

70-
def get_job_input(self, job_id) -> Optional[Union[dict, list, str, int, float, bool]]:
86+
def get_job(self, job_id) -> Optional[Union[dict, list, str, int, float, bool]]:
7187
'''
7288
Returns the job with the given id.
7389
Used within rp_fastapi.py for local testing.
7490
'''
7591
for job in self.jobs:
76-
if job.job_id == job_id:
77-
return job.job_input
92+
if job.id == job_id:
93+
return job
7894

7995
return None
8096

tests/test_serverless/test_modules/test_fastapi.py

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
import asyncio
66

77
import unittest
8-
from unittest.mock import patch, Mock
8+
from unittest.mock import patch, Mock, MagicMock
99
import pytest
1010

11+
import requests
1112
import runpod
1213
from runpod.serverless.modules import rp_fastapi
1314

@@ -53,6 +54,31 @@ def test_start_serverless_with_realtime(self):
5354
os.environ.pop("RUNPOD_REALTIME_PORT")
5455
os.environ.pop("RUNPOD_ENDPOINT_ID")
5556

57+
def test_webhook_sender_success(self):
58+
"""Test the webhook sender when the request is successful."""
59+
module_location = "runpod.serverless.modules.rp_fastapi.requests.Session.post"
60+
61+
with patch(module_location, new_callable=MagicMock) as mock_post:
62+
# Simulate a successful response
63+
mock_post.return_value.status_code = 200
64+
65+
# Call the function
66+
success = rp_fastapi._send_webhook("test_webhook", {"test": "output"})
67+
assert success is True
68+
69+
def test_webhook_sender_failure(self):
70+
"""Test the webhook sender when the request fails."""
71+
module_location = "runpod.serverless.modules.rp_fastapi.requests.Session.post"
72+
73+
with patch(module_location, new_callable=MagicMock) as mock_post:
74+
# Configure the mock to simulate a failure (e.g., a 500 status code)
75+
mock_post.return_value.raise_for_status.side_effect = requests.HTTPError()
76+
mock_post.return_value.status_code = 500
77+
78+
# Call the function
79+
success = rp_fastapi._send_webhook("test_webhook", {"test": "output"})
80+
assert success is False
81+
5682
@pytest.mark.asyncio
5783
def test_run(self):
5884
'''
@@ -72,7 +98,7 @@ def test_run(self):
7298
input={"test_input": "test_input"}
7399
)
74100

75-
default_input_object = rp_fastapi.DefaultInput(
101+
default_input_object = rp_fastapi.DefaultRequest(
76102
input={"test_input": "test_input"}
77103
)
78104

@@ -115,12 +141,18 @@ def test_runsync(self):
115141
with patch(f"{module_location}.FastAPI", Mock()), \
116142
patch(f"{module_location}.APIRouter", return_value=Mock()), \
117143
patch(f"{module_location}.uvicorn", Mock()), \
118-
patch(f"{module_location}.uuid.uuid4", return_value="123"):
144+
patch(f"{module_location}.uuid.uuid4", return_value="123"), \
145+
patch(f"{module_location}.threading") as mock_threading:
119146

120-
default_input_object = rp_fastapi.DefaultInput(
147+
default_input_object = rp_fastapi.DefaultRequest(
121148
input={"test_input": "test_input"}
122149
)
123150

151+
input_object_with_webhook = rp_fastapi.DefaultRequest(
152+
input={"test_input": "test_input"},
153+
webhook="test_webhook"
154+
)
155+
124156
# Test with handler
125157
worker_api = rp_fastapi.WorkerAPI({"handler": self.handler})
126158

@@ -151,6 +183,10 @@ def generator_handler(job):
151183
error_worker_api._sim_runsync(default_input_object))
152184
assert "error" in error_runsync_return
153185

186+
# Test webhook caller sent
187+
asyncio.run(worker_api._sim_runsync(input_object_with_webhook))
188+
assert mock_threading.Thread.called
189+
154190
loop.close()
155191

156192
@pytest.mark.asyncio
@@ -164,12 +200,18 @@ def test_stream(self):
164200
with patch(f"{module_location}.FastAPI", Mock()), \
165201
patch(f"{module_location}.APIRouter", return_value=Mock()), \
166202
patch(f"{module_location}.uvicorn", Mock()), \
167-
patch(f"{module_location}.uuid.uuid4", return_value="123"):
203+
patch(f"{module_location}.uuid.uuid4", return_value="123"), \
204+
patch(f"{module_location}.threading") as mock_threading:
168205

169-
default_input_object = rp_fastapi.DefaultInput(
206+
default_input_object = rp_fastapi.DefaultRequest(
170207
input={"test_input": "test_input"}
171208
)
172209

210+
input_object_with_webhook = rp_fastapi.DefaultRequest(
211+
input={"test_input": "test_input"},
212+
webhook="test_webhook"
213+
)
214+
173215
worker_api = rp_fastapi.WorkerAPI({"handler": self.handler})
174216

175217
# Add job to job_list
@@ -203,6 +245,11 @@ def generator_handler(job):
203245
"stream": [{"output": {"result": "success"}}]
204246
}
205247

248+
# Test webhook caller sent
249+
asyncio.run(generator_worker_api._sim_run(input_object_with_webhook))
250+
asyncio.run(generator_worker_api._sim_stream("test-123"))
251+
assert mock_threading.Thread.called
252+
206253
loop.close()
207254

208255
@pytest.mark.asyncio
@@ -216,14 +263,20 @@ def test_status(self):
216263
with patch(f"{module_location}.FastAPI", Mock()), \
217264
patch(f"{module_location}.APIRouter", return_value=Mock()), \
218265
patch(f"{module_location}.uvicorn", Mock()), \
219-
patch(f"{module_location}.uuid.uuid4", return_value="123"):
266+
patch(f"{module_location}.uuid.uuid4", return_value="123"), \
267+
patch(f"{module_location}.threading") as mock_threading:
220268

221269
worker_api = rp_fastapi.WorkerAPI({"handler": self.handler})
222270

223-
default_input_object = rp_fastapi.DefaultInput(
271+
default_input_object = rp_fastapi.DefaultRequest(
224272
input={"test_input": "test_input"}
225273
)
226274

275+
input_object_with_webhook = rp_fastapi.DefaultRequest(
276+
input={"test_input": "test_input"},
277+
webhook="test_webhook"
278+
)
279+
227280
# Add job to job_list
228281
asyncio.run(worker_api._sim_run(default_input_object))
229282

@@ -241,6 +294,11 @@ def test_status(self):
241294
"output": {"result": "success"}
242295
}
243296

297+
# Test webhook caller sent
298+
asyncio.run(worker_api._sim_run(input_object_with_webhook))
299+
asyncio.run(worker_api._sim_status("test-123"))
300+
assert mock_threading.Thread.called
301+
244302
# Test with generator handler
245303
def generator_handler(job):
246304
del job

0 commit comments

Comments
 (0)