Skip to content

Commit 90fdb46

Browse files
fix: TensorFlow explainers have the same signature as the rest (#415)
Right now the TensorFlow gradient explainers do not implement the `_generate_counterfactuals` method from the `ExplainerBase` class. This means that: 1) You cannot instantiate an object of class `DiceTensorFlow(1/2)` without replacing the `__class__` of another method, because it is not a valid child of `ExplainerBase`. 2) By overriding the parent `generate_counterfactuals` method, the two classes bypass any validation steps that would normally be carried out by the base class (e.g. checking that the number of CF queries is non-negative). Signed-off-by: Asen Dotsinski <asendotsinski@proton.me>
1 parent 7081b33 commit 90fdb46

2 files changed

Lines changed: 22 additions & 24 deletions

File tree

dice_ml/explainer_interfaces/dice_tensorflow1.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import tensorflow as tf
1111

1212
from dice_ml import diverse_counterfactuals as exp
13-
from dice_ml.counterfactual_explanations import CounterfactualExplanations
1413
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
1514

1615

@@ -61,20 +60,22 @@ def __init__(self, data_interface, model_interface):
6160
self.loss_weights = [] # yloss_type, diversity_loss_type, feature_weights
6261
self.optimizer_weights = [] # optimizer
6362

64-
def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opposite", proximity_weight=0.5,
65-
diversity_weight=1.0, categorical_penalty=0.1, algorithm="DiverseCF",
66-
features_to_vary="all", permitted_range=None, yloss_type="hinge_loss",
67-
diversity_loss_type="dpp_style:inverse_dist", feature_weights="inverse_mad",
68-
optimizer="tensorflow:adam", learning_rate=0.05, min_iter=500, max_iter=5000,
69-
project_iter=0, loss_diff_thres=1e-5, loss_converge_maxiter=1, verbose=False,
70-
init_near_query_instance=True, tie_random=False, stopping_threshold=0.5,
71-
posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", limit_steps_ls=10000):
63+
def _generate_counterfactuals(self, query_instance, total_CFs,
64+
desired_class="opposite", desired_range=None,
65+
proximity_weight=0.5,
66+
diversity_weight=1.0, categorical_penalty=0.1, algorithm="DiverseCF", features_to_vary="all",
67+
permitted_range=None, yloss_type="hinge_loss", diversity_loss_type="dpp_style:inverse_dist",
68+
feature_weights="inverse_mad", optimizer="tensorflow:adam", learning_rate=0.05, min_iter=500,
69+
max_iter=5000, project_iter=0, loss_diff_thres=1e-5, loss_converge_maxiter=1, verbose=False,
70+
init_near_query_instance=True, tie_random=False, stopping_threshold=0.5,
71+
posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", limit_steps_ls=10000):
7272
"""Generates diverse counterfactual explanations
7373
7474
:param query_instance: Test point of interest. A dictionary of feature names and values or a single row dataframe.
7575
:param total_CFs: Total number of counterfactuals required.
7676
:param desired_class: Desired counterfactual class - can take 0 or 1. Default value is "opposite" to the
7777
outcome class of query_instance for binary classification.
78+
:param desired_range: Not supported currently.
7879
:param proximity_weight: A positive float. Larger this weight, more close the counterfactuals are to the
7980
query_instance.
8081
:param diversity_weight: A positive float. Larger this weight, more diverse the counterfactuals are.
@@ -159,16 +160,14 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp
159160
loss_diff_thres, loss_converge_maxiter, verbose, init_near_query_instance, tie_random,
160161
stopping_threshold, posthoc_sparsity_param, posthoc_sparsity_algorithm)
161162

162-
counterfactual_explanations = exp.CounterfactualExamples(
163+
return exp.CounterfactualExamples(
163164
data_interface=self.data_interface,
164165
final_cfs_df=final_cfs_df,
165166
test_instance_df=test_instance_df,
166167
final_cfs_df_sparse=final_cfs_df_sparse,
167168
posthoc_sparsity_param=posthoc_sparsity_param,
168169
desired_class=desired_class)
169170

170-
return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations])
171-
172171
def do_cf_initializations(self, total_CFs, algorithm, features_to_vary):
173172
"""Intializes TF variables required for CF generation."""
174173

dice_ml/explainer_interfaces/dice_tensorflow2.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import tensorflow as tf
1010

1111
from dice_ml import diverse_counterfactuals as exp
12-
from dice_ml.counterfactual_explanations import CounterfactualExplanations
1312
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
1413

1514

@@ -49,20 +48,22 @@ def __init__(self, data_interface, model_interface):
4948
self.hyperparameters = [1, 1, 1] # proximity_weight, diversity_weight, categorical_penalty
5049
self.optimizer_weights = [] # optimizer, learning_rate
5150

52-
def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opposite", proximity_weight=0.5,
53-
diversity_weight=1.0, categorical_penalty=0.1, algorithm="DiverseCF",
54-
features_to_vary="all", permitted_range=None, yloss_type="hinge_loss",
55-
diversity_loss_type="dpp_style:inverse_dist", feature_weights="inverse_mad",
56-
optimizer="tensorflow:adam", learning_rate=0.05, min_iter=500, max_iter=5000,
57-
project_iter=0, loss_diff_thres=1e-5, loss_converge_maxiter=1, verbose=False,
58-
init_near_query_instance=True, tie_random=False, stopping_threshold=0.5,
59-
posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", limit_steps_ls=10000):
51+
def _generate_counterfactuals(self, query_instance, total_CFs,
52+
desired_class="opposite", desired_range=None,
53+
proximity_weight=0.5,
54+
diversity_weight=1.0, categorical_penalty=0.1, algorithm="DiverseCF", features_to_vary="all",
55+
permitted_range=None, yloss_type="hinge_loss", diversity_loss_type="dpp_style:inverse_dist",
56+
feature_weights="inverse_mad", optimizer="tensorflow:adam", learning_rate=0.05, min_iter=500,
57+
max_iter=5000, project_iter=0, loss_diff_thres=1e-5, loss_converge_maxiter=1, verbose=False,
58+
init_near_query_instance=True, tie_random=False, stopping_threshold=0.5,
59+
posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", limit_steps_ls=10000):
6060
"""Generates diverse counterfactual explanations
6161
6262
:param query_instance: Test point of interest. A dictionary of feature names and values or a single row dataframe
6363
:param total_CFs: Total number of counterfactuals required.
6464
:param desired_class: Desired counterfactual class - can take 0 or 1. Default value is "opposite" to the
6565
outcome class of query_instance for binary classification.
66+
:param desired_range: Not supported currently.
6667
:param proximity_weight: A positive float. Larger this weight, more close the counterfactuals are to
6768
the query_instance.
6869
:param diversity_weight: A positive float. Larger this weight, more diverse the counterfactuals are.
@@ -136,16 +137,14 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp
136137
init_near_query_instance, tie_random, stopping_threshold,
137138
posthoc_sparsity_param, posthoc_sparsity_algorithm, limit_steps_ls)
138139

139-
counterfactual_explanations = exp.CounterfactualExamples(
140+
return exp.CounterfactualExamples(
140141
data_interface=self.data_interface,
141142
final_cfs_df=final_cfs_df,
142143
test_instance_df=test_instance_df,
143144
final_cfs_df_sparse=final_cfs_df_sparse,
144145
posthoc_sparsity_param=posthoc_sparsity_param,
145146
desired_class=desired_class)
146147

147-
return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations])
148-
149148
def predict_fn(self, input_instance):
150149
"""prediction function"""
151150
temp_preds = self.model.get_output(input_instance).numpy()

0 commit comments

Comments
 (0)