Skip to content

Commit 46390cf

Browse files
nimrod-gileadicopybara-github
authored andcommitted
Implement repr and equality for all stateless variations in dm_control.
Equality is useful for tests, and the repr implementations are useful for debugging, especially for binary operations, as at the moment you get opaque strings. PiperOrigin-RevId: 737684596 Change-Id: I2a3a466200efdfc97f60ff0e4b825333a3b44a00
1 parent 6504caf commit 46390cf

9 files changed

Lines changed: 297 additions & 13 deletions

File tree

dm_control/composer/variation/base.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,21 @@ def __init__(self, op, variation):
9191
self._op = op
9292
self._variation = variation
9393

94+
def __eq__(self, other):
95+
if not isinstance(other, _UnaryOperation):
96+
return False
97+
return self._op == other._op and self._variation == other._variation
98+
99+
def __str__(self):
100+
return f"{self._op.__name__}({self._variation})"
101+
102+
def __repr__(self):
103+
return f"UnaryOperation({self._op.__name__}({self._variation}))"
104+
94105
def __call__(self, initial_value=None, current_value=None, random_state=None):
95106
value = variation_values.evaluate(
96-
self._variation, initial_value, current_value, random_state)
107+
self._variation, initial_value, current_value, random_state
108+
)
97109
return self._op(value)
98110

99111

@@ -105,21 +117,56 @@ def __init__(self, op, first, second):
105117
self._second = second
106118
self._op = op
107119

120+
def __eq__(self, other):
121+
if not isinstance(other, _BinaryOperation):
122+
return False
123+
return (
124+
self._op == other._op
125+
and self._first == other._first
126+
and self._second == other._second
127+
)
128+
129+
def __str__(self):
130+
return f"{self._op.__name__}({self._first}, {self._second})"
131+
132+
def __repr__(self):
133+
return (
134+
f"BinaryOperation({self._op.__name__}({self._first!r},"
135+
f" {self._second!r}))"
136+
)
137+
108138
def __call__(self, initial_value=None, current_value=None, random_state=None):
109139
first_value = variation_values.evaluate(
110-
self._first, initial_value, current_value, random_state)
140+
self._first, initial_value, current_value, random_state
141+
)
111142
second_value = variation_values.evaluate(
112-
self._second, initial_value, current_value, random_state)
143+
self._second, initial_value, current_value, random_state
144+
)
113145
return self._op(first_value, second_value)
114146

115147

116148
class _GetItemOperation(Variation):
149+
"""Returns a single element from the output of a Variation."""
117150

118151
def __init__(self, variation, index):
119152
self._variation = variation
120153
self._index = index
121154

155+
def __eq__(self, other):
156+
if not isinstance(other, _GetItemOperation):
157+
return False
158+
return self._variation == other._variation and self._index == other._index
159+
160+
def __str__(self):
161+
return f"{self._variation}[{self._index}]"
162+
163+
def __repr__(self):
164+
return (
165+
f"GetItemOperation({self._variation!r}[{self._index}])"
166+
)
167+
122168
def __call__(self, initial_value=None, current_value=None, random_state=None):
123169
value = variation_values.evaluate(
124-
self._variation, initial_value, current_value, random_state)
170+
self._variation, initial_value, current_value, random_state
171+
)
125172
return np.asarray(value)[self._index]

dm_control/composer/variation/colors.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,16 @@ def __call__(self, initial_value=None, current_value=None, random_state=None):
4343
variation_values.evaluate([self._r, self._g, self._b, self._alpha],
4444
initial_value, current_value, random_state))
4545

46+
def __eq__(self, other):
47+
if not isinstance(other, RgbVariation):
48+
return False
49+
return (
50+
self._r == other._r
51+
and self._g == other._g
52+
and self._b == other._b
53+
and self._alpha == other._alpha
54+
)
55+
4656

4757
class HsvVariation(base.Variation):
4858
"""Represents a variation in the HSV color space.
@@ -62,6 +72,22 @@ def __call__(self, initial_value=None, current_value=None, random_state=None):
6272
random_state)
6373
return np.asarray(list(colorsys.hsv_to_rgb(h, s, v)) + [alpha])
6474

75+
def __eq__(self, other):
76+
if not isinstance(other, HsvVariation):
77+
return False
78+
return (
79+
self._h == other._h
80+
and self._s == other._s
81+
and self._v == other._v
82+
and self._alpha == other._alpha
83+
)
84+
85+
def __repr__(self):
86+
return (
87+
f"HsvVariation(h={self._h}, s={self._s}, v={self._v}, "
88+
f"alpha={self._alpha})"
89+
)
90+
6591

6692
class GrayVariation(HsvVariation):
6793
"""Represents a variation in gray level.
@@ -73,3 +99,8 @@ class GrayVariation(HsvVariation):
7399

