Skip to content

Commit 45b63ab

Browse files
committed
Rename AlgorithmSettings to PredictionAlgorithmSettings and AlgorithmMeta to PredictionAlgorithmMeta
1 parent b49d30e commit 45b63ab

1 file changed

Lines changed: 62 additions & 57 deletions

File tree

dataikuapi/dss/ml.py

Lines changed: 62 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -765,10 +765,15 @@ def set_value(self, value):
765765
self._algo_settings[self.name] = value
766766

767767

768-
class AlgorithmSettings(dict):
768+
class PredictionAlgorithmSettings(dict):
769769

770+
"""
771+
Object to read and modify the settings of a prediction ML algorithm.
772+
773+
Do not create this object directly, use :meth:`DSSMLTask.get_algorithm_settings(algorithm)` instead
774+
"""
770775
def __init__(self, raw_settings, hyperparameter_search_params):
771-
super(AlgorithmSettings, self).__init__(raw_settings)
776+
super(PredictionAlgorithmSettings, self).__init__(raw_settings)
772777
self._hyperparameter_search_params = hyperparameter_search_params
773778
self._hyperparameters_registry = dict()
774779

@@ -796,7 +801,7 @@ def _repr_html_(self):
796801
return res + "<details><pre>{}</pre></details>".format(self.__repr__())
797802

798803
def __repr__(self):
799-
return self.__class__.__name__ + "(values={})".format(super(AlgorithmSettings, self).copy())
804+
return self.__class__.__name__ + "(values={})".format(super(PredictionAlgorithmSettings, self).copy())
800805

801806
__str__ = __repr__
802807

@@ -817,7 +822,7 @@ def strategy(self):
817822
return self._hyperparameter_search_params["strategy"]
818823

819824

820-
class RandomForestSettings(AlgorithmSettings):
825+
class RandomForestSettings(PredictionAlgorithmSettings):
821826

822827
def __init__(self, raw_settings, hyperparameter_search_params):
823828
super(RandomForestSettings, self).__init__(raw_settings, hyperparameter_search_params)
@@ -830,7 +835,7 @@ def __init__(self, raw_settings, hyperparameter_search_params):
830835
self.selection_mode = self._register_single_category_hyperparameter("selection_mode", accepted_values=["auto", "sqrt", "log2", "number", "prop"])
831836

832837

833-
class XGBoostSettings(AlgorithmSettings):
838+
class XGBoostSettings(PredictionAlgorithmSettings):
834839

835840
def __init__(self, raw_settings, hyperparameter_search_params):
836841
super(XGBoostSettings, self).__init__(raw_settings, hyperparameter_search_params)
@@ -860,7 +865,7 @@ def __init__(self, raw_settings, hyperparameter_search_params):
860865
self.early_stopping_rounds = self._register_single_value_hyperparameter("early_stopping_rounds", accepted_types=[int])
861866

862867

863-
class GradientBoostedTreesSettings(AlgorithmSettings):
868+
class GradientBoostedTreesSettings(PredictionAlgorithmSettings):
864869

865870
def __init__(self, raw_settings, hyperparameter_search_params):
866871
super(GradientBoostedTreesSettings, self).__init__(raw_settings, hyperparameter_search_params)
@@ -874,7 +879,7 @@ def __init__(self, raw_settings, hyperparameter_search_params):
874879
self.selection_mode = self._register_single_category_hyperparameter("selection_mode", accepted_values=["auto", "sqrt", "log2", "number", "prop"])
875880

876881

877-
class DecisionTreeSettings(AlgorithmSettings):
882+
class DecisionTreeSettings(PredictionAlgorithmSettings):
878883

879884
def __init__(self, raw_settings, hyperparameter_search_params):
880885
super(DecisionTreeSettings, self).__init__(raw_settings, hyperparameter_search_params)
@@ -884,7 +889,7 @@ def __init__(self, raw_settings, hyperparameter_search_params):
884889
self.splitter = self._register_categorical_hyperparameter("splitter")
885890

886891

