Skip to content

Commit 0c7e4d4

Browse files
authored
Fix tests (#455)
* fix pytorch tests Signed-off-by: Tobias Pitters <tobias.pitters@gmail.com> * fix tests Signed-off-by: Tobias Pitters <tobias.pitters@gmail.com> * don't use formatting Signed-off-by: Tobias Pitters <tobias.pitters@gmail.com> * fix issue Signed-off-by: Tobias Pitters <tobias.pitters@gmail.com> * remove commented out code Signed-off-by: Tobias Pitters <tobias.pitters@gmail.com> * simplify change Signed-off-by: Tobias Pitters <tobias.pitters@gmail.com> --------- Signed-off-by: Tobias Pitters <tobias.pitters@gmail.com>
1 parent 90fdb46 commit 0c7e4d4

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

dice_ml/explainer_interfaces/dice_genetic.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,8 +476,11 @@ def find_counterfactuals(self, query_instance, desired_range, desired_class,
476476
if rest_members > 0:
477477
new_generation_2 = np.zeros((rest_members, self.data_interface.number_of_features))
478478
for new_gen_idx in range(rest_members):
479-
parent1 = random.choice(population[:int(len(population) / 2)])
480-
parent2 = random.choice(population[:int(len(population) / 2)])
479+
top_half = int(len(population) / 2)
480+
parent1_idx = random.randrange(top_half)
481+
parent2_idx = random.randrange(top_half)
482+
parent1 = population[parent1_idx]
483+
parent2 = population[parent2_idx]
481484
child = self.mate(parent1, parent2, features_to_vary, query_instance)
482485
new_generation_2[new_gen_idx] = child
483486

dice_ml/model_interfaces/pytorch_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ def __init__(self, model=None, model_path='', backend='PYT', func=None, kw_args=
2222

2323
super().__init__(model, model_path, backend, func, kw_args)
2424

25-
def load_model(self):
25+
def load_model(self, weights_only=False):
2626
if self.model_path != '':
27-
self.model = torch.load(self.model_path)
27+
self.model = torch.load(self.model_path, weights_only=weights_only)
2828

2929
def get_output(self, input_instance, model_score=True,
3030
transform_data=False, out_tensor=False):
@@ -35,9 +35,9 @@ def get_output(self, input_instance, model_score=True,
3535
"""
3636
input_tensor = input_instance
3737
if transform_data:
38-
input_tensor = torch.tensor(self.transformer.transform(input_instance).to_numpy()).float()
38+
input_tensor = torch.tensor(self.transformer.transform(input_instance).to_numpy(dtype=np.float64)).float()
3939
if not torch.is_tensor(input_instance):
40-
input_tensor = torch.tensor(self.transformer.transform(input_instance).to_numpy()).float()
40+
input_tensor = torch.tensor(self.transformer.transform(input_instance).to_numpy(dtype=np.float64)).float()
4141
out = self.model(input_tensor).float()
4242
if not out_tensor:
4343
out = out.data.numpy()

0 commit comments

Comments
 (0)