Skip to content

Commit 3022f61

Browse files
committed
Add missing docstrings, update CategoricalHyperparameterSettings API to provide list of valid categories
1 parent a9e4742 commit 3022f61

1 file changed

Lines changed: 115 additions & 19 deletions

File tree

dataikuapi/dss/ml.py

Lines changed: 115 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -395,21 +395,36 @@ def strategy(self, strategy):
395395
self._raw_settings["strategy"] = strategy
396396

397397
def set_grid_search(self, shuffle=True, seed=0):
398+
"""
399+
Sets the search strategy to "GRID" to perform a grid-search on the hyperparameters.
400+
:param bool shuffle: if True, the search will iterate over a shuffled grid as opposed to the lexicographical
401+
iteration over the cartesian product of the hyperparameters.
402+
:param int seed:
403+
"""
398404
self._raw_settings["strategy"] = "GRID"
399405
if shuffle is not None:
400406
if not isinstance(shuffle, bool):
401-
warnings.warn()
407+
warnings.warn("HyperparameterSearchSettings.set_grid_search ignoring invalid input: shuffle")
408+
warnings.warn("shuffle must be a boolean")
402409
else:
403410
self._raw_settings["randomized"] = shuffle
404411
self._set_seed(seed)
405412
return self
406413

407414
def set_random_search(self, seed=0):
415+
"""
416+
Sets the search strategy to "RANDOM" to perform a random search on the hyperparameters.
417+
:param int seed:
418+
"""
408419
self._raw_settings["strategy"] = "RANDOM"
409420
self._set_seed(seed)
410421
return self
411422

412423
def set_bayesian_search(self, seed=0):
424+
"""
425+
Sets the search strategy to "BAYESIAN" to perform a Bayesian search on the hyperparameters.
426+
:param int seed:
427+
"""
413428
self._raw_settings["strategy"] = "BAYESIAN"
414429
self._set_seed(seed)
415430
return self
@@ -420,17 +435,27 @@ def validation_mode(self):
420435

421436
@validation_mode.setter
422437
def validation_mode(self, mode):
438+
"""
439+
:param str mode: "KFOLD" | "SHUFFLE" | "TIME_SERIES_KFOLD" | "TIME_SERIES_SINGLE_SPLIT" | "CUSTOM"
440+
"""
423441
assert mode in {"KFOLD", "SHUFFLE", "TIME_SERIES_KFOLD", "TIME_SERIES_SINGLE_SPLIT", "CUSTOM"}
424442
self._raw_settings["mode"] = mode
425443

426444
def set_kfold_validation(self, n_folds=5, stratified=True):
445+
"""
446+
Sets the validation mode to k-fold cross-validation (either "KFOLD" or "TIME_SERIES_KFOLD" if time-based ordering
447+
is enabled).
448+
:param int n_folds: the number of folds used for the hyperparameter search
449+
:param bool stratified: if True, will keep the same proportion of each target classes in all folds
450+
:return:
451+
"""
427452
if self._raw_settings["mode"] == "TIME_SERIES_SINGLE_SPLIT":
428453
self._raw_settings["mode"] = "TIME_SERIES_KFOLD"
429454
else:
430455
self._raw_settings["mode"] = "KFOLD"
431456
if n_folds is not None:
432457
if not (isinstance(n_folds, int) and n_folds > 0):
433-
warnings.warn("HyperparameterSearchSettings.set_validation_mode_to_kfold ignoring invalid input: n_folds")
458+
warnings.warn("HyperparameterSearchSettings.set_kfold_validation ignoring invalid input: n_folds")
434459
warnings.warn("n_folds must be a positive integer")
435460
self._raw_settings["nFolds"] = n_folds
436461
if stratified is not None:
@@ -442,33 +467,51 @@ def set_kfold_validation(self, n_folds=5, stratified=True):
442467
return self
443468

444469
def set_single_split_validation(self, split_ratio=0.8, stratified=True):
470+
"""
471+
Sets the validation mode to single split (either "SHUFFLE" or "TIME_SERIES_SINGLE_SPLIT" if time-based ordering
472+
is enabled).
473+
:param float split_ratio: ratio of the data used for the train during hyperparameter search
474+
:param bool stratified: if True, will keep the same proportion of each target classes in both splits
475+
:return:
476+
"""
445477
if self._raw_settings["mode"] == "TIME_SERIES_KFOLD":
446478
self._raw_settings["mode"] = "TIME_SERIES_SINGLE_SPLIT"
447479
else:
448480
self._raw_settings["mode"] = "SHUFFLE"
449481
if split_ratio is not None:
450482
if not (isinstance(split_ratio, float) and split_ratio > 0 and split_ratio < 1):
451-
warnings.warn("HyperparameterSearchSettings.set_validation_mode_to_single_split ignoring invalid input: split_ratio")
483+
warnings.warn("HyperparameterSearchSettings.set_single_split_validation ignoring invalid input: split_ratio")
452484
warnings.warn(" split_ratio must be float between 0 and 1")
453485
self._raw_settings["splitRatio"] = split_ratio
454486
if stratified is not None:
455487
if not isinstance(stratified, bool):
456-
warnings.warn("HyperparameterSearchSettings.set_validation_mode_to_single_split ignoring invalid input: stratified")
488+
warnings.warn("HyperparameterSearchSettings.set_single_split_validation ignoring invalid input: stratified")
457489
warnings.warn("stratified must be a boolean")
458490
else:
459491
self._raw_settings["stratified"] = stratified
460492
return self
461493

