Skip to content

Commit c2959e6

Browse files
DeepMindcopybara-github
authored andcommitted
Enable 2 changes to composer environment resets:
- Don't recompile the mjcf model every episode. This allows us to gain a lot of time in between resets if we are not making any changes to the mjcf model. - Have a fixed initial state for every episode. This allows to have repeatable episode if desired. PiperOrigin-RevId: 585901093 Change-Id: I9d0b29dc1aba80113b1437ff3fff3f06862923ef
1 parent 634d885 commit c2959e6

2 files changed

Lines changed: 106 additions & 19 deletions

File tree

dm_control/composer/environment.py

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -292,19 +292,26 @@ def control_timestep(self):
292292
class Environment(_CommonEnvironment, dm_env.Environment):
293293
"""Reinforcement learning environment for Composer tasks."""
294294

295-
def __init__(self, task, time_limit=float('inf'), random_state=None,
296-
n_sub_steps=None,
297-
raise_exception_on_physics_error=True,
298-
strip_singleton_obs_buffer_dim=False,
299-
max_reset_attempts=1,
300-
delayed_observation_padding=ObservationPadding.ZERO,
301-
legacy_step: bool = True):
295+
def __init__(
296+
self,
297+
task,
298+
time_limit=float('inf'),
299+
random_state=None,
300+
n_sub_steps=None,
301+
raise_exception_on_physics_error=True,
302+
strip_singleton_obs_buffer_dim=False,
303+
max_reset_attempts=1,
304+
recompile_mjcf_every_episode=True,
305+
fixed_initial_state=False,
306+
delayed_observation_padding=ObservationPadding.ZERO,
307+
legacy_step: bool = True,
308+
):
302309
"""Initializes an instance of `Environment`.
303310
304311
Args:
305312
task: Instance of `composer.base.Task`.
306-
time_limit: (optional) A float, the time limit in seconds beyond which
307-
an episode is forced to terminate.
313+
time_limit: (optional) A float, the time limit in seconds beyond which an
314+
episode is forced to terminate.
308315
random_state: (optional) an int seed or `np.random.RandomState` instance.
309316
n_sub_steps: (DEPRECATED) An integer, number of physics steps to take per
310317
agent control step. New code should instead override the
@@ -313,15 +320,22 @@ def __init__(self, task, time_limit=float('inf'), random_state=None,
313320
`PhysicsError` should be raised as an exception. If `False`, physics
314321
errors will result in the current episode being terminated with a
315322
warning logged, and a new episode started.
316-
strip_singleton_obs_buffer_dim: (optional) A boolean, if `True`,
317-
the array shape of observations with `buffer_size == 1` will not have a
318-
leading buffer dimension.
323+
strip_singleton_obs_buffer_dim: (optional) A boolean, if `True`, the array
324+
shape of observations with `buffer_size == 1` will not have a leading
325+
buffer dimension.
319326
max_reset_attempts: (optional) Maximum number of times to try resetting
320-
the environment. If an `EpisodeInitializationError` is raised
321-
during this process, an environment reset is reattempted up to this
322-
number of times. If this count is exceeded then the most recent
323-
exception will be allowed to propagate. Defaults to 1, i.e. no failure
324-
is allowed.
327+
the environment. If an `EpisodeInitializationError` is raised during
328+
this process, an environment reset is reattempted up to this number of
329+
times. If this count is exceeded then the most recent exception will be
330+
allowed to propagate. Defaults to 1, i.e. no failure is allowed.
331+
recompile_mjcf_every_episode: If True will recompile the mjcf model
332+
between episodes. This specifically skips the `initialize_episode_mjcf`
333+
and `after_compile` steps. This allows a speedup if no changes are made
334+
to the model.
335+
fixed_initial_state: If True the starting state of every single episode
336+
will be the same. Meaning an identical sequence of action will lead to
337+
an identical final state. If False, will randomize the starting state at
338+
every episode.
325339
delayed_observation_padding: (optional) An `ObservationPadding` enum value
326340
specifying the padding behavior of the initial buffers for delayed
327341
observables. If `ZERO` then the buffer is initially filled with zeroes.
@@ -340,6 +354,10 @@ def __init__(self, task, time_limit=float('inf'), random_state=None,
340354
delayed_observation_padding=delayed_observation_padding,
341355
legacy_step=legacy_step)
342356
self._max_reset_attempts = max_reset_attempts
357+
self._recompile_mjcf_every_episode = recompile_mjcf_every_episode
358+
self._mjcf_never_compiled = True
359+
self._fixed_initial_state = fixed_initial_state
360+
self._fixed_random_state = self._random_state.get_state()
343361
self._reset_next_step = True
344362

345363
def reset(self):
@@ -355,8 +373,15 @@ def reset(self):
355373
raise
356374

357375
def _reset_attempt(self):
358-
self._hooks.initialize_episode_mjcf(self._random_state)
359-
self._recompile_physics_and_update_observables()
376+
if self._recompile_mjcf_every_episode or self._mjcf_never_compiled:
377+
if self._fixed_initial_state:
378+
self._random_state.set_state(self._fixed_random_state)
379+
self._hooks.initialize_episode_mjcf(self._random_state)
380+
self._recompile_physics_and_update_observables()
381+
self._mjcf_never_compiled = False
382+
383+
if self._fixed_initial_state:
384+
self._random_state.set_state(self._fixed_random_state)
360385
with self._physics.reset_context():
361386
self._hooks.initialize_episode(self._physics_proxy, self._random_state)
362387
self._observation_updater.reset(self._physics_proxy, self._random_state)

dm_control/composer/environment_test.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,25 @@ def initialize_episode(self, physics, random_state):
5353
raise composer.EpisodeInitializationError()
5454

5555

56+
class DummyTaskWithRandomObservation(composer.NullTask):
57+
58+
def __init__(self):
59+
null_entity = composer.ModelWrapperEntity(mjcf.RootElement())
60+
super().__init__(null_entity)
61+
62+
self._observation = [0.0] * 1000
63+
64+
def initialize_episode(self, physics, random_state):
65+
del physics
66+
self._observation = random_state.randint(1000, size=1000)
67+
68+
@property
69+
def task_observables(self):
70+
random_int = observable.Generic(lambda physics: self._observation)
71+
random_int.enabled = True
72+
return {'random_int': random_int}
73+
74+
5675
class EnvironmentTest(parameterized.TestCase):
5776

5877
def test_failed_resets(self):
@@ -96,5 +115,48 @@ def test_can_provide_observation(self):
96115
self.assertLen(obs, 1)
97116
np.testing.assert_array_equal(obs['time'], env.physics.time())
98117

118+
def test_dont_compile_mjcf_between_episodes(self):
119+
class AfterCompileHook(object):
120+
121+
def __init__(self):
122+
self.after_compile_call_count = 0
123+
124+
def __call__(self, physics, random_state):
125+
del physics, random_state
126+
self.after_compile_call_count += 1
127+
128+
after_compile_hook = AfterCompileHook()
129+
task = DummyTask()
130+
env = composer.Environment(task, recompile_mjcf_every_episode=False)
131+
env.add_extra_hook('after_compile', after_compile_hook)
132+
env.reset()
133+
self.assertEqual(after_compile_hook.after_compile_call_count, 1)
134+
for _ in range(4):
135+
env.reset()
136+
env.step([])
137+
138+
# Check the hook is not called.
139+
self.assertEqual(after_compile_hook.after_compile_call_count, 1)
140+
141+
def test_fixed_initial_state(self):
142+
task = DummyTaskWithRandomObservation()
143+
fixed_env = composer.Environment(task, fixed_initial_state=True)
144+
non_fixed_env = composer.Environment(task, fixed_initial_state=False)
145+
fixed_obs = fixed_env.reset().observation['random_int']
146+
non_fixed_obs = non_fixed_env.reset().observation['random_int']
147+
for _ in range(3):
148+
np.testing.assert_array_equal(
149+
fixed_env.reset().observation['random_int'], fixed_obs
150+
)
151+
self.assertTrue(
152+
np.any(
153+
np.not_equal(
154+
np.asarray(non_fixed_obs),
155+
np.asarray(non_fixed_env.reset().observation['random_int']),
156+
)
157+
)
158+
)
159+
160+
99161
if __name__ == '__main__':
100162
absltest.main()

0 commit comments

Comments
 (0)