1313# limitations under the License.
1414# ============================================================================
1515
16- """Tests for distributions."""
17-
1816from absl .testing import absltest
1917from absl .testing import parameterized
2018from dm_control .composer .variation import distributions
@@ -43,6 +41,10 @@ def testUniform(self):
4341 variation (random_state = self ._variation_random_state ),
4442 self ._np_random_state .uniform (lower , upper ))
4543
44+ self .assertEqual (variation , distributions .Uniform (low = lower , high = upper ))
45+ self .assertNotEqual (variation , distributions .Uniform (low = upper , high = upper ))
46+ self .assertIn ('[2, 3, 4]' , repr (variation ))
47+
4648 def testUniformChoice (self ):
4749 choices = ['apple' , 'banana' , 'cherry' ]
4850 variation = distributions .UniformChoice (choices )
@@ -51,6 +53,8 @@ def testUniformChoice(self):
5153 variation (random_state = self ._variation_random_state ),
5254 self ._np_random_state .choice (choices ))
5355
56+ self .assertIn ('banana' , repr (variation ))
57+
5458 def testUniformPointOnSphere (self ):
5559 variation = distributions .UniformPointOnSphere ()
5660 samples = []
@@ -60,8 +64,11 @@ def testUniformPointOnSphere(self):
6064 np .testing .assert_approx_equal (np .linalg .norm (sample ), 1.0 )
6165 samples .append (sample )
6266 # Make sure that none of the samples are the same.
63- self .assertLen (
64- set (np .reshape (np .asarray (samples ), - 1 )), 3 * NUM_ITERATIONS )
67+ self .assertLen (set (np .reshape (np .asarray (samples ), - 1 )), 3 * NUM_ITERATIONS )
68+ self .assertEqual (variation , distributions .UniformPointOnSphere ())
69+ self .assertNotEqual (
70+ variation , distributions .UniformPointOnSphere (single_sample = True )
71+ )
6572
6673 def testNormal (self ):
6774 loc , scale = 1 , 2
@@ -70,6 +77,14 @@ def testNormal(self):
7077 self .assertEqual (
7178 variation (random_state = self ._variation_random_state ),
7279 self ._np_random_state .normal (loc , scale ))
80+ self .assertEqual (variation , distributions .Normal (loc = loc , scale = scale ))
81+ self .assertNotEqual (
82+ variation , distributions .Normal (loc = loc * 2 , scale = scale )
83+ )
84+ self .assertEqual (
85+ "Normal(args=(), kwargs={'loc': 1, 'scale': 2}, single_sample=False)" ,
86+ repr (variation ),
87+ )
7388
7489 def testExponential (self ):
7590 scale = 3
@@ -78,6 +93,14 @@ def testExponential(self):
7893 self .assertEqual (
7994 variation (random_state = self ._variation_random_state ),
8095 self ._np_random_state .exponential (scale ))
96+ self .assertEqual (variation , distributions .Exponential (scale = scale ))
97+ self .assertNotEqual (
98+ variation , distributions .Exponential (scale = scale * 2 )
99+ )
100+ self .assertEqual (
101+ "Exponential(args=(), kwargs={'scale': 3}, single_sample=False)" ,
102+ repr (variation ),
103+ )
81104
82105 def testPoisson (self ):
83106 lam = 4
@@ -86,6 +109,14 @@ def testPoisson(self):
86109 self .assertEqual (
87110 variation (random_state = self ._variation_random_state ),
88111 self ._np_random_state .poisson (lam ))
112+ self .assertEqual (variation , distributions .Poisson (lam = lam ))
113+ self .assertNotEqual (
114+ variation , distributions .Poisson (lam = lam * 2 )
115+ )
116+ self .assertEqual (
117+ "Poisson(args=(), kwargs={'lam': 4}, single_sample=False)" ,
118+ repr (variation ),
119+ )
89120
90121 @parameterized .parameters (0 , 10 )
91122 def testBiasedRandomWalk (self , timescale ):
0 commit comments