@@ -232,35 +232,57 @@ def render(
232232 camera ._scene .free () # pylint: disable=protected-access
233233 return image
234234
235- def get_state (self ):
235+ def get_state (self , sig = None ):
236236 """Returns the physics state.
237237
238+ Args:
239+ sig: Optional integer, if specified then the returned array corresponds to
240+ the state obtained by calling `mj_getState` with `sig`.
241+
238242 Returns:
239243 NumPy array containing full physics simulation state.
240244 """
241- return np .concatenate (self ._physics_state_items ())
245+ if sig is None :
246+ return np .concatenate (self ._physics_state_items ())
247+ else :
248+ retval = np .empty (mujoco .mj_stateSize (self .model .ptr , sig ), np .float64 )
249+ mujoco .mj_getState (self .model .ptr , self .data .ptr , retval , sig )
250+ return retval
242251
243- def set_state (self , physics_state ):
252+ def set_state (self , physics_state , sig = None ):
244253 """Sets the physics state.
245254
246255 Args:
247256 physics_state: NumPy array containing the full physics simulation state.
257+ sig: Optional integer, if specified then physics_state is passed directly
258+ to `mj_setState` with `sig`.
248259
249260 Raises:
250261 ValueError: If `physics_state` has invalid size.
251262 """
252- state_items = self ._physics_state_items ()
253-
254- expected_shape = (sum (item .size for item in state_items ),)
255- if expected_shape != physics_state .shape :
256- raise ValueError ('Input physics state has shape {}. Expected {}.' .format (
257- physics_state .shape , expected_shape ))
258-
259- start = 0
260- for state_item in state_items :
261- size = state_item .size
262- np .copyto (state_item , physics_state [start :start + size ])
263- start += size
263+ if sig is None :
264+ state_items = self ._physics_state_items ()
265+
266+ expected_shape = (sum (item .size for item in state_items ),)
267+ if expected_shape != physics_state .shape :
268+ raise ValueError (
269+ 'Input physics state has shape {}. Expected {}.' .format (
270+ physics_state .shape , expected_shape
271+ )
272+ )
273+
274+ start = 0
275+ for state_item in state_items :
276+ size = state_item .size
277+ np .copyto (state_item , physics_state [start :start + size ])
278+ start += size
279+ else :
280+ mujoco .mj_setState (
281+ self .model .ptr ,
282+ self .data .ptr ,
283+ np .asarray (physics_state , np .float64 ),
284+ sig ,
285+ )
264286
265287 def copy (self , share_model = False ):
266288 """Creates a copy of this `Physics` instance.
0 commit comments