1818import contextlib
1919import copy
2020import ctypes
21+ from typing import Union
2122import weakref
2223
2324from absl import logging
@@ -235,7 +236,7 @@ def _get_model_ptr_from_binary(binary_path=None, byte_string=None):
235236class _MjModelMeta (type ):
236237 """Metaclass which allows MjModel below to delegate to mujoco.MjModel."""
237238
238- def __new__ (cls , name , bases , dct ):
239+ def __new__ (mcs , name , bases , dct ):
239240 for attr in dir (mujoco .MjModel ):
240241 if not attr .startswith ("_" ):
241242 if attr not in dct :
@@ -245,7 +246,7 @@ def __new__(cls, name, bases, dct):
245246 lambda self , value , attr = attr : setattr (self ._model , attr , value ))
246247 # pylint: enable=protected-access
247248 dct [attr ] = property (fget , fset )
248- return super ().__new__ (cls , name , bases , dct )
249+ return super ().__new__ (mcs , name , bases , dct )
249250
250251
251252class MjModel (metaclass = _MjModelMeta ):
@@ -426,7 +427,7 @@ def name(self):
426427class _MjDataMeta (type ):
427428 """Metaclass which allows MjData below to delegate to mujoco.MjData."""
428429
429- def __new__ (cls , name , bases , dct ):
430+ def __new__ (mcs , name , bases , dct ):
430431 for attr in dir (mujoco .MjData ):
431432 if not attr .startswith ("_" ):
432433 if attr not in dct :
@@ -435,7 +436,7 @@ def __new__(cls, name, bases, dct):
435436 fset = lambda self , value , attr = attr : setattr (self ._data , attr , value )
436437 # pylint: enable=protected-access
437438 dct [attr ] = property (fget , fset )
438- return super ().__new__ (cls , name , bases , dct )
439+ return super ().__new__ (mcs , name , bases , dct )
439440
440441
441442class MjData (metaclass = _MjDataMeta ):
@@ -447,31 +448,35 @@ class MjData(metaclass=_MjDataMeta):
447448
448449 _HAS_DYNAMIC_ATTRIBUTES = True
449450
450- def __init__ (self , model ):
451- """Construct a new MjData instance.
451+ def __init__ (self , model_or_data : Union [ MjModel , mujoco . MjData ] ):
452+ """Constructs a new MjData instance.
452453
453454 Args:
454- model: An MjModel instance.
455+ model_or_data: dm_control.mujoco.wrapper.MjModel instance, or
456+ mujoco.MjData.
455457 """
456- self ._model = model
457- self ._data = mujoco .MjData (model ._model )
458+ if isinstance (model_or_data , MjModel ):
459+ self ._model = model_or_data
460+ self ._data = mujoco .MjData (model_or_data ._model )
461+ elif isinstance (model_or_data , mujoco .MjData ):
462+ self ._data = model_or_data
463+ self ._model = MjModel (self ._data .model )
458464
459465 def __getstate__ (self ):
460- return ( self ._model , self . _data )
466+ return self ._data
461467
462468 def __setstate__ (self , state ):
463- self ._model , self ._data = state
469+ self ._data = state
470+ self ._model = MjModel (self ._data .model )
464471
465472 def __copy__ (self ):
466473 # This makes a shallow copy that shares the same parent MjModel instance.
467474 return self ._make_copy (share_model = True )
468475
469476 def _make_copy (self , share_model ):
470- # TODO(nimrod): Avoid allocating a new MjData just to replace it.
471- new_obj = self .__class__ (
472- self ._model if share_model else copy .copy (self ._model ))
473- super (self .__class__ , new_obj ).__setattr__ ("_data" , copy .copy (self ._data ))
474- return new_obj
477+ if share_model :
478+ return self .__class__ (copy .copy (self ._data ))
479+ return self .__class__ (copy .deepcopy (self ._data ))
475480
476481 def copy (self ):
477482 """Returns a copy of this MjData instance with the same parent MjModel."""
0 commit comments