Skip to content

Rotated bboxes transforms #9104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
19 changes: 15 additions & 4 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import io, tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_image, to_pil_image
from torchvision.transforms.v2.functional import clamp_bounding_boxes, to_image, to_pil_image


IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
Expand Down Expand Up @@ -461,9 +461,20 @@ def sample_position(values, max_value):
parts = (x1, y1, x2, y2, x3, y3, x4, y4)
else:
raise ValueError(f"Format {format} is not supported")
return tv_tensors.BoundingBoxes(
torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
)
out_boxes = torch.stack(parts, dim=-1).to(dtype=dtype, device=device)
if tv_tensors.is_rotated_bounding_format(format):
# The rotated bounding boxes are not guaranteed to be within the canvas by design,
# so we apply clamping. We also add a 2 buffer to the canvas size to avoid
# numerical issues during the testing
buffer = 4
out_boxes = clamp_bounding_boxes(
out_boxes, format=format, canvas_size=(canvas_size[0] - buffer, canvas_size[1] - buffer)
)
if format is tv_tensors.BoundingBoxFormat.XYWHR or format is tv_tensors.BoundingBoxFormat.CXCYWHR:
out_boxes[:, :2] += buffer // 2
elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
out_boxes[:, :] += buffer // 2
return tv_tensors.BoundingBoxes(out_boxes, format=format, canvas_size=canvas_size)


def make_detection_masks(size=DEFAULT_SIZE, *, num_masks=1, dtype=None, device="cpu"):
Expand Down
190 changes: 123 additions & 67 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,6 @@
from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal


# While we are working on adjusting transform functions
# for rotated and oriented bounding boxes formats,
# we limit the perimeter of tests to formats
# for which transform functions are already implemented.
# In the future, this global variable will be replaced with `list(tv_tensors.BoundingBoxFormat)`
# to support all available formats.
SUPPORTED_BOX_FORMATS = [tv_tensors.BoundingBoxFormat[x] for x in ["XYXY", "XYWH", "CXCYWH"]]
NEW_BOX_FORMATS = [tv_tensors.BoundingBoxFormat[x] for x in ["XYWHR", "CXCYWHR", "XYXYXYXY"]]

# turns all warnings into errors for this module
pytestmark = [pytest.mark.filterwarnings("error")]

Expand Down Expand Up @@ -568,6 +559,13 @@ def reference_affine_rotated_bounding_boxes_helper(

def affine_rotated_bounding_boxes(bounding_boxes):
dtype = bounding_boxes.dtype
int_dtype = dtype in (
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
)
device = bounding_boxes.device

# Go to float before converting to prevent precision loss in case of CXCYWHR -> XYXYXYXY and W or H is 1
Expand Down Expand Up @@ -602,19 +600,14 @@ def affine_rotated_bounding_boxes(bounding_boxes):
)

output = output[[2, 3, 0, 1, 6, 7, 4, 5]] if flip else output
output = _parallelogram_to_bounding_boxes(output)
if not int_dtype:
output = _parallelogram_to_bounding_boxes(output)

output = F.convert_bounding_box_format(
output, old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, new_format=format
)

if torch.is_floating_point(output) and dtype in (
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
):
if torch.is_floating_point(output) and int_dtype:
# it is better to round before cast
output = torch.round(output)

Expand Down Expand Up @@ -1462,9 +1455,9 @@ def test_functional_bounding_boxes_correctness(self, format, angle, translate, s
center=center,
)

torch.testing.assert_close(actual, expected)
torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)

@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
@pytest.mark.parametrize("seed", list(range(5)))
def test_transform_bounding_boxes_correctness(self, format, center, seed):
Expand All @@ -1480,7 +1473,7 @@ def test_transform_bounding_boxes_correctness(self, format, center, seed):

expected = self._reference_affine_bounding_boxes(bounding_boxes, **params, center=center)

torch.testing.assert_close(actual, expected)
torch.testing.assert_close(actual, expected, atol=1e-5, rtol=2e-5)

@pytest.mark.parametrize("degrees", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["degrees"])
@pytest.mark.parametrize("translate", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["translate"])
Expand Down Expand Up @@ -1712,7 +1705,7 @@ def test_kernel_image(self, param, value, dtype, device):
expand=[False, True],
center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
)
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
Expand Down Expand Up @@ -1848,6 +1841,13 @@ def _recenter_bounding_boxes_after_expand(self, bounding_boxes, *, recenter_xy):
x, y = recenter_xy
if bounding_boxes.format is tv_tensors.BoundingBoxFormat.XYXY:
translate = [x, y, x, y]
elif bounding_boxes.format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
translate = [x, y, x, y, x, y, x, y]
elif (
bounding_boxes.format is tv_tensors.BoundingBoxFormat.CXCYWHR
or bounding_boxes.format is tv_tensors.BoundingBoxFormat.XYWHR
):
translate = [x, y, 0.0, 0.0, 0.0]
else:
translate = [x, y, 0.0, 0.0]
return tv_tensors.wrap(
Expand All @@ -1872,7 +1872,12 @@ def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, cen
expand=expand, canvas_size=bounding_boxes.canvas_size, affine_matrix=affine_matrix
)

