Skip to content

Commit b002aa3

Browse files
nimrod-gileadicopybara-github
authored andcommitted
Never return None from Physics.bind.
Physics.bind used to handle `mjcf_elements == None` by returning None. For clients using type annotations, this made the API annoying, as the None case had to be handled. After this commit, passing None will lead to a crash. PiperOrigin-RevId: 613142103 Change-Id: I2ce3a98dddb61400202c5b2cd7fc67810e44c487
1 parent 3adfe8c commit b002aa3

1 file changed

Lines changed: 7 additions & 11 deletions

File tree

dm_control/mjcf/physics.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ def forward(self):
512512
super().forward()
513513
self._dirty = False
514514

515-
def bind(self, mjcf_elements):
515+
def bind(self, mjcf_elements) -> Binding:
516516
"""Creates a binding between this `Physics` instance and `mjcf.Element`s.
517517
518518
The binding allows for easier interaction with the `Physics` data structures
@@ -609,7 +609,7 @@ def bind(self, mjcf_elements):
609609
ValueError: If `mjcf_elements` cannot be bound to this Physics.
610610
"""
611611
if mjcf_elements is None:
612-
return None
612+
raise ValueError('mjcf_elements is None.')
613613

614614
# To reduce overhead from processing MJCF elements and making new bindings,
615615
# we cache and reuse existing Binding objects. The cheapest version of
@@ -625,13 +625,11 @@ def bind(self, mjcf_elements):
625625
# `mjcf_elements` is not iterable.
626626
cache_key = mjcf_elements
627627

628-
needs_new_binding = False
629628
try:
630-
binding = self._bindings[cache_key]
629+
return self._bindings[cache_key]
631630
except KeyError:
632631
# This means `mjcf_elements` is hashable, so we use it as cache key.
633632
namespace, named_index = names_from_elements(mjcf_elements)
634-
needs_new_binding = True
635633
except TypeError:
636634
# This means `mjcf_elements` is unhashable, fallback to caching by name.
637635
namespace, named_index = names_from_elements(mjcf_elements)
@@ -644,14 +642,12 @@ def bind(self, mjcf_elements):
644642
cache_key = (namespace, named_index)
645643

646644
try:
647-
binding = self._bindings[cache_key]
645+
return self._bindings[cache_key]
648646
except KeyError:
649-
needs_new_binding = True
650-
651-
if needs_new_binding:
652-
binding = Binding(weakref.proxy(self), namespace, named_index)
653-
self._bindings[cache_key] = binding
647+
pass
654648

649+
binding = Binding(weakref.proxy(self), namespace, named_index)
650+
self._bindings[cache_key] = binding
655651
return binding
656652

657653

0 commit comments

Comments
 (0)