Skip to content

Commit d70e155

Browse files
authored
fixed failing tests for py311 (#456)
* fixed errors with tf and pytorch version updates Signed-off-by: Amit Sharma <amit_sharma@live.com> * updated params error Signed-off-by: Amit Sharma <amit_sharma@live.com> * updated isort for xgboost files Signed-off-by: Amit Sharma <amit_sharma@live.com> * fixed flake8 errors Signed-off-by: Amit Sharma <amit_sharma@live.com> --------- Signed-off-by: Amit Sharma <amit_sharma@live.com>
1 parent 0c7e4d4 commit d70e155

File tree

6 files changed

+10
-12
lines changed

6 files changed

+10
-12
lines changed

dice_ml/explainer_interfaces/dice_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def predict_fn(self, input_instance):
150150

151151
def predict_fn_for_sparsity(self, input_instance):
152152
"""prediction function for sparsity correction"""
153-
input_instance = self.model.transformer.transform(input_instance).to_numpy()[0]
153+
input_instance = self.model.transformer.transform(input_instance).to_numpy(dtype=np.float64)[0]
154154
return self.predict_fn(torch.tensor(input_instance).float())
155155

156156
def do_cf_initializations(self, total_CFs, algorithm, features_to_vary):
@@ -418,7 +418,7 @@ def find_counterfactuals(self, query_instance, desired_class, optimizer, learnin
418418
init_near_query_instance, tie_random, stopping_threshold, posthoc_sparsity_param,
419419
posthoc_sparsity_algorithm, limit_steps_ls):
420420
"""Finds counterfactuals by gradient-descent."""
421-
query_instance = self.model.transformer.transform(query_instance).to_numpy()[0]
421+
query_instance = self.model.transformer.transform(query_instance).to_numpy(dtype=np.float64)[0]
422422
self.x1 = torch.tensor(query_instance)
423423

424424
# find the predicted value of query_instance

dice_ml/explainer_interfaces/dice_xgboost.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
22

3+
34
class DiceXGBoost(ExplainerBase):
45
def __init__(self, data_interface, model_interface):
56
"""Initialize with data and model interfaces"""

dice_ml/model_interfaces/xgboost_model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import xgboost as xgb
2-
from dice_ml.model_interfaces.base_model import BaseModel
2+
33
from dice_ml.constants import ModelTypes
4+
from dice_ml.model_interfaces.base_model import BaseModel
5+
46

57
class XGBoostModel(BaseModel):
6-
8+
79
def __init__(self, model=None, model_path='', backend='', func=None, kw_args=None):
810
super().__init__(model=model, model_path=model_path, backend='xgboost', func=func, kw_args=kw_args)
911
if model is None and model_path:
@@ -27,4 +29,4 @@ def get_output(self, input_instance, model_score=True):
2729
return self.model.predict(input_instance)
2830

2931
def get_gradient(self):
30-
raise NotImplementedError("XGBoost does not support gradient calculation in this context")
32+
raise NotImplementedError("XGBoost does not support gradient calculation in this context")
1.36 KB
Binary file not shown.

tests/test_dice_interface/test_dice_pytorch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ def _initiate_exp_object(self, pyt_exp_object, sample_adultincome_query):
3030
# query_instance = self.exp.data_interface.prepare_query_instance(
3131
# query_instance=sample_adultincome_query, encoding='one-hot')
3232
# self.query_instance = query_instance.iloc[0].values
33-
self.query_instance = self.exp.data_interface.get_ohe_min_max_normalized_data(sample_adultincome_query).iloc[0].values
33+
self.query_instance = self.exp.data_interface.get_ohe_min_max_normalized_data(
34+
sample_adultincome_query).iloc[0].to_numpy(dtype=np.float64)
3435

3536
self.exp.initialize_CFs(self.query_instance, init_near_query_instance=True) # initialize CFs
3637
self.exp.target_cf_class = torch.tensor(1).float() # set desired class to 1

tests/test_dice_interface/test_explainer_base.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,6 @@ def test_desired_class(
261261
ans = exp.generate_counterfactuals(query_instances=sample_custom_query_2,
262262
features_to_vary='all',
263263
total_CFs=2, desired_class=desired_class,
264-
proximity_weight=0.2, sparsity_weight=0.2,
265-
diversity_weight=5.0,
266-
categorical_penalty=0.1,
267264
permitted_range=None)
268265
if method != 'kdtree':
269266
assert all(ans.cf_examples_list[0].final_cfs_df[exp.data_interface.outcome_name].values == [desired_class] * 2)
@@ -277,9 +274,6 @@ def test_desired_class(
277274
ans = new_exp.generate_counterfactuals(query_instances=sample_custom_query_2,
278275
features_to_vary='all',
279276
total_CFs=2, desired_class=desired_class,
280-
proximity_weight=0.2, sparsity_weight=0.2,
281-
diversity_weight=5.0,
282-
categorical_penalty=0.1,
283277
permitted_range=None)
284278
if method != 'kdtree':
285279
assert all(ans.cf_examples_list[0].final_cfs_df[new_exp.data_interface.outcome_name].values == [desired_class] * 2)

0 commit comments

Comments
 (0)