|
27 | 27 | EmptyValue, LongType, SetType, UTF8Type, |
28 | 28 | cql_typename, int8_pack, int64_pack, lookup_casstype, |
29 | 29 | lookup_casstype_simple, parse_casstype_args, |
30 | | - int32_pack, Int32Type, ListType, MapType |
| 30 | + int32_pack, Int32Type, ListType, MapType, VectorType, |
| 31 | + FloatType |
31 | 32 | ) |
32 | 33 | from cassandra.encoder import cql_quote |
33 | 34 | from cassandra.pool import Host |
@@ -190,6 +191,12 @@ class BarType(FooType): |
190 | 191 | self.assertEqual(UTF8Type, ctype.subtypes[2]) |
191 | 192 | self.assertEqual([b'city', None, b'zip'], ctype.names) |
192 | 193 |
|
| 194 | + def test_parse_casstype_vector(self): |
| 195 | + ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 3)") |
| 196 | + self.assertTrue(issubclass(ctype, VectorType)) |
| 197 | + self.assertEqual(3, ctype.vector_size) |
| 198 | + self.assertEqual(FloatType, ctype.subtype) |
| 199 | + |
193 | 200 | def test_empty_value(self): |
194 | 201 | self.assertEqual(str(EmptyValue()), 'EMPTY') |
195 | 202 |
|
@@ -303,6 +310,19 @@ def test_cql_quote(self): |
303 | 310 | self.assertEqual(cql_quote('test'), "'test'") |
304 | 311 | self.assertEqual(cql_quote(0), '0') |
305 | 312 |
|
| 313 | + def test_vector_round_trip(self): |
| 314 | + base = [3.4, 2.9, 41.6, 12.0] |
| 315 | + ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") |
| 316 | + base_bytes = ctype.serialize(base, 0) |
| 317 | + self.assertEqual(16, len(base_bytes)) |
| 318 | + result = ctype.deserialize(base_bytes, 0) |
| 319 | + self.assertEqual(len(base), len(result)) |
| 320 | + for idx in range(0,len(base)): |
| 321 | + self.assertAlmostEqual(base[idx], result[idx], places=5) |
| 322 | + |
| 323 | + def test_vector_cql_parameterized_type(self): |
| 324 | + ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") |
| 325 | + self.assertEqual(ctype.cql_parameterized_type(), "org.apache.cassandra.db.marshal.VectorType<float, 4>") |
306 | 326 |
|
307 | 327 | ZERO = datetime.timedelta(0) |
308 | 328 |
|
|
0 commit comments