462494
def set_custom_validation(self, code=None):
495+
"""
496+
Sets the validation mode to "CUSTOM".
497+
:param str code: definition of the validation
498+
"""
463499
self._raw_settings["mode"] = "CUSTOM"
464500
if code is not None:
465501
if not isinstance(code, string_types):
466-
warnings.warn("HyperparameterSearchSettings.set_validation_mode_to_custom ignoring invalid input: code")
502+
warnings.warn("HyperparameterSearchSettings.set_custom_validation ignoring invalid input: code")
467503
warnings.warn("code must be a Python interpretable string")
468504
self._raw_settings["code"] = code
469505
return self
470506

471507
def set_search_distribution(self, distributed=False, n_containers=4):
508+
"""
509+
Sets the distribution parameters for the hyperparameter search execution.
510+
:param bool distributed: if True, search will be distributed across n_containers containers in the Kubernetes
511+
cluster selected in containerized execution configuration of the runtime environment
512+
:param int n_containers:
513+
:return:
514+
"""
472515
assert isinstance(distributed, bool)
473516
if n_containers is not None:
474517
assert isinstance(n_containers, int)
@@ -567,6 +610,12 @@ def definition_mode(self):
567610

568611
@definition_mode.setter
569612
def definition_mode(self, mode):
613+
"""
614+
"EXPLICIT" means that the hyperparameter search is performed over a given set of values (default for grid search)
615+
"RANGE" means that the hyperparameter search is performed over a range of values (default for random and Bayesian
616+
searches)
617+
:param str mode: "EXPLICIT" | "RANGE"
618+
"""
570619
assert mode in ["EXPLICIT", "RANGE"], "Hyperparameter definition mode must be either \"EXPLICIT\" or \"RANGE\""
571620
if self._algo_settings.strategy == "GRID":
572621
self._algo_settings[self.name]["gridMode"] = mode
@@ -575,6 +624,12 @@ def definition_mode(self, mode):
575624
self._algo_settings[self.name]["randomMode"] = mode
576625

577626
def set_explicit_values(self, values):
627+
"""
628+
Sets both:
629+
- the explicit values to search over for the current numerical hyperparameter
630+
- the definition mode of the current numerical hyperparameter to "EXPLICIT"
631+
:param list values: the explicit list of numerical values that will be searched for this hyperparameter
632+
"""
578633
self.values(values)
579634
self.definition_mode = "EXPLICIT"
580635

@@ -584,6 +639,9 @@ def values(self):
584639

585640
@values.setter
586641
def values(self, values):
642+
"""
643+
:param list values: the explicit list of numerical values that will be searched for this hyperparameter
644+
"""
587645
error_message = "Invalid values input type for hyperparameter " \
588646
"\"{}\": ".format(self.name) + \
589647
" expecting a non-empty list of numbers"
@@ -630,6 +688,14 @@ def _set_range(self, min=None, max=None, nb_values=None):
630688
self._algo_settings[self.name]["range"]["nbValues"] = nb_values
631689

632690
def set_range(self, min=None, max=None, nb_values=None):
691+
"""
692+
Sets both:
693+
- the Range parameters to search over for the current numerical hyperparameter
694+
- the definition mode of the current numerical hyperparameter to "RANGE"
695+
:param min: the lower bound of the Range that will be searched for this hyperparameter
696+
:param max: the upper bound of the Range that will be searched for this hyperparameter
697+
:param nb_values: for grid-search ("GRID" strategy) only, the number of values between min and max to consider
698+
"""
633699
self._set_range(min=min, max=max, nb_values=nb_values)
634700
self.definition_mode = "RANGE"
635701

@@ -650,6 +716,9 @@ def min(self):
650716

651717
@min.setter
652718
def min(self, val):
719+
"""
720+
:param float | int val: the lower bound of the Range that will be searched for this hyperparameter
721+
"""
653722
self._numerical_hyperparameter_settings._set_range(min=val)
654723

