Skip to content

Commit fd21e5d

Browse files
PaliCMark SaroufimCopilot
authored
[ez] update references so they run locally (#299)
* [ez] update references so they run locally * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Mark Saroufim <marksaroufim@meta.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent afe623c commit fd21e5d

4 files changed

Lines changed: 16 additions & 13 deletions

File tree

examples/identity_py/reference.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
from task import input_t, output_t
33
from utils import verbose_allclose
4-
4+
from typing import Tuple
55

66
def generate_input(size: int, seed: int) -> input_t:
77
gen = torch.Generator(device='cuda')
@@ -15,13 +15,13 @@ def ref_kernel(data: input_t) -> output_t:
1515
return data
1616

1717

18-
def check_implementation(data, output) -> str:
18+
def check_implementation(data: input_t, output: output_t) -> Tuple[bool, str]:
1919
expected = ref_kernel(data)
2020
reasons = verbose_allclose(output, expected)
2121
if len(reasons) > 0:
2222
# TODO better processing of reasons
23-
return "mismatch found! custom implementation doesn't match reference.: " + reasons[0]
23+
return False, "mismatch found! custom implementation doesn't match reference: " + reasons[0]
2424

25-
return ''
25+
return True, ''
2626

2727

examples/matmul_py/reference.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
from task import input_t, output_t
33
from utils import verbose_allclose
4+
from typing import Tuple
45

56
def generate_input(m: int, n: int, k: int, seed: int) -> input_t:
67
gen = torch.Generator(device='cuda')
@@ -15,12 +16,12 @@ def ref_kernel(data: input_t) -> output_t:
1516
a, b = data
1617
return a @ b
1718

18-
def check_implementation(data: input_t, output: output_t) -> str:
19+
def check_implementation(data: input_t, output: output_t) -> Tuple[bool, str]:
1920
expected = ref_kernel(data)
2021
reasons = verbose_allclose(output, expected)
2122
if len(reasons) > 0:
2223
# TODO better processing of reasons
23-
return "mismatch found! custom implementation doesn't match reference.: " + reasons[0]
24+
return False, "mismatch found! custom implementation doesn't match reference: " + reasons[0]
2425

25-
return ''
26+
return True, ''
2627

examples/softmax_py/reference.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
from utils import verbose_allclose, get_device
33
from task import input_t, output_t
4+
from typing import Tuple
45

56

67
def generate_input(size: int, seed: int) -> input_t:
@@ -13,12 +14,12 @@ def generate_input(size: int, seed: int) -> input_t:
1314
def ref_kernel(data: input_t) -> output_t:
1415
return torch.nn.functional.softmax(data, dim=-1)
1516

16-
def check_implementation(data: input_t, output: output_t) -> str:
17+
def check_implementation(data: input_t, output: output_t) -> Tuple[bool, str]:
1718

1819
expected = ref_kernel(data)
1920
reasons = verbose_allclose(output, expected)
2021

2122
if len(reasons) > 0:
22-
return "mismatch found! custom implementation doesn't match reference: " + reasons[0]
23+
return False, "mismatch found! custom implementation doesn't match reference: " + reasons[0]
2324

24-
return ''
25+
return True, ''

examples/vectoradd_py/reference.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from utils import verbose_allclose
22
import torch
33
from task import input_t, output_t
4+
from typing import Tuple
45

56
def ref_kernel(data: input_t) -> output_t:
67
"""
@@ -28,11 +29,11 @@ def generate_input(size: int, seed: int) -> input_t:
2829
def check_implementation(
2930
data: input_t,
3031
output: output_t,
31-
) -> bool:
32+
) -> Tuple[bool, str]:
3233
expected = ref_kernel(data)
3334
reasons = verbose_allclose(output, expected)
3435

3536
if len(reasons) > 0:
36-
return "mismatch found! custom implementation doesn't match reference: " + reasons[0]
37+
return False, "mismatch found! custom implementation doesn't match reference: " + reasons[0]
3738

38-
return ''
39+
return True, ''

0 commit comments

Comments
 (0)