887-
class LogitSettings(AlgorithmSettings):
892+
class LogitSettings(PredictionAlgorithmSettings):
888893

889894
def __init__(self, raw_settings, hyperparameter_search_params):
890895
super(LogitSettings, self).__init__(raw_settings, hyperparameter_search_params)
@@ -894,38 +899,38 @@ def __init__(self, raw_settings, hyperparameter_search_params):
894899
self.n_jobs = self._register_single_value_hyperparameter("n_jobs", accepted_types=[int])
895900

896901

897-
class RidgeRegressionSettings(AlgorithmSettings):
902+
class RidgeRegressionSettings(PredictionAlgorithmSettings):
898903

899904
def __init__(self, raw_settings, hyperparameter_search_params):
900905
super(RidgeRegressionSettings, self).__init__(raw_settings, hyperparameter_search_params)
901906
self.alpha = self._register_numerical_hyperparameter("alpha", self)
902907
self.alpha_mode = self._register_single_category_hyperparameter("alphaMode", accepted_values=["MANUAL", "AUTO"])
903908

904909

905-
class LassoRegressionSettings(AlgorithmSettings):
910+
class LassoRegressionSettings(PredictionAlgorithmSettings):
906911

907912
def __init__(self, raw_settings, hyperparameter_search_params):
908913
super(LassoRegressionSettings, self).__init__(raw_settings, hyperparameter_search_params)
909914
self.alpha = self._register_numerical_hyperparameter("alpha")
910915
self.alpha_mode = self._register_single_category_hyperparameter("alphaMode", accepted_values=["MANUAL", "AUTO_CV", "AUTO_IC"]) # TODO: enforce attribute name = parameter name ?
911916

912917

913-
class OLSSettings(AlgorithmSettings):
918+
class OLSSettings(PredictionAlgorithmSettings):
914919

915920
def __init__(self, raw_settings, hyperparameter_search_params):
916921
super(OLSSettings, self).__init__(raw_settings, hyperparameter_search_params)
917922
self.n_jobs = self._register_single_value_hyperparameter("n_jobs", accepted_types=[int])
918923

919924

920-
class LARSSettings(AlgorithmSettings):
925+
class LARSSettings(PredictionAlgorithmSettings):
921926

922927
def __init__(self, raw_settings, hyperparameter_search_params):
923928
super(LARSSettings, self).__init__(raw_settings, hyperparameter_search_params)
924929
self.max_features = self._register_single_value_hyperparameter("max_features", accepted_types=[int])
925930
self.K = self._register_single_value_hyperparameter("K", accepted_types=[int])
926931

927932

928-
class SGDSettings(AlgorithmSettings):
933+
class SGDSettings(PredictionAlgorithmSettings):
929934

930935
def __init__(self, raw_settings, hyperparameter_search_params):
931936
super(SGDSettings, self).__init__(raw_settings, hyperparameter_search_params)
@@ -938,7 +943,7 @@ def __init__(self, raw_settings, hyperparameter_search_params):
938943
self.n_jobs = self._register_single_value_hyperparameter("n_jobs", accepted_types=[int])
939944

940945

941-
class KNNSettings(AlgorithmSettings):
946+
class KNNSettings(PredictionAlgorithmSettings):
942947

943948
def __init__(self, raw_settings, hyperparameter_search_params):
944949
super(KNNSettings, self).__init__(raw_settings, hyperparameter_search_params)
@@ -949,7 +954,7 @@ def __init__(self, raw_settings, hyperparameter_search_params):
949954
self.leaf_size = self._register_single_value_hyperparameter("leaf_size", accepted_types=[int])
950955

951956

952-
class SVMSettings(AlgorithmSettings):
957+
class SVMSettings(PredictionAlgorithmSettings):
953958

954959
def __init__(self, raw_settings, hyperparameter_search_params):
955960
super(SVMSettings, self).__init__(raw_settings, hyperparameter_search_params)
@@ -962,7 +967,7 @@ def __init__(self, raw_settings, hyperparameter_search_params):
962967
self.max_iter = self._register_single_value_hyperparameter("max_iter", accepted_types=[int])
963968

