@@ -395,21 +395,36 @@ def strategy(self, strategy):
395395 self ._raw_settings ["strategy" ] = strategy
396396
397397 def set_grid_search (self , shuffle = True , seed = 0 ):
398+ """
399+ Sets the search strategy to "GRID" to perform a grid-search on the hyperparameters.
400+ :param bool shuffle: if True, the search will iterate over a shuffled grid as opposed to the lexicographical
401+ iteration over the cartesian product of the hyperparameters.
402+ :param int seed:
403+ """
398404 self ._raw_settings ["strategy" ] = "GRID"
399405 if shuffle is not None :
400406 if not isinstance (shuffle , bool ):
401- warnings .warn ()
407+ warnings .warn ("HyperparameterSearchSettings.set_grid_search ignoring invalid input: shuffle" )
408+ warnings .warn ("shuffle must be a boolean" )
402409 else :
403410 self ._raw_settings ["randomized" ] = shuffle
404411 self ._set_seed (seed )
405412 return self
406413
407414 def set_random_search (self , seed = 0 ):
415+ """
416+ Sets the search strategy to "RANDOM" to perform a random search on the hyperparameters.
417+ :param int seed:
418+ """
408419 self ._raw_settings ["strategy" ] = "RANDOM"
409420 self ._set_seed (seed )
410421 return self
411422
412423 def set_bayesian_search (self , seed = 0 ):
424+ """
425+ Sets the search strategy to "BAYESIAN" to perform a Bayesian search on the hyperparameters.
426+ :param int seed:
427+ """
413428 self ._raw_settings ["strategy" ] = "BAYESIAN"
414429 self ._set_seed (seed )
415430 return self
@@ -420,17 +435,27 @@ def validation_mode(self):
420435
421436 @validation_mode .setter
422437 def validation_mode (self , mode ):
438+ """
439+ :param str mode: "KFOLD" | "SHUFFLE" | "TIME_SERIES_KFOLD" | "TIME_SERIES_SINGLE_SPLIT" | "CUSTOM"
440+ """
423441 assert mode in {"KFOLD" , "SHUFFLE" , "TIME_SERIES_KFOLD" , "TIME_SERIES_SINGLE_SPLIT" , "CUSTOM" }
424442 self ._raw_settings ["mode" ] = mode
425443
426444 def set_kfold_validation (self , n_folds = 5 , stratified = True ):
445+ """
446+ Sets the validation mode to k-fold cross-validation (either "KFOLD" or "TIME_SERIES_KFOLD" if time-based ordering
447+ is enabled).
448+ :param int n_folds: the number of folds used for the hyperparameter search
449+ :param bool stratified: if True, will keep the same proportion of each target classes in all folds
450+ :return:
451+ """
427452 if self ._raw_settings ["mode" ] == "TIME_SERIES_SINGLE_SPLIT" :
428453 self ._raw_settings ["mode" ] = "TIME_SERIES_KFOLD"
429454 else :
430455 self ._raw_settings ["mode" ] = "KFOLD"
431456 if n_folds is not None :
432457 if not (isinstance (n_folds , int ) and n_folds > 0 ):
433- warnings .warn ("HyperparameterSearchSettings.set_validation_mode_to_kfold ignoring invalid input: n_folds" )
458+ warnings .warn ("HyperparameterSearchSettings.set_kfold_validation ignoring invalid input: n_folds" )
434459 warnings .warn ("n_folds must be a positive integer" )
435460 self ._raw_settings ["nFolds" ] = n_folds
436461 if stratified is not None :
@@ -442,33 +467,51 @@ def set_kfold_validation(self, n_folds=5, stratified=True):
442467 return self
443468
444469 def set_single_split_validation (self , split_ratio = 0.8 , stratified = True ):
470+ """
471+ Sets the validation mode to single split (either "SHUFFLE" or "TIME_SERIES_SINGLE_SPLIT" if time-based ordering
472+ is enabled).
473+ :param float split_ratio: ratio of the data used for the train during hyperparameter search
474+ :param bool stratified: if True, will keep the same proportion of each target classes in both splits
475+ :return:
476+ """
445477 if self ._raw_settings ["mode" ] == "TIME_SERIES_KFOLD" :
446478 self ._raw_settings ["mode" ] = "TIME_SERIES_SINGLE_SPLIT"
447479 else :
448480 self ._raw_settings ["mode" ] = "SHUFFLE"
449481 if split_ratio is not None :
450482 if not (isinstance (split_ratio , float ) and split_ratio > 0 and split_ratio < 1 ):
451- warnings .warn ("HyperparameterSearchSettings.set_validation_mode_to_single_split ignoring invalid input: split_ratio" )
483+ warnings .warn ("HyperparameterSearchSettings.set_single_split_validation ignoring invalid input: split_ratio" )
452484 warnings .warn (" split_ratio must be float between 0 and 1" )
453485 self ._raw_settings ["splitRatio" ] = split_ratio
454486 if stratified is not None :
455487 if not isinstance (stratified , bool ):
456- warnings .warn ("HyperparameterSearchSettings.set_validation_mode_to_single_split ignoring invalid input: stratified" )
488+ warnings .warn ("HyperparameterSearchSettings.set_single_split_validation ignoring invalid input: stratified" )
457489 warnings .warn ("stratified must be a boolean" )
458490 else :
459491 self ._raw_settings ["stratified" ] = stratified
460492 return self
461493
462494 def set_custom_validation (self , code = None ):
495+ """
496+ Sets the validation mode to "CUSTOM".
497+ :param str code: definition of the validation
498+ """
463499 self ._raw_settings ["mode" ] = "CUSTOM"
464500 if code is not None :
465501 if not isinstance (code , string_types ):
466- warnings .warn ("HyperparameterSearchSettings.set_validation_mode_to_custom ignoring invalid input: code" )
502+ warnings .warn ("HyperparameterSearchSettings.set_custom_validation ignoring invalid input: code" )
467503 warnings .warn ("code must be a Python interpretable string" )
468504 self ._raw_settings ["code" ] = code
469505 return self
470506
471507 def set_search_distribution (self , distributed = False , n_containers = 4 ):
508+ """
509+ Sets the distribution parameters for the hyperparameter search execution.
510+ :param bool distributed: if True, search will be distributed across n_containers containers in the Kubernetes
511+ cluster selected in containerized execution configuration of the runtime environment
512+ :param int n_containers:
513+ :return:
514+ """
472515 assert isinstance (distributed , bool )
473516 if n_containers is not None :
474517 assert isinstance (n_containers , int )
@@ -567,6 +610,12 @@ def definition_mode(self):
567610
568611 @definition_mode .setter
569612 def definition_mode (self , mode ):
613+ """
614+ "EXPLICIT" means that the hyperparameter search is performed over a given set of values (default for grid search)
615+ "RANGE" means that the hyperparameter search is performed over a range of values (default for random and Bayesian
616+ searches)
617+ :param str mode: "EXPLICIT" | "RANGE"
618+ """
570619 assert mode in ["EXPLICIT" , "RANGE" ], "Hyperparameter definition mode must be either \" EXPLICIT\" or \" RANGE\" "
571620 if self ._algo_settings .strategy == "GRID" :
572621 self ._algo_settings [self .name ]["gridMode" ] = mode
@@ -575,6 +624,12 @@ def definition_mode(self, mode):
575624 self ._algo_settings [self .name ]["randomMode" ] = mode
576625
577626 def set_explicit_values (self , values ):
627+ """
628+ Sets both:
629+ - the explicit values to search over for the current numerical hyperparameter
630+ - the definition mode of the current numerical hyperparameter to "EXPLICIT"
631+ :param list values: the explicit list of numerical values that will be searched for this hyperparameter
632+ """
578633 self .values (values )
579634 self .definition_mode = "EXPLICIT"
580635
@@ -584,6 +639,9 @@ def values(self):
584639
585640 @values .setter
586641 def values (self , values ):
642+ """
643+ :param list values: the explicit list of numerical values that will be searched for this hyperparameter
644+ """
587645 error_message = "Invalid values input type for hyperparameter " \
588646 "\" {}\" : " .format (self .name ) + \
589647 " expecting a non-empty list of numbers"
@@ -630,6 +688,14 @@ def _set_range(self, min=None, max=None, nb_values=None):
630688 self ._algo_settings [self .name ]["range" ]["nbValues" ] = nb_values
631689
632690 def set_range (self , min = None , max = None , nb_values = None ):
691+ """
692+ Sets both:
693+ - the Range parameters to search over for the current numerical hyperparameter
694+ - the definition mode of the current numerical hyperparameter to "RANGE"
695+ :param min: the lower bound of the Range that will be searched for this hyperparameter
696+ :param max: the upper bound of the Range that will be searched for this hyperparameter
697+ :param nb_values: for grid-search ("GRID" strategy) only, the number of values between min and max to consider
698+ """
633699 self ._set_range (min = min , max = max , nb_values = nb_values )
634700 self .definition_mode = "RANGE"
635701
@@ -650,6 +716,9 @@ def min(self):
650716
651717 @min .setter
652718 def min (self , val ):
719+ """
720+ :param float | int val: the lower bound of the Range that will be searched for this hyperparameter
721+ """
653722 self ._numerical_hyperparameter_settings ._set_range (min = val )
654723
655724 @property
@@ -658,6 +727,9 @@ def max(self):
658727
659728 @max .setter
660729 def max (self , val ):
730+ """
731+ :param float | int val: the upper bound of the Range that will be searched for this hyperparameter
732+ """
661733 self ._numerical_hyperparameter_settings ._set_range (max = val )
662734
663735 @property
@@ -666,6 +738,9 @@ def nb_values(self):
666738
667739 @nb_values .setter
668740 def nb_values (self , val ):
741+ """
742+ :param int val: for grid-search ("GRID" strategy) only, the number of values between min and max to consider
743+ """
669744 self ._numerical_hyperparameter_settings ._set_range (nb_values = val )
670745
671746
@@ -679,7 +754,7 @@ def __repr__(self):
679754 def _pretty_repr (self ):
680755 return self .__class__ .__name__ + "(hyperparameter=\" {}\" , settings={})" .format (self .name , json .dumps (self ._algo_settings [self .name ], indent = 4 ))
681756
682- def set_values (self , values = None ):
757+ def _set_values (self , values = None ):
683758 if values is None :
684759 warnings .warn ("Categorical hyperparameter \" {}\" not modified" .format (self .name ))
685760 else :
@@ -698,28 +773,45 @@ def set_values(self, values=None):
698773 self ._algo_settings [self .name ]["values" ][category ] = setting
699774
700775 def enable_categories (self , categories , disable_others = False ):
701- accepted_categories = self ._algo_settings [self .name ]["values" ].keys ()
776+ """
777+ Enables the search over categories listed in the first argument.
778+ :param list categories: will enable the search over the provided categories
779+ :param bool disable_others: if True, will also disable the search over categories not listed in the first argument
780+ """
781+ accepted_categories = self .get_all_categories ()
702782 for category in categories :
703783 assert isinstance (category , string_types )
704784 assert category in accepted_categories
705- self .set_values ({category : {"enabled" : True }
706- for category in categories })
785+ self ._set_values ({category : {"enabled" : True }
786+ for category in categories })
707787 if disable_others :
708- self .set_values ({category : {"enabled" : False }
709- for category in accepted_categories
710- if category not in categories })
788+ self ._set_values ({category : {"enabled" : False }
789+ for category in accepted_categories
790+ if category not in categories })
711791
712792 def disable_categories (self , categories , enable_others = False ):
713- accepted_categories = self ._algo_settings [self .name ]["values" ].keys ()
793+ """
794+ Disables the search over categories listed in the first argument.
795+ :param list categories: will disable the search over the provided categories
796+ :param bool enable_others: if True, will also enable the search over categories not listed in the first argument
797+ """
798+ accepted_categories = self .get_all_categories ()
714799 for category in categories :
715800 assert isinstance (category , string_types )
716801 assert category in accepted_categories
717- self .set_values ({category : {"enabled" : False }
718- for category in categories })
802+ self ._set_values ({category : {"enabled" : False }
803+ for category in categories })
719804 if enable_others :
720- self .set_values ({category : {"enabled" : True }
721- for category in accepted_categories
722- if category not in categories })
805+ self ._set_values ({category : {"enabled" : True }
806+ for category in accepted_categories
807+ if category not in categories })
808+
809+ def get_all_categories (self ):
810+ """
811+ :return: list of valid categories for this hyperparameter
812+ """
813+ return list (self ._algo_settings [self .name ]["values" ].keys ())
814+
723815
724816
725817class SingleValueHyperparameterSettings (HyperparameterSettings ):
@@ -1138,7 +1230,7 @@ def get_algorithm_settings(self, algorithm_name):
11381230
11391231 This method returns the settings for this algorithm as an AlgorithmSettings (extended dict).
11401232 All algorithm dicts have at least an "enabled" property/key in the settings.
1141- The ' enabled' key/ property indicates whether this algorithm will be trained.
1233+ The " enabled" property/key indicates whether this algorithm will be trained.
11421234
11431235 Other settings are algorithm-dependent and are the various hyperparameters of the
11441236 algorithm. The precise properties/keys for each algorithm are not all documented. You can print
@@ -1170,6 +1262,10 @@ def get_algorithm_settings(self, algorithm_name):
11701262
11711263 def get_hyperparameter_search_settings (self ):
11721264 """
1265+ Gets the hyperparameter search parameters of the current DSSPredictionMLTaskSettings instance as a
1266+ HyperparameterSearchSettings object. This object can be used to both get and set properties relevant to
1267+ hyperparameter search, such as search strategy, cross-validation method, execution limits and parallelism.
1268+
11731269 :return: A HyperparameterSearchSettings
11741270 :rtype: :class:`HyperparameterSearchSettings`
11751271 """
0 commit comments