Skip to content

Commit b49d30e

Browse files
committed
Redispatch get_algorithm_settings and get_hyperparameter_search_settings methods to appropriate subclasses of DSSMLTaskSettings
1 parent 5851b24 commit b49d30e

1 file changed

Lines changed: 68 additions & 38 deletions

File tree

dataikuapi/dss/ml.py

Lines changed: 68 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -233,41 +233,7 @@ def use_feature(self, feature_name):
233233
self.get_feature_preprocessing(feature_name)["role"] = "INPUT"
234234

235235
def get_algorithm_settings(self, algorithm_name):
236-
"""
237-
Gets the training settings for a particular algorithm. This returns a reference to the
238-
algorithm's settings, not a copy, so changes made to the returned object will be reflected when saving.
239-
240-
This method returns the settings for this algorithm as an AlgorithmSettings (extended dict).
241-
All algorithm dicts have at least an "enabled" property/key in the settings.
242-
The 'enabled' key/property indicates whether this algorithm will be trained.
243-
244-
Other settings are algorithm-dependent and are the various hyperparameters of the
245-
algorithm. The precise properties/keys for each algorithm are not all documented. You can print
246-
the returned AlgorithmSettings to learn more about the settings of each particular algorithm.
247-
248-
Please refer to the documentation for details on available algorithms.
249-
250-
:param str algorithm_name: Name (in capitals) of the algorithm.
251-
:return: An AlgorithmSettings (extended dict) for a single built-in algorithm,
252-
or a plain dict for the settings of custom inline-code (CUSTOM_*) or plugin (PLUGIN_*) algorithms.
253-
:rtype: AlgorithmSettings | dict
254-
"""
255-
if algorithm_name in ["CUSTOM_MLLIB", "CUSTOM_PYTHON", "PLUGIN_PYTHON"]:
256-
return self.mltask_settings["modeling"][algorithm_name.lower()]
257-
elif algorithm_name in self.__class__.algorithm_remap:
258-
algorithm_meta = self.__class__.algorithm_remap[algorithm_name]
259-
algorithm_name = algorithm_meta.algorithm_name
260-
algorithm_settings_class = algorithm_meta.algorithm_settings_class
261-
262-
algorithm_settings = self.mltask_settings["modeling"][algorithm_name.lower()]
263-
if not isinstance(algorithm_settings, AlgorithmSettings):
264-
raw_hyperparameter_search_params = self.mltask_settings["modeling"]["gridSearchParams"]
265-
algorithm_settings = algorithm_settings_class(algorithm_settings, raw_hyperparameter_search_params)
266-
# Subsequent calls get the same object
267-
self.mltask_settings["modeling"][algorithm_name.lower()] = algorithm_settings
268-
return self.mltask_settings["modeling"][algorithm_name.lower()]
269-
else:
270-
raise ValueError("Unknown algorithm: {}".format(algorithm_name))
236+
raise NotImplementedError()
271237