964969

965-
class MLPSettings(AlgorithmSettings):
970+
class MLPSettings(PredictionAlgorithmSettings):
966971

967972
def __init__(self, raw_settings, hyperparameter_search_params):
968973
super(MLPSettings, self).__init__(raw_settings, hyperparameter_search_params)
@@ -988,7 +993,7 @@ def __init__(self, raw_settings, hyperparameter_search_params):
988993
self.learning_rate_init = self._register_single_value_hyperparameter("learning_rate_init", accepted_types=[int, float])
989994

990995

991-
class MLLibLogitSettings(AlgorithmSettings):
996+
class MLLibLogitSettings(PredictionAlgorithmSettings):
992997

993998
def __init__(self, raw_settings, hyperparameter_search_params):
994999
super(MLLibLogitSettings, self).__init__(raw_settings, hyperparameter_search_params)
@@ -997,14 +1002,14 @@ def __init__(self, raw_settings, hyperparameter_search_params):
9971002
self.max_iter = self._register_single_value_hyperparameter("max_iter", accepted_types=[int])
9981003

9991004

1000-
class MLLibNaiveBayesSettings(AlgorithmSettings):
1005+
class MLLibNaiveBayesSettings(PredictionAlgorithmSettings):
10011006

10021007
def __init__(self, raw_settings, hyperparameter_search_params):
10031008
super(MLLibNaiveBayesSettings, self).__init__(raw_settings, hyperparameter_search_params)
10041009
self.lambda_ = self._register_numerical_hyperparameter("lambda")
10051010

10061011

1007-
class MLLibLinearRegressionSettings(AlgorithmSettings):
1012+
class MLLibLinearRegressionSettings(PredictionAlgorithmSettings):
10081013

10091014
def __init__(self, raw_settings, hyperparameter_search_params):
10101015
super(MLLibLinearRegressionSettings, self).__init__(raw_settings, hyperparameter_search_params)
@@ -1013,7 +1018,7 @@ def __init__(self, raw_settings, hyperparameter_search_params):
10131018
self.max_iter = self._register_single_value_hyperparameter("max_iter", accepted_types=[int])
10141019

10151020

1016-
class MLLibDecisionTreeSettings(AlgorithmSettings):
1021+
class MLLibDecisionTreeSettings(PredictionAlgorithmSettings):
10171022

10181023
def __init__(self, raw_settings, hyperparameter_search_params):
10191024
super(MLLibDecisionTreeSettings, self).__init__(raw_settings, hyperparameter_search_params)
@@ -1026,7 +1031,7 @@ def __init__(self, raw_settings, hyperparameter_search_params):
10261031
self.min_instance_per_node = self._register_single_value_hyperparameter("min_instance_per_node", accepted_types=[int])
10271032

10281033

1029-
class _MLLibTreeEnsembleSettings(AlgorithmSettings):
1034+
class _MLLibTreeEnsembleSettings(PredictionAlgorithmSettings):
10301035

10311036
def __init__(self, raw_settings, hyperparameter_search_params):
10321037
super(_MLLibTreeEnsembleSettings, self).__init__(raw_settings, hyperparameter_search_params)
@@ -1060,49 +1065,49 @@ def __init__(self, raw_settings, hyperparameter_search_params):
10601065
self.step_size = self._register_numerical_hyperparameter("step_size")
10611066

10621067

1063-
class AlgorithmMeta:
1064-
def __init__(self, algorithm_name, algorithm_settings_class=AlgorithmSettings):
1068+
class PredictionAlgorithmMeta:
1069+
def __init__(self, algorithm_name, algorithm_settings_class=PredictionAlgorithmSettings):
10651070
self.algorithm_name = algorithm_name
10661071
self.algorithm_settings_class = algorithm_settings_class
10671072

10681073

