Skip to content

Commit 783b386

Browse files
saran-tcopybara-github
authored andcommitted
Allow dm_control.Physics.{get,set}_state to accept a sig parameter.
When sig is provided, these functions behave like mj_{get,set}State. PiperOrigin-RevId: 823171024 Change-Id: I9d42f3f59c6404e6894f3ab384e92ca928fa6b56
1 parent 049656f commit 783b386

File tree

2 files changed

+59
-16
lines changed

2 files changed

+59
-16
lines changed

dm_control/mujoco/engine.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -232,35 +232,57 @@ def render(
232232
camera._scene.free() # pylint: disable=protected-access
233233
return image
234234

235-
def get_state(self):
235+
def get_state(self, sig=None):
236236
"""Returns the physics state.
237237
238+
Args:
239+
sig: Optional integer, if specified then the returned array corresponds to
240+
the state obtained by calling `mj_getState` with `sig`.
241+
238242
Returns:
239243
NumPy array containing full physics simulation state.
240244
"""
241-
return np.concatenate(self._physics_state_items())
245+
if sig is None:
246+
return np.concatenate(self._physics_state_items())
247+
else:
248+
retval = np.empty(mujoco.mj_stateSize(self.model.ptr, sig), np.float64)
249+
mujoco.mj_getState(self.model.ptr, self.data.ptr, retval, sig)
250+
return retval
242251

243-
def set_state(self, physics_state):
252+
def set_state(self, physics_state, sig=None):
244253
"""Sets the physics state.
245254
246255
Args:
247256
physics_state: NumPy array containing the full physics simulation state.
257+
sig: Optional integer, if specified then physics_state is passed directly
258+
to `mj_setState` with `sig`.
248259
249260
Raises:
250261
ValueError: If `physics_state` has invalid size.
251262
"""
252-
state_items = self._physics_state_items()
253-
254-
expected_shape = (sum(item.size for item in state_items),)
255-
if expected_shape != physics_state.shape:
256-
raise ValueError('Input physics state has shape {}. Expected {}.'.format(
257-
physics_state.shape, expected_shape))
258-
259-
start = 0
260-
for state_item in state_items:
261-
size = state_item.size
262-
np.copyto(state_item, physics_state[start:start + size])
263-
start += size
263+
if sig is None:
264+
state_items = self._physics_state_items()
265+
266+
expected_shape = (sum(item.size for item in state_items),)
267+
if expected_shape != physics_state.shape:
268+
raise ValueError(
269+
'Input physics state has shape {}. Expected {}.'.format(
270+
physics_state.shape, expected_shape
271+
)
272+
)
273+
274+
start = 0
275+
for state_item in state_items:
276+
size = state_item.size
277+
np.copyto(state_item, physics_state[start:start + size])
278+
start += size
279+
else:
280+
mujoco.mj_setState(
281+
self.model.ptr,
282+
self.data.ptr,
283+
np.asarray(physics_state, np.float64),
284+
sig,
285+
)
264286

265287
def copy(self, share_model=False):
266288
"""Creates a copy of this `Physics` instance.

dm_control/mujoco/engine_test.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def testNamedViews(self):
370370
self.assertEqual(0., self._physics.time())
371371
self.assertEqual(0.01, self._physics.timestep())
372372

373-
def testSetGetPhysicsState(self):
373+
def testSetGetPhysicsStateLegacy(self):
374374
physics_state = self._physics.get_state()
375375

376376
# qpos, qvel, act
@@ -384,6 +384,27 @@ def testSetGetPhysicsState(self):
384384
np.testing.assert_allclose(new_physics_state,
385385
self._physics.get_state())
386386

387+
def testSetGetPhysicsState(self):
388+
actual_physics_state = self._physics.get_state(
389+
sig=mujoco.mjtState.mjSTATE_FULLPHYSICS.value
390+
)
391+
expected_physics_state = np.zeros(mujoco.mj_stateSize(
392+
self._physics.model.ptr, mujoco.mjtState.mjSTATE_FULLPHYSICS.value
393+
))
394+
mujoco.mj_getState(self._physics.model.ptr, self._physics.data.ptr,
395+
expected_physics_state,
396+
mujoco.mjtState.mjSTATE_FULLPHYSICS.value)
397+
np.testing.assert_array_equal(expected_physics_state, actual_physics_state)
398+
399+
new_physics_state = np.random.random_sample(expected_physics_state.shape)
400+
self._physics.set_state(
401+
new_physics_state, sig=mujoco.mjtState.mjSTATE_FULLPHYSICS.value
402+
)
403+
np.testing.assert_array_equal(
404+
new_physics_state,
405+
self._physics.get_state(sig=mujoco.mjtState.mjSTATE_FULLPHYSICS.value),
406+
)
407+
387408
def testSetGetPhysicsStateWithPlugin(self):
388409
# Model copied from mujoco/test/plugin/elasticity/elasticity_test.cc
389410
model_with_cable_plugin = """

0 commit comments

Comments
 (0)