Skip to content

Commit d420392

Browse files
Shuangping Liu authored and facebook-github-bot committed
Add validation logic for lengths & offsets of KJT (#2963)
Summary:
Pull Request resolved: #2963

This diff adds validation logic for lengths and offsets of `KeyedJaggedTensor`. The changes include:
- Add a new library `jagged_tensor_validator` and a new test file `test_jagged_tensor_validator.py`.
- The `validate_keyed_jagged_tensor` function checks that the input lengths and/or offsets are valid in the non-VBE case, including:
  - At least one of lengths or offsets is provided
  - If both are provided, they are consistent with each other
  - The dimensions of these tensors align with the values tensor
- Generates test cases using Hypothesis to cover corner cases and ensure valid KJTs can successfully pass the validator.

More validation logic & test cases will be added in follow-up diffs.

Reviewed By: TroyGarden
Differential Revision: D71531326
fbshipit-source-id: e84d115c15b34cee9672f7422aa3e109878f7bc1
1 parent ceebcf0 commit d420392

File tree

2 files changed

+261
-0
lines changed

2 files changed

+261
-0
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import torch
11+
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
12+
13+
14+
def validate_keyed_jagged_tensor(kjt: KeyedJaggedTensor) -> None:
    """
    Check that a KeyedJaggedTensor was constructed from well-formed inputs.

    Only lengths and offsets are inspected for now. The work is delegated to
    ``_validate_lengths_and_offsets``, which enforces that:
        - lengths and offsets are not both missing
        - when both are present, they agree with each other
        - their contents line up with the number of entries in values

    Args:
        kjt (KeyedJaggedTensor): the tensor to validate.

    Raises:
        ValueError: if any of the checks above fails.
    """
    # TODO: Add validation checks on keys, values, weights
    _validate_lengths_and_offsets(kjt)
29+
30+
31+
def _validate_lengths_and_offsets(kjt: KeyedJaggedTensor) -> None:
    """
    Run the check that matches whichever of lengths / offsets the KJT carries.

    Args:
        kjt (KeyedJaggedTensor): the tensor whose lengths and offsets are checked.

    Raises:
        ValueError: if neither lengths nor offsets is set, or if the check
            selected for the available tensor(s) fails.
    """
    lengths = kjt.lengths_or_none()
    offsets = kjt.offsets_or_none()

    if lengths is not None:
        if offsets is not None:
            # Both present: validate each one and then cross-check them.
            _validate_lengths_and_offsets_consistency(lengths, offsets, kjt.values())
        else:
            _validate_lengths(lengths, kjt.values())
    elif offsets is not None:
        _validate_offsets(offsets, kjt.values())
    else:
        raise ValueError(
            "lengths and offsets cannot be both empty in KeyedJaggedTensor"
        )
44+
45+
46+
def _validate_lengths_and_offsets_consistency(
    lengths: torch.Tensor, offsets: torch.Tensor, values: torch.Tensor
) -> None:
    """
    Validate lengths and offsets individually, then check that they agree.

    Args:
        lengths (torch.Tensor): lengths tensor of the KJT.
        offsets (torch.Tensor): offsets tensor of the KJT.
        values (torch.Tensor): values tensor both are checked against.

    Raises:
        ValueError: if either tensor fails its own check, if offsets does not
            have exactly one more element than lengths, or if the differences
            between consecutive offsets do not equal lengths.
    """
    # Each tensor must be valid against values on its own before the two are
    # compared; the order also decides which error is reported first.
    _validate_lengths(lengths, values)
    _validate_offsets(offsets, values)

    num_lengths = lengths.numel()
    num_offsets = offsets.numel()
    # NOTE(review): the condition requires lengths to be one element shorter
    # than offsets, while the message wording says "1 more". The text is kept
    # unchanged because the tests match on it.
    if num_lengths + 1 != num_offsets:
        raise ValueError(
            f"Expected lengths size to be 1 more than offsets size, but got lengths size: {num_lengths} and offsets size: {num_offsets}"
        )

    # Differences between consecutive offsets must reproduce lengths exactly.
    consecutive_gaps = torch.diff(offsets)
    if not torch.equal(lengths, consecutive_gaps):
        raise ValueError("offsets is not equal to the cumulative sum of lengths")
59+
60+
61+
def _validate_lengths(lengths: torch.Tensor, values: torch.Tensor) -> None:
    """
    Check that the lengths add up to the number of entries in values.

    Args:
        lengths (torch.Tensor): lengths tensor of the KJT.
        values (torch.Tensor): values tensor the lengths should account for.

    Raises:
        ValueError: if ``lengths.sum()`` differs from ``values.numel()``.
    """
    total_length = lengths.sum().item()
    num_values = values.numel()
    if total_length == num_values:
        return

    raise ValueError(
        f"Sum of lengths must equal the number of values, but got {total_length} and {num_values}"
    )
66+
67+
68+
def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> None:
    """
    Check that offsets is non-empty, starts at 0 and ends at the number of values.

    Args:
        offsets (torch.Tensor): offsets tensor of the KJT.
        values (torch.Tensor): values tensor the offsets should span.

    Raises:
        ValueError: if offsets has no elements, if its first element is not 0,
            or if its last element differs from ``values.numel()``.
    """
    if offsets.numel() == 0:
        raise ValueError("offsets cannot be empty")

    # Extract the boundary entries as Python scalars so the error messages
    # report plain numbers (e.g. "1") rather than a tensor repr ("tensor(1)"),
    # matching how _validate_lengths reports its sum.
    first_offset = offsets[0].item()
    if first_offset != 0:
        raise ValueError(
            f"Expected first offset to be 0, but got {first_offset} instead"
        )

    last_offset = offsets[-1].item()
    num_values = values.numel()
    if last_offset != num_values:
        raise ValueError(
            f"The last element of offsets must equal to the number of values, but got {last_offset} and {num_values}"
        )
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
11+
import unittest
12+
from typing import List, Optional, Tuple
13+
14+
import torch
15+
from hypothesis import given, settings, strategies as st, Verbosity
16+
from parameterized import param, parameterized
17+
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
18+
from torchrec.sparse.jagged_tensor_validator import validate_keyed_jagged_tensor
19+
20+
21+
@st.composite
def valid_kjt_from_lengths_offsets_strategy(
    draw: st.DrawFn,
) -> Tuple[List[str], torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
    """
    Hypothesis strategy producing mutually consistent KJT construction inputs.

    Returns:
        A ``(keys, values, weights, lengths, offsets)`` tuple where:
            - keys is a list of 1 to 10 unique strings
            - lengths has ``len(keys) * stride`` entries, each in [0, 20],
              with stride drawn from [1, 10]
            - offsets is 0 followed by the cumulative sum of lengths
            - values has ``lengths.sum()`` floats in [0, 100]
            - weights is either None or a tensor shaped like values
    """
    keys = draw(st.lists(st.text(), min_size=1, max_size=10, unique=True))
    stride = draw(st.integers(1, 10))

    # One length per (key, stride position) pair.
    num_lengths = len(keys) * stride
    raw_lengths = draw(
        st.lists(st.integers(0, 20), min_size=num_lengths, max_size=num_lengths)
    )
    lengths = torch.tensor(raw_lengths)
    offsets = torch.cat((torch.tensor([0]), torch.cumsum(lengths, dim=0)))

    # values and (optional) weights must both hold exactly sum(lengths) entries.
    value_length = int(lengths.sum().item())
    floats_per_value = st.lists(
        st.floats(0, 100), min_size=value_length, max_size=value_length
    )
    values = torch.tensor(draw(floats_per_value))
    weights_raw = draw(st.one_of(st.none(), floats_per_value))
    weights = None if weights_raw is None else torch.tensor(weights_raw)

    return keys, values, weights, lengths, offsets
62+
63+
64+
class TestJaggedTensorValidator(unittest.TestCase):
    """Exercises validate_keyed_jagged_tensor with malformed and well-formed KJTs."""

    # Every case builds a KJT over the same two keys and five values and names
    # the fragment of the ValueError message the validator should produce.
    INVALID_LENGTHS_OFFSETS_CASES = [
        # Neither lengths nor offsets given.
        param(
            expected_error_msg="lengths and offsets cannot be both empty",
            keys=["f1", "f2"],
            values=torch.tensor([1, 2, 3, 4, 5]),
            lengths=None,
            offsets=None,
        ),
        # lengths and offsets have the same number of elements.
        param(
            expected_error_msg="Expected lengths size to be 1 more than offsets size",
            keys=["f1", "f2"],
            values=torch.tensor([1, 2, 3, 4, 5]),
            lengths=torch.tensor([1, 2, 0, 2]),
            offsets=torch.tensor([0, 1, 3, 5]),
        ),
        # Empty lengths is allowed but values must be empty as well
        param(
            expected_error_msg="Sum of lengths must equal the number of values",
            keys=["f1", "f2"],
            values=torch.tensor([1, 2, 3, 4, 5]),
            lengths=torch.tensor([]),
            offsets=None,
        ),
        # lengths sum to 9 while there are only 5 values.
        param(
            expected_error_msg="Sum of lengths must equal the number of values",
            keys=["f1", "f2"],
            values=torch.tensor([1, 2, 3, 4, 5]),
            lengths=torch.tensor([3, 3, 2, 1]),
            offsets=None,
        ),
        # offsets without any element.
        param(
            expected_error_msg="offsets cannot be empty",
            keys=["f1", "f2"],
            values=torch.tensor([1, 2, 3, 4, 5]),
            lengths=None,
            offsets=torch.tensor([]),
        ),
        # offsets starting at 1 instead of 0.
        param(
            expected_error_msg="Expected first offset to be 0",
            keys=["f1", "f2"],
            values=torch.tensor([1, 2, 3, 4, 5]),
            lengths=torch.tensor([1, 2, 0, 2]),
            offsets=torch.tensor([1, 2, 4, 4, 6]),
        ),
        # offsets ending at 6 while there are only 5 values.
        param(
            expected_error_msg="The last element of offsets must equal to the number of values",
            keys=["f1", "f2"],
            values=torch.tensor([1, 2, 3, 4, 5]),
            lengths=torch.tensor([1, 2, 0, 2]),
            offsets=torch.tensor([0, 2, 4, 4, 6]),
        ),
        # Individually valid lengths and offsets that disagree with each other.
        param(
            expected_error_msg="offsets is not equal to the cumulative sum of lengths",
            keys=["f1", "f2"],
            values=torch.tensor([1, 2, 3, 4, 5]),
            lengths=torch.tensor([1, 2, 0, 2]),
            offsets=torch.tensor([0, 2, 3, 3, 5]),
        ),
    ]

    # Comments are used instead of docstrings on the test methods so that the
    # descriptions unittest reports for them stay as they were.

    # Each malformed KJT must be rejected with a ValueError whose message
    # contains the expected fragment.
    @parameterized.expand(INVALID_LENGTHS_OFFSETS_CASES)
    def test_invalid_keyed_jagged_tensor(
        self,
        expected_error_msg: str,
        keys: List[str],
        values: torch.Tensor,
        lengths: Optional[torch.Tensor],
        offsets: Optional[torch.Tensor],
    ) -> None:
        invalid_kjt = KeyedJaggedTensor(
            keys=keys, values=values, lengths=lengths, offsets=offsets
        )

        with self.assertRaises(ValueError) as raised:
            validate_keyed_jagged_tensor(invalid_kjt)

        self.assertIn(expected_error_msg, str(raised.exception))

    # A KJT built from generated, consistent lengths must pass validation.
    # pyre-ignore[56]
    @given(valid_kjt_from_lengths_offsets_strategy())
    @settings(verbosity=Verbosity.verbose, max_examples=20)
    def test_valid_kjt_from_lengths(
        self,
        test_data: Tuple[
            List[str],
            torch.Tensor,
            Optional[torch.Tensor],
            torch.Tensor,
            torch.Tensor,
        ],
    ) -> None:
        keys, values, weights, lengths, _unused_offsets = test_data

        validate_keyed_jagged_tensor(
            KeyedJaggedTensor.from_lengths_sync(
                keys=keys, values=values, weights=weights, lengths=lengths
            )
        )

    # A KJT built from generated, consistent offsets must pass validation.
    # pyre-ignore[56]
    @given(valid_kjt_from_lengths_offsets_strategy())
    @settings(verbosity=Verbosity.verbose, max_examples=20)
    def test_valid_kjt_from_offsets(
        self,
        test_data: Tuple[
            List[str],
            torch.Tensor,
            Optional[torch.Tensor],
            torch.Tensor,
            torch.Tensor,
        ],
    ) -> None:
        keys, values, weights, _unused_lengths, offsets = test_data

        validate_keyed_jagged_tensor(
            KeyedJaggedTensor.from_offsets_sync(
                keys=keys, values=values, weights=weights, offsets=offsets
            )
        )

0 commit comments

Comments
 (0)