Skip to content

Commit cb2edaf

Browse files
nimrod-gileadicopybara-github
authored andcommitted
Make use of the new MjData.model attribute to avoid unnecessary copies.
PiperOrigin-RevId: 581939935 Change-Id: I8b5548266580d2dfe63d66070c903a8cd9e305f5
1 parent d4e76b3 commit cb2edaf

1 file changed

Lines changed: 21 additions & 16 deletions

File tree

dm_control/mujoco/wrapper/core.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import contextlib
1919
import copy
2020
import ctypes
21+
from typing import Union
2122
import weakref
2223

2324
from absl import logging
@@ -235,7 +236,7 @@ def _get_model_ptr_from_binary(binary_path=None, byte_string=None):
235236
class _MjModelMeta(type):
236237
"""Metaclass which allows MjModel below to delegate to mujoco.MjModel."""
237238

238-
def __new__(cls, name, bases, dct):
239+
def __new__(mcs, name, bases, dct):
239240
for attr in dir(mujoco.MjModel):
240241
if not attr.startswith("_"):
241242
if attr not in dct:
@@ -245,7 +246,7 @@ def __new__(cls, name, bases, dct):
245246
lambda self, value, attr=attr: setattr(self._model, attr, value))
246247
# pylint: enable=protected-access
247248
dct[attr] = property(fget, fset)
248-
return super().__new__(cls, name, bases, dct)
249+
return super().__new__(mcs, name, bases, dct)
249250

250251

251252
class MjModel(metaclass=_MjModelMeta):
@@ -426,7 +427,7 @@ def name(self):
426427
class _MjDataMeta(type):
427428
"""Metaclass which allows MjData below to delegate to mujoco.MjData."""
428429

429-
def __new__(cls, name, bases, dct):
430+
def __new__(mcs, name, bases, dct):
430431
for attr in dir(mujoco.MjData):
431432
if not attr.startswith("_"):
432433
if attr not in dct:
@@ -435,7 +436,7 @@ def __new__(cls, name, bases, dct):
435436
fset = lambda self, value, attr=attr: setattr(self._data, attr, value)
436437
# pylint: enable=protected-access
437438
dct[attr] = property(fget, fset)
438-
return super().__new__(cls, name, bases, dct)
439+
return super().__new__(mcs, name, bases, dct)
439440

440441

441442
class MjData(metaclass=_MjDataMeta):
@@ -447,31 +448,35 @@ class MjData(metaclass=_MjDataMeta):
447448

448449
_HAS_DYNAMIC_ATTRIBUTES = True
449450

450-
def __init__(self, model):
451-
"""Construct a new MjData instance.
451+
def __init__(self, model_or_data: Union[MjModel, mujoco.MjData]):
452+
"""Constructs a new MjData instance.
452453
453454
Args:
454-
model: An MjModel instance.
455+
model_or_data: dm_control.mujoco.wrapper.MjModel instance, or
456+
mujoco.MjData.
455457
"""
456-
self._model = model
457-
self._data = mujoco.MjData(model._model)
458+
if isinstance(model_or_data, MjModel):
459+
self._model = model_or_data
460+
self._data = mujoco.MjData(model_or_data._model)
461+
elif isinstance(model_or_data, mujoco.MjData):
462+
self._data = model_or_data
463+
self._model = MjModel(self._data.model)
458464

459465
def __getstate__(self):
460-
return (self._model, self._data)
466+
return self._data
461467

462468
def __setstate__(self, state):
463-
self._model, self._data = state
469+
self._data = state
470+
self._model = MjModel(self._data.model)
464471

465472
def __copy__(self):
466473
# This makes a shallow copy that shares the same parent MjModel instance.
467474
return self._make_copy(share_model=True)
468475

469476
def _make_copy(self, share_model):
470-
# TODO(nimrod): Avoid allocating a new MjData just to replace it.
471-
new_obj = self.__class__(
472-
self._model if share_model else copy.copy(self._model))
473-
super(self.__class__, new_obj).__setattr__("_data", copy.copy(self._data))
474-
return new_obj
477+
if share_model:
478+
return self.__class__(copy.copy(self._data))
479+
return self.__class__(copy.deepcopy(self._data))
475480

476481
def copy(self):
477482
"""Returns a copy of this MjData instance with the same parent MjModel."""

0 commit comments

Comments
 (0)