|
20 | 20 | __all__ = ['TensorFunction', 'TensorTimeFunction', 'VectorFunction', 'VectorTimeFunction'] |
21 | 21 |
|
22 | 22 |
|
| 23 | +def staggering(stagg, i, j, d, dims): |
| 24 | + if stagg is None: |
| 25 | + # No input |
| 26 | + return NODE if i == j else (d, dims[j]) |
| 27 | + elif isinstance(stagg, (tuple, list)): |
| 28 | + # User input as list or tuple |
| 29 | + return stagg[i][j] |
| 30 | + elif isinstance(stagg, AbstractTensor): |
| 31 | + # From rebuild/tensor property. Indexed as a sympy Matrix |
| 32 | + return stagg[i, j] |
| 33 | + |
| 34 | + |
23 | 35 | class TensorFunction(AbstractTensor): |
24 | 36 | """ |
25 | 37 | Tensor valued Function represented as a Matrix. |
@@ -128,8 +140,7 @@ def __subfunc_setup__(cls, *args, **kwargs): |
128 | 140 | start = i if (symm or diag) else 0 |
129 | 141 | stop = i + 1 if diag else len(dims) |
130 | 142 | for j in range(start, stop): |
131 | | - staggj = (stagg[i][j] if stagg is not None |
132 | | - else (NODE if i == j else (d, dims[j]))) |
| 143 | + staggj = staggering(stagg, i, j, d, dims) |
133 | 144 | sub_kwargs = cls._component_kwargs((i, j), **kwargs) |
134 | 145 | sub_kwargs.update({'name': f"{name}_{d.name}{dims[j].name}", |
135 | 146 | 'staggered': staggj}) |
|
0 commit comments