Skip to content

Commit 283acec

Browse files
committed
fix: change async to threading
1 parent 8bf0170 commit 283acec

2 files changed

Lines changed: 50 additions & 50 deletions

File tree

runpod/serverless/modules/rp_fastapi.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33

44
import os
55
import uuid
6-
import asyncio
6+
import time
7+
import threading
78
from dataclasses import dataclass
89
from typing import Union, Optional, Dict, Any
910

10-
import aiohttp
1111
import uvicorn
12+
import requests
1213
from fastapi import FastAPI, APIRouter
1314
from fastapi.encoders import jsonable_encoder
1415
from fastapi.responses import RedirectResponse
@@ -85,31 +86,31 @@ class StreamOutput:
8586

8687

8788
# ------------------------------ Webhook Sender ------------------------------ #
88-
async def _send_webhook_async(url: str, payload: Dict[str, Any]) -> None:
89+
def _send_webhook(url: str, payload: Dict[str, Any]) -> None:
8990
"""
90-
Sends a webhook to the provided URL asynchronously. Retries once if the first attempt fails.
91+
Sends a webhook to the provided URL. Retries once if the first attempt fails.
9192
9293
Args:
9394
url (str): The URL to send the webhook to.
9495
payload (Dict[str, Any]): The JSON payload to send.
9596
"""
96-
async def attempt_send(session, url, payload):
97+
def attempt_send(session, url, payload):
9798
try:
98-
async with session.post(url, json=payload, timeout=10) as response:
99-
response.raise_for_status() # Raises exception for 4xx/5xx responses
100-
return True
101-
except (aiohttp.ClientError, aiohttp.http_exceptions.HttpProcessingError) as err:
99+
response = session.post(url, json=payload, timeout=10)
100+
response.raise_for_status() # Raises exception for 4xx/5xx responses
101+
return True
102+
except requests.RequestException as err:
102103
print(f"Request to {url} failed: {err}")
103104
return False
104105

105-
async with aiohttp.ClientSession() as session:
106-
if await attempt_send(session, url, payload):
106+
with requests.Session() as session:
107+
if attempt_send(session, url, payload):
107108
return True
108109

109110
print("Retrying...")
110-
await asyncio.sleep(1) # Wait for 1 second before retrying
111+
time.sleep(1) # Wait for 1 second before retrying
111112

112-
if await attempt_send(session, url, payload):
113+
if attempt_send(session, url, payload):
113114
return True
114115

115116
print("Failed to send webhook after retry.")
@@ -240,7 +241,10 @@ async def _sim_runsync(self, job_request: DefaultRequest) -> JobOutput:
240241
})
241242

242243
if job_request.webhook:
243-
asyncio.create_task(_send_webhook_async(job_request.webhook, job_output))
244+
thread = threading.Thread(
245+
target=_send_webhook,
246+
args=(job_request.webhook, job_output), daemon=True)
247+
thread.start()
244248

