Skip to content

Optimize SetView#equals() to avoid unnecessary iterations. Fixes #7716. #7784

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 200 additions & 0 deletions android/guava-tests/test/com/google/common/collect/SetViewTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

package com.google.common.collect;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Iterators.emptyIterator;
import static com.google.common.collect.Sets.difference;
import static com.google.common.collect.Sets.intersection;
import static com.google.common.collect.Sets.newHashSet;
Expand All @@ -35,12 +38,16 @@
import com.google.common.collect.testing.TestStringSetGenerator;
import com.google.common.collect.testing.features.CollectionFeature;
import com.google.common.collect.testing.features.CollectionSize;
import com.google.common.testing.EqualsTester;
import java.util.AbstractSet;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import org.jspecify.annotations.NullMarked;
import org.jspecify.annotations.Nullable;

/**
* Unit tests for {@link SetView}s: {@link Sets#union}, {@link Sets#intersection}, {@link
Expand Down Expand Up @@ -539,4 +546,197 @@ public void testCopyInto_nonEmptySet() {
assertThat(symmetricDifference(newHashSet(1, 2), newHashSet(2, 3)).copyInto(newHashSet(0, 1)))
.containsExactly(0, 1, 3);
}

public void testUnion_minSize() {
assertMinSize(union(emptySet(), emptySet()), 0);
assertMinSize(union(setSize(2), setSize(3)), 3);
assertMinSize(union(setSize(3), setSize(2)), 3);
assertMinSize(union(setSizeRange(10, 20), setSizeRange(11, 12)), 11);
assertMinSize(union(setSizeRange(11, 12), setSizeRange(10, 20)), 11);
}

public void testUnion_maxSize() {
assertMaxSize(union(emptySet(), emptySet()), 0);
assertMaxSize(union(setSize(2), setSize(3)), 5);
assertMaxSize(union(setSize(3), setSize(2)), 5);
assertMaxSize(union(setSizeRange(10, 20), setSizeRange(11, 12)), 32);
assertMaxSize(union(setSizeRange(11, 12), setSizeRange(10, 20)), 32);
}

public void testUnion_maxSize_saturated() {
assertThat(union(setSize(Integer.MAX_VALUE), setSize(1)).maxSize())
.isEqualTo(Integer.MAX_VALUE);
assertThat(union(setSize(1), setSize(Integer.MAX_VALUE)).maxSize())
.isEqualTo(Integer.MAX_VALUE);
}

public void testIntersection_minSize() {
assertMinSize(intersection(emptySet(), emptySet()), 0);
assertMinSize(intersection(setSize(2), setSize(3)), 0);
assertMinSize(intersection(setSize(3), setSize(2)), 0);
assertMinSize(intersection(setSizeRange(10, 20), setSizeRange(11, 12)), 0);
assertMinSize(intersection(setSizeRange(11, 12), setSizeRange(10, 20)), 0);
}

public void testIntersection_maxSize() {
assertMaxSize(intersection(emptySet(), emptySet()), 0);
assertMaxSize(intersection(setSize(2), setSize(3)), 2);
assertMaxSize(intersection(setSize(3), setSize(2)), 2);
assertMaxSize(intersection(setSizeRange(10, 20), setSizeRange(11, 12)), 12);
assertMaxSize(intersection(setSizeRange(11, 12), setSizeRange(10, 20)), 12);
}

public void testDifference_minSize() {
assertMinSize(difference(emptySet(), emptySet()), 0);
assertMinSize(difference(setSize(2), setSize(3)), 0);
assertMinSize(difference(setSize(3), setSize(2)), 1);
assertMinSize(difference(setSizeRange(10, 20), setSizeRange(1, 2)), 8);
assertMinSize(difference(setSizeRange(1, 2), setSizeRange(10, 20)), 0);
assertMinSize(difference(setSizeRange(10, 20), setSizeRange(11, 12)), 0);
assertMinSize(difference(setSizeRange(11, 12), setSizeRange(10, 20)), 0);
}

public void testDifference_maxSize() {
assertMaxSize(difference(emptySet(), emptySet()), 0);
assertMaxSize(difference(setSize(2), setSize(3)), 2);
assertMaxSize(difference(setSize(3), setSize(2)), 3);
assertMaxSize(difference(setSizeRange(10, 20), setSizeRange(1, 2)), 20);
assertMaxSize(difference(setSizeRange(1, 2), setSizeRange(10, 20)), 2);
assertMaxSize(difference(setSizeRange(10, 20), setSizeRange(11, 12)), 20);
assertMaxSize(difference(setSizeRange(11, 12), setSizeRange(10, 20)), 12);
}

public void testSymmetricDifference_minSize() {
assertMinSize(symmetricDifference(emptySet(), emptySet()), 0);
assertMinSize(symmetricDifference(setSize(2), setSize(3)), 1);
assertMinSize(symmetricDifference(setSize(3), setSize(2)), 1);
assertMinSize(symmetricDifference(setSizeRange(10, 20), setSizeRange(1, 2)), 8);
assertMinSize(symmetricDifference(setSizeRange(1, 2), setSizeRange(10, 20)), 8);
assertMinSize(symmetricDifference(setSizeRange(10, 20), setSizeRange(11, 12)), 0);
assertMinSize(symmetricDifference(setSizeRange(11, 12), setSizeRange(10, 20)), 0);
}

public void testSymmetricDifference_maxSize() {
assertMaxSize(symmetricDifference(emptySet(), emptySet()), 0);
assertMaxSize(symmetricDifference(setSize(2), setSize(3)), 5);
assertMaxSize(symmetricDifference(setSize(3), setSize(2)), 5);
assertMaxSize(symmetricDifference(setSizeRange(10, 20), setSizeRange(1, 2)), 22);
assertMaxSize(symmetricDifference(setSizeRange(1, 2), setSizeRange(10, 20)), 22);
assertMaxSize(symmetricDifference(setSizeRange(10, 20), setSizeRange(11, 12)), 32);
assertMaxSize(symmetricDifference(setSizeRange(11, 12), setSizeRange(10, 20)), 32);
}

public void testSymmetricDifference_maxSize_saturated() {
assertThat(symmetricDifference(setSize(Integer.MAX_VALUE), setSize(1)).maxSize())
.isEqualTo(Integer.MAX_VALUE);
assertThat(symmetricDifference(setSize(1), setSize(Integer.MAX_VALUE)).maxSize())
.isEqualTo(Integer.MAX_VALUE);
}

public void testEquals() {
new EqualsTester()
.addEqualityGroup(
emptySet(),
union(emptySet(), emptySet()),
intersection(newHashSet(1, 2), newHashSet(3, 4)),
difference(newHashSet(1, 2), newHashSet(1, 2)),
symmetricDifference(newHashSet(1, 2), newHashSet(1, 2)))
.addEqualityGroup(
singleton(1),
union(singleton(1), singleton(1)),
intersection(newHashSet(1, 2), newHashSet(1, 3)),
difference(newHashSet(1, 2), newHashSet(2, 3)),
symmetricDifference(newHashSet(1, 2, 3), newHashSet(2, 3)))
.addEqualityGroup(
singleton(2),
union(singleton(2), singleton(2)),
intersection(newHashSet(1, 2), newHashSet(2, 3)),
difference(newHashSet(1, 2), newHashSet(1, 3)),
symmetricDifference(newHashSet(1, 2, 3), newHashSet(1, 3)))
.addEqualityGroup(
newHashSet(1, 2),
union(singleton(1), singleton(2)),
intersection(newHashSet(1, 2), newHashSet(1, 2, 3)),
difference(newHashSet(1, 2, 3), newHashSet(3)),
symmetricDifference(newHashSet(1, 3), newHashSet(2, 3)))
.addEqualityGroup(
newHashSet(3, 2),
union(singleton(3), singleton(2)),
intersection(newHashSet(3, 2), newHashSet(3, 2, 1)),
difference(newHashSet(3, 2, 1), newHashSet(1)),
symmetricDifference(newHashSet(3, 1), newHashSet(2, 1)))
.addEqualityGroup(
newHashSet(1, 2, 3),
union(newHashSet(1, 2), newHashSet(2, 3)),
intersection(newHashSet(1, 2, 3), newHashSet(1, 2, 3)),
difference(newHashSet(1, 2, 3), emptySet()),
symmetricDifference(emptySet(), newHashSet(1, 2, 3)))
.testEquals();
}

public void testEquals_otherSetContainsThrows() {
new EqualsTester()
.addEqualityGroup(new SetContainsThrows())
.addEqualityGroup(intersection(singleton(null), singleton(null))) // NPE
.addEqualityGroup(intersection(singleton(0), singleton(0))) // CCE
.testEquals();
}

/** Returns a {@link Set} with a {@link Set#size()} of {@code size}. */
private static ContiguousSet<Integer> setSize(int size) {
checkArgument(size >= 0);
ContiguousSet<Integer> set = ContiguousSet.closedOpen(0, size);
checkState(set.size() == size);
return set;
}

