Skip to content

Commit 0b6e703

Browse files
Merge pull request #281 from runpod/flash-boot
feat: add flash boot enable to endpoints
2 parents 5d9f9c5 + d52cc03 commit 0b6e703

4 files changed

Lines changed: 12 additions & 7 deletions

File tree

examples/api/create_endpoint.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
new_template = runpod.create_template(
1111
name="test",
12-
image_name="runpod/base:0.1.0",
12+
image_name="runpod/base:0.4.4",
1313
is_serverless=True
1414
)
1515

@@ -20,7 +20,8 @@
2020
template_id=new_template["id"],
2121
gpu_ids="AMPERE_16",
2222
workers_min=0,
23-
workers_max=1
23+
workers_max=1,
24+
flash_boot=True
2425
)
2526

2627
print(new_endpoint)

runpod/api/ctl_commands.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def create_endpoint(
249249
name: str, template_id: str, gpu_ids: str = "AMPERE_16",
250250
network_volume_id: str = None, locations: str = None,
251251
idle_timeout: int = 5, scaler_type: str = "QUEUE_DELAY", scaler_value: int = 4,
252-
workers_min: int = 0, workers_max: int = 3
252+
workers_min: int = 0, workers_max: int = 3, flash_boot=False
253253
):
254254
'''
255255
Create an endpoint
@@ -274,7 +274,7 @@ def create_endpoint(
274274
name, template_id, gpu_ids,
275275
network_volume_id, locations,
276276
idle_timeout, scaler_type, scaler_value,
277-
workers_min, workers_max
277+
workers_min, workers_max, flash_boot
278278
)
279279
)
280280

runpod/api/mutations/endpoints.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@ def generate_endpoint_mutation(
77
name: str, template_id: str, gpu_ids: str = "AMPERE_16",
88
network_volume_id: str = None, locations: str = None,
99
idle_timeout: int = 5, scaler_type: str = "QUEUE_DELAY", scaler_value: int = 4,
10-
workers_min: int = 0, workers_max: int = 3
10+
workers_min: int = 0, workers_max: int = 3, flash_boot=False
1111
):
1212
""" Generate a string for a GraphQL mutation to create a new endpoint. """
1313
input_fields = []
1414

1515
# ------------------------------ Required Fields ----------------------------- #
16+
if flash_boot:
17+
name = name + "-fb"
18+
1619
input_fields.append(f'name: "{name}"')
1720
input_fields.append(f'templateId: "{template_id}"')
1821
input_fields.append(f'gpuIds: "{gpu_ids}"')

tests/test_api/test_mutation_endpoints.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from runpod.api.mutations.endpoints import generate_endpoint_mutation
66

7+
78
class TestGenerateEndpointMutation(unittest.TestCase):
89
"""Tests for the endpoint mutation generation."""
910

@@ -20,9 +21,9 @@ def test_all_fields(self):
2021
"""Test all the fields."""
2122
result = generate_endpoint_mutation(
2223
"test_name", "test_template_id", "AMPERE_20",
23-
"test_volume_id", "US_WEST", 10, "WORKER_COUNT", 5, 2, 4
24+
"test_volume_id", "US_WEST", 10, "WORKER_COUNT", 5, 2, 4, True
2425
)
25-
self.assertIn('name: "test_name"', result)
26+
self.assertIn('name: "test_name-fb"', result)
2627
self.assertIn('templateId: "test_template_id"', result)
2728
self.assertIn('gpuIds: "AMPERE_20"', result)
2829
self.assertIn('networkVolumeId: "test_volume_id"', result)

0 commit comments

Comments
 (0)