diff --git a/test/common_utils.py b/test/common_utils.py index df3126fa4b4..2d664ba8caf 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -21,7 +21,7 @@ from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torchvision import io, tv_tensors from torchvision.transforms._functional_tensor import _max_value as get_max_value -from torchvision.transforms.v2.functional import to_image, to_pil_image +from torchvision.transforms.v2.functional import clamp_bounding_boxes, to_image, to_pil_image IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"]) @@ -461,9 +461,20 @@ def sample_position(values, max_value): parts = (x1, y1, x2, y2, x3, y3, x4, y4) else: raise ValueError(f"Format {format} is not supported") - return tv_tensors.BoundingBoxes( - torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size - ) + out_boxes = torch.stack(parts, dim=-1).to(dtype=dtype, device=device) + if tv_tensors.is_rotated_bounding_format(format): + # By design, rotated bounding boxes are not guaranteed to lie within the canvas, so we + # clamp against a canvas shrunk by `buffer` pixels and shift the result by `buffer // 2` + # to keep boxes away from the borders and avoid numerical issues during testing. + buffer = 4 + out_boxes = clamp_bounding_boxes( + out_boxes, format=format, canvas_size=(canvas_size[0] - buffer, canvas_size[1] - buffer) + ) + if format is tv_tensors.BoundingBoxFormat.XYWHR or format is tv_tensors.BoundingBoxFormat.CXCYWHR: + out_boxes[:, :2] += buffer // 2 + elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY: + out_boxes[:, :] += buffer // 2 + return tv_tensors.BoundingBoxes(out_boxes, format=format, canvas_size=canvas_size) def make_detection_masks(size=DEFAULT_SIZE, *, num_masks=1, dtype=None, device="cpu"): diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 7ade10a8ea5..e9ae54a9735 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -53,15 +53,6 @@ from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal -# While we are working on adjusting transform functions -# for rotated and oriented bounding boxes formats, -# we limit the perimeter of tests to formats -# for which transform functions are already implemented. -# In the future, this global variable will be replaced with `list(tv_tensors.BoundingBoxFormat)` -# to support all available formats. 
-SUPPORTED_BOX_FORMATS = [tv_tensors.BoundingBoxFormat[x] for x in ["XYXY", "XYWH", "CXCYWH"]] -NEW_BOX_FORMATS = [tv_tensors.BoundingBoxFormat[x] for x in ["XYWHR", "CXCYWHR", "XYXYXYXY"]] - # turns all warnings into errors for this module pytestmark = [pytest.mark.filterwarnings("error")] @@ -568,6 +559,13 @@ def reference_affine_rotated_bounding_boxes_helper( def affine_rotated_bounding_boxes(bounding_boxes): dtype = bounding_boxes.dtype + int_dtype = dtype in ( + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + ) device = bounding_boxes.device # Go to float before converting to prevent precision loss in case of CXCYWHR -> XYXYXYXY and W or H is 1 @@ -602,19 +600,14 @@ def affine_rotated_bounding_boxes(bounding_boxes): ) output = output[[2, 3, 0, 1, 6, 7, 4, 5]] if flip else output - output = _parallelogram_to_bounding_boxes(output) + if not int_dtype: + output = _parallelogram_to_bounding_boxes(output) output = F.convert_bounding_box_format( output, old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, new_format=format ) - if torch.is_floating_point(output) and dtype in ( - torch.uint8, - torch.int8, - torch.int16, - torch.int32, - torch.int64, - ): + if torch.is_floating_point(output) and int_dtype: # it is better to round before cast output = torch.round(output) @@ -1462,9 +1455,9 @@ def test_functional_bounding_boxes_correctness(self, format, angle, translate, s center=center, ) - torch.testing.assert_close(actual, expected) + torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) @pytest.mark.parametrize("seed", list(range(5))) def test_transform_bounding_boxes_correctness(self, format, center, seed): @@ -1480,7 +1473,7 @@ def test_transform_bounding_boxes_correctness(self, format, center, seed): expected = self._reference_affine_bounding_boxes(bounding_boxes, **params, center=center) - torch.testing.assert_close(actual, expected) + torch.testing.assert_close(actual, expected, atol=1e-5, rtol=2e-5) @pytest.mark.parametrize("degrees", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["degrees"]) @pytest.mark.parametrize("translate", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["translate"]) @@ -1712,7 +1705,7 @@ def test_kernel_image(self, param, value, dtype, device): expand=[False, True], center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"], ) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_bounding_boxes(self, param, value, format, dtype, device): @@ -1848,6 +1841,13 @@ def _recenter_bounding_boxes_after_expand(self, bounding_boxes, *, recenter_xy): x, y = recenter_xy if bounding_boxes.format is tv_tensors.BoundingBoxFormat.XYXY: translate = [x, y, x, y] + elif bounding_boxes.format is tv_tensors.BoundingBoxFormat.XYXYXYXY: + translate = [x, y, x, y, x, y, x, y] + elif ( + bounding_boxes.format is tv_tensors.BoundingBoxFormat.CXCYWHR + or bounding_boxes.format is tv_tensors.BoundingBoxFormat.XYWHR + ): + translate = [x, y, 0.0, 0.0, 0.0] else: translate = [x, y, 0.0, 0.0] return tv_tensors.wrap( @@ -1872,7 +1872,12 @@ def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, cen expand=expand, 
canvas_size=bounding_boxes.canvas_size, affine_matrix=affine_matrix ) - output = reference_affine_bounding_boxes_helper( + helper = ( + reference_affine_rotated_bounding_boxes_helper + if tv_tensors.is_rotated_bounding_format(bounding_boxes.format) + else reference_affine_bounding_boxes_helper + ) + output = helper( bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=new_canvas_size, @@ -1883,7 +1888,7 @@ def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, cen bounding_boxes ) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"]) @pytest.mark.parametrize("expand", [False, True]) @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) @@ -1896,7 +1901,7 @@ def test_functional_bounding_boxes_correctness(self, format, angle, expand, cent torch.testing.assert_close(actual, expected) torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("expand", [False, True]) @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) @pytest.mark.parametrize("seed", list(range(5))) @@ -2817,7 +2822,7 @@ def test_kernel_image(self, param, value, dtype, device): check_cuda_vs_cpu=dtype is not torch.float16, ) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_bounding_boxes(self, format, dtype, device): @@ -3647,8 +3652,14 @@ def test_rand_augment_num_ops_error(self, num_ops): class TestConvertBoundingBoxFormat: - old_new_formats = list(itertools.permutations(SUPPORTED_BOX_FORMATS, 2)) - old_new_formats += list(itertools.permutations(NEW_BOX_FORMATS, 2)) + old_new_formats = list( + itertools.permutations( + [f for f in tv_tensors.BoundingBoxFormat if not tv_tensors.is_rotated_bounding_format(f)], 2 + ) + ) + old_new_formats += list( + itertools.permutations([f for f in tv_tensors.BoundingBoxFormat if tv_tensors.is_rotated_bounding_format(f)], 2) + ) @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats) def test_kernel(self, old_format, new_format): @@ -3659,7 +3670,7 @@ def test_kernel(self, old_format, new_format): old_format=old_format, ) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("inplace", [False, True]) def test_kernel_noop(self, format, inplace): input = make_bounding_boxes(format=format).as_subclass(torch.Tensor) @@ -3720,7 +3731,7 @@ def test_strings(self, old_format, new_format): ) out_transform = transforms.ConvertBoundingBoxFormat(new_format)(input) for out in (out_functional, out_functional_tensor, out_transform): - assert_equal(out, expected) + torch.testing.assert_close(out, expected) def _reference_convert_bounding_box_format(self, bounding_boxes, new_format): return tv_tensors.wrap( @@ -3748,7 +3759,7 @@ def test_correctness(self, old_format, new_format, dtype, device, fn_type): actual = fn(bounding_boxes) expected = self._reference_convert_bounding_box_format(bounding_boxes, new_format) - assert_equal(actual, expected) + 
torch.testing.assert_close(actual, expected) def test_errors(self): input_tv_tensor = make_bounding_boxes() @@ -4259,7 +4270,7 @@ def test_kernel_image_error(self): coefficients=COEFFICIENTS, start_end_points=START_END_POINTS, ) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) def test_kernel_bounding_boxes(self, param, value, format): if param == "start_end_points": kwargs = dict(zip(["startpoints", "endpoints"], value)) @@ -4367,6 +4378,12 @@ def _reference_perspective_bounding_boxes(self, bounding_boxes, *, startpoints, canvas_size = bounding_boxes.canvas_size dtype = bounding_boxes.dtype device = bounding_boxes.device + is_rotated = tv_tensors.is_rotated_bounding_format(format) + ndims = 4 + if is_rotated and format == tv_tensors.BoundingBoxFormat.XYXYXYXY: + ndims = 8 + if is_rotated and format != tv_tensors.BoundingBoxFormat.XYXYXYXY: + ndims = 5 coefficients = _get_perspective_coeffs(endpoints, startpoints) @@ -4384,39 +4401,74 @@ def perspective_bounding_boxes(bounding_boxes): ] ) - # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1 - input_xyxy = F.convert_bounding_box_format( - bounding_boxes.to(dtype=torch.float64, device="cpu", copy=True), - old_format=format, - new_format=tv_tensors.BoundingBoxFormat.XYXY, - inplace=True, - ) - x1, y1, x2, y2 = input_xyxy.squeeze(0).tolist() + if is_rotated: + input_xyxyxyxy = F.convert_bounding_box_format( + bounding_boxes.to(device="cpu", copy=True), + old_format=format, + new_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, + inplace=True, + ) + x1, y1, x2, y2, x3, y3, x4, y4 = input_xyxyxyxy.squeeze(0).tolist() + points = np.array( + [ + [x1, y1, 1.0], + [x2, y2, 1.0], + [x3, y3, 1.0], + [x4, y4, 1.0], + ] + ) - points = np.array( - [ - [x1, y1, 1.0], - [x2, y1, 1.0], - [x1, y2, 1.0], - [x2, y2, 1.0], - ] - ) + else: + # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1 + input_xyxy = F.convert_bounding_box_format( + bounding_boxes.to(dtype=torch.float64, device="cpu", copy=True), + old_format=format, + new_format=tv_tensors.BoundingBoxFormat.XYXY, + inplace=True, + ) + x1, y1, x2, y2 = input_xyxy.squeeze(0).tolist() + + points = np.array( + [ + [x1, y1, 1.0], + [x2, y1, 1.0], + [x1, y2, 1.0], + [x2, y2, 1.0], + ] + ) - numerator = points @ m1.T - denominator = points @ m2.T + numerator = points @ m1.astype(points.dtype).T + denominator = points @ m2.astype(points.dtype).T transformed_points = numerator / denominator - output_xyxy = torch.Tensor( - [ - float(np.min(transformed_points[:, 0])), - float(np.min(transformed_points[:, 1])), - float(np.max(transformed_points[:, 0])), - float(np.max(transformed_points[:, 1])), - ] - ) + if is_rotated: + output = torch.Tensor( + [ + float(transformed_points[0, 0]), + float(transformed_points[0, 1]), + float(transformed_points[1, 0]), + float(transformed_points[1, 1]), + float(transformed_points[2, 0]), + float(transformed_points[2, 1]), + float(transformed_points[3, 0]), + float(transformed_points[3, 1]), + ] + ) + output = _parallelogram_to_bounding_boxes(output) + else: + output = torch.Tensor( + [ + float(np.min(transformed_points[:, 0])), + float(np.min(transformed_points[:, 1])), + float(np.max(transformed_points[:, 0])), + float(np.max(transformed_points[:, 1])), + ] + ) output = F.convert_bounding_box_format( - output_xyxy, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format + output, + 
old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY, + new_format=format, + ) # It is important to clamp before casting, especially for CXCYWH format, dtype=int64 @@ -4427,15 +4479,15 @@ def perspective_bounding_boxes(bounding_boxes): ).to(dtype=dtype, device=device) return tv_tensors.BoundingBoxes( - torch.cat([perspective_bounding_boxes(b) for b in bounding_boxes.reshape(-1, 4).unbind()], dim=0).reshape( - bounding_boxes.shape - ), + torch.cat( + [perspective_bounding_boxes(b) for b in bounding_boxes.reshape(-1, ndims).unbind()], dim=0 + ).reshape(bounding_boxes.shape), format=format, canvas_size=canvas_size, ) @pytest.mark.parametrize(("startpoints", "endpoints"), START_END_POINTS) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.int64, torch.float32]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_correctness_perspective_bounding_boxes(self, startpoints, endpoints, format, dtype, device): @@ -4642,7 +4694,7 @@ def test_correctness_image(self, mean, std, dtype, fn): class TestClampBoundingBoxes: - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.int64, torch.float32]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel(self, format, dtype, device): @@ -4654,7 +4706,7 @@ def test_kernel(self, format, dtype, device): canvas_size=bounding_boxes.canvas_size, ) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) def test_functional(self, format): check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format)) @@ -5959,15 +6011,19 @@ def test_parallelogram_to_bounding_boxes(input_size, dtype, device): # 1---2 1----2 # \ \ -> | | # 4---3 4----3 - parallelogram = torch.tensor([[1, 0, 4, 0, 3, 2, 0, 2], [0, 0, 3, 0, 4, 2, 1, 2]]) + parallelogram = torch.tensor( + [[1, 0, 4, 0, 3, 2, 0, 2], [0, 0, 3, 0, 4, 2, 1, 2]], + dtype=torch.float32, + ) expected = torch.tensor( [ [0, 0, 4, 0, 4, 2, 0, 2], [0, 0, 4, 0, 4, 2, 0, 2], - ] + ], + dtype=torch.float32, ) actual = _parallelogram_to_bounding_boxes(parallelogram) - assert_equal(actual, expected) + torch.testing.assert_close(actual, expected) @pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image)) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 09be745a695..e3003b866bc 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -407,6 +407,10 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso torch.int32, torch.int64, ) + if int_dtype: + # Do not apply the transformation to integer boxes: rounding errors typically + # prevent the resulting box from being exactly rectangular. 
+ return parallelogram.clone() out_boxes = parallelogram.clone() @@ -415,12 +419,14 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso dy13 = parallelogram[..., 5] - parallelogram[..., 1] dx42 = parallelogram[..., 2] - parallelogram[..., 6] dy42 = parallelogram[..., 3] - parallelogram[..., 7] + dx12 = parallelogram[..., 2] - parallelogram[..., 0] + dy12 = parallelogram[..., 1] - parallelogram[..., 3] diag13 = torch.sqrt(dx13**2 + dy13**2) diag24 = torch.sqrt(dx42**2 + dy42**2) mask = diag13 > diag24 # Calculate rotation angle in radians - r_rad = torch.atan2(parallelogram[..., 1] - parallelogram[..., 3], parallelogram[..., 2] - parallelogram[..., 0]) + r_rad = torch.atan2(dy12, dx12) cos, sin = torch.cos(r_rad), torch.sin(r_rad) # Calculate width using the angle between diagonal and rotation @@ -432,7 +438,6 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso delta_x = torch.round(w * cos).to(dtype) if int_dtype else w * cos delta_y = torch.round(w * sin).to(dtype) if int_dtype else w * sin - # Update coordinates to form a rectangle # Keeping the points (x1, y1) and (x3, y3) unchanged. out_boxes[..., 2] = torch.where(mask, parallelogram[..., 0] + delta_x, parallelogram[..., 2]) @@ -478,6 +483,9 @@ def resize_bounding_boxes( ) transformed_points = xyxyxyxy_boxes.mul(ratios) out_bboxes = _parallelogram_to_bounding_boxes(transformed_points) + out_bboxes = clamp_bounding_boxes( + out_bboxes, format=tv_tensors.BoundingBoxFormat.XYXYXYXY, canvas_size=(new_height, new_width) + ) return ( convert_bounding_box_format( out_bboxes, @@ -983,7 +991,7 @@ def _affine_bounding_boxes_with_expand( new_points = torch.matmul(points, transposed_affine_matrix) tr = torch.amin(new_points, dim=0, keepdim=True) # Translate bounding boxes - out_bboxes.sub_(tr.repeat((1, 2))) + out_bboxes.sub_(tr.repeat((1, 4 if is_rotated else 2))) # Estimate meta-data for image with inverted=True affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear) new_width, new_height = _compute_affine_output_size(affine_vector, width, height) @@ -1753,10 +1761,13 @@ def perspective_bounding_boxes( perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients) original_shape = bounding_boxes.shape + original_dtype = bounding_boxes.dtype + is_rotated = tv_tensors.is_rotated_bounding_format(format) + intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY # TODO: first cast to float if bbox is int64 before convert_bounding_box_format bounding_boxes = ( - convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY) - ).reshape(-1, 4) + convert_bounding_box_format(bounding_boxes, old_format=format, new_format=intermediate_format) + ).reshape(-1, 8 if is_rotated else 4) dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32 device = bounding_boxes.device @@ -1805,7 +1816,8 @@ def perspective_bounding_boxes( # Tensor of points has shape (N * 4, 3), where N is the number of bboxes # Single point structure is similar to # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)] - points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2) + points = bounding_boxes if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]] + points = points.reshape(-1, 2) points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1) # 2) Now let's 
transform the points using perspective matrices # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) @@ -1817,21 +1829,23 @@ def perspective_bounding_boxes( # 3) Reshape transformed points to [N boxes, 4 points, x/y coords] # and compute bounding box from 4 transformed points: - transformed_points = transformed_points.reshape(-1, 4, 2) - out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1) - - out_bboxes = clamp_bounding_boxes( - torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype), - format=tv_tensors.BoundingBoxFormat.XYXY, - canvas_size=canvas_size, - ) + if is_rotated: + transformed_points = transformed_points.reshape(-1, 8) + out_bboxes = _parallelogram_to_bounding_boxes(transformed_points) + else: + transformed_points = transformed_points.reshape(-1, 4, 2) + out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1) + out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1) - # out_bboxes should be of shape [N boxes, 4] + out_bboxes = clamp_bounding_boxes(out_bboxes, format=intermediate_format, canvas_size=canvas_size) - return convert_bounding_box_format( - out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True + out_bboxes = convert_bounding_box_format( + out_bboxes, old_format=intermediate_format, new_format=format, inplace=True ).reshape(original_shape) + out_bboxes = out_bboxes.to(original_dtype) + return out_bboxes + @_register_kernel_internal(perspective, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) def _perspective_bounding_boxes_dispatch( @@ -2011,15 +2025,18 @@ def elastic_bounding_boxes( # TODO: add in docstring about approximation we are doing for grid inversion device = bounding_boxes.device dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32 + is_rotated = tv_tensors.is_rotated_bounding_format(format) if displacement.dtype != dtype or displacement.device != device: displacement = displacement.to(dtype=dtype, device=device) original_shape = bounding_boxes.shape # TODO: first cast to float if bbox is int64 before convert_bounding_box_format + intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY + bounding_boxes = ( - convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY) - ).reshape(-1, 4) + convert_bounding_box_format(bounding_boxes.clone(), old_format=format, new_format=intermediate_format) + ).reshape(-1, 8 if is_rotated else 4) id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype) # We construct an approximation of inverse grid as inv_grid = id_grid - displacement @@ -2027,7 +2044,8 @@ def elastic_bounding_boxes( inv_grid = id_grid.sub_(displacement) # Get points from bboxes - points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2) + points = bounding_boxes if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]] + points = points.reshape(-1, 2) if points.is_floating_point(): points = points.ceil_() index_xy = points.to(dtype=torch.long) @@ -2037,16 +2055,22 @@ def elastic_bounding_boxes( t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype) transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5) - transformed_points = transformed_points.reshape(-1, 4, 2) - out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1) + if is_rotated: + 
transformed_points = transformed_points.reshape(-1, 8) + out_bboxes = _parallelogram_to_bounding_boxes(transformed_points).to(bounding_boxes.dtype) + else: + transformed_points = transformed_points.reshape(-1, 4, 2) + out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1) + out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype) + + out_bboxes = clamp_bounding_boxes( - torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype), - format=tv_tensors.BoundingBoxFormat.XYXY, + out_bboxes, + format=intermediate_format, canvas_size=canvas_size, ) return convert_bounding_box_format( - out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True + out_bboxes, old_format=intermediate_format, new_format=format, inplace=False ).reshape(original_shape) diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 019f4e25cb7..9921e0b4282 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -352,14 +352,144 @@ def _clamp_bounding_boxes( return out_boxes.to(in_dtype) +def _order_bounding_boxes_points( + bounding_boxes: torch.Tensor, indices: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: + """Re-order points in bounding boxes based on specific criteria or provided indices. + + This function reorders the points of bounding boxes either according to provided indices or + by a default ordering strategy. In the default strategy, (x1, y1) corresponds to the point + with the lowest x value. If multiple points have the same lowest x value, the point with the + highest y value is chosen. + + Args: + bounding_boxes (torch.Tensor): A tensor containing bounding box coordinates in format [x1, y1, x2, y2, x3, y3, x4, y4]. + indices (torch.Tensor | None): Optional tensor containing indices for reordering. If None, default ordering is applied. + + Returns: + tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - indices: The indices used for reordering + - reordered_boxes: The bounding boxes with reordered points + """ + if indices is None: + output_xyxyxyxy = bounding_boxes.reshape(-1, 8) + x, y = output_xyxyxyxy[..., 0::2], output_xyxyxyxy[..., 1::2] + y_max = torch.max(y, dim=1, keepdim=True)[0] + _, x1 = ((y_max - y) / y_max + (x + 1) * 100).min(dim=1) + indices = torch.ones_like(output_xyxyxyxy) + indices[..., 0] = x1.mul(2) + indices.cumsum_(1).remainder_(8) + return indices, bounding_boxes.gather(1, indices.to(torch.int64)) + + +def _area(box: torch.Tensor) -> torch.Tensor: + x1, y1, x2, y2, x3, y3, x4, y4 = box.reshape(-1, 8).unbind(-1) + w = torch.sqrt((y2 - y1) ** 2 + (x2 - x1) ** 2) + h = torch.sqrt((y3 - y2) ** 2 + (x3 - x2) ** 2) + return w * h + + +def _clamp_along_y_axis( + bounding_boxes: torch.Tensor, +) -> torch.Tensor: + """ + Clamp bounding boxes against the left canvas boundary (the y-axis, x = 0). + + This function evaluates several geometric cases and applies the appropriate + transformation to each box so that no vertex ends up with a negative + x coordinate. + + Args: + bounding_boxes (torch.Tensor): A tensor containing bounding box coordinates. + + Returns: + torch.Tensor: The adjusted bounding boxes. 
+ """ + original_dtype = bounding_boxes.dtype + original_shape = bounding_boxes.shape + x1, y1, x2, y2, x3, y3, x4, y4 = bounding_boxes.reshape(-1, 8).unbind(-1) + a = (y2 - y1) / (x2 - x1) + b1 = y1 - a * x1 + b2 = y2 + x2 / a + b3 = y3 - a * x3 + b4 = y4 + x4 / a + b23 = (b2 - b3) / 2 * a / (1 + a**2) + z = torch.zeros_like(b1) + case_a = torch.cat([x.unsqueeze(1) for x in [z, b1, x2, y2, x3, y3, x3 - x2, y3 + b1 - y2]], dim=1) + case_b = torch.cat([x.unsqueeze(1) for x in [z, b4, x2 - x1, y2 - y1 + b4, x3, y3, x4, y4]], dim=1) + case_c = torch.cat( + [x.unsqueeze(1) for x in [z, (b2 + b3) / 2, b23, -b23 / a + b2, x3, y3, b23, b23 * a + b3]], dim=1 + ) + case_d = torch.zeros_like(case_c) + case_e = torch.cat([x.unsqueeze(1) for x in [x1.clamp(0), y1, x2.clamp(0), y2, x3, y3, x4, y4]], dim=1) + + cond_a = x1.lt(0).logical_and(x2.ge(0)).logical_and(x3.ge(0)).logical_and(x4.ge(0)) + cond_a = cond_a.logical_and(_area(case_a) > _area(case_b)) + cond_a = cond_a.logical_or(x1.lt(0).logical_and(x2.ge(0)).logical_and(x3.ge(0)).logical_and(x4.le(0))) + cond_b = x1.lt(0).logical_and(x2.ge(0)).logical_and(x3.ge(0)).logical_and(x4.ge(0)) + cond_b = cond_b.logical_and(_area(case_a) <= _area(case_b)) + cond_b = cond_b.logical_or(x1.lt(0).logical_and(x2.le(0)).logical_and(x3.ge(0)).logical_and(x4.ge(0))) + cond_c = x1.lt(0).logical_and(x2.le(0)).logical_and(x3.ge(0)).logical_and(x4.le(0)) + cond_d = x1.lt(0).logical_and(x2.le(0)).logical_and(x3.le(0)).logical_and(x4.le(0)) + cond_e = x1.isclose(x2) + + for cond, case in zip( + [cond_a, cond_b, cond_c, cond_d, cond_e], + [case_a, case_b, case_c, case_d, case_e], + ): + bounding_boxes = torch.where(cond.unsqueeze(1).repeat(1, 8), case.reshape(-1, 8), bounding_boxes) + return bounding_boxes.to(original_dtype).reshape(original_shape) + + def _clamp_rotated_bounding_boxes( bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: tuple[int, int] ) -> torch.Tensor: - # TODO: For now we are not clamping rotated bounding boxes. - in_dtype = bounding_boxes.dtype - out_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float() + """ + Clamp rotated bounding boxes to ensure they stay within the canvas boundaries. + + This function handles rotated bounding boxes by: + 1. Converting them to XYXYXYXY format (8 coordinates representing 4 corners). + 2. Re-ordering the points in the bounding boxes to ensure (x1, y1) corresponds to the point with the lowest x value + 2. Translates the points (x1, y1), (x2, y2) and (x3, y3) + to ensure the bounding box does not go out beyond the left boundary of the canvas. + 3. Rotate the bounding box four times and apply the same transformation to each vertex to ensure + the box does not go beyond the top, right, and bottom boundaries. + 3. Converting back to the original format and re-order the points as in the original input. 
+ + Args: + bounding_boxes (torch.Tensor): Tensor containing rotated bounding box coordinates + format (BoundingBoxFormat): The format of the input bounding boxes + canvas_size (tuple[int, int]): The size of the canvas as (height, width) + + Returns: + torch.Tensor: Clamped bounding boxes in the original format and shape + """ + original_shape = bounding_boxes.shape + original_dtype = bounding_boxes.dtype + bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float() + out_boxes = ( + convert_bounding_box_format( + bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, inplace=True + ) + ).reshape(-1, 8) + + for _ in range(4): + indices, out_boxes = _order_bounding_boxes_points(out_boxes) + out_boxes = _clamp_along_y_axis(out_boxes) + _, out_boxes = _order_bounding_boxes_points(out_boxes, indices) + # rotate 90 degrees counterclockwise + out_boxes[:, ::2], out_boxes[:, 1::2] = ( + out_boxes[:, 1::2].clone(), + canvas_size[1] - out_boxes[:, ::2].clone(), + ) + canvas_size = (canvas_size[1], canvas_size[0]) - return out_boxes.to(in_dtype) + out_boxes = convert_bounding_box_format( + out_boxes, old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, new_format=format, inplace=True + ).reshape(original_shape) + + out_boxes = out_boxes.to(original_dtype) + return out_boxes def clamp_bounding_boxes( diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 7e167d788e6..191d01c9922 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -405,16 +405,27 @@ def _get_sanitize_bounding_boxes_mask( min_area: float = 1.0, ) -> torch.Tensor: - bounding_boxes = _convert_bounding_box_format( - bounding_boxes, new_format=tv_tensors.BoundingBoxFormat.XYXY, old_format=format - ) + is_rotated = tv_tensors.is_rotated_bounding_format(format) + intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY + bounding_boxes = _convert_bounding_box_format(bounding_boxes, new_format=intermediate_format, old_format=format) image_h, image_w = canvas_size - ws, hs = bounding_boxes[:, 2] - bounding_boxes[:, 0], bounding_boxes[:, 3] - bounding_boxes[:, 1] + if is_rotated: + dx12 = bounding_boxes[..., 0] - bounding_boxes[..., 2] + dy12 = bounding_boxes[..., 1] - bounding_boxes[..., 3] + dx23 = bounding_boxes[..., 2] - bounding_boxes[..., 4] + dy23 = bounding_boxes[..., 3] - bounding_boxes[..., 5] + ws = torch.sqrt(dx12**2 + dy12**2) + hs = torch.sqrt(dx23**2 + dy23**2) + else: + ws, hs = bounding_boxes[:, 2] - bounding_boxes[:, 0], bounding_boxes[:, 3] - bounding_boxes[:, 1] valid = (ws >= min_size) & (hs >= min_size) & (bounding_boxes >= 0).all(dim=-1) & (ws * hs >= min_area) # TODO: Do we really need to check for out of bounds here? All # transforms should be clamping anyway, so this should never happen? image_h, image_w = canvas_size valid &= (bounding_boxes[:, 0] <= image_w) & (bounding_boxes[:, 2] <= image_w) valid &= (bounding_boxes[:, 1] <= image_h) & (bounding_boxes[:, 3] <= image_h) + if is_rotated: + valid &= (bounding_boxes[..., 4] <= image_w) & (bounding_boxes[..., 5] <= image_h) + valid &= (bounding_boxes[..., 6] <= image_w) & (bounding_boxes[..., 7] <= image_h)
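The patch relies on `_order_bounding_boxes_points` to pick a canonical starting vertex before clamping. A standalone sketch of that default ordering strategy, using a hypothetical helper `order_points` that is not part of the patch (tie-breaking on equal x values is ignored here for simplicity):

    import torch

    def order_points(boxes: torch.Tensor) -> torch.Tensor:
        # Reshape (N, 8) corner lists into (N, 4, 2) points.
        pts = boxes.reshape(-1, 4, 2)
        start = pts[..., 0].argmin(dim=1)                  # lowest-x vertex per box
        idx = (start.unsqueeze(1) + torch.arange(4)) % 4   # cyclic shift of vertex indices
        return torch.gather(pts, 1, idx.unsqueeze(-1).expand(-1, 4, 2)).reshape(-1, 8)

    box = torch.tensor([[4.0, 0.0, 3.0, 2.0, 0.0, 2.0, 1.0, 0.0]])
    print(order_points(box))  # tensor([[0., 2., 1., 0., 4., 0., 3., 2.]])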
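The loop in `_clamp_rotated_bounding_boxes` only ever clamps against the left boundary and then rotates the coordinate frame, so one clamping rule covers all four edges. A sketch of that rotation step with a hypothetical helper `rotate_90_ccw`, using the same (x, y) -> (y, W - x) mapping as the patch:

    import torch

    def rotate_90_ccw(points: torch.Tensor, canvas_size: tuple[int, int]):
        # (x, y) -> (y, W - x), with the canvas swapped from (H, W) to (W, H).
        # The top edge (y = 0) of the old canvas becomes the left edge (x = 0)
        # of the new one, so clamping against the left boundary once per
        # rotation eventually handles every canvas edge.
        h, w = canvas_size
        x, y = points[:, ::2], points[:, 1::2]
        out = points.clone()
        out[:, ::2], out[:, 1::2] = y, w - x
        return out, (w, h)

    pts = torch.tensor([[10.0, 0.0, 30.0, 0.0, 30.0, 5.0, 10.0, 5.0]])  # hugging the top edge
    rotated, new_canvas = rotate_90_ccw(pts, (20, 40))
    print(rotated)     # x coordinates are now 0 or 5: the box hugs the left edge
    print(new_canvas)  # (40, 20)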
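End to end, the new behavior should be observable through the public `clamp_bounding_boxes` functional. A hedged usage sketch with illustrative values (not taken from the test suite):

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2.functional import clamp_bounding_boxes

    # A rotated box whose center sits near the left edge, so part of it falls
    # outside the canvas.
    boxes = tv_tensors.BoundingBoxes(
        torch.tensor([[2.0, 16.0, 20.0, 8.0, 30.0]]),  # (cx, cy, w, h, angle in degrees)
        format=tv_tensors.BoundingBoxFormat.CXCYWHR,
        canvas_size=(32, 32),
    )
    clamped = clamp_bounding_boxes(boxes)
    # With this patch applied, the corners of `clamped` should all lie inside
    # the 32x32 canvas; previously, rotated formats were returned unchanged.
    print(clamped)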
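Finally, the sanitization change measures rotated box sides from adjacent corners rather than from axis-aligned extents. A minimal sketch mirroring that computation with a hypothetical `rotated_wh` helper:

    import torch

    def rotated_wh(xyxyxyxy: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # w = |P1 - P2| and h = |P2 - P3| for corners (x1, y1, ..., x4, y4).
        x1, y1, x2, y2, x3, y3, x4, y4 = xyxyxyxy.unbind(-1)
        w = torch.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
        h = torch.sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2)
        return w, h

    # A 3x4 rectangle rotated by 90 degrees: the side lengths are preserved,
    # whereas axis-aligned extents would report the diagonal-spanning bounds.
    box = torch.tensor([0.0, 0.0, 0.0, 3.0, -4.0, 3.0, -4.0, 0.0])
    print(rotated_wh(box))  # (tensor(3.), tensor(4.))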