Skip to content

Batch Support for AxisAlignedBoundingBox#495

Open
zhx06 wants to merge 6 commits intomainfrom
zxiao/bbox_batch_support
Open

Batch Support for AxisAlignedBoundingBox#495
zhx06 wants to merge 6 commits intomainfrom
zxiao/bbox_batch_support

Conversation

@zhx06
Copy link
Copy Markdown
Collaborator

@zhx06 zhx06 commented Mar 23, 2026

Summary

Add batch dimension support to AxisAlignedBoundingBox

Detailed description

  • What was the reason for the change?
    • To support multi-environment batched placement, AxisAlignedBoundingBox needs to handle N bounding boxes simultaneously instead of only one.
  • What has been changed?
    • Refactored AxisAlignedBoundingBox from a dataclass to a class with (N, 3) float32 tensor storage. Returns tuples/floats for N=1, tensors for N>1.
    • Fixed float32 rounding issue in _validate_on_relations Z boundary check.
    • Added test_bounding_box.py for single-env and multi-env coverage.
  • What is the impact of this change?
    • No behavior change for existing code: backward compatible for single env (N=1).
    • Enables future multi-env batch support for the coming relation solver MR.

@zhx06 zhx06 changed the title add batch support and test cases for AxisAlignedBoundingBox Batch Support for AxisAlignedBoundingBox Mar 23, 2026
@zhx06 zhx06 force-pushed the zxiao/bbox_batch_support branch 2 times, most recently from 9aa9df0 to 46af265 Compare March 23, 2026 20:40
min_point: tuple[float, float, float]
"""Local minimum extent (x, y, z) relative to object origin."""
Stores min/max extents as (N, 3) tensors where N is the number of environments.
Properties return tuples/floats when N=1 and tensors when N>1.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like the dual-mode return type is a reliability issue. Properties return tuple/float when N=1 and tensor when N>1. Code that works in single-env silently changes behavior when N increases. Recommend to always returning tensors

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additionally we could simplify below some parts like self._min_point = self._to_batched_tensor(min_point),as its always a tensor (the input as well).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would require some updates downstream, but it would be worth it in my opinion

return value.unsqueeze(0).float()
return value.float()

def _format_output(self, tensor: torch.Tensor):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion to rename this function to what is formatted

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However can remove this if the class only handles tensors

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this would break silently for N>1? When N>1, min_point returns a tensor and the subsequent torch.rand(3) arithmetic produces wrong-shaped results ((N,3) instead of (3,)).
I would suggest to move to single input and output type (tensors)

@zhx06 zhx06 force-pushed the zxiao/bbox_batch_support branch from 46af265 to db21d7f Compare March 24, 2026 18:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants