Skip to content

Commit e8add33

Browse files
committed
Add configure_union_single_collection_dispatch as a union strategy
1 parent 087e1ce commit e8add33

File tree

3 files changed

+161
-3
lines changed

3 files changed

+161
-3
lines changed

src/cattrs/strategies/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33
from ._class_methods import use_class_methods
44
from ._subclasses import include_subclasses
5-
from ._unions import configure_tagged_union, configure_union_passthrough
5+
from ._unions import configure_tagged_union, configure_union_passthrough, configure_union_single_collection_dispatch
66

77
__all__ = [
88
"configure_tagged_union",
99
"configure_union_passthrough",
10+
"configure_union_single_collection_dispatch",
1011
"include_subclasses",
1112
"use_class_methods",
1213
]

src/cattrs/strategies/_unions.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from collections import defaultdict
2-
from typing import Any, Callable, Union
1+
import collections.abc
2+
from collections import defaultdict, deque
3+
from typing import Any, Callable, Union, get_origin
34

45
from attrs import NOTHING, NothingType
56

@@ -282,3 +283,74 @@ def contains_native_union(exact_type: Any) -> bool:
282283
converter.register_structure_hook_factory(
283284
contains_native_union, make_structure_native_union
284285
)
286+
287+
288+
# Design choice: it was easy to extend the logic to deque and set types
289+
# but are they worth adding?
290+
_COLLECTION_TYPES = frozenset([
291+
collections.abc.MutableSequence,
292+
collections.abc.MutableSet,
293+
collections.abc.Sequence,
294+
collections.abc.Set,
295+
deque,
296+
frozenset,
297+
list,
298+
set,
299+
tuple,
300+
])
301+
302+
def configure_union_single_collection_dispatch(converter: BaseConverter):
303+
def is_union_single_collection(exact_type: Any) -> bool:
304+
# TODO: Handle TypeAliasType (see #742)
305+
306+
if not is_union_type(exact_type):
307+
return False
308+
309+
type_args = set(exact_type.__args__)
310+
if len(type_args) == 2 and type(None) in type_args:
311+
# As in union_passthrough, we do not want to handle optionals
312+
return False
313+
314+
# Design choice: only support the case where one of _COLLECTION_TYPES
315+
# appears in the Union
316+
collection_type_args = [
317+
t
318+
for t in type_args
319+
if t in _COLLECTION_TYPES or get_origin(t) in _COLLECTION_TYPES
320+
]
321+
return len(collection_type_args) == 1
322+
323+
def make_structure_union_single_collection(
324+
exact_type: Any, /
325+
) -> Callable[[Any, Any], Any]:
326+
# TODO: Handle TypeAliasType (see #742)
327+
328+
type_args = set(exact_type.__args__)
329+
collection_type_arg = next(
330+
t
331+
for t in type_args
332+
if t in _COLLECTION_TYPES or get_origin(t) in _COLLECTION_TYPES
333+
)
334+
335+
other_type_args = [t for t in type_args if t != collection_type_arg]
336+
spillover_type: Any = (
337+
Union[tuple(other_type_args)]
338+
if len(other_type_args) > 1
339+
else other_type_args[0]
340+
)
341+
342+
def structure_union_single_collection(
343+
val: Any,
344+
_: Any,
345+
collection_type=collection_type_arg,
346+
spillover=spillover_type,
347+
) -> Any:
348+
# Design choice: only detect known concrete types as valid source types
349+
# That avoids having to blacklist e.g. str or bytes
350+
if isinstance(val, (deque, frozenset, list, set, tuple)):
351+
return converter.structure(val, collection_type)
352+
return converter.structure(val, spillover)
353+
354+
return structure_union_single_collection
355+
356+
converter.register_structure_hook_factory(is_union_single_collection, make_structure_union_single_collection)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
2+
from collections import deque
3+
from collections.abc import Callable, Collection, Iterable, MutableSequence, MutableSet, Sequence, Set
4+
from typing import Any, Union
5+
6+
import pytest
7+
8+
from attrs import define
9+
from cattrs.converters import BaseConverter
10+
from cattrs.strategies import configure_union_single_collection_dispatch
11+
12+
13+
@define
14+
class CollectionParameter:
15+
type_factory: Callable[[Any], type[Collection]]
16+
factory: Callable[[Iterable], Collection]
17+
18+
19+
@pytest.fixture(
20+
params=[
21+
pytest.param(CollectionParameter(lambda t: deque[t], deque), id="deque"),
22+
pytest.param(CollectionParameter(lambda t: frozenset[t], frozenset), id="frozenset"),
23+
pytest.param(CollectionParameter(lambda t: list[t], list), id="list"),
24+
pytest.param(CollectionParameter(lambda t: MutableSequence[t], list), id="MutableSequence"),
25+
pytest.param(CollectionParameter(lambda t: MutableSet[t], set), id="MutableSet"),
26+
pytest.param(CollectionParameter(lambda t: Sequence[t], tuple), id="Sequence"),
27+
pytest.param(CollectionParameter(lambda t: Set[t], frozenset), id="Set"),
28+
pytest.param(CollectionParameter(lambda t: set[t], set), id="set"),
29+
pytest.param(CollectionParameter(lambda t: tuple[t, ...], tuple), id="tuple"),
30+
],
31+
)
32+
def collection(request: pytest.FixtureRequest) -> CollectionParameter:
33+
return request.param
34+
35+
36+
def test_works_with_simple_union(converter: BaseConverter, collection: CollectionParameter):
37+
configure_union_single_collection_dispatch(converter)
38+
39+
union = Union[collection.type_factory(str) | str]
40+
41+
assert converter.structure("abcd", union) == "abcd"
42+
assert converter.structure("abcd", str) == "abcd"
43+
44+
45+
expected_structured = collection.factory(["abcd"])
46+
assert converter.structure(["abcd"], union) == expected_structured
47+
assert converter.structure(["abcd"], collection.type_factory(str)) == expected_structured
48+
assert converter.structure(deque(["abcd"]), union) == expected_structured
49+
assert converter.structure(deque(["abcd"]), collection.type_factory(str)) == expected_structured
50+
assert converter.structure(frozenset(["abcd"]), union) == expected_structured
51+
assert converter.structure(frozenset(["abcd"]), collection.type_factory(str)) == expected_structured
52+
assert converter.structure(set(["abcd"]), union) == expected_structured
53+
assert converter.structure(set(["abcd"]), collection.type_factory(str)) == expected_structured
54+
assert converter.structure(tuple(["abcd"]), union) == expected_structured
55+
assert converter.structure(tuple(["abcd"]), collection.type_factory(str)) == expected_structured
56+
57+
58+
def test_apply_union_disambiguation(converter: BaseConverter, collection: CollectionParameter):
59+
configure_union_single_collection_dispatch(converter)
60+
61+
@define(frozen=True)
62+
class A:
63+
a: int
64+
65+
@define(frozen=True)
66+
class B:
67+
b: int
68+
69+
collection_type = collection.type_factory(Union[A, B])
70+
union = Union[collection_type, A, B]
71+
72+
assert converter.structure({"a": 1}, union) == A(1)
73+
assert converter.structure({"a": 1}, Union[A, B]) == A(1)
74+
assert converter.structure({"a": 1}, A) == A(1)
75+
assert converter.structure({"b": 2}, union) == B(2)
76+
assert converter.structure({"b": 2}, Union[A, B]) == B(2)
77+
assert converter.structure({"b": 2}, B) == B(2)
78+
79+
expected_structured = collection.factory([A(1), B(2)])
80+
assert converter.structure([{"a": 1}, {"b": 2}], union) == expected_structured
81+
assert converter.structure([{"a": 1}, {"b": 2}], collection_type) == expected_structured
82+
assert converter.structure(deque([{"a": 1}, {"b": 2}]), union) == expected_structured
83+
assert converter.structure(deque([{"a": 1}, {"b": 2}]), collection_type) == expected_structured
84+
assert converter.structure(tuple([{"a": 1}, {"b": 2}]), union) == expected_structured
85+
assert converter.structure(tuple([{"a": 1}, {"b": 2}]), collection_type) == expected_structured

0 commit comments

Comments
 (0)