Skip to content

Commit 9c44791

Browse files
authored
Add the nplike test, testing broadcasting (#69)
1 parent a44650f commit 9c44791

3 files changed

Lines changed: 142 additions & 0 deletions

File tree

pep.rst

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,6 +1080,50 @@ dataclasses-style method generation
10801080
]
10811081

10821082

1083+
.. _numpy-impl:
1084+
1085+
NumPy-style broadcasting
1086+
------------------------
1087+
1088+
::
1089+
1090+
class Array[DType, *Shape]:
1091+
def __add__[*Shape2](
1092+
self, other: Array[DType, *Shape2]
1093+
) -> Array[DType, *Broadcast[tuple[*Shape], tuple[*Shape2]]]:
1094+
raise BaseException
1095+
1096+
type AppendTuple[A, B] = tuple[
1097+
*[x for x in typing.Iter[A]],
1098+
B,
1099+
]
1100+
1101+
type MergeOne[T, S] = (
1102+
T
1103+
if typing.Matches[T, S] or typing.Matches[S, Literal[1]]
1104+
else S
1105+
if typing.Matches[T, Literal[1]]
1106+
else typing.RaiseError[Literal["Broadcast mismatch"], T, S]
1107+
)
1108+
1109+
type DropLast[T] = typing.Slice[T, Literal[0], Literal[-1]]
1110+
type Last[T] = typing.GetArg[T, tuple, Literal[-1]]
1111+
1112+
# Matching on Never here is intentional; it prevents infinite
1113+
# recursions when T is not a tuple.
1114+
type Empty[T] = typing.IsSub[typing.Length[T], Literal[0]]
1115+
1116+
type Broadcast[T, S] = (
1117+
S
1118+
if typing.Bool[Empty[T]]
1119+
else T
1120+
if typing.Bool[Empty[S]]
1121+
else AppendTuple[
1122+
Broadcast[DropLast[T], DropLast[S]], MergeOne[Last[T], Last[S]]
1123+
]
1124+
)
1125+
1126+
10831127
Rationale
10841128
=========
10851129

scripts/update-examples.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@ scripts/py2rst.py tests/test_fastapilike_2.py --start "Begin PEP section: datacl
99

1010
scripts/py2rst.py tests/test_fastapilike_2.py --start "Begin PEP section: Automatically deriving FastAPI CRUD models" --end "End PEP section" \
1111
| scripts/rst_replace_section.py pep.rst fastapi-impl -i
12+
13+
scripts/py2rst.py tests/test_nplike.py --start "Begin PEP section" --end "End PEP section" \
14+
| scripts/rst_replace_section.py pep.rst numpy-impl -i

tests/test_nplike.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from typing import Literal
2+
3+
from typemap import typing
4+
5+
import pytest
6+
7+
# Begin PEP section
8+
9+
10+
class Array[DType, *Shape]:
11+
def __add__[*Shape2](
12+
self, other: Array[DType, *Shape2]
13+
) -> Array[DType, *Broadcast[tuple[*Shape], tuple[*Shape2]]]:
14+
raise BaseException
15+
16+
17+
type AppendTuple[A, B] = tuple[
18+
*[x for x in typing.Iter[A]],
19+
B,
20+
]
21+
22+
type MergeOne[T, S] = (
23+
T
24+
if typing.Matches[T, S] or typing.Matches[S, Literal[1]]
25+
else S
26+
if typing.Matches[T, Literal[1]]
27+
else typing.RaiseError[Literal["Broadcast mismatch"], T, S]
28+
)
29+
30+
type DropLast[T] = typing.Slice[T, Literal[0], Literal[-1]]
31+
type Last[T] = typing.GetArg[T, tuple, Literal[-1]]
32+
33+
# Matching on Never here is intentional; it prevents infinite
34+
# recursions when T is not a tuple.
35+
type Empty[T] = typing.IsSub[typing.Length[T], Literal[0]]
36+
37+
type Broadcast[T, S] = (
38+
S
39+
if typing.Bool[Empty[T]]
40+
else T
41+
if typing.Bool[Empty[S]]
42+
else AppendTuple[
43+
Broadcast[DropLast[T], DropLast[S]], MergeOne[Last[T], Last[S]]
44+
]
45+
)
46+
47+
# End PEP section
48+
49+
type GetElem[T] = typing.GetArg[T, Array, Literal[0]]
50+
type GetShape[T] = typing.Slice[typing.GetArgs[T, Array], Literal[1], None]
51+
52+
# type Apply[T, S] = Array[GetElem[T], *Broadcast[GetShape[T], GetShape[S]]]
53+
type Apply[T, S] = Array[
54+
GetElem[T],
55+
*[x for x in typing.Iter[Broadcast[GetShape[T], GetShape[S]]]],
56+
]
57+
58+
######
59+
from typemap.type_eval import eval_typing, TypeMapError
60+
61+
from typing import Literal as L
62+
63+
64+
def test_nplike_1():
65+
a1 = Array[float, L[4], L[1]]
66+
a2 = Array[float, L[3]]
67+
res = eval_typing(Apply[a1, a2])
68+
69+
assert res == Array[float, L[4], L[3]]
70+
71+
72+
def test_nplike_2():
73+
b1 = Array[float, int, int]
74+
b2 = Array[float, int]
75+
res = eval_typing(Apply[b1, b2])
76+
77+
assert res == Array[float, int, int]
78+
79+
80+
def test_nplike_3():
81+
c1 = Array[float, L[4], L[1], L[5]]
82+
c2 = Array[float, L[4], L[3], L[1]]
83+
res = eval_typing(Apply[c1, c2])
84+
85+
assert res == Array[float, L[4], L[3], L[5]]
86+
87+
88+
def test_nplike_4():
89+
err1 = Array[float, L[4], L[2]]
90+
err2 = Array[float, L[3]]
91+
92+
with pytest.raises(
93+
TypeMapError, match=r"Broadcast mismatch:.*Literal\[2\].*Literal\[3\]"
94+
):
95+
eval_typing(Apply[err1, err2])

0 commit comments

Comments
 (0)