655724
@property
@@ -658,6 +727,9 @@ def max(self):
658727

659728
@max.setter
660729
def max(self, val):
730+
"""
731+
:param float | int val: the upper bound of the Range that will be searched for this hyperparameter
732+
"""
661733
self._numerical_hyperparameter_settings._set_range(max=val)
662734

663735
@property
@@ -666,6 +738,9 @@ def nb_values(self):
666738

667739
@nb_values.setter
668740
def nb_values(self, val):
741+
"""
742+
:param int val: for grid-search ("GRID" strategy) only, the number of values between min and max to consider
743+
"""
669744
self._numerical_hyperparameter_settings._set_range(nb_values=val)
670745

671746

@@ -679,7 +754,7 @@ def __repr__(self):
679754
def _pretty_repr(self):
680755
return self.__class__.__name__ + "(hyperparameter=\"{}\", settings={})".format(self.name, json.dumps(self._algo_settings[self.name], indent=4))
681756

682-
def set_values(self, values=None):
757+
def _set_values(self, values=None):
683758
if values is None:
684759
warnings.warn("Categorical hyperparameter \"{}\" not modified".format(self.name))
685760
else:
@@ -698,28 +773,45 @@ def set_values(self, values=None):
698773
self._algo_settings[self.name]["values"][category] = setting
699774

700775
def enable_categories(self, categories, disable_others=False):
701-
accepted_categories = self._algo_settings[self.name]["values"].keys()
776+
"""
777+
Enables the search over categories listed in the first argument.
778+
:param list categories: will enable the search over the provided categories
779+
:param bool disable_others: if True, will also disable the search over categories not listed in the first argument
780+
"""
781+
accepted_categories = self.get_all_categories()
702782
for category in categories:
703783
assert isinstance(category, string_types)
704784
assert category in accepted_categories
705-
self.set_values({category: {"enabled": True}
706-
for category in categories})
785+
self._set_values({category: {"enabled": True}
786+
for category in categories})
707787
if disable_others:
708-
self.set_values({category: {"enabled": False}
709-
for category in accepted_categories
710-
if category not in categories})
788+
self._set_values({category: {"enabled": False}
789+
for category in accepted_categories
790+
if category not in categories})
711791

712792
def disable_categories(self, categories, enable_others=False):
713-
accepted_categories = self._algo_settings[self.name]["values"].keys()
793+
"""
794+
Disables the search over categories listed in the first argument.
795+
:param list categories: will disable the search over the provided categories
796+
:param bool enable_others: if True, will also enable the search over categories not listed in the first argument
797+
"""
798+
accepted_categories = self.get_all_categories()
714799
for category in categories:
715800
assert isinstance(category, string_types)
716801
assert category in accepted_categories
717-
self.set_values({category: {"enabled": False}
718-
for category in categories})
802+
self._set_values({category: {"enabled": False}
803+
for category in categories})
719804
if enable_others:
720-
self.set_values({category: {"enabled": True}
721-
for category in accepted_categories
722-
if category not in categories})
805+
self._set_values({category: {"enabled": True}
806+
for category in accepted_categories
807+
if category not in categories})
808+
809+
def get_all_categories(self):
810+
"""
811+
:return: list of valid categories for this hyperparameter
812+
"""
813+
return list(self._algo_settings[self.name]["values"].keys())
814+
723815

724816

725817
class SingleValueHyperparameterSettings(HyperparameterSettings):
@@ -1138,7 +1230,7 @@ def get_algorithm_settings(self, algorithm_name):
11381230
11391231
This method returns the settings for this algorithm as an AlgorithmSettings (extended dict).
11401232
All algorithm dicts have at least an "enabled" property/key in the settings.
1141-
The 'enabled' key/property indicates whether this algorithm will be trained.
1233+
The "enabled" property/key indicates whether this algorithm will be trained.
11421234
11431235
Other settings are algorithm-dependent and are the various hyperparameters of the
11441236
algorithm. The precise properties/keys for each algorithm are not all documented. You can print
@@ -1170,6 +1262,10 @@ def get_algorithm_settings(self, algorithm_name):
11701262

11711263
def get_hyperparameter_search_settings(self):
11721264
"""
1265+
Gets the hyperparameter search parameters of the current DSSPredictionMLTaskSettings instance as a
1266+
HyperparameterSearchSettings object. This object can be used to both get and set properties relevant to
1267+
hyperparameter search, such as search strategy, cross-validation method, execution limits and parallelism.
1268+
11731269
:return: A HyperparameterSearchSettings
11741270
:rtype: :class:`HyperparameterSearchSettings`
11751271
"""

0 commit comments

Comments
 (0)