Skip to content

Commit 316e0f1

Browse files
nimrod-gileadicopybara-github
authored andcommitted
Throw EpisodeInitializationError instead of raw RuntimeError when prop initializer fails.
The more specific error allows callers to catch it if need be. Existing code should continue to work because EpisodeInitializationError is a subclass of RuntimeError. PiperOrigin-RevId: 718344531 Change-Id: I344dd518dab878b36e9a99cd0f4ea3ca4fffcf2a
1 parent 3634882 commit 316e0f1

4 files changed

Lines changed: 27 additions & 18 deletions

File tree

dm_control/composer/initializers/prop_initializer.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def __init__(self,
8888
velocities are less than this threshold.
8989
max_attempts_per_prop: The maximum number of rejection sampling attempts
9090
per prop. If a non-colliding pose cannot be found before this limit is
91-
reached, a `RuntimeError` will be raised.
91+
reached, an `EpisodeInitializationError` will be raised.
9292
settle_physics: (optional) If True, the physics simulation will be
9393
advanced for a few steps to allow the prop positions to settle.
9494
min_settle_physics_time: (optional) When `settle_physics` is True, lower
@@ -170,8 +170,9 @@ def __call__(self, physics, random_state, ignore_contacts_with_entities=None):
170170
subsequently).
171171
172172
Raises:
173-
RuntimeError: If `ignore_collisions == False` and a non-colliding prop
174-
pose could not be found within `max_attempts_per_prop`.
173+
EpisodeInitializationError: If `ignore_collisions == False` and a
174+
non-colliding prop pose could not be found within
175+
`max_attempts_per_prop`.
175176
"""
176177
if ignore_contacts_with_entities is None:
177178
ignore_contacts_with_entities = []
@@ -222,9 +223,12 @@ def place_props():
222223
break
223224

224225
if not success:
225-
raise RuntimeError(_REJECTION_SAMPLING_FAILED.format(
226-
model_name=prop.mjcf_model.model,
227-
max_attempts=self._max_attempts_per_prop))
226+
raise composer.EpisodeInitializationError(
227+
_REJECTION_SAMPLING_FAILED.format(
228+
model_name=prop.mjcf_model.model,
229+
max_attempts=self._max_attempts_per_prop,
230+
)
231+
)
228232

229233
for prop in ignore_contacts_with_entities:
230234
self._restore_contact_parameters(physics, prop, cached_contact_params)
@@ -256,7 +260,7 @@ def place_and_settle():
256260
physics.data.time = original_time
257261

258262
if self._raise_exception_on_settle_failure:
259-
raise RuntimeError(
263+
raise composer.EpisodeInitializationError(
260264
_SETTLING_PHYSICS_FAILED.format(
261265
max_attempts=self._max_settle_physics_attempts,
262266
max_time=self._max_settle_physics_time,

dm_control/composer/initializers/prop_initializer_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ def test_rejection_sampling_failure(self):
109109
expected_message = prop_initializer._REJECTION_SAMPLING_FAILED.format(
110110
model_name=spheres[1].mjcf_model.model, # Props are placed in order.
111111
max_attempts=max_attempts_per_prop)
112-
with self.assertRaisesWithLiteralMatch(RuntimeError, expected_message):
112+
with self.assertRaisesWithLiteralMatch(
113+
composer.EpisodeInitializationError, expected_message
114+
):
113115
prop_placer(physics, random_state=np.random.RandomState(0))
114116

115117
def test_ignore_contacts_with_entities(self):
@@ -141,7 +143,9 @@ def test_ignore_contacts_with_entities(self):
141143
prop_placer_init(physics, random_state=np.random.RandomState(0))
142144
expected_message = prop_initializer._REJECTION_SAMPLING_FAILED.format(
143145
model_name=spheres[0].mjcf_model.model, max_attempts=1)
144-
with self.assertRaisesWithLiteralMatch(RuntimeError, expected_message):
146+
with self.assertRaisesWithLiteralMatch(
147+
composer.EpisodeInitializationError, expected_message
148+
):
145149
prop_placer_seq[0](physics, random_state=np.random.RandomState(0))
146150

147151
# Placing the first sphere should succeed if we ignore contacts involving

dm_control/composer/initializers/tcp_initializer.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ def __call__(self, physics, random_state):
131131
random_state: An `np.random.RandomState` instance.
132132
133133
Raises:
134-
RuntimeError: If a collision-free pose could not be found within
135-
`max_ik_attempts`.
134+
composer.EpisodeInitializationError: If a collision-free pose could not be
135+
found within `max_ik_attempts`.
136136
"""
137137
if self._hand is not None:
138138
target_site = self._hand.tool_center_point
@@ -162,6 +162,9 @@ def __call__(self, physics, random_state):
162162
# positions and try again with a new target.
163163
physics.bind(self._arm.joints).qpos = initial_qpos
164164

165-
raise RuntimeError(_REJECTION_SAMPLING_FAILED.format(
166-
max_rejection_samples=self._max_rejection_samples,
167-
max_ik_attempts=self._max_ik_attempts))
165+
raise composer.EpisodeInitializationError(
166+
_REJECTION_SAMPLING_FAILED.format(
167+
max_rejection_samples=self._max_rejection_samples,
168+
max_ik_attempts=self._max_ik_attempts,
169+
)
170+
)

dm_control/composer/initializers/tcp_initializer_test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# limitations under the License.
1414
# ============================================================================
1515

16-
"""Tests for tcp_initializer."""
17-
1816
import functools
1917

2018
from absl.testing import absltest
@@ -105,7 +103,7 @@ def test_exception_if_hand_colliding_with_fixed_body(self):
105103

106104
initializer = make_initializer()
107105
with self.assertRaisesWithLiteralMatch(
108-
RuntimeError,
106+
composer.EpisodeInitializationError,
109107
tcp_initializer._REJECTION_SAMPLING_FAILED.format(
110108
max_rejection_samples=max_rejection_samples,
111109
max_ik_attempts=max_ik_attempts)):
@@ -145,7 +143,7 @@ def test_exception_if_self_collision(self, with_hand):
145143

146144
initializer = make_initializer()
147145
with self.assertRaisesWithLiteralMatch(
148-
RuntimeError,
146+
composer.EpisodeInitializationError,
149147
tcp_initializer._REJECTION_SAMPLING_FAILED.format(
150148
max_rejection_samples=max_rejection_samples,
151149
max_ik_attempts=max_ik_attempts)):

0 commit comments

Comments
 (0)