|
1 | 1 | import torch |
2 | | -from pytest import mark, raises |
| 2 | +from pytest import mark |
3 | 3 | from torch import Tensor, tensor |
4 | 4 | from torch.ops import aten # type: ignore |
5 | 5 | from torch.testing import assert_close |
|
12 | 12 | ) |
13 | 13 | from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim |
14 | 14 | from torchjd.sparse._coalesce import fix_zero_stride_columns |
15 | | -from torchjd.sparse._linalg import intdiv_c, mod_c |
16 | 15 | from torchjd.sparse._structured_sparse_tensor import ( |
17 | 16 | StructuredSparseTensor, |
18 | 17 | fix_ungrouped_dims, |
@@ -421,49 +420,3 @@ def test_fix_zero_stride_columns( |
421 | 420 | physical, strides = fix_zero_stride_columns(physical, strides) |
422 | 421 | assert torch.equal(physical, expected_physical) |
423 | 422 | assert torch.equal(strides, expected_strides) |
424 | | - |
425 | | - |
426 | | -@mark.parametrize( |
427 | | - ["t1", "t2", "expected"], |
428 | | - [ |
429 | | - (tensor([8, 12]), tensor([2, 3]), tensor([0, 0])), |
430 | | - (tensor([8, 12]), tensor([2, 4]), tensor([2, 0])), |
431 | | - (tensor([8, 12]), tensor([3, 3]), tensor([2, 6])), |
432 | | - (tensor([8, 12]), tensor([2, 0]), tensor([0, 12])), |
433 | | - (tensor([8, 12]), tensor([0, 2]), tensor([8, 0])), |
434 | | - ], |
435 | | -) |
436 | | -def test_mod_c( |
437 | | - t1: Tensor, |
438 | | - t2: Tensor, |
439 | | - expected: Tensor, |
440 | | -): |
441 | | - assert torch.equal(mod_c(t1, t2), expected) |
442 | | - |
443 | | - |
444 | | -def test_mod_c_by_0_raises(): |
445 | | - with raises(ZeroDivisionError): |
446 | | - mod_c(tensor([3, 4]), tensor([0, 0])) |
447 | | - |
448 | | - |
449 | | -@mark.parametrize( |
450 | | - ["t1", "t2", "expected"], |
451 | | - [ |
452 | | - (tensor([8, 12]), tensor([2, 3]), 4), |
453 | | - (tensor([8, 12]), tensor([2, 4]), 3), |
454 | | - (tensor([8, 12]), tensor([3, 3]), 2), |
455 | | - (tensor([8, 12]), tensor([2, 0]), 4), |
456 | | - (tensor([8, 12]), tensor([0, 2]), 6), |
457 | | - ], |
458 | | -) |
459 | | -def test_intdiv_c( |
460 | | - t1: Tensor, |
461 | | - t2: Tensor, |
462 | | - expected: Tensor, |
463 | | -): |
464 | | - assert intdiv_c(t1, t2) == expected |
465 | | - |
466 | | - |
467 | | -def test_intdiv_c_by_0_raises(): |
468 | | - with raises(ZeroDivisionError): |
469 | | - intdiv_c(tensor([3, 4]), tensor([0, 0])) |
0 commit comments