|
3 | 3 | """ |
4 | 4 | # pylint: disable=too-many-arguments, too-many-locals, too-many-branches |
5 | 5 |
|
| 6 | +from typing import Optional, List |
| 7 | + |
6 | 8 |
|
7 | 9 | def generate_pod_deployment_mutation( |
8 | | - name:str, image_name:str, gpu_type_id:str, |
9 | | - cloud_type:str="ALL", support_public_ip:bool=True, start_ssh:bool=True, |
| 10 | + name: str, image_name: str, gpu_type_id: str, |
| 11 | + cloud_type: str = "ALL", support_public_ip: bool = True, start_ssh: bool = True, |
10 | 12 | data_center_id=None, country_code=None, |
11 | 13 | gpu_count=None, volume_in_gb=None, container_disk_in_gb=None, min_vcpu_count=None, |
12 | 14 | min_memory_in_gb=None, docker_args=None, ports=None, volume_mount_path=None, |
13 | | - env:dict=None, template_id=None, network_volume_id=None): |
| 15 | + env: dict = None, template_id=None, network_volume_id=None, |
| 16 | + allowed_cuda_versions: Optional[List[str]] = None): |
14 | 17 | ''' |
15 | 18 | Generates a mutation to deploy a pod on demand. |
16 | 19 | ''' |
@@ -64,6 +67,11 @@ def generate_pod_deployment_mutation( |
64 | 67 | if network_volume_id is not None: |
65 | 68 | input_fields.append(f'networkVolumeId: "{network_volume_id}"') |
66 | 69 |
|
| 70 | + if allowed_cuda_versions is not None: |
| 71 | + allowed_cuda_versions_string = ", ".join( |
| 72 | + [f'"{version}"' for version in allowed_cuda_versions]) |
| 73 | + input_fields.append(f'allowedCudaVersions: [{allowed_cuda_versions_string}]') |
| 74 | + |
67 | 75 | # Format input fields |
68 | 76 | input_string = ", ".join(input_fields) |
69 | 77 |
|
|
0 commit comments