 from torch.utils._pytree import tree_map

 # pointwise functions applied to a single Tensor that map 0.0 → 0.0
-_pointwise_functions = {
+_POINTWISE_FUNCTIONS = {
     aten.abs.default,
     aten.abs_.default,
     aten.absolute.default,
⋮
     aten.leaky_relu.default,
     aten.leaky_relu_.default,
 }
+import functools
+
+_HANDLED_FUNCTIONS = dict()
+
+
+def implements(torch_function):
+    """Register a torch dispatch override for DiagonalSparseTensor."""
+
+    def decorator(func):
+        functools.update_wrapper(func, torch_function)
+        _HANDLED_FUNCTIONS[torch_function] = func
+        return func
+
+    return decorator


 class DiagonalSparseTensor(torch.Tensor):
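This is the handler-registry pattern from the PyTorch "Extending torch" tutorial, adapted from `__torch_function__` to `__torch_dispatch__`. A minimal self-contained sketch of the mechanics, with hypothetical names (`demo_implements`, `_DEMO_HANDLERS`) standing in for `implements` and `_HANDLED_FUNCTIONS`:

    import functools

    _DEMO_HANDLERS = {}

    def demo_implements(key):
        def decorator(func):
            functools.update_wrapper(func, key)  # copy __name__/__doc__ from the wrapped op
            _DEMO_HANDLERS[key] = func
            return func
        return decorator

    def original_op(x):
        """The op being overridden."""
        return x + 1

    @demo_implements(original_op)
    def override(x):
        return x * 10

    # At dispatch time, a dict lookup redirects the call to the override:
    assert _DEMO_HANDLERS[original_op](4) == 40
    assert override.__doc__ == "The op being overridden."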
@@ -85,6 +98,10 @@ def __new__(cls, data: Tensor, v_to_p: list[int]):
         # (which is bad!)
         assert not data.requires_grad or not torch.is_grad_enabled()

+        # TODO: assert that data is minimal: every one of its dimensions must be used at least once.
+        # TODO: if v_to_p has no repeated index, return a plain view of data (a non-sparse tensor).
+        # If that cannot be done in __new__, add a helper function for it and use it everywhere.
+
         shape = [data.shape[i] for i in v_to_p]
         return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device)
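To make the `v_to_p` bookkeeping concrete: virtual dimension `i` takes its size from physical dimension `v_to_p[i]` of `data`, and repeated entries are what encode diagonal structure. A hedged sketch under that reading (the densification loop is illustrative, not the class's actual fallback path):

    import torch

    data = torch.tensor([1.0, 2.0, 3.0])
    v_to_p = [0, 0]                          # both virtual dims read physical dim 0
    shape = [data.shape[i] for i in v_to_p]  # same computation as in __new__
    assert shape == [3, 3]

    # Entry (i, j) holds data[i] where the repeated dims agree (i == j), else 0:
    dense = torch.zeros(shape)
    for i in range(3):
        dense[i, i] = data[i]
    assert torch.equal(dense, torch.diag(data))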
@@ -117,7 +134,7 @@ def __torch_dispatch__(cls, func: {__name__}, types: Any, args: tuple = (), kwar

         # If `func` is a pointwise operator that applies to a single Tensor and satisfies
         # func(0) = 0, then we can apply it directly to self._data and wrap the result.
-        if func in _pointwise_functions:
+        if func in _POINTWISE_FUNCTIONS:
             assert (
                 isinstance(args, tuple) and len(args) == 1 and func(torch.zeros([])).item() == 0.0
             )
@@ -126,9 +143,8 @@ def __torch_dispatch__(cls, func: {__name__}, types: Any, args: tuple = (), kwar
             new_data = func(sparse_tensor._data)
             return DiagonalSparseTensor(new_data, sparse_tensor._v_to_p)

-        # TODO: Handle batched operations (apply to self._data and wrap)
-        # TODO: Handle all operations that can be represented with an einsum by translating them
-        # to operations on self._data and wrapping accordingly.
+        if func in _HANDLED_FUNCTIONS:
+            return _HANDLED_FUNCTIONS[func](*args, **kwargs)

         # --- Fallback: Fold to Dense Tensor ---
         def unwrap_to_dense(t: Tensor):
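The `func(0) = 0` requirement is what makes the pointwise fast path sound: the virtual tensor is zero everywhere off its structural positions, so an op that preserves zero only needs to touch `_data`. A small check with plain tensors (no subclass involved):

    import torch

    data = torch.tensor([1.0, -2.0, 3.0])
    dense = torch.diag(data)  # what a DiagonalSparseTensor with v_to_p=[0, 0] represents

    # abs maps 0.0 -> 0.0, so applying it to the stored diagonal alone suffices:
    assert torch.equal(torch.abs(dense), torch.diag(torch.abs(data)))

    # cos(0.0) = 1.0, so the same shortcut would silently drop the off-diagonal ones:
    assert not torch.equal(torch.cos(dense), torch.diag(torch.cos(data)))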
@@ -145,3 +161,15 @@ def __repr__(self):
145161 f"DiagonalSparseTensor(data={ self ._data } , v_to_p_map={ self ._v_to_p } , shape="
146162 f"{ self ._v_shape } )"
147163 )
+
+
+@implements(aten.mean.default)
+def mean_default(t: Tensor) -> Tensor:
+    assert isinstance(t, DiagonalSparseTensor)
+    return aten.sum.default(t._data) / t.numel()
+
+
+@implements(aten.sum.default)
+def sum_default(t: Tensor) -> Tensor:
+    assert isinstance(t, DiagonalSparseTensor)
+    return aten.sum.default(t._data)
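Assuming the constructor semantics above, a quick consistency check for the two new handlers; the construction call is a sketch, so the exact invocation may differ:

    import torch

    data = torch.tensor([1.0, 2.0, 3.0])
    t = DiagonalSparseTensor(data, [0, 0])  # a virtual 3x3 diagonal matrix

    # sum only needs the stored diagonal; the implicit zeros contribute nothing
    assert torch.sum(t).item() == 6.0

    # mean divides by the *virtual* element count (9), not data.numel() (3)
    assert abs(torch.mean(t).item() - 6.0 / 9) < 1e-6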