Skip to content

Commit 20e5986

Browse files
authored
Merge pull request #272 from python-adaptive/pickle-LearnerND
Make LearnerND pickleable
2 parents a612a0e + 700bbc8 commit 20e5986

4 files changed

Lines changed: 20 additions & 13 deletions

File tree

adaptive/learner/base_learner.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import abc
22
from contextlib import suppress
3-
from copy import deepcopy
3+
4+
import cloudpickle
45

56
from adaptive.utils import _RequireAttrsABCMeta, load, save
67

@@ -191,7 +192,7 @@ def load(self, fname, compress=True):
191192
self._set_data(data)
192193

193194
def __getstate__(self):
194-
return deepcopy(self.__dict__)
195+
return cloudpickle.dumps(self.__dict__)
195196

196197
def __setstate__(self, state):
197-
self.__dict__ = state
198+
self.__dict__ = cloudpickle.loads(state)

adaptive/learner/learnerND.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import random
44
from collections import OrderedDict
55
from collections.abc import Iterable
6+
from copy import deepcopy
67

78
import numpy as np
89
import scipy.spatial
@@ -319,6 +320,7 @@ def __init__(self, func, bounds, loss_per_simplex=None):
319320
else:
320321
self._bounds_points = sorted(list(map(tuple, itertools.product(*bounds))))
321322
self._bbox = tuple(tuple(map(float, b)) for b in bounds)
323+
self._interior = None
322324

323325
self.ndim = len(self._bbox)
324326

@@ -337,6 +339,7 @@ def __init__(self, func, bounds, loss_per_simplex=None):
337339
# for the output
338340
self._min_value = None
339341
self._max_value = None
342+
self._old_scale = None
340343
self._output_multiplier = (
341344
1 # If we do not know anything, do not scale the values
342345
)
@@ -453,7 +456,7 @@ def _simplex_exists(self, simplex):
453456

454457
def inside_bounds(self, point):
455458
"""Check whether a point is inside the bounds."""
456-
if hasattr(self, "_interior"):
459+
if self._interior is not None:
457460
return self._interior.find_simplex(point, tol=1e-8) >= 0
458461
else:
459462
eps = 1e-8
@@ -988,13 +991,6 @@ def plot_3D(self, with_triangulation=False):
988991

989992
return plotly.offline.iplot(fig)
990993

991-
def _get_data(self):
992-
return self.data
993-
994-
def _set_data(self, data):
995-
if data:
996-
self.tell_many(*zip(*data.items()))
997-
998994
def _get_iso(self, level=0.0, which="surface"):
999995
if which == "surface":
1000996
if self.ndim != 3 or self.vdim != 1:
@@ -1182,3 +1178,10 @@ def _get_plane_color(simplex):
11821178
opacity=opacity,
11831179
lighting=lighting,
11841180
)
1181+
1182+
def _get_data(self):
1183+
return deepcopy(self.__dict__)
1184+
1185+
def _set_data(self, state):
1186+
for k, v in state.items():
1187+
setattr(self, k, v)

adaptive/tests/test_pickling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
IntegratorLearner,
1010
Learner1D,
1111
Learner2D,
12+
LearnerND,
1213
SequenceLearner,
1314
)
1415
from adaptive.runner import simple
@@ -70,6 +71,7 @@ def balancing_learner(f, learner_type, learner_kwargs):
7071
balancing_learner,
7172
dict(learner_type=Learner1D, learner_kwargs=dict(bounds=(-1, 1))),
7273
),
74+
(LearnerND, dict(bounds=((-1, 1), (-1, 1), (-1, 1)))),
7375
]
7476

7577
serializers = [(pickle, pickleable_f)]

adaptive/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from contextlib import contextmanager
77
from itertools import product
88

9+
import cloudpickle
910
from atomicwrites import AtomicWriter
1011

1112

@@ -46,7 +47,7 @@ def save(fname, data, compress=True):
4647
if dirname:
4748
os.makedirs(dirname, exist_ok=True)
4849

49-
blob = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
50+
blob = cloudpickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
5051
if compress:
5152
blob = gzip.compress(blob)
5253

@@ -58,7 +59,7 @@ def load(fname, compress=True):
5859
fname = os.path.expanduser(fname)
5960
_open = gzip.open if compress else open
6061
with _open(fname, "rb") as f:
61-
return pickle.load(f)
62+
return cloudpickle.load(f)
6263

6364

6465
def copy_docstring_from(other):

0 commit comments

Comments
 (0)