Skip to content

Commit abed946

Browse files
Merge pull request #282 from runpod/allowed-cuda-pod
feat: add the ability to create pod with selected cuda
2 parents 0b6e703 + f095e4b commit abed946

3 files changed

Lines changed: 18 additions & 6 deletions

File tree

runpod/api/ctl_commands.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def create_pod(
9393
min_vcpu_count: int = 1, min_memory_in_gb: int = 1, docker_args: str = "",
9494
ports: Optional[str] = None, volume_mount_path: str = "/runpod-volume",
9595
env: Optional[dict] = None, template_id: Optional[str] = None,
96-
network_volume_id: Optional[str] = None
96+
network_volume_id: Optional[str] = None, allowed_cuda_versions: Optional[list] = None
9797
) -> dict:
9898
'''
9999
Create a pod
@@ -138,7 +138,7 @@ def create_pod(
138138
cloud_type, support_public_ip, start_ssh,
139139
data_center_id, country_code, gpu_count,
140140
volume_in_gb, container_disk_in_gb, min_vcpu_count, min_memory_in_gb, docker_args,
141-
ports, volume_mount_path, env, template_id, network_volume_id)
141+
ports, volume_mount_path, env, template_id, network_volume_id, allowed_cuda_versions)
142142
)
143143

144144
cleaned_response = raw_response["data"]["podFindAndDeployOnDemand"]

runpod/api/mutations/pods.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@
33
"""
44
# pylint: disable=too-many-arguments, too-many-locals, too-many-branches
55

6+
from typing import Optional, List
7+
68

79
def generate_pod_deployment_mutation(
8-
name:str, image_name:str, gpu_type_id:str,
9-
cloud_type:str="ALL", support_public_ip:bool=True, start_ssh:bool=True,
10+
name: str, image_name: str, gpu_type_id: str,
11+
cloud_type: str = "ALL", support_public_ip: bool = True, start_ssh: bool = True,
1012
data_center_id=None, country_code=None,
1113
gpu_count=None, volume_in_gb=None, container_disk_in_gb=None, min_vcpu_count=None,
1214
min_memory_in_gb=None, docker_args=None, ports=None, volume_mount_path=None,
13-
env:dict=None, template_id=None, network_volume_id=None):
15+
env: dict = None, template_id=None, network_volume_id=None,
16+
allowed_cuda_versions: Optional[List[str]] = None):
1417
'''
1518
Generates a mutation to deploy a pod on demand.
1619
'''
@@ -64,6 +67,11 @@ def generate_pod_deployment_mutation(
6467
if network_volume_id is not None:
6568
input_fields.append(f'networkVolumeId: "{network_volume_id}"')
6669

70+
if allowed_cuda_versions is not None:
71+
allowed_cuda_versions_string = ", ".join(
72+
[f'"{version}"' for version in allowed_cuda_versions])
73+
input_fields.append(f'allowedCudaVersions: [{allowed_cuda_versions_string}]')
74+
6775
# Format input fields
6876
input_string = ", ".join(input_fields)
6977

tests/test_api/test_mutations_pods.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from runpod.api.mutations import pods
66

7+
78
class TestPodMutations(unittest.TestCase):
89
''' Test API Wrapper Pod Mutations '''
910

@@ -28,7 +29,9 @@ def test_generate_pod_deployment_mutation(self):
2829
volume_mount_path="/path",
2930
env={"ENV": "test"},
3031
support_public_ip=True,
31-
template_id="abcde")
32+
template_id="abcde",
33+
allowed_cuda_versions=["11.8", "12.0"]
34+
)
3235

3336
# Here you should check the correct structure of the result
3437
self.assertIn("mutation", result)
@@ -57,5 +60,6 @@ def test_generate_pod_terminate_mutation(self):
5760
# Here you should check the correct structure of the result
5861
self.assertIn("mutation", result)
5962

63+
6064
if __name__ == "__main__":
6165
unittest.main()

0 commit comments

Comments
 (0)