-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathbackward.py
More file actions
36 lines (26 loc) · 1.29 KB
/
backward.py
File metadata and controls
36 lines (26 loc) · 1.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from torch import Tensor
from torch.ops import aten # type: ignore
from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor, impl
@impl(aten.threshold_backward.default)
def threshold_backward_default(
    grad_output: StructuredSparseTensor,
    self: Tensor,
    threshold: Tensor | int | float,
) -> StructuredSparseTensor:
    """Implement ``aten.threshold_backward`` for a structured-sparse ``grad_output``.

    The aten op is applied to ``grad_output.physical``. The result is re-wrapped
    with ``grad_output.strides``, so the output keeps the strides of
    ``grad_output``.

    Args:
        grad_output: Incoming gradient, stored as a StructuredSparseTensor.
        self: Forward input of the threshold op. Only a tensor that is not a
            StructuredSparseTensor is supported.
        threshold: Threshold value, forwarded unchanged to the aten op.

    Returns:
        A StructuredSparseTensor built from the new physical data and
        ``grad_output.strides``.

    Raises:
        NotImplementedError: If ``self`` is itself a StructuredSparseTensor.
    """
    # Same guard as the sibling backward impls. Reject a structured-sparse
    # ``self`` explicitly rather than letting the aten call fail indirectly.
    if isinstance(self, StructuredSparseTensor):
        raise NotImplementedError()
    # NOTE(review): assumes grad_output.physical and self are shape-compatible
    # for this op — TODO confirm against the StructuredSparseTensor layout.
    new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold)
    return StructuredSparseTensor(new_physical, grad_output.strides)
@impl(aten.hardtanh_backward.default)
def hardtanh_backward_default(
    grad_output: StructuredSparseTensor,
    self: Tensor,
    min_val: Tensor | int | float,
    max_val: Tensor | int | float,
) -> StructuredSparseTensor:
    """Implement ``aten.hardtanh_backward`` for a structured-sparse ``grad_output``.

    The aten op runs on the physical data of ``grad_output``. The result is
    wrapped back into a StructuredSparseTensor that reuses
    ``grad_output.strides``.

    Args:
        grad_output: Incoming gradient, stored as a StructuredSparseTensor.
        self: Forward input of the hardtanh op. Must not be a
            StructuredSparseTensor.
        min_val: Lower clamp bound, forwarded unchanged to the aten op.
        max_val: Upper clamp bound, forwarded unchanged to the aten op.

    Returns:
        A StructuredSparseTensor with the updated physical data and the
        original strides.

    Raises:
        NotImplementedError: If ``self`` is a StructuredSparseTensor.
    """
    # A structured-sparse forward input is not handled here.
    if isinstance(self, StructuredSparseTensor):
        raise NotImplementedError()
    source_physical = grad_output.physical
    result_physical = aten.hardtanh_backward.default(source_physical, self, min_val, max_val)
    layout = grad_output.strides
    return StructuredSparseTensor(result_physical, layout)
@impl(aten.hardswish_backward.default)
def hardswish_backward_default(
    grad_output: StructuredSparseTensor, self: Tensor
) -> StructuredSparseTensor:
    """Implement ``aten.hardswish_backward`` for a structured-sparse ``grad_output``.

    The aten op is applied to ``grad_output.physical``. The result is re-wrapped
    with ``grad_output.strides``.

    Args:
        grad_output: Incoming gradient, stored as a StructuredSparseTensor.
        self: Forward input of the hardswish op. Must not be a
            StructuredSparseTensor.

    Returns:
        A StructuredSparseTensor with the new physical data and
        ``grad_output.strides``.

    Raises:
        NotImplementedError: If ``self`` is a StructuredSparseTensor.
    """
    # A structured-sparse forward input is not supported.
    if isinstance(self, StructuredSparseTensor):
        raise NotImplementedError()
    new_physical = aten.hardswish_backward.default(grad_output.physical, self)
    return StructuredSparseTensor(new_physical, grad_output.strides)