output = reference_affine_bounding_boxes_helper(
helper = (
reference_affine_rotated_bounding_boxes_helper
if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
else reference_affine_bounding_boxes_helper
)
output = helper(
bounding_boxes,
affine_matrix=affine_matrix,
new_canvas_size=new_canvas_size,
Expand All @@ -1883,7 +1888,7 @@ def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, cen
bounding_boxes
)

@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
@pytest.mark.parametrize("expand", [False, True])
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
Expand All @@ -1896,7 +1901,7 @@ def test_functional_bounding_boxes_correctness(self, format, angle, expand, cent
torch.testing.assert_close(actual, expected)
torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)

@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("expand", [False, True])
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
@pytest.mark.parametrize("seed", list(range(5)))
Expand Down Expand Up @@ -2817,7 +2822,7 @@ def test_kernel_image(self, param, value, dtype, device):
check_cuda_vs_cpu=dtype is not torch.float16,
)

@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_boxes(self, format, dtype, device):
Expand Down Expand Up @@ -3647,8 +3652,14 @@ def test_rand_augment_num_ops_error(self, num_ops):


class TestConvertBoundingBoxFormat:
old_new_formats = list(itertools.permutations(SUPPORTED_BOX_FORMATS, 2))
old_new_formats += list(itertools.permutations(NEW_BOX_FORMATS, 2))
old_new_formats = list(
itertools.permutations(
[f for f in tv_tensors.BoundingBoxFormat if not tv_tensors.is_rotated_bounding_format(f)], 2
)
)
old_new_formats += list(
itertools.permutations([f for f in tv_tensors.BoundingBoxFormat if tv_tensors.is_rotated_bounding_format(f)], 2)
)

@pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
def test_kernel(self, old_format, new_format):
Expand All @@ -3659,7 +3670,7 @@ def test_kernel(self, old_format, new_format):
old_format=old_format,
)

@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("inplace", [False, True])
def test_kernel_noop(self, format, inplace):
input = make_bounding_boxes(format=format).as_subclass(torch.Tensor)
Expand Down Expand Up @@ -3720,7 +3731,7 @@ def test_strings(self, old_format, new_format):
)
out_transform = transforms.ConvertBoundingBoxFormat(new_format)(input)
for out in (out_functional, out_functional_tensor, out_transform):
assert_equal(out, expected)
torch.testing.assert_close(out, expected)

def _reference_convert_bounding_box_format(self, bounding_boxes, new_format):
return tv_tensors.wrap(
Expand Down Expand Up @@ -3748,7 +3759,7 @@ def test_correctness(self, old_format, new_format, dtype, device, fn_type):
actual = fn(bounding_boxes)
expected = self._reference_convert_bounding_box_format(bounding_boxes, new_format)

assert_equal(actual, expected)
torch.testing.assert_close(actual, expected)

def test_errors(self):
input_tv_tensor = make_bounding_boxes()
Expand Down Expand Up @@ -4259,7 +4270,7 @@ def test_kernel_image_error(self):
coefficients=COEFFICIENTS,
start_end_points=START_END_POINTS,
)
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
def test_kernel_bounding_boxes(self, param, value, format):
if param == "start_end_points":
kwargs = dict(zip(["startpoints", "endpoints"], value))
Expand Down Expand Up @@ -4367,6 +4378,12 @@ def _reference_perspective_bounding_boxes(self, bounding_boxes, *, startpoints,
canvas_size = bounding_boxes.canvas_size
dtype = bounding_boxes.dtype
device = bounding_boxes.device
is_rotated = tv_tensors.is_rotated_bounding_format(format)
ndims = 4
if is_rotated and format == tv_tensors.BoundingBoxFormat.XYXYXYXY:
ndims = 8
if is_rotated and format != tv_tensors.BoundingBoxFormat.XYXYXYXY:
ndims = 5

coefficients = _get_perspective_coeffs(endpoints, startpoints)

Expand All @@ -4384,39 +4401,74 @@ def perspective_bounding_boxes(bounding_boxes):
]
)

# Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
input_xyxy = F.convert_bounding_box_format(
bounding_boxes.to(dtype=torch.float64, device="cpu", copy=True),
old_format=format,
new_format=tv_tensors.BoundingBoxFormat.XYXY,
inplace=True,
)
x1, y1, x2, y2 = input_xyxy.squeeze(0).tolist()
if is_rotated:
input_xyxyxyxy = F.convert_bounding_box_format(
bounding_boxes.to(device="cpu", copy=True),
old_format=format,
new_format=tv_tensors.BoundingBoxFormat.XYXYXYXY,
inplace=True,
)
x1, y1, x2, y2, x3, y3, x4, y4 = input_xyxyxyxy.squeeze(0).tolist()
points = np.array(
[
[x1, y1, 1.0],
[x2, y2, 1.0],
[x3, y3, 1.0],
[x4, y4, 1.0],
]
)

points = np.array(
[
[x1, y1, 1.0],
[x2, y1, 1.0],
[x1, y2, 1.0],
[x2, y2, 1.0],
]
)
else:
# Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
input_xyxy = F.convert_bounding_box_format(
bounding_boxes.to(dtype=torch.float64, device="cpu", copy=True),
old_format=format,
new_format=tv_tensors.BoundingBoxFormat.XYXY,
inplace=True,
)
x1, y1, x2, y2 = input_xyxy.squeeze(0).tolist()

points = np.array(
[
[x1, y1, 1.0],
[x2, y1, 1.0],
[x1, y2, 1.0],
[x2, y2, 1.0],
]
)

numerator = points @ m1.T
denominator = points @ m2.T
numerator = points @ m1.astype(points.dtype).T
denominator = points @ m2.astype(points.dtype).T
transformed_points = numerator / denominator

output_xyxy = torch.Tensor(
[
float(np.min(transformed_points[:, 0])),
float(np.min(transformed_points[:, 1])),
float(np.max(transformed_points[:, 0])),
float(np.max(transformed_points[:, 1])),
]
)
if is_rotated:
output = torch.Tensor(
[
float(transformed_points[0, 0]),
float(transformed_points[0, 1]),
float(transformed_points[1, 0]),
float(transformed_points[1, 1]),
float(transformed_points[2, 0]),
float(transformed_points[2, 1]),
float(transformed_points[3, 0]),
float(transformed_points[3, 1]),
]
)
output = _parallelogram_to_bounding_boxes(output)
else:
output = torch.Tensor(
[
float(np.min(transformed_points[:, 0])),
float(np.min(transformed_points[:, 1])),
float(np.max(transformed_points[:, 0])),
float(np.max(transformed_points[:, 1])),
]
)

output = F.convert_bounding_box_format(
output_xyxy, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format
output,
old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY,
new_format=format,
)

# It is important to clamp before casting, especially for CXCYWH format, dtype=int64
Expand All @@ -4427,15 +4479,15 @@ def perspective_bounding_boxes(bounding_boxes):
).to(dtype=dtype, device=device)

return tv_tensors.BoundingBoxes(
torch.cat([perspective_bounding_boxes(b) for b in bounding_boxes.reshape(-1, 4).unbind()], dim=0).reshape(
bounding_boxes.shape
),
torch.cat(
[perspective_bounding_boxes(b) for b in bounding_boxes.reshape(-1, ndims).unbind()], dim=0
).reshape(bounding_boxes.shape),
format=format,
canvas_size=canvas_size,
)

@pytest.mark.parametrize(("startpoints", "endpoints"), START_END_POINTS)
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_perspective_bounding_boxes(self, startpoints, endpoints, format, dtype, device):
Expand Down Expand Up @@ -4642,7 +4694,7 @@ def test_correctness_image(self, mean, std, dtype, fn):


class TestClampBoundingBoxes:
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel(self, format, dtype, device):
Expand All @@ -4654,7 +4706,7 @@ def test_kernel(self, format, dtype, device):
canvas_size=bounding_boxes.canvas_size,
)

@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
def test_functional(self, format):
check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format))

Expand Down Expand Up @@ -5959,15 +6011,19 @@ def test_parallelogram_to_bounding_boxes(input_size, dtype, device):
# 1---2 1----2
# \ \ -> | |
# 4---3 4----3
parallelogram = torch.tensor([[1, 0, 4, 0, 3, 2, 0, 2], [0, 0, 3, 0, 4, 2, 1, 2]])
parallelogram = torch.tensor(
[[1, 0, 4, 0, 3, 2, 0, 2], [0, 0, 3, 0, 4, 2, 1, 2]],
dtype=torch.float32,
)
expected = torch.tensor(
[
[0, 0, 4, 0, 4, 2, 0, 2],
[0, 0, 4, 0, 4, 2, 0, 2],
]
],
dtype=torch.float32,
)
actual = _parallelogram_to_bounding_boxes(parallelogram)
assert_equal(actual, expected)
torch.testing.assert_close(actual, expected)


@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
Expand Down
Loading
Loading