@@ -292,19 +292,26 @@ def control_timestep(self):
292292class Environment (_CommonEnvironment , dm_env .Environment ):
293293 """Reinforcement learning environment for Composer tasks."""
294294
295- def __init__ (self , task , time_limit = float ('inf' ), random_state = None ,
296- n_sub_steps = None ,
297- raise_exception_on_physics_error = True ,
298- strip_singleton_obs_buffer_dim = False ,
299- max_reset_attempts = 1 ,
300- delayed_observation_padding = ObservationPadding .ZERO ,
301- legacy_step : bool = True ):
295+ def __init__ (
296+ self ,
297+ task ,
298+ time_limit = float ('inf' ),
299+ random_state = None ,
300+ n_sub_steps = None ,
301+ raise_exception_on_physics_error = True ,
302+ strip_singleton_obs_buffer_dim = False ,
303+ max_reset_attempts = 1 ,
304+ recompile_mjcf_every_episode = True ,
305+ fixed_initial_state = False ,
306+ delayed_observation_padding = ObservationPadding .ZERO ,
307+ legacy_step : bool = True ,
308+ ):
302309 """Initializes an instance of `Environment`.
303310
304311 Args:
305312 task: Instance of `composer.base.Task`.
306- time_limit: (optional) A float, the time limit in seconds beyond which
307- an episode is forced to terminate.
313+ time_limit: (optional) A float, the time limit in seconds beyond which an
314+ episode is forced to terminate.
308315 random_state: (optional) an int seed or `np.random.RandomState` instance.
309316 n_sub_steps: (DEPRECATED) An integer, number of physics steps to take per
310317 agent control step. New code should instead override the
@@ -313,15 +320,22 @@ def __init__(self, task, time_limit=float('inf'), random_state=None,
313320 `PhysicsError` should be raised as an exception. If `False`, physics
314321 errors will result in the current episode being terminated with a
315322 warning logged, and a new episode started.
316- strip_singleton_obs_buffer_dim: (optional) A boolean, if `True`,
317- the array shape of observations with `buffer_size == 1` will not have a
318- leading buffer dimension.
323+ strip_singleton_obs_buffer_dim: (optional) A boolean, if `True`, the array
324+ shape of observations with `buffer_size == 1` will not have a leading
325+ buffer dimension.
319326 max_reset_attempts: (optional) Maximum number of times to try resetting
320- the environment. If an `EpisodeInitializationError` is raised
321- during this process, an environment reset is reattempted up to this
322- number of times. If this count is exceeded then the most recent
323- exception will be allowed to propagate. Defaults to 1, i.e. no failure
324- is allowed.
327+ the environment. If an `EpisodeInitializationError` is raised during
328+ this process, an environment reset is reattempted up to this number of
329+ times. If this count is exceeded then the most recent exception will be
330+ allowed to propagate. Defaults to 1, i.e. no failure is allowed.
331+ recompile_mjcf_every_episode: If True will recompile the mjcf model
332+ between episodes. This specifically skips the `initialize_episode_mjcf`
333+ and `after_compile` steps. This allows a speedup if no changes are made
334+ to the model.
335+ fixed_initial_state: If True the starting state of every single episode
336+ will be the same. Meaning an identical sequence of action will lead to
337+ an identical final state. If False, will randomize the starting state at
338+ every episode.
325339 delayed_observation_padding: (optional) An `ObservationPadding` enum value
326340 specifying the padding behavior of the initial buffers for delayed
327341 observables. If `ZERO` then the buffer is initially filled with zeroes.
@@ -340,6 +354,10 @@ def __init__(self, task, time_limit=float('inf'), random_state=None,
340354 delayed_observation_padding = delayed_observation_padding ,
341355 legacy_step = legacy_step )
342356 self ._max_reset_attempts = max_reset_attempts
357+ self ._recompile_mjcf_every_episode = recompile_mjcf_every_episode
358+ self ._mjcf_never_compiled = True
359+ self ._fixed_initial_state = fixed_initial_state
360+ self ._fixed_random_state = self ._random_state .get_state ()
343361 self ._reset_next_step = True
344362
345363 def reset (self ):
@@ -355,8 +373,15 @@ def reset(self):
355373 raise
356374
357375 def _reset_attempt (self ):
358- self ._hooks .initialize_episode_mjcf (self ._random_state )
359- self ._recompile_physics_and_update_observables ()
376+ if self ._recompile_mjcf_every_episode or self ._mjcf_never_compiled :
377+ if self ._fixed_initial_state :
378+ self ._random_state .set_state (self ._fixed_random_state )
379+ self ._hooks .initialize_episode_mjcf (self ._random_state )
380+ self ._recompile_physics_and_update_observables ()
381+ self ._mjcf_never_compiled = False
382+
383+ if self ._fixed_initial_state :
384+ self ._random_state .set_state (self ._fixed_random_state )
360385 with self ._physics .reset_context ():
361386 self ._hooks .initialize_episode (self ._physics_proxy , self ._random_state )
362387 self ._observation_updater .reset (self ._physics_proxy , self ._random_state )
0 commit comments