245249
return jsonable_encoder({
246250
"id": job.id,
@@ -276,7 +280,10 @@ async def _sim_stream(self, job_id: str) -> StreamOutput:
276280
job_list.remove_job(job.id)
277281

278282
if stashed_job.webhook:
279-
asyncio.create_task(_send_webhook_async(stashed_job.webhook, stream_accumulator))
283+
thread = threading.Thread(
284+
target=_send_webhook,
285+
args=(stashed_job.webhook, stream_accumulator), daemon=True)
286+
thread.start()
280287

281288
return jsonable_encoder({
282289
"id": job_id,
@@ -315,7 +322,10 @@ async def _sim_status(self, job_id: str) -> JobOutput:
315322
})
316323

317324
if stashed_job.webhook:
318-
asyncio.create_task(_send_webhook_async(stashed_job.webhook, job_output))
325+
thread = threading.Thread(
326+
target=_send_webhook,
327+
args=(stashed_job.webhook, job_output), daemon=True)
328+
thread.start()
319329

320330
return jsonable_encoder({
321331
"id": job_id,

tests/test_serverless/test_modules/test_fastapi.py

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import asyncio
66

77
import unittest
8-
from unittest.mock import patch, Mock, AsyncMock
8+
from unittest.mock import patch, Mock, MagicMock
99
import pytest
1010

11-
import aiohttp
11+
import requests
1212
import runpod
1313
from runpod.serverless.modules import rp_fastapi
1414

@@ -54,43 +54,31 @@ def test_start_serverless_with_realtime(self):
5454
os.environ.pop("RUNPOD_REALTIME_PORT")
5555
os.environ.pop("RUNPOD_ENDPOINT_ID")
5656

57-
@pytest.mark.asyncio
5857
def test_webhook_sender_success(self):
5958
"""Test the webhook sender when the request is successful."""
60-
loop = asyncio.get_event_loop()
61-
62-
module_location = "runpod.serverless.modules.rp_fastapi.aiohttp.ClientSession"
59+
module_location = "runpod.serverless.modules.rp_fastapi.requests.Session.post"
6360

64-
with patch(f"{module_location}.post", new_callable=AsyncMock) as mock_post:
61+
with patch(module_location, new_callable=MagicMock) as mock_post:
6562
# Simulate a successful response
66-
mock_post.return_value.__aenter__.return_value.status = 200
63+
mock_post.return_value.status_code = 200
6764

68-
# Directly await the function
69-
success = asyncio.run(rp_fastapi._send_webhook_async(
70-
"test_webhook", {"test": "output"}))
65+
# Call the function
66+
success = rp_fastapi._send_webhook("test_webhook", {"test": "output"})
7167
assert success is True
7268

73-
loop.close()
74-
75-
@pytest.mark.asyncio
7669
def test_webhook_sender_failure(self):
7770
"""Test the webhook sender when the request fails."""
78-
loop = asyncio.get_event_loop()
79-
80-
module_location = "runpod.serverless.modules.rp_fastapi.aiohttp.ClientSession"
71+
module_location = "runpod.serverless.modules.rp_fastapi.requests.Session.post"
8172

82-
with patch(f"{module_location}.post", new_callable=AsyncMock) as mock_post:
83-
# Configure the mock to raise an exception to simulate a 500 error
84-
mock_post.return_value.__aenter__.return_value.raise_for_status.side_effect = aiohttp.ClientResponseError( # pylint: disable=line-too-long
85-
request_info=None, history=None, status=500)
73+
with patch(module_location, new_callable=MagicMock) as mock_post:
74+
# Configure the mock to simulate a failure (e.g., a 500 status code)
75+
mock_post.return_value.raise_for_status.side_effect = requests.HTTPError()
76+
mock_post.return_value.status_code = 500
8677

87-
# Directly await the function
88-
success = asyncio.run(rp_fastapi._send_webhook_async(
89-
"test_webhook", {"test": "output"}))
78+
# Call the function
79+
success = rp_fastapi._send_webhook("test_webhook", {"test": "output"})
9080
assert success is False
9181

92-
loop.close()
93-
9482
@pytest.mark.asyncio
9583
def test_run(self):
9684
'''
@@ -195,9 +183,9 @@ def generator_handler(job):
195183
assert "error" in error_runsync_return
196184

197185
# Test webhook caller sent
198-
with patch(f"{module_location}._send_webhook_async", return_value=True) as mock_send_webhook: # pylint: disable=line-too-long
186+
with patch(f"{module_location}.threading") as mock_thread:
199187
asyncio.run(worker_api._sim_runsync(input_object_with_webhook))
200-
assert mock_send_webhook.called
188+
assert mock_thread.Thread.called
201189

202190
loop.close()
203191

@@ -212,7 +200,8 @@ def test_stream(self):
212200
with patch(f"{module_location}.FastAPI", Mock()), \
213201
patch(f"{module_location}.APIRouter", return_value=Mock()), \
214202
patch(f"{module_location}.uvicorn", Mock()), \
215-
patch(f"{module_location}.uuid.uuid4", return_value="123"):
203+
patch(f"{module_location}.uuid.uuid4", return_value="123"), \
204+
patch(f"{module_location}.threading"):
216205

217206
default_input_object = rp_fastapi.DefaultRequest(
218207
input={"test_input": "test_input"}
@@ -243,10 +232,10 @@ def test_stream(self):
243232
}
244233

245234
# Test webhook caller sent
246-
with patch(f"{module_location}._send_webhook_async", return_value=True) as mock_send_webhook: # pylint: disable=line-too-long
235+
with patch(f"{module_location}.threading", return_value=True) as mock_threading:
247236
asyncio.run(worker_api._sim_run(input_object_with_webhook))
248237
asyncio.run(worker_api._sim_stream("test-123"))
249-
assert mock_send_webhook.called
238+
assert mock_threading.Thread.called
250239

251240
# Test with generator handler
252241
def generator_handler(job):
@@ -275,7 +264,8 @@ def test_status(self):
275264
with patch(f"{module_location}.FastAPI", Mock()), \
276265
patch(f"{module_location}.APIRouter", return_value=Mock()), \
277266
patch(f"{module_location}.uvicorn", Mock()), \
278-
patch(f"{module_location}.uuid.uuid4", return_value="123"):
267+
patch(f"{module_location}.uuid.uuid4", return_value="123"), \
268+
patch(f"{module_location}.threading"):
279269

280270
worker_api = rp_fastapi.WorkerAPI({"handler": self.handler})
281271

@@ -306,10 +296,10 @@ def test_status(self):
306296
}
307297

308298
# Test webhook caller sent
309-
with patch(f"{module_location}._send_webhook_async", return_value=True) as mock_send_webhook: # pylint: disable=line-too-long
299+
with patch(f"{module_location}.threading", return_value=True) as mock_threading:
310300
asyncio.run(worker_api._sim_run(input_object_with_webhook))
311301
asyncio.run(worker_api._sim_status("test-123"))
312-
assert mock_send_webhook.called
302+
assert mock_threading.Thread.called
313303

314304
# Test with generator handler
315305
def generator_handler(job):

0 commit comments

Comments
 (0)