Skip to content

Commit 5d9f9c5

Browse files
Merge pull request #278 from runpod/209-support-container-registry-auth-in-create_template
209 support container registry auth in create template
2 parents 9debd76 + f8d26e9 commit 5d9f9c5

4 files changed

Lines changed: 46 additions & 29 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
### Added
66

77
- Expose bucket name for rp_upload.
8+
- Exposed `containerRegistryAuthId` for template creation.
89

910
### Fixed
1011

runpod/api/ctl_commands.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# Templates
1818
from .mutations import templates as template_mutations
1919

20+
2021
def get_user() -> dict:
2122
'''
2223
Get the current user
@@ -25,7 +26,8 @@ def get_user() -> dict:
2526
cleaned_return = raw_response["data"]["myself"]
2627
return cleaned_return
2728

28-
def update_user_settings(pubkey : str) -> dict:
29+
30+
def update_user_settings(pubkey: str) -> dict:
2931
'''
3032
Update the current user
3133
@@ -35,6 +37,7 @@ def update_user_settings(pubkey : str) -> dict:
3537
cleaned_return = raw_response["data"]["updateUserSettings"]
3638
return cleaned_return
3739

40+
3841
def get_gpus() -> dict:
3942
'''
4043
Get all GPU types
@@ -44,7 +47,7 @@ def get_gpus() -> dict:
4447
return cleaned_return
4548

4649

47-
def get_gpu(gpu_id : str, gpu_quantity : int = 1):
50+
def get_gpu(gpu_id: str, gpu_quantity: int = 1):
4851
'''
4952
Get a specific GPU type
5053
@@ -61,6 +64,7 @@ def get_gpu(gpu_id : str, gpu_quantity : int = 1):
6164

6265
return cleaned_return[0]
6366

67+
6468
def get_pods() -> dict:
6569
'''
6670
Get all pods
@@ -69,7 +73,8 @@ def get_pods() -> dict:
6973
cleaned_return = raw_return["data"]["myself"]["pods"]
7074
return cleaned_return
7175

72-
def get_pod(pod_id : str):
76+
77+
def get_pod(pod_id: str):
7378
'''
7479
Get a specific pod
7580
@@ -78,17 +83,18 @@ def get_pod(pod_id : str):
7883
raw_response = run_graphql_query(pod_queries.generate_pod_query(pod_id))
7984
return raw_response["data"]["pod"]
8085

86+
8187
def create_pod(
82-
name:str, image_name:str, gpu_type_id:str,
83-
cloud_type:str="ALL", support_public_ip:bool=True,
84-
start_ssh:bool=True,
85-
data_center_id : Optional[str]=None, country_code:Optional[str]=None,
86-
gpu_count:int=1, volume_in_gb:int=0, container_disk_in_gb:Optional[int]=None,
87-
min_vcpu_count:int=1, min_memory_in_gb:int=1, docker_args:str="",
88-
ports:Optional[str]=None, volume_mount_path:str="/runpod-volume",
89-
env:Optional[dict]=None, template_id:Optional[str]=None,
90-
network_volume_id:Optional[str]=None
91-
) -> dict:
88+
name: str, image_name: str, gpu_type_id: str,
89+
cloud_type: str = "ALL", support_public_ip: bool = True,
90+
start_ssh: bool = True,
91+
data_center_id: Optional[str] = None, country_code: Optional[str] = None,
92+
gpu_count: int = 1, volume_in_gb: int = 0, container_disk_in_gb: Optional[int] = None,
93+
min_vcpu_count: int = 1, min_memory_in_gb: int = 1, docker_args: str = "",
94+
ports: Optional[str] = None, volume_mount_path: str = "/runpod-volume",
95+
env: Optional[dict] = None, template_id: Optional[str] = None,
96+
network_volume_id: Optional[str] = None
97+
) -> dict:
9298
'''
9399
Create a pod
94100
@@ -112,7 +118,7 @@ def create_pod(
112118
>>> pod_id = runpod.create_pod("test", "runpod/stack", "NVIDIA GeForce RTX 3070")
113119
'''
114120
# Input Validation
115-
get_gpu(gpu_type_id) # Check if GPU exists, will raise ValueError if not.
121+
get_gpu(gpu_type_id) # Check if GPU exists, will raise ValueError if not.
116122
if cloud_type not in ["ALL", "COMMUNITY", "SECURE"]:
117123
raise ValueError("cloud_type must be one of ALL, COMMUNITY or SECURE")
118124

@@ -196,9 +202,9 @@ def terminate_pod(pod_id: str):
196202

197203

