diff --git a/guava/src/com/google/common/graph/AbstractBaseGraph.java b/guava/src/com/google/common/graph/AbstractBaseGraph.java index 9ce3838947b8..04311a21a979 100644 --- a/guava/src/com/google/common/graph/AbstractBaseGraph.java +++ b/guava/src/com/google/common/graph/AbstractBaseGraph.java @@ -31,6 +31,9 @@ import com.google.common.primitives.Ints; import java.util.AbstractSet; import java.util.Set; +import java.util.Iterator; +import java.util.Map; +import java.util.HashMap; import org.jspecify.annotations.Nullable; /** @@ -63,6 +66,9 @@ protected long edgeCount() { * An implementation of {@link BaseGraph#edges()} defined in terms of {@link Graph#nodes()} and * {@link #successors(Object)}. */ + // Cache for edges + private transient Set> cachedEdges; + @Override public Set> edges() { return new AbstractSet>() { @@ -104,34 +110,45 @@ public ElementOrder incidentEdgeOrder() { } @Override - public Set> incidentEdges(N node) { - checkNotNull(node); - checkArgument(nodes().contains(node), "Node %s is not an element of this graph.", node); - IncidentEdgeSet incident = - new IncidentEdgeSet(this, node) { - @Override - public UnmodifiableIterator> iterator() { - if (graph.isDirected()) { - return Iterators.unmodifiableIterator( - Iterators.concat( - Iterators.transform( - graph.predecessors(node).iterator(), - (N predecessor) -> EndpointPair.ordered(predecessor, node)), - Iterators.transform( - // filter out 'node' from successors (already covered by predecessors, - // above) - Sets.difference(graph.successors(node), ImmutableSet.of(node)).iterator(), - (N successor) -> EndpointPair.ordered(node, successor)))); - } else { - return Iterators.unmodifiableIterator( - Iterators.transform( - graph.adjacentNodes(node).iterator(), - (N adjacentNode) -> EndpointPair.unordered(node, adjacentNode))); - } - } - }; - return nodeInvalidatableSet(incident, node); - } + public Set> edges() { + if (cachedEdges == null) { + cachedEdges = new AbstractSet>() { + @Override + public Iterator> iterator() { + return endpointPairIterator(); + } + + @Override + public int size() { + return edgeCount(); + } + + @Override + public boolean contains(Object obj) { + if (!(obj instanceof EndpointPair)) { + return false; + } + EndpointPair endpointPair = (EndpointPair) obj; + return isDirected() == endpointPair.isOrdered() + && nodes().contains(endpointPair.nodeU()) + && successors((N) endpointPair.nodeU()).contains(endpointPair.nodeV()); + } + }; + } + return cachedEdges; + } + + // Method to clear the cache when the graph is modified + protected void clearCache() { + cachedEdges = null; + } + + // Abstract methods that subclasses must implement + protected abstract Iterator> endpointPairIterator(); + protected abstract int edgeCount(); + protected abstract boolean isDirected(); + protected abstract Set nodes(); + protected abstract Set successors(N node); @Override public int degree(N node) {