Skip to content

Commit 7e75012

Browse files
committed
fix: enable SLS_CORE with envvar RUNPOD_SLS_CORE=1
1 parent a433a29 commit 7e75012

2 files changed

Lines changed: 33 additions & 21 deletions

File tree

runpod/serverless/__init__.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,22 @@
44
Arguments can be passed in when the worker is started, and will be passed to the worker.
55
"""
66

7+
import argparse
8+
import json
79
import os
10+
import signal
811
import sys
9-
import json
1012
import time
11-
import signal
12-
import argparse
13-
from typing import Dict, Any
13+
import typing
14+
from typing import Any, Dict
1415

1516
from runpod.serverless import core
17+
18+
from ..version import __version__ as runpod_version
1619
from . import worker
1720
from .modules import rp_fastapi
1821
from .modules.rp_logger import RunPodLogger
1922
from .modules.rp_progress import progress_update
20-
from ..version import __version__ as runpod_version
21-
2223

2324
log = RunPodLogger()
2425

@@ -65,11 +66,13 @@ def _set_config_args(config) -> dict:
6566

6667
# Parse the test input from JSON
6768
if config["rp_args"]["test_input"]:
68-
config["rp_args"]["test_input"] = json.loads(config["rp_args"]["test_input"])
69+
config["rp_args"]["test_input"] = json.loads(
70+
config["rp_args"]["test_input"])
6971

7072
# Parse the test output from JSON
7173
if config["rp_args"].get("test_output", None):
72-
config["rp_args"]["test_output"] = json.loads(config["rp_args"]["test_output"])
74+
config["rp_args"]["test_output"] = json.loads(
75+
config["rp_args"]["test_output"])
7376

7477
# Set the log level
7578
if config["rp_args"]["rp_log_level"]:
@@ -133,8 +136,9 @@ def start(config: Dict[str, Any]):
133136
api_port=config['rp_args']['rp_api_port'],
134137
api_concurrency=config['rp_args']['rp_api_concurrency']
135138
)
139+
return
136140

137-
elif realtime_port:
141+
if realtime_port:
138142
log.info(f"Starting API server for realtime on port {realtime_port}.")
139143
api_server = rp_fastapi.WorkerAPI(config)
140144

@@ -143,12 +147,13 @@ def start(config: Dict[str, Any]):
143147
api_port=realtime_port,
144148
api_concurrency=realtime_concurrency
145149
)
150+
return
146151

147152
# --------------------------------- SLS-Core --------------------------------- #
148-
elif os.environ.get("RUNPOD_USE_CORE", None) or os.environ.get("RUNPOD_CORE_PATH", None):
149-
log.info("Starting worker with SLS-Core.")
153+
if os.getenv("RUNPOD_SLS_CORE", "false").lower() in ("1", 't', 'T', 'TRUE', 'true', 'True', '0', 'f', 'F', 'FALSE', 'false', 'False'):
150154
core.main(config)
155+
return
151156

152157
# --------------------------------- Standard --------------------------------- #
153-
else:
154-
worker.main(config)
158+
worker.main(config)
159+
return

runpod/serverless/core.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import asyncio
99
from ctypes import CDLL, byref, c_char_p, c_int
1010
from typing import Any, Callable, List, Dict, Optional
11+
import typing
1112

1213
from runpod.version import __version__ as runpod_version
1314
from runpod.serverless.modules.rp_logger import RunPodLogger
@@ -31,28 +32,35 @@ def __str__(self) -> str:
3132
return f"CGetJobResult(res_len={self.res_len}, status_code={self.status_code})"
3233

3334

35+
def notregistered():
36+
""" Function to raise NotImplementedError """
37+
raise RuntimeError("This function is not registered with the SLS Core.")
3438
class Hook: # pylint: disable=too-many-instance-attributes
3539
""" Singleton class for interacting with sls_core.so"""
3640

3741
_instance = None
3842

43+
3944
# C function pointers
40-
_get_jobs: Callable = None
41-
_progress_update: Callable = None
42-
_stream_output: Callable = None
43-
_post_output: Callable = None
44-
_finish_stream: Callable = None
45+
_get_jobs: Callable = notregistered
46+
_progress_update: Callable = notregistered
47+
_stream_output: Callable = notregistered
48+
_post_output: Callable = notregistered
49+
_finish_stream: Callable = notregistered
4550

4651
def __new__(cls):
4752
if Hook._instance is None:
4853
log.debug("SLS Core | Initializing Hook.")
4954
Hook._instance = object.__new__(cls)
5055
Hook._initialized = False
56+
5157
return Hook._instance
5258

5359
def __init__(self, rust_so_path: Optional[str] = None) -> None:
60+
5461
if self._initialized:
5562
return
63+
5664

5765
if rust_so_path is None:
5866
default_path = os.path.join(
@@ -169,8 +177,6 @@ def finish_stream(self, job_id: str) -> bool:
169177
return bool(self._finish_stream(
170178
c_char_p(id_bytes), c_int(len(id_bytes))
171179
))
172-
173-
174180
# -------------------------------- Process Job ------------------------------- #
175181
async def _process_job(config: Dict[str, Any], job: Dict[str, Any], hook) -> Dict[str, Any]:
176182
""" Process a single job. """
@@ -181,7 +187,7 @@ async def _process_job(config: Dict[str, Any], job: Dict[str, Any], hook) -> Dic
181187
if inspect.isgeneratorfunction(handler) or inspect.isasyncgenfunction(handler):
182188
log.debug("SLS Core | Running job as a generator.")
183189
generator_output = rp_job.run_job_generator(handler, job)
184-
aggregated_output = {'output': []}
190+
aggregated_output: dict[str, typing.Any] = {'output': []}
185191

186192
async for part in generator_output:
187193
log.debug(f"SLS Core | Streaming output: {part}", job['id'])
@@ -210,6 +216,7 @@ async def _process_job(config: Dict[str, Any], job: Dict[str, Any], hook) -> Dic
210216
finally:
211217
log.debug(f"SLS Core | Posting output: {result}", job['id'])
212218
hook.post_output(job['id'], result)
219+
return result
213220

214221

215222
# ---------------------------------------------------------------------------- #

0 commit comments

Comments
 (0)