198204
def create_template(
199-
name:str, image_name:str, docker_start_cmd:str=None,
200-
container_disk_in_gb:int=10, volume_in_gb:int=None, volume_mount_path:str=None,
201-
ports:str=None, env:dict=None, is_serverless:bool=False
205+
name: str, image_name: str, docker_start_cmd: str = None,
206+
container_disk_in_gb: int = 10, volume_in_gb: int = None, volume_mount_path: str = None,
207+
ports: str = None, env: dict = None, is_serverless: bool = False, registry_auth: str = None
202208
):
203209
'''
204210
Create a template
@@ -223,12 +229,13 @@ def create_template(
223229
template_mutations.generate_pod_template(
224230
name, image_name, docker_start_cmd,
225231
container_disk_in_gb, volume_in_gb, volume_mount_path,
226-
ports, env, is_serverless
232+
ports, env, is_serverless, registry_auth
227233
)
228234
)
229235

230236
return raw_response["data"]["saveTemplate"]
231237

238+
232239
def get_endpoints() -> dict:
233240
'''
234241
Get all endpoints
@@ -237,11 +244,12 @@ def get_endpoints() -> dict:
237244
cleaned_return = raw_return["data"]["myself"]["endpoints"]
238245
return cleaned_return
239246

247+
240248
def create_endpoint(
241-
name:str, template_id:str, gpu_ids:str="AMPERE_16",
242-
network_volume_id:str=None, locations:str=None,
243-
idle_timeout:int=5, scaler_type:str="QUEUE_DELAY", scaler_value:int=4,
244-
workers_min:int=0, workers_max:int=3
249+
name: str, template_id: str, gpu_ids: str = "AMPERE_16",
250+
network_volume_id: str = None, locations: str = None,
251+
idle_timeout: int = 5, scaler_type: str = "QUEUE_DELAY", scaler_value: int = 4,
252+
workers_min: int = 0, workers_max: int = 3
245253
):
246254
'''
247255
Create an endpoint
@@ -274,7 +282,7 @@ def create_endpoint(
274282

275283

276284
def update_endpoint_template(
277-
endpoint_id:str, template_id:str
285+
endpoint_id: str, template_id: str
278286
):
279287
'''
280288
Update an endpoint template

runpod/api/mutations/templates.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
""" RunPod | API Wrapper | Mutations | Templates """
22

3-
# pylint: disable=too-many-arguments
3+
# pylint: disable=too-many-arguments, too-many-branches
4+
45

56
def generate_pod_template(
6-
name:str, image_name:str, docker_start_cmd:str=None,
7-
container_disk_in_gb:int=10, volume_in_gb:int=None, volume_mount_path:str=None,
8-
ports:str=None, env:dict=None, is_serverless:bool=False
7+
name: str, image_name: str, docker_start_cmd: str = None,
8+
container_disk_in_gb: int = 10, volume_in_gb: int = None, volume_mount_path: str = None,
9+
ports: str = None, env: dict = None, is_serverless: bool = False, registry_auth: str = None
910
):
1011
""" Generate a string for a GraphQL mutation to create a new pod template. """
1112
input_fields = []
@@ -44,12 +45,16 @@ def generate_pod_template(
4445
else:
4546
input_fields.append('env: []')
4647

47-
4848
if is_serverless:
4949
input_fields.append('isServerless: true')
5050
else:
5151
input_fields.append('isServerless: false')
5252

53+
if registry_auth is not None:
54+
input_fields.append(f'containerRegistryAuthId : "{registry_auth}"')
55+
else:
56+
input_fields.append('containerRegistryAuthId : ""')
57+
5358
# ------------------------------ Enforced Fields ----------------------------- #
5459
input_fields.append('startSsh: true')
5560
input_fields.append('isPublic: false')

tests/test_api/test_mutations_templates.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from runpod.api.mutations.templates import generate_pod_template
66

7+
78
class TestGeneratePodTemplate(unittest.TestCase):
89
""" Unit tests for the function generate_pod_template in the file api_wrapper.py """
910

@@ -24,7 +25,8 @@ def test_optional_fields(self):
2425
result = generate_pod_template(
2526
"test_name", "test_image_name", docker_start_cmd="test_cmd",
2627
volume_in_gb=5, volume_mount_path="/path/to/volume",
27-
ports="8000, 8001", env={"VAR1": "val1", "VAR2": "val2"}, is_serverless=True
28+
ports="8000, 8001", env={"VAR1": "val1", "VAR2": "val2"}, is_serverless=True,
29+
registry_auth="test_auth"
2830
)
2931
self.assertIn('dockerArgs: "test_cmd"', result)
3032
self.assertIn('volumeInGb: 5', result)
@@ -33,3 +35,4 @@ def test_optional_fields(self):
3335
self.assertIn(
3436
'env: [{ key: "VAR1", value: "val1" }, { key: "VAR2", value: "val2" }]', result)
3537
self.assertIn('isServerless: true', result)
38+
self.assertIn('containerRegistryAuthId : "test_auth"', result)

0 commit comments

Comments
 (0)