74100
def __init__(self, gray_level, alpha=1.0):
75101
super().__init__(h=0.0, s=0.0, v=gray_level, alpha=alpha)
102+
103+
def __repr__(self):
104+
return (
105+
f"GrayVariation(gray_level={self._v}, alpha={self._alpha})"
106+
)

dm_control/composer/variation/deterministic.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,17 @@ def __init__(self, value):
3333
def __call__(self, initial_value=None, current_value=None, random_state=None):
3434
return self._value
3535

36+
def __eq__(self, other):
37+
if not isinstance(other, Constant):
38+
return False
39+
return self._value == other._value
40+
41+
def __str__(self):
42+
return f"{self._value}"
43+
44+
def __repr__(self):
45+
return f"Constant({self._value!r})"
46+
3647

3748
class Sequence(base.Variation):
3849
"""Variation representing a fixed sequence of values."""
@@ -52,6 +63,8 @@ def __call__(self, initial_value=None, current_value=None, random_state=None):
5263

5364

5465
class Identity(base.Variation):
55-
5666
def __call__(self, initial_value=None, current_value=None, random_state=None):
5767
return current_value
68+
69+
def __eq__(self, other):
70+
return isinstance(other, Identity)

dm_control/composer/variation/distributions.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,22 @@ def __getattr__(self, name):
7171
def _callable(self, random_state):
7272
raise NotImplementedError
7373

74+
def __eq__(self, other):
75+
if not isinstance(other, type(self)):
76+
return False
77+
return (
78+
self._args == other._args
79+
and self._kwargs == other._kwargs
80+
and self._single_sample == other._single_sample
81+
)
82+
83+
def __repr__(self):
84+
return '{}(args={}, kwargs={}, single_sample={})'.format(
85+
type(self).__name__,
86+
self._args,
87+
self._kwargs,
88+
self._single_sample)
89+
7490

7591
class Uniform(Distribution):
7692
__slots__ = ()
@@ -118,6 +134,17 @@ def __call__(self, initial_value=None, current_value=None, random_state=None):
118134
axis /= np.linalg.norm(axis, axis=-1, keepdims=True)
119135
return axis
120136

137+
def __eq__(self, other):
138+
if not isinstance(other, UniformPointOnSphere):
139+
return False
140+
return self._single_sample == other._single_sample
141+
142+
def __repr__(self):
143+
return '{}(single_sample={})'.format(
144+
type(self).__name__,
145+
self._single_sample,
146+
)
147+
121148

122149
class Normal(Distribution):
123150
__slots__ = ()
@@ -215,3 +242,17 @@ def __call__(self, initial_value=None, current_value=None, random_state=None):
215242
self._retain * self._value +
216243
random_state.normal(loc=0.0, scale=self._scale))
217244
return self._value
245+
246+
def __eq__(self, other):
247+
# __eq__ shouldn't be used for this one, because it's stateful.
248+
return id(self) == id(other)
249+
250+
def __repr__(self):
251+
# include id(self), to make sure that two instances with the same parameters
252+
# don't appear equal in logs.
253+
return '{}(id={}, scale={}, retain={}, value={})'.format(
254+
type(self).__name__,
255+
id(self),
256+
self._scale,
257+
self._retain,
258+
self._value)

dm_control/composer/variation/distributions_test.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# limitations under the License.
1414
# ============================================================================
1515

16-
"""Tests for distributions."""
17-
1816
from absl.testing import absltest
1917
from absl.testing import parameterized
2018
from dm_control.composer.variation import distributions
@@ -43,6 +41,10 @@ def testUniform(self):
4341
variation(random_state=self._variation_random_state),
4442
self._np_random_state.uniform(lower, upper))
4543

44+
self.assertEqual(variation, distributions.Uniform(low=lower, high=upper))
45+
self.assertNotEqual(variation, distributions.Uniform(low=upper, high=upper))
46+
self.assertIn('[2, 3, 4]', repr(variation))
47+
4648
def testUniformChoice(self):
4749
choices = ['apple', 'banana', 'cherry']
4850
variation = distributions.UniformChoice(choices)
@@ -51,6 +53,8 @@ def testUniformChoice(self):
5153
variation(random_state=self._variation_random_state),
5254
self._np_random_state.choice(choices))
5355

56+
self.assertIn('banana', repr(variation))
57+
5458
def testUniformPointOnSphere(self):
5559
variation = distributions.UniformPointOnSphere()
5660
samples = []
@@ -60,8 +64,11 @@ def testUniformPointOnSphere(self):
6064
np.testing.assert_approx_equal(np.linalg.norm(sample), 1.0)
6165
samples.append(sample)
6266
# Make sure that none of the samples are the same.
63-
self.assertLen(
64-
set(np.reshape(np.asarray(samples), -1)), 3 * NUM_ITERATIONS)
67+
self.assertLen(set(np.reshape(np.asarray(samples), -1)), 3 * NUM_ITERATIONS)
68+
self.assertEqual(variation, distributions.UniformPointOnSphere())
69+
self.assertNotEqual(
70+
variation, distributions.UniformPointOnSphere(single_sample=True)
71+
)
6572

