diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index 03d0572..c38a859 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -307,18 +307,23 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens while cursor_plan != len(new_shape) or cursor_self != self.tensor.dim(): if cursor_self == self.tensor.dim() and cursor_plan != len(new_shape): new_shape_check = new_shape[cursor_plan] - if (isinstance(new_shape_check, int) and new_shape_check == 1) or ( - new_shape_check == (1, 0) + if ( + (isinstance(new_shape_check, int) and new_shape_check == 1) + or new_shape_check == (1, 0) + or new_shape_check == (0, 1) ): if cursor_plan < len(self.arrow): arrow.append(self.arrow[cursor_plan]) else: arrow.append(False) - edges.append((1, 0)) + if new_shape_check == (0, 1): + edges.append((0, 1)) + else: + edges.append((1, 0)) shape.append(1) cursor_plan += 1 continue - raise AssertionError( + raise ValueError( "New shape exceeds after exhausting self dimensions: " f"edges={self.edges}, new_shape={new_shape}" ) @@ -359,6 +364,14 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens # A trivial self edge cursor_self += 1 continue + if cursor_plan == len(new_shape) and cursor_self != self.tensor.dim(): + if self.tensor.shape[cursor_self] == 1: + cursor_self += 1 + continue + raise ValueError( + "New shape exhausted but self still has non-trivial dimensions: " + f"edges={self.edges}, new_shape={new_shape}, cursor_self={cursor_self}" + ) cursor_new_shape = new_shape[cursor_plan] total = ( cursor_new_shape @@ -493,7 +506,6 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens torch.zeros([], dtype=torch.bool, device=self.tensor.device), ) tensor = torch.where(splitting_parity, -tensor, +tensor) - tensor = tensor.reshape(shape) merging_parity = functools.reduce( diff --git a/tests/reshape_test.py b/tests/reshape_test.py index 8a8ee34..f904e25 100644 --- a/tests/reshape_test.py +++ b/tests/reshape_test.py @@ -170,6 +170,30 @@ def test_reshape_equal_edges_nontrivial_merging() -> None: _ = a.reshape(((3, 1),)) +def test_reshape_pure_even_merging() -> None: + arrow = (True, True, True) + edges = ((2, 2), (2, 2), (0, 1)) + a = GrassmannTensor(arrow, edges, torch.randn([4, 4, 1])) + _ = a.reshape(((8, 8),)) + + +def test_reshape_pure_even_splitting() -> None: + arrow = (True, True) + edges = ((2, 2), (2, 2)) + a = GrassmannTensor(arrow, edges, torch.randn([4, 4])) + _ = a.reshape(((2, 2), (2, 2), (0, 1))) + + +def test_reshape_merging_plan_exhausted_self_remaining_nontrivial() -> None: + arrow = (True, True, True) + edges = ((2, 2), (2, 2), (0, 2)) + a = GrassmannTensor(arrow, edges, torch.randn([4, 4, 2])) + with pytest.raises( + ValueError, match="New shape exhausted but self still has non-trivial dimensions" + ): + _ = a.reshape(((8, 8),)) + + def test_reshape_equal_edges_nontrivial_merging_with_other_edge() -> None: arrow = (True, True, True, True) edges = ((1, 3), (1, 0), (0, 1), (2, 2)) @@ -229,7 +253,7 @@ def test_reshape_with_one_dimension( def test_reshape_trailing_nontrivial_dim_raises() -> None: a = GrassmannTensor((True,), ((2, 2),), torch.randn([4])) - with pytest.raises(AssertionError, match="New shape exceeds after exhausting self dimensions"): + with pytest.raises(ValueError, match="New shape exceeds after exhausting self dimensions"): _ = a.reshape((-1, (2, 2))) @@ -395,6 +419,22 @@ def test_named_tensor_equal_edges_nontrivial_merging() -> None: _ = a.merge_edge({"a": ("a", "b", "c")}) +def test_named_tensor_pure_even_merging() -> None: + names = ("a", "b", "c") + arrow = (True, True, True) + edges = ((2, 2), (2, 2), (0, 1)) + a = NamedGrassmannTensor(names, arrow, edges, torch.randn([4, 4, 1])) + _ = a.merge_edge({"b": ("b", "c")}) + + +def test_named_tensor_pure_even_splitting() -> None: + names = ("a", "b") + arrow = (True, True) + edges = ((2, 2), (2, 2)) + a = NamedGrassmannTensor(names, arrow, edges, torch.randn([4, 4])) + _ = a.split_edge({"b": (("b", (2, 2)), ("c", (0, 1)))}) + + def test_named_tensor_equal_edges_nontrivial_merging_with_other_edge() -> None: names = ("a", "b", "c", "d") arrow = (True, True, True, True)