/**
* Returns a {@link SetView} with a {@link SetView#minSize()} of {@code min} and a {@link
* SetView#maxSize()} of {@code max}.
*/
private static SetView<Integer> setSizeRange(int min, int max) {
checkArgument(min >= 0 && max >= min);
SetView<Integer> set = difference(setSize(max), setSize(max - min));
checkState(set.minSize() == min && set.maxSize() == max);
return set;
}

/**
* Asserts that {@code code} has a {@link SetView#minSize()} of {@code min} and a {@link
* Set#size()} of at least {@code min}.
*/
private static void assertMinSize(SetView<?> set, int min) {
assertThat(set.minSize()).isEqualTo(min);
assertThat(set.size()).isAtLeast(min);
}

/**
* Asserts that {@code code} has a {@link SetView#maxSize()} of {@code max} and a {@link
* Set#size()} of at most {@code max}.
*/
private static void assertMaxSize(SetView<?> set, int max) {
assertThat(set.maxSize()).isEqualTo(max);
assertThat(set.size()).isAtMost(max);
}

/**
* A {@link Set} that throws {@link NullPointerException} and {@link ClassCastException} from
* {@link #contains}.
*/
private static final class SetContainsThrows extends AbstractSet<Void> {
@Override
public boolean contains(@Nullable Object o) {
throw o == null ? new NullPointerException() : new ClassCastException();
}

@Override
public int size() {
return 0;
}

@Override
public Iterator<Void> iterator() {
return emptyIterator();
}
}
}
Loading
Loading