diff --git a/examples/yaml_define/Kubernetes.yaml b/examples/yaml_define/Kubernetes.yaml index 1128ca70..fd94d79b 100644 --- a/examples/yaml_define/Kubernetes.yaml +++ b/examples/yaml_define/Kubernetes.yaml @@ -25,5 +25,5 @@ tasks: task_type: K8S image: ds-dev namespace: '{ "name": "default","cluster": "lab" }' - minCpuCores: 2.0 - minMemorySpace: 10.0 \ No newline at end of file + min_cpu_cores: 2.0 + min_memory_space: 10.0 \ No newline at end of file diff --git a/src/pydolphinscheduler/core/yaml_workflow.py b/src/pydolphinscheduler/core/yaml_workflow.py index d7c57f16..c536dc43 100644 --- a/src/pydolphinscheduler/core/yaml_workflow.py +++ b/src/pydolphinscheduler/core/yaml_workflow.py @@ -26,7 +26,7 @@ from typing import Any from pydolphinscheduler import configuration, tasks -from pydolphinscheduler.constants import Symbol +from pydolphinscheduler.constants import Symbol, TaskType from pydolphinscheduler.core.parameter import ParameterType from pydolphinscheduler.core.task import Task from pydolphinscheduler.core.workflow import Workflow @@ -113,15 +113,33 @@ def get_possible_path(file_path, base_folder): return possible_path +# Alias from a ``TaskType`` constant value to its task class name, for the +# cases where the task type constant sent to the backend differs from the +# Python task class name (e.g. ``TaskType.KUBERNETES`` is ``"K8S"`` while the +# class is ``Kubernetes``). Add an entry here whenever a new task type's +# constant value does not equal its class name (case-insensitively). +_TASK_TYPE_ALIAS = { + TaskType.KUBERNETES: "Kubernetes", +} + + def get_task_cls(task_type) -> Task: - """Get the task class object by task_type (case compatible).""" + """Get the task class object by task_type (case compatible). + + Match by class name first (case-insensitively), then fall back to the + ``TaskType`` constant values so a YAML ``task_type`` like ``K8S`` resolves + to the ``Kubernetes`` task class. + """ # only get task class from tasks.__all__ all_task_types = {type_.capitalize(): type_ for type_ in tasks.__all__} task_type_cap = task_type.capitalize() - if task_type_cap not in all_task_types: + if task_type_cap in all_task_types: + standard_name = all_task_types[task_type_cap] + elif task_type.upper() in _TASK_TYPE_ALIAS: + standard_name = _TASK_TYPE_ALIAS[task_type.upper()] + else: raise PyDSTaskNoFoundException(f"cant not find task {task_type}") - standard_name = all_task_types[task_type_cap] return getattr(tasks, standard_name) diff --git a/tests/core/test_yaml_workflow.py b/tests/core/test_yaml_workflow.py index 45a35f0c..48761938 100644 --- a/tests/core/test_yaml_workflow.py +++ b/tests/core/test_yaml_workflow.py @@ -137,6 +137,9 @@ def test_parse_tool_parse_possible_path_file(): ("SubWorkflow", tasks.SubWorkflow), ("Switch", tasks.Switch), ("SageMaker", tasks.SageMaker), + ("Kubernetes", tasks.Kubernetes), + ("K8S", tasks.Kubernetes), + ("k8s", tasks.Kubernetes), ], ) def test_get_task(task_type, expect): @@ -168,6 +171,7 @@ def test_get_error(task_type): ("Flink.yaml"), ("Procedure.yaml"), ("Http.yaml"), + ("Kubernetes.yaml"), ("MapReduce.yaml"), ("Python.yaml"), ("Shell.yaml"),