10691074
class DSSPredictionMLTaskSettings(DSSMLTaskSettings):
10701075
__doc__ = []
10711076
algorithm_remap = {
1072-
"RANDOM_FOREST_CLASSIFICATION": AlgorithmMeta("random_forest_classification", RandomForestSettings),
1073-
"RANDOM_FOREST_REGRESSION": AlgorithmMeta("random_forest_regression", RandomForestSettings),
1074-
"EXTRA_TREES": AlgorithmMeta("extra_trees", RandomForestSettings),
1075-
"GBT_CLASSIFICATION": AlgorithmMeta("gbt_classification", GradientBoostedTreesSettings),
1076-
"GBT_REGRESSION": AlgorithmMeta("gbt_regression", GradientBoostedTreesSettings),
1077-
"DECISION_TREE_CLASSIFICATION": AlgorithmMeta("decision_tree_classification", DecisionTreeSettings),
1078-
"DECISION_TREE_REGRESSION": AlgorithmMeta("decision_tree_regression", DecisionTreeSettings),
1079-
"RIDGE_REGRESSION": AlgorithmMeta("ridge_regression", RidgeRegressionSettings),
1080-
"LASSO_REGRESSION": AlgorithmMeta("lasso_regression", LassoRegressionSettings),
1081-
"LEASTSQUARE_REGRESSION": AlgorithmMeta("leastsquare_regression", OLSSettings),
1082-
"SGD_REGRESSION": AlgorithmMeta("sgd_regression", SGDSettings),
1083-
"KNN": AlgorithmMeta("knn", KNNSettings),
1084-
"LOGISTIC_REGRESSION": AlgorithmMeta("logistic_regression", LogitSettings),
1085-
"NEURAL_NETWORK": AlgorithmMeta("neural_network", MLPSettings),
1086-
"SVC_CLASSIFICATION": AlgorithmMeta("svc_classifier", SVMSettings),
1087-
"SVM_REGRESSION": AlgorithmMeta("svm_regression", SVMSettings),
1088-
"SGD_CLASSIFICATION": AlgorithmMeta("sgd_classifier", SGDSettings),
1089-
"LARS": AlgorithmMeta("lars_params", LARSSettings),
1090-
"XGBOOST_CLASSIFICATION": AlgorithmMeta("xgboost", XGBoostSettings),
1091-
"XGBOOST_REGRESSION": AlgorithmMeta("xgboost", XGBoostSettings),
1092-
"SPARKLING_DEEP_LEARNING": AlgorithmMeta("deep_learning_sparkling"),
1093-
"SPARKLING_GBM": AlgorithmMeta("gbm_sparkling"),
1094-
"SPARKLING_RF": AlgorithmMeta("rf_sparkling"),
1095-
"SPARKLING_GLM": AlgorithmMeta("glm_sparkling"),
1096-
"SPARKLING_NB": AlgorithmMeta("nb_sparkling"),
1097-
"MLLIB_LOGISTIC_REGRESSION": AlgorithmMeta("mllib_logit", MLLibLogitSettings),
1098-
"MLLIB_NAIVE_BAYES": AlgorithmMeta("mllib_naive_bayes", MLLibNaiveBayesSettings),
1099-
"MLLIB_LINEAR_REGRESSION": AlgorithmMeta("mllib_linreg", MLLibLinearRegressionSettings),
1100-
"MLLIB_RANDOM_FOREST": AlgorithmMeta("mllib_rf", MLLibRandomForestSettings),
1101-
"MLLIB_GBT": AlgorithmMeta("mllib_gbt", MLLibGBTSettings),
1102-
"MLLIB_DECISION_TREE": AlgorithmMeta("mllib_dt", MLLibDecisionTreeSettings),
1103-
"VERTICA_LINEAR_REGRESSION": AlgorithmMeta("vertica_linear_regression"),
1104-
"VERTICA_LOGISTIC_REGRESSION": AlgorithmMeta("vertica_logistic_regression"),
1105-
"KERAS_CODE": AlgorithmMeta("keras")
1077+
"RANDOM_FOREST_CLASSIFICATION": PredictionAlgorithmMeta("random_forest_classification", RandomForestSettings),
1078+
"RANDOM_FOREST_REGRESSION": PredictionAlgorithmMeta("random_forest_regression", RandomForestSettings),
1079+
"EXTRA_TREES": PredictionAlgorithmMeta("extra_trees", RandomForestSettings),
1080+
"GBT_CLASSIFICATION": PredictionAlgorithmMeta("gbt_classification", GradientBoostedTreesSettings),
1081+
"GBT_REGRESSION": PredictionAlgorithmMeta("gbt_regression", GradientBoostedTreesSettings),
1082+
"DECISION_TREE_CLASSIFICATION": PredictionAlgorithmMeta("decision_tree_classification", DecisionTreeSettings),
1083+
"DECISION_TREE_REGRESSION": PredictionAlgorithmMeta("decision_tree_regression", DecisionTreeSettings),
1084+
"RIDGE_REGRESSION": PredictionAlgorithmMeta("ridge_regression", RidgeRegressionSettings),
1085+
"LASSO_REGRESSION": PredictionAlgorithmMeta("lasso_regression", LassoRegressionSettings),
1086+
"LEASTSQUARE_REGRESSION": PredictionAlgorithmMeta("leastsquare_regression", OLSSettings),
1087+
"SGD_REGRESSION": PredictionAlgorithmMeta("sgd_regression", SGDSettings),
1088+
"KNN": PredictionAlgorithmMeta("knn", KNNSettings),
1089+
"LOGISTIC_REGRESSION": PredictionAlgorithmMeta("logistic_regression", LogitSettings),
1090+
"NEURAL_NETWORK": PredictionAlgorithmMeta("neural_network", MLPSettings),
1091+
"SVC_CLASSIFICATION": PredictionAlgorithmMeta("svc_classifier", SVMSettings),
1092+
"SVM_REGRESSION": PredictionAlgorithmMeta("svm_regression", SVMSettings),
1093+
"SGD_CLASSIFICATION": PredictionAlgorithmMeta("sgd_classifier", SGDSettings),
1094+
"LARS": PredictionAlgorithmMeta("lars_params", LARSSettings),
1095+
"XGBOOST_CLASSIFICATION": PredictionAlgorithmMeta("xgboost", XGBoostSettings),
1096+
"XGBOOST_REGRESSION": PredictionAlgorithmMeta("xgboost", XGBoostSettings),
1097+
"SPARKLING_DEEP_LEARNING": PredictionAlgorithmMeta("deep_learning_sparkling"),
1098+
"SPARKLING_GBM": PredictionAlgorithmMeta("gbm_sparkling"),
1099+
"SPARKLING_RF": PredictionAlgorithmMeta("rf_sparkling"),
1100+
"SPARKLING_GLM": PredictionAlgorithmMeta("glm_sparkling"),
1101+
"SPARKLING_NB": PredictionAlgorithmMeta("nb_sparkling"),
1102+
"MLLIB_LOGISTIC_REGRESSION": PredictionAlgorithmMeta("mllib_logit", MLLibLogitSettings),
1103+
"MLLIB_NAIVE_BAYES": PredictionAlgorithmMeta("mllib_naive_bayes", MLLibNaiveBayesSettings),
1104+
"MLLIB_LINEAR_REGRESSION": PredictionAlgorithmMeta("mllib_linreg", MLLibLinearRegressionSettings),
1105+
"MLLIB_RANDOM_FOREST": PredictionAlgorithmMeta("mllib_rf", MLLibRandomForestSettings),
1106+
"MLLIB_GBT": PredictionAlgorithmMeta("mllib_gbt", MLLibGBTSettings),
1107+
"MLLIB_DECISION_TREE": PredictionAlgorithmMeta("mllib_dt", MLLibDecisionTreeSettings),
1108+
"VERTICA_LINEAR_REGRESSION": PredictionAlgorithmMeta("vertica_linear_regression"),
1109+
"VERTICA_LOGISTIC_REGRESSION": PredictionAlgorithmMeta("vertica_logistic_regression"),
1110+
"KERAS_CODE": PredictionAlgorithmMeta("keras")
11061111
}
11071112

11081113
class PredictionTypes:

0 commit comments

Comments
 (0)