Skip to content

Commit 405f74c

Browse files
sbohezcopybara-github
authored andcommitted
Add QuatRotate variation and use dm_control.utils.transformations for calcs.
PiperOrigin-RevId: 719432052 Change-Id: I17cfbe510532b6b97212bd1c5ae87f5aa204ee78
1 parent 316e0f1 commit 405f74c

3 files changed

Lines changed: 80 additions & 9 deletions

File tree

dm_control/composer/variation/rotations.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from dm_control.composer.variation import base
2020
from dm_control.composer.variation import variation_values
21+
from dm_control.utils import transformations
2122
import numpy as np
2223

2324
IDENTITY_QUATERNION = np.array([1., 0., 0., 0.])
@@ -48,8 +49,7 @@ def __call__(self, initial_value=None, current_value=None, random_state=None):
4849
self._axis, initial_value, current_value, random_state)
4950
angle = variation_values.evaluate(
5051
self._angle, initial_value, current_value, random_state)
51-
sine, cosine = np.sin(angle / 2), np.cos(angle / 2)
52-
return np.array([cosine, axis[0] * sine, axis[1] * sine, axis[2] * sine])
52+
return transformations.axisangle_to_quat(np.asarray(axis) * angle)
5353

5454

5555
class QuaternionPreMultiply(base.Variation):
@@ -70,8 +70,32 @@ def __call__(self, initial_value=None, current_value=None, random_state=None):
7070
q1 = variation_values.evaluate(self._quat, initial_value, current_value,
7171
random_state)
7272
q2 = current_value if self._cumulative else initial_value
73-
return np.array([
74-
q1[0]*q2[0] - q1[1]*q2[1] - q1[2]*q2[2] - q1[3]*q2[3],
75-
q1[0]*q2[1] + q1[1]*q2[0] + q1[2]*q2[3] - q1[3]*q2[2],
76-
q1[0]*q2[2] - q1[1]*q2[3] + q1[2]*q2[0] + q1[3]*q2[1],
77-
q1[0]*q2[3] + q1[1]*q2[2] - q1[2]*q2[1] + q1[3]*q2[0]])
73+
return transformations.quat_mul(np.asarray(q1), np.asarray(q2))
74+
75+
76+
class QuaternionRotate(base.Variation):
77+
"""Variation that rotates a given vector by the given quaternion.
78+
79+
The vector can either be an existing value passed at evaluation, or specified
80+
as a separate variation at construction. In the former case, cumulative mode
81+
determines whether to use the current or initial value of the vector. The#
82+
quaternion is always specified by a variation at construction.
83+
"""
84+
85+
def __init__(self, quat, vec=None, cumulative=False):
86+
self._quat = quat
87+
self._vec = vec
88+
self._cumulative = cumulative
89+
90+
def __call__(self, initial_value=None, current_value=None, random_state=None):
91+
random_state = random_state or np.random
92+
quat = variation_values.evaluate(
93+
self._quat, initial_value, current_value, random_state
94+
)
95+
if self._vec is None:
96+
vec = current_value if self._cumulative else initial_value
97+
else:
98+
vec = variation_values.evaluate(
99+
self._vec, initial_value, current_value, random_state
100+
)
101+
return transformations.quat_rotate(np.asarray(quat), np.asarray(vec))

dm_control/utils/transformations.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,26 @@ def mat_to_quat(mat):
625625
return q
626626

627627

628+
def axisangle_to_quat(axisangle, tol=0.0):
629+
"""Returns the quaternion corresponding to the provided axis-angle.
630+
631+
Args:
632+
axisangle: A 3x1 numpy array describing the axis of rotation, with angle
633+
encoded by its length.
634+
tol: Tolerance for the angle magnitude below which the identity quaternion
635+
is returned.
636+
637+
Returns:
638+
A quaternion [w, i, j, k].
639+
"""
640+
axisangle = np.asarray(axisangle)
641+
angle = np.linalg.norm(axisangle, axis=-1, keepdims=True)
642+
axis = np.where(angle <= tol, [1.0, 0.0, 0.0], axisangle / angle)
643+
angle = np.where(angle <= tol, [0.0], angle)
644+
sine, cosine = np.sin(angle / 2), np.cos(angle / 2)
645+
return np.concatenate([cosine, axis * sine], axis=-1)
646+
647+
628648
# ################
629649
# # 2D Functions #
630650
# ################

dm_control/utils/transformations_test.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# limitations under the License.
1414
# ============================================================================
1515

16-
"""Tests for dm_control.locomotion.tasks.transformations."""
17-
1816
import itertools
1917

2018
from absl.testing import absltest
@@ -261,6 +259,35 @@ def _random_quaternion(self):
261259
(np.cos(t2) * r2, np.sin(t1) * r1, np.cos(t1) * r1, np.sin(t2) * r2),
262260
dtype=np.float64)
263261

262+
def test_axisangle_to_quat(self):
263+
axisangle = np.array([0.1, 0.2, 0.3])
264+
quat = transformations.axisangle_to_quat(axisangle)
265+
np.testing.assert_allclose(
266+
quat, [0.982551, 0.0497088, 0.0994177, 0.1491265], atol=1e-6
267+
)
268+
269+
def test_axisangle_to_quat_zero(self):
270+
axisangle = np.array([0, 0, 0])
271+
quat = transformations.axisangle_to_quat(axisangle)
272+
np.testing.assert_allclose(quat, [1, 0, 0, 0])
273+
274+
def test_axisangle_to_quat_zero_tol(self):
275+
axisangle = np.array([0, 0, 1e-2])
276+
quat = transformations.axisangle_to_quat(axisangle, tol=1e-1)
277+
np.testing.assert_allclose(quat, [1, 0, 0, 0])
278+
279+
def test_axisangle_to_quat_batched(self):
280+
axisangle = np.stack([np.array([0.1, 0.2, 0.3]), np.array([0.4, 0.5, 0.6])])
281+
quat = transformations.axisangle_to_quat(axisangle)
282+
np.testing.assert_allclose(
283+
quat,
284+
[
285+
[0.982551, 0.0497088, 0.0994177, 0.1491265],
286+
[0.9052841, 0.1936448, 0.242056, 0.2904672],
287+
],
288+
atol=1e-6,
289+
)
290+
264291

265292
if __name__ == '__main__':
266293
absltest.main()

0 commit comments

Comments
 (0)