Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions grassmann_tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,18 +307,23 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
while cursor_plan != len(new_shape) or cursor_self != self.tensor.dim():
if cursor_self == self.tensor.dim() and cursor_plan != len(new_shape):
new_shape_check = new_shape[cursor_plan]
if (isinstance(new_shape_check, int) and new_shape_check == 1) or (
new_shape_check == (1, 0)
if (
(isinstance(new_shape_check, int) and new_shape_check == 1)
or new_shape_check == (1, 0)
or new_shape_check == (0, 1)
):
if cursor_plan < len(self.arrow):
arrow.append(self.arrow[cursor_plan])
else:
arrow.append(False)
edges.append((1, 0))
if new_shape_check == (0, 1):
edges.append((0, 1))
else:
edges.append((1, 0))
shape.append(1)
cursor_plan += 1
continue
raise AssertionError(
raise ValueError(
"New shape exceeds after exhausting self dimensions: "
f"edges={self.edges}, new_shape={new_shape}"
)
Expand Down Expand Up @@ -359,6 +364,14 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
# A trivial self edge
cursor_self += 1
continue
if cursor_plan == len(new_shape) and cursor_self != self.tensor.dim():
if self.tensor.shape[cursor_self] == 1:
cursor_self += 1
continue
raise ValueError(
"New shape exhausted but self still has non-trivial dimensions: "
f"edges={self.edges}, new_shape={new_shape}, cursor_self={cursor_self}"
)
cursor_new_shape = new_shape[cursor_plan]
total = (
cursor_new_shape
Expand Down Expand Up @@ -493,7 +506,6 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
torch.zeros([], dtype=torch.bool, device=self.tensor.device),
)
tensor = torch.where(splitting_parity, -tensor, +tensor)

tensor = tensor.reshape(shape)

merging_parity = functools.reduce(
Expand Down
42 changes: 41 additions & 1 deletion tests/reshape_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,30 @@ def test_reshape_equal_edges_nontrivial_merging() -> None:
_ = a.reshape(((3, 1),))


def test_reshape_pure_even_merging() -> None:
arrow = (True, True, True)
edges = ((2, 2), (2, 2), (0, 1))
a = GrassmannTensor(arrow, edges, torch.randn([4, 4, 1]))
_ = a.reshape(((8, 8),))


def test_reshape_pure_even_splitting() -> None:
arrow = (True, True)
edges = ((2, 2), (2, 2))
a = GrassmannTensor(arrow, edges, torch.randn([4, 4]))
_ = a.reshape(((2, 2), (2, 2), (0, 1)))


def test_reshape_merging_plan_exhausted_self_remaining_nontrivial() -> None:
arrow = (True, True, True)
edges = ((2, 2), (2, 2), (0, 2))
a = GrassmannTensor(arrow, edges, torch.randn([4, 4, 2]))
with pytest.raises(
ValueError, match="New shape exhausted but self still has non-trivial dimensions"
):
_ = a.reshape(((8, 8),))


def test_reshape_equal_edges_nontrivial_merging_with_other_edge() -> None:
arrow = (True, True, True, True)
edges = ((1, 3), (1, 0), (0, 1), (2, 2))
Expand Down Expand Up @@ -229,7 +253,7 @@ def test_reshape_with_one_dimension(

def test_reshape_trailing_nontrivial_dim_raises() -> None:
a = GrassmannTensor((True,), ((2, 2),), torch.randn([4]))
with pytest.raises(AssertionError, match="New shape exceeds after exhausting self dimensions"):
with pytest.raises(ValueError, match="New shape exceeds after exhausting self dimensions"):
_ = a.reshape((-1, (2, 2)))


Expand Down Expand Up @@ -395,6 +419,22 @@ def test_named_tensor_equal_edges_nontrivial_merging() -> None:
_ = a.merge_edge({"a": ("a", "b", "c")})


def test_named_tensor_pure_even_merging() -> None:
names = ("a", "b", "c")
arrow = (True, True, True)
edges = ((2, 2), (2, 2), (0, 1))
a = NamedGrassmannTensor(names, arrow, edges, torch.randn([4, 4, 1]))
_ = a.merge_edge({"b": ("b", "c")})


def test_named_tensor_pure_even_splitting() -> None:
names = ("a", "b")
arrow = (True, True)
edges = ((2, 2), (2, 2))
a = NamedGrassmannTensor(names, arrow, edges, torch.randn([4, 4]))
_ = a.split_edge({"b": (("b", (2, 2)), ("c", (0, 1)))})


def test_named_tensor_equal_edges_nontrivial_merging_with_other_edge() -> None:
names = ("a", "b", "c", "d")
arrow = (True, True, True, True)
Expand Down
Loading