|
9 | 9 | import tensorflow as tf |
10 | 10 |
|
11 | 11 | from dice_ml import diverse_counterfactuals as exp |
12 | | -from dice_ml.counterfactual_explanations import CounterfactualExplanations |
13 | 12 | from dice_ml.explainer_interfaces.explainer_base import ExplainerBase |
14 | 13 |
|
15 | 14 |
|
@@ -49,20 +48,22 @@ def __init__(self, data_interface, model_interface): |
49 | 48 | self.hyperparameters = [1, 1, 1] # proximity_weight, diversity_weight, categorical_penalty |
50 | 49 | self.optimizer_weights = [] # optimizer, learning_rate |
51 | 50 |
|
52 | | - def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opposite", proximity_weight=0.5, |
53 | | - diversity_weight=1.0, categorical_penalty=0.1, algorithm="DiverseCF", |
54 | | - features_to_vary="all", permitted_range=None, yloss_type="hinge_loss", |
55 | | - diversity_loss_type="dpp_style:inverse_dist", feature_weights="inverse_mad", |
56 | | - optimizer="tensorflow:adam", learning_rate=0.05, min_iter=500, max_iter=5000, |
57 | | - project_iter=0, loss_diff_thres=1e-5, loss_converge_maxiter=1, verbose=False, |
58 | | - init_near_query_instance=True, tie_random=False, stopping_threshold=0.5, |
59 | | - posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", limit_steps_ls=10000): |
| 51 | + def _generate_counterfactuals(self, query_instance, total_CFs, |
| 52 | + desired_class="opposite", desired_range=None, |
| 53 | + proximity_weight=0.5, |
| 54 | + diversity_weight=1.0, categorical_penalty=0.1, algorithm="DiverseCF", features_to_vary="all", |
| 55 | + permitted_range=None, yloss_type="hinge_loss", diversity_loss_type="dpp_style:inverse_dist", |
| 56 | + feature_weights="inverse_mad", optimizer="tensorflow:adam", learning_rate=0.05, min_iter=500, |
| 57 | + max_iter=5000, project_iter=0, loss_diff_thres=1e-5, loss_converge_maxiter=1, verbose=False, |
| 58 | + init_near_query_instance=True, tie_random=False, stopping_threshold=0.5, |
| 59 | + posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", limit_steps_ls=10000): |
60 | 60 | """Generates diverse counterfactual explanations |
61 | 61 |
|
62 | 62 | :param query_instance: Test point of interest. A dictionary of feature names and values or a single row dataframe |
63 | 63 | :param total_CFs: Total number of counterfactuals required. |
64 | 64 | :param desired_class: Desired counterfactual class - can take 0 or 1. Default value is "opposite" to the |
65 | 65 | outcome class of query_instance for binary classification. |
| 66 | + :param desired_range: Not supported currently. |
66 | 67 | :param proximity_weight: A positive float. Larger this weight, more close the counterfactuals are to |
67 | 68 | the query_instance. |
68 | 69 | :param diversity_weight: A positive float. Larger this weight, more diverse the counterfactuals are. |
@@ -136,16 +137,14 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp |
136 | 137 | init_near_query_instance, tie_random, stopping_threshold, |
137 | 138 | posthoc_sparsity_param, posthoc_sparsity_algorithm, limit_steps_ls) |
138 | 139 |
|
139 | | - counterfactual_explanations = exp.CounterfactualExamples( |
| 140 | + return exp.CounterfactualExamples( |
140 | 141 | data_interface=self.data_interface, |
141 | 142 | final_cfs_df=final_cfs_df, |
142 | 143 | test_instance_df=test_instance_df, |
143 | 144 | final_cfs_df_sparse=final_cfs_df_sparse, |
144 | 145 | posthoc_sparsity_param=posthoc_sparsity_param, |
145 | 146 | desired_class=desired_class) |
146 | 147 |
|
147 | | - return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations]) |
148 | | - |
149 | 148 | def predict_fn(self, input_instance): |
150 | 149 | """prediction function""" |
151 | 150 | temp_preds = self.model.get_output(input_instance).numpy() |
|
0 commit comments