Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/yaml_define/Kubernetes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ tasks:
task_type: K8S
image: ds-dev
namespace: '{ "name": "default","cluster": "lab" }'
minCpuCores: 2.0
minMemorySpace: 10.0
min_cpu_cores: 2.0
min_memory_space: 10.0
26 changes: 22 additions & 4 deletions src/pydolphinscheduler/core/yaml_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from typing import Any

from pydolphinscheduler import configuration, tasks
from pydolphinscheduler.constants import Symbol
from pydolphinscheduler.constants import Symbol, TaskType
from pydolphinscheduler.core.parameter import ParameterType
from pydolphinscheduler.core.task import Task
from pydolphinscheduler.core.workflow import Workflow
Expand Down Expand Up @@ -113,15 +113,33 @@ def get_possible_path(file_path, base_folder):
return possible_path


# Alias from a ``TaskType`` constant value to its task class name, for the
# cases where the task type constant sent to the backend differs from the
# Python task class name (e.g. ``TaskType.KUBERNETES`` is ``"K8S"`` while the
# class is ``Kubernetes``). Add an entry here whenever a new task type's
# constant value does not equal its class name (case-insensitively).
_TASK_TYPE_ALIAS = {
TaskType.KUBERNETES: "Kubernetes",
}


def get_task_cls(task_type) -> Task:
"""Get the task class object by task_type (case compatible)."""
"""Get the task class object by task_type (case compatible).

Match by class name first (case-insensitively), then fall back to the
``TaskType`` constant values so a YAML ``task_type`` like ``K8S`` resolves
to the ``Kubernetes`` task class.
"""
# only get task class from tasks.__all__
all_task_types = {type_.capitalize(): type_ for type_ in tasks.__all__}
task_type_cap = task_type.capitalize()
if task_type_cap not in all_task_types:
if task_type_cap in all_task_types:
standard_name = all_task_types[task_type_cap]
elif task_type.upper() in _TASK_TYPE_ALIAS:
standard_name = _TASK_TYPE_ALIAS[task_type.upper()]
else:
raise PyDSTaskNoFoundException(f"cant not find task {task_type}")

standard_name = all_task_types[task_type_cap]
return getattr(tasks, standard_name)


Expand Down
4 changes: 4 additions & 0 deletions tests/core/test_yaml_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ def test_parse_tool_parse_possible_path_file():
("SubWorkflow", tasks.SubWorkflow),
("Switch", tasks.Switch),
("SageMaker", tasks.SageMaker),
("Kubernetes", tasks.Kubernetes),
("K8S", tasks.Kubernetes),
("k8s", tasks.Kubernetes),
],
)
def test_get_task(task_type, expect):
Expand Down Expand Up @@ -168,6 +171,7 @@ def test_get_error(task_type):
("Flink.yaml"),
("Procedure.yaml"),
("Http.yaml"),
("Kubernetes.yaml"),
("MapReduce.yaml"),
("Python.yaml"),
("Shell.yaml"),
Expand Down