6673
def testNormal(self):
6774
loc, scale = 1, 2
@@ -70,6 +77,14 @@ def testNormal(self):
7077
self.assertEqual(
7178
variation(random_state=self._variation_random_state),
7279
self._np_random_state.normal(loc, scale))
80+
self.assertEqual(variation, distributions.Normal(loc=loc, scale=scale))
81+
self.assertNotEqual(
82+
variation, distributions.Normal(loc=loc*2, scale=scale)
83+
)
84+
self.assertEqual(
85+
"Normal(args=(), kwargs={'loc': 1, 'scale': 2}, single_sample=False)",
86+
repr(variation),
87+
)
7388

7489
def testExponential(self):
7590
scale = 3
@@ -78,6 +93,14 @@ def testExponential(self):
7893
self.assertEqual(
7994
variation(random_state=self._variation_random_state),
8095
self._np_random_state.exponential(scale))
96+
self.assertEqual(variation, distributions.Exponential(scale=scale))
97+
self.assertNotEqual(
98+
variation, distributions.Exponential(scale=scale*2)
99+
)
100+
self.assertEqual(
101+
"Exponential(args=(), kwargs={'scale': 3}, single_sample=False)",
102+
repr(variation),
103+
)
81104

82105
def testPoisson(self):
83106
lam = 4
@@ -86,6 +109,14 @@ def testPoisson(self):
86109
self.assertEqual(
87110
variation(random_state=self._variation_random_state),
88111
self._np_random_state.poisson(lam))
112+
self.assertEqual(variation, distributions.Poisson(lam=lam))
113+
self.assertNotEqual(
114+
variation, distributions.Poisson(lam=lam*2)
115+
)
116+
self.assertEqual(
117+
"Poisson(args=(), kwargs={'lam': 4}, single_sample=False)",
118+
repr(variation),
119+
)
89120

90121
@parameterized.parameters(0, 10)
91122
def testBiasedRandomWalk(self, timescale):

dm_control/composer/variation/math.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
class MathOp(base.Variation):
2727
"""Base MathOp class for applying math operations on variation objects.
2828
29-
Subclasses need to implement `_op`, which takes in a single value and applies
30-
the desired math operation. This operation gets applied to the result of the
31-
evaluated base variation object passed at construction. Structured variation
32-
objects are automatically traversed.
29+
Subclasses need to implement `_callable`, which takes in a single value and
30+
applies the desired math operation. This operation gets applied to the result
31+
of the evaluated base variation object passed at construction. Structured
32+
variation objects are automatically traversed.
3333
"""
3434

3535
def __init__(self, *args, **kwargs):
@@ -54,6 +54,21 @@ def __call__(self, initial_value=None, current_value=None, random_state=None):
5454
def _callable(self):
5555
pass
5656

57+
def __eq__(self, other):
58+
if not isinstance(other, type(self)):
59+
return False
60+
return (
61+
self._args == other._args
62+
and self._kwargs == other._kwargs
63+
)
64+
65+
def __repr__(self):
66+
return '{}(args={}, kwargs={})'.format(
67+
type(self).__name__,
68+
self._args,
69+
self._kwargs,
70+
)
71+
5772

5873
class Log(MathOp):
5974

dm_control/composer/variation/noises.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,19 @@ def __call__(self, initial_value=None, current_value=None, random_state=None):
3939
variation_values.evaluate(self._variation, initial_value, current_value,
4040
random_state))
4141

42+
def __eq__(self, other):
43+
if not isinstance(other, Additive):
44+
return False
45+
return (
46+
self._variation == other._variation
47+
and self._cumulative == other._cumulative
48+
)
49+
50+
def __repr__(self):
51+
return (
52+
f"Additive(variation={self._variation}, cumulative={self._cumulative})"
53+
)
54+
4255

4356
class Multiplicative(base.Variation):
4457
"""A variation that multiplies to an existing value.
@@ -58,3 +71,17 @@ def __call__(self, initial_value=None, current_value=None, random_state=None):
5871
return base_value * (
5972
variation_values.evaluate(self._variation, initial_value, current_value,
6073
random_state))
74+
75+
def __eq__(self, other):
76+
if not isinstance(other, Multiplicative):
77+
return False
78+
return (
79+
self._variation == other._variation
80+
and self._cumulative == other._cumulative
81+
)
82+
83+
def __repr__(self):
84+
return (
85+
f"Multiplicative(variation={self._variation}, "
86+
f"cumulative={self._cumulative})"
87+
)

0 commit comments

Comments
 (0)