Skip to content

Commit 50e69aa

Browse files
authored
Include LightGBM to the list of ML algorithms in the public API (#159)
1 parent 8041263 commit 50e69aa

1 file changed

Lines changed: 27 additions & 0 deletions

File tree

dataikuapi/dss/ml.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,6 +1147,31 @@ def __init__(self, raw_settings, hyperparameter_search_params):
11471147
self.selection_mode = self._register_single_category_hyperparameter("selection_mode", accepted_values=["auto", "sqrt", "log2", "number", "prop"])
11481148

11491149

1150+
class LightGBMSettings(PredictionAlgorithmSettings):
1151+
1152+
def __init__(self, raw_settings, hyperparameter_search_params):
1153+
super(LightGBMSettings, self).__init__(raw_settings, hyperparameter_search_params)
1154+
self.boosting_type = self._register_categorical_hyperparameter("boosting_type")
1155+
self.num_leaves = self._register_numerical_hyperparameter("num_leaves")
1156+
self.learning_rate = self._register_numerical_hyperparameter("learning_rate")
1157+
self.n_estimators = self._register_numerical_hyperparameter("n_estimators")
1158+
self.min_split_gain = self._register_numerical_hyperparameter("min_split_gain")
1159+
self.min_child_weight = self._register_numerical_hyperparameter("min_child_weight")
1160+
self.min_child_samples = self._register_numerical_hyperparameter("min_child_samples")
1161+
self.colsample_bytree = self._register_numerical_hyperparameter("colsample_bytree")
1162+
self.reg_alpha = self._register_numerical_hyperparameter("reg_alpha")
1163+
self.reg_lambda = self._register_numerical_hyperparameter("reg_lambda")
1164+
1165+
self.early_stopping = self._register_single_value_hyperparameter("early_stopping", accepted_types=[bool])
1166+
self.early_stopping_rounds = self._register_single_value_hyperparameter("early_stopping_rounds", accepted_types=[int])
1167+
self.random_state = self._register_single_value_hyperparameter("random_state", accepted_types=[int])
1168+
self.n_jobs = self._register_single_value_hyperparameter("n_jobs", accepted_types=[int])
1169+
self.max_depth = self._register_single_value_hyperparameter("max_depth", accepted_types=[int])
1170+
self.subsample = self._register_single_value_hyperparameter("subsample", accepted_types=[float])
1171+
self.subsample_freq = self._register_single_value_hyperparameter("subsample_freq", accepted_types=[int])
1172+
self.use_bagging = self._register_single_value_hyperparameter("use_bagging", accepted_types=[bool])
1173+
1174+
11501175
class XGBoostSettings(PredictionAlgorithmSettings):
11511176

11521177
def __init__(self, raw_settings, hyperparameter_search_params):
@@ -1403,6 +1428,8 @@ class DSSPredictionMLTaskSettings(DSSMLTaskSettings):
14031428
"SVM_REGRESSION": PredictionAlgorithmMeta("svm_regression", SVMSettings),
14041429
"SGD_CLASSIFICATION": PredictionAlgorithmMeta("sgd_classifier", SGDSettings),
14051430
"LARS": PredictionAlgorithmMeta("lars_params", LARSSettings),
1431+
"LIGHTGBM_CLASSIFICATION": PredictionAlgorithmMeta("lightgbm_classification", LightGBMSettings),
1432+
"LIGHTGBM_REGRESSION": PredictionAlgorithmMeta("lightgbm_regression", LightGBMSettings),
14061433
"XGBOOST_CLASSIFICATION": PredictionAlgorithmMeta("xgboost", XGBoostSettings),
14071434
"XGBOOST_REGRESSION": PredictionAlgorithmMeta("xgboost", XGBoostSettings),
14081435
"SPARKLING_DEEP_LEARNING": PredictionAlgorithmMeta("deep_learning_sparkling"),

0 commit comments

Comments
 (0)