272238
def set_algorithm_enabled(self, algorithm_name, enabled):
273239
"""
@@ -346,9 +312,6 @@ def set_metric(self, metric=None, custom_metric=None, custom_metric_greater_is_b
346312
self.mltask_settings["modeling"]["metrics"]["customEvaluationMetricGIB"] = custom_metric_greater_is_better
347313
self.mltask_settings["modeling"]["metrics"]["customEvaluationMetricNeedsProba"] = custom_metric_use_probas
348314

349-
def get_hyperparameter_search_settings(self):
350-
return HyperparameterSearchSettings(self.mltask_settings["modeling"]["gridSearchParams"])
351-
352315
def save(self):
353316
"""Saves back these settings to the ML Task"""
354317

@@ -1158,6 +1121,50 @@ def __init__(self, client, project_key, analysis_id, mltask_id, mltask_settings)
11581121
def get_prediction_type(self):
11591122
return self.mltask_settings['predictionType']
11601123

1124+
def get_algorithm_settings(self, algorithm_name):
1125+
"""
1126+
Gets the training settings for a particular algorithm. This returns a reference to the
1127+
algorithm's settings, not a copy, so changes made to the returned object will be reflected when saving.
1128+
1129+
This method returns the settings for this algorithm as an AlgorithmSettings (extended dict).
1130+
All algorithm dicts have at least an "enabled" property/key in the settings.
1131+
The 'enabled' key/property indicates whether this algorithm will be trained.
1132+
1133+
Other settings are algorithm-dependent and are the various hyperparameters of the
1134+
algorithm. The precise properties/keys for each algorithm are not all documented. You can print
1135+
the returned AlgorithmSettings to learn more about the settings of each particular algorithm.
1136+
1137+
Please refer to the documentation for details on available algorithms.
1138+
1139+
:param str algorithm_name: Name (in capitals) of the algorithm.
1140+
:return: A PredictionAlgorithmSettings (extended dict) for a single built-in prediction algorithm,
1141+
or a plain dict for the settings of custom inline-code (CUSTOM_*) or plugin (PLUGIN_*) algorithms.
1142+
:rtype: PredictionAlgorithmSettings | dict
1143+
"""
1144+
if algorithm_name in ["CUSTOM_MLLIB", "CUSTOM_PYTHON", "PLUGIN_PYTHON"]:
1145+
return self.mltask_settings["modeling"][algorithm_name.lower()]
1146+
elif algorithm_name in self.__class__.algorithm_remap:
1147+
algorithm_meta = self.__class__.algorithm_remap[algorithm_name]
1148+
algorithm_name = algorithm_meta.algorithm_name
1149+
algorithm_settings_class = algorithm_meta.algorithm_settings_class
1150+
1151+
algorithm_settings = self.mltask_settings["modeling"][algorithm_name.lower()]
1152+
if not isinstance(algorithm_settings, PredictionAlgorithmSettings):
1153+
raw_hyperparameter_search_params = self.mltask_settings["modeling"]["gridSearchParams"]
1154+
algorithm_settings = algorithm_settings_class(algorithm_settings, raw_hyperparameter_search_params)
1155+
# Subsequent calls get the same object
1156+
self.mltask_settings["modeling"][algorithm_name.lower()] = algorithm_settings
1157+
return self.mltask_settings["modeling"][algorithm_name.lower()]
1158+
else:
1159+
raise ValueError("Unknown algorithm: {}".format(algorithm_name))
1160+
1161+
def get_hyperparameter_search_settings(self):
1162+
"""
1163+
:return: A HyperparameterSearchSettings
1164+
:rtype: :class:`HyperparameterSearchSettings`
1165+
"""
1166+
return HyperparameterSearchSettings(self.mltask_settings["modeling"]["gridSearchParams"])
1167+
11611168
@property
11621169
def split_params(self):
11631170
"""
@@ -1262,6 +1269,29 @@ class DSSClusteringMLTaskSettings(DSSMLTaskSettings):
12621269
"DBSCAN" : "db_scan_clustering",
12631270
}
12641271

1272+
def get_algorithm_settings(self, algorithm_name):
1273+
"""
1274+
Gets the training settings for a particular algorithm. This returns a reference to the
1275+
algorithm's settings, not a copy, so changes made to the returned object will be reflected when saving.
1276+
1277+
This method returns a dictionary of the settings for this algorithm.
1278+
All algorithm dicts have at least an "enabled" key in the dictionary.
1279+
The 'enabled' key indicates whether this algorithm will be trained
1280+
1281+
Other settings are algorithm-dependent and are the various hyperparameters of the
1282+
algorithm. The precise keys for each algorithm are not all documented. You can print
1283+
the returned dictionary to learn more about the settings of each particular algorithm
1284+
1285+
Please refer to the documentation for details on available algorithms.
1286+
1287+
:param str algorithm_name: Name (in capitals) of the algorithm.
1288+
:return: A dict of the settings for an algorithm
1289+
:rtype: dict
1290+
"""
1291+
if algorithm_name in self.__class__.algorithm_remap:
1292+
algorithm_name = self.__class__.algorithm_remap[algorithm_name]
1293+
1294+
return self.mltask_settings["modeling"][algorithm_name.lower()]
12651295

12661296

12671297
class DSSTrainedModelDetails(object):

0 commit comments

Comments
 (0)