diff --git a/core/commonMain/src/implementations/immutableMap/PersistentHashMapBuilderContentIterators.kt b/core/commonMain/src/implementations/immutableMap/PersistentHashMapBuilderContentIterators.kt index 986369ca..ccc03ab5 100644 --- a/core/commonMain/src/implementations/immutableMap/PersistentHashMapBuilderContentIterators.kt +++ b/core/commonMain/src/implementations/immutableMap/PersistentHashMapBuilderContentIterators.kt @@ -57,7 +57,7 @@ internal open class PersistentHashMapBuilderBaseIterator( val currentKey = currentKey() builder.remove(lastIteratedKey) - resetPath(currentKey.hashCode(), builder.node, currentKey, 0) + resetPath(currentKey.hashCode(), builder.node, currentKey, 0, lastIteratedKey.hashCode(), afterRemove = true) } else { builder.remove(lastIteratedKey) } @@ -82,7 +82,7 @@ internal open class PersistentHashMapBuilderBaseIterator( expectedModCount = builder.modCount } - private fun resetPath(keyHash: Int, node: TrieNode<*, *>, key: K, pathIndex: Int) { + private fun resetPath(keyHash: Int, node: TrieNode<*, *>, key: K, pathIndex: Int, removedKeyHash: Int = 0, afterRemove: Boolean = false) { val shift = pathIndex * LOG_MAX_BRANCHING_FACTOR if (shift > MAX_SHIFT) { // collision @@ -99,6 +99,21 @@ internal open class PersistentHashMapBuilderBaseIterator( if (node.hasEntryAt(keyPositionMask)) { // key is directly in buffer val keyIndex = node.entryKeyIndex(keyPositionMask) + // After removing an element, we need to handle node promotion properly to maintain a correct iteration order. + // `removedKeyPositionMask` represents the bit position of the removed key's hash at the current level. + // This is needed to detect if the current key was potentially promoted from a deeper level. + val removedKeyPositionMask = if (afterRemove) 1 shl indexSegment(removedKeyHash, shift) else 0 + + // Check if the removed key is at the same position as the current key and was previously at a deeper level. + // This indicates a node promotion occurred during removal, + // and we need to handle it in a special way to prevent re-traversing already visited elements. + if (keyPositionMask == removedKeyPositionMask && pathIndex < pathLastIndex) { + // Instead of traversing the normal way, we create a special path entry at the previous depth + // that points directly to the promoted entry, maintaining the original iteration sequence. + path[pathLastIndex].reset(arrayOf(node.buffer[keyIndex], node.buffer[keyIndex + 1]), ENTRY_SIZE) + return + } + // assert(node.keyAtIndex(keyIndex) == key) path[pathIndex].reset(node.buffer, ENTRY_SIZE * node.entryCount(), keyIndex) @@ -111,7 +126,7 @@ internal open class PersistentHashMapBuilderBaseIterator( val nodeIndex = node.nodeIndex(keyPositionMask) val targetNode = node.nodeAtIndex(nodeIndex) path[pathIndex].reset(node.buffer, ENTRY_SIZE * node.entryCount(), nodeIndex) - resetPath(keyHash, targetNode, key, pathIndex + 1) + resetPath(keyHash, targetNode, key, pathIndex + 1, removedKeyHash, afterRemove) } private fun checkNextWasInvoked() { diff --git a/core/commonMain/src/implementations/immutableMap/TrieNode.kt b/core/commonMain/src/implementations/immutableMap/TrieNode.kt index 9fa6898b..e00eaa7a 100644 --- a/core/commonMain/src/implementations/immutableMap/TrieNode.kt +++ b/core/commonMain/src/implementations/immutableMap/TrieNode.kt @@ -180,7 +180,7 @@ internal class TrieNode( } /** The given [newNode] must not be a part of any persistent map instance. */ - private fun updateNodeAtIndex(nodeIndex: Int, positionMask: Int, newNode: TrieNode): TrieNode { + private fun updateNodeAtIndex(nodeIndex: Int, positionMask: Int, newNode: TrieNode, owner: MutabilityOwnership? = null): TrieNode { // assert(buffer[nodeIndex] !== newNode) val newNodeBuffer = newNode.buffer if (newNodeBuffer.size == 2 && newNode.nodeMap == 0) { @@ -192,30 +192,14 @@ internal class TrieNode( val keyIndex = entryKeyIndex(positionMask) val newBuffer = buffer.replaceNodeWithEntry(nodeIndex, keyIndex, newNodeBuffer[0], newNodeBuffer[1]) - return TrieNode(dataMap xor positionMask, nodeMap xor positionMask, newBuffer) + return TrieNode(dataMap xor positionMask, nodeMap xor positionMask, newBuffer, owner) } - val newBuffer = buffer.copyOf(buffer.size) - newBuffer[nodeIndex] = newNode - return TrieNode(dataMap, nodeMap, newBuffer) - } - - /** The given [newNode] must not be a part of any persistent map instance. */ - private fun mutableUpdateNodeAtIndex(nodeIndex: Int, newNode: TrieNode, owner: MutabilityOwnership): TrieNode { - assert(newNode.ownedBy === owner) -// assert(buffer[nodeIndex] !== newNode) - - // nodes (including collision nodes) that have only one entry are upped if they have no siblings - if (buffer.size == 1 && newNode.buffer.size == ENTRY_SIZE && newNode.nodeMap == 0) { -// assert(dataMap == 0 && nodeMap xor positionMask == 0) - newNode.dataMap = nodeMap - return newNode - } - - if (ownedBy === owner) { + if (owner != null && ownedBy === owner) { buffer[nodeIndex] = newNode return this } + val newBuffer = buffer.copyOf() newBuffer[nodeIndex] = newNode return TrieNode(dataMap, nodeMap, newBuffer, owner) @@ -716,7 +700,7 @@ internal class TrieNode( if (targetNode === newNode) { return this } - return mutableUpdateNodeAtIndex(nodeIndex, newNode, mutator.ownership) + return updateNodeAtIndex(nodeIndex, keyPositionMask, newNode, mutator.ownership) } // key is absent @@ -791,7 +775,7 @@ internal class TrieNode( newNode == null -> mutableRemoveNodeAtIndex(nodeIndex, positionMask, owner) targetNode !== newNode -> - mutableUpdateNodeAtIndex(nodeIndex, newNode, owner) + updateNodeAtIndex(nodeIndex, positionMask, newNode, owner) else -> this } diff --git a/core/commonTest/src/contract/map/PersistentHashMapBuilderTest.kt b/core/commonTest/src/contract/map/PersistentHashMapBuilderTest.kt new file mode 100644 index 00000000..732eb81a --- /dev/null +++ b/core/commonTest/src/contract/map/PersistentHashMapBuilderTest.kt @@ -0,0 +1,121 @@ +/* + * Copyright 2016-2025 JetBrains s.r.o. + * Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file. + */ + +package tests.contract.map + +import kotlinx.collections.immutable.implementations.immutableMap.PersistentHashMap +import kotlinx.collections.immutable.persistentHashMapOf +import tests.stress.IntWrapper +import kotlin.collections.iterator +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class PersistentHashMapBuilderTest { + + @Test + fun `should correctly iterate after removing integer key and promotion colliding key during iteration`() { + val removedKey = 0 + val map: PersistentHashMap = + persistentHashMapOf(1 to "a", 2 to "b", 3 to "c", removedKey to "y", 32 to "z") + as PersistentHashMap + + validatePromotion(map, removedKey) + } + + @Test + fun `should correctly iterate after removing IntWrapper key and promotion colliding key during iteration`() { + val removedKey = IntWrapper(0, 0) + val map: PersistentHashMap = persistentHashMapOf( + removedKey to "a", + IntWrapper(1, 0) to "b", + IntWrapper(2, 32) to "c", + IntWrapper(3, 32) to "d" + ) as PersistentHashMap + + validatePromotion(map, removedKey) + } + + private fun validatePromotion(map: PersistentHashMap, removedKey: K) { + val builder = map.builder() + val iterator = builder.entries.iterator() + + val expectedCount = map.size + var actualCount = 0 + + while (iterator.hasNext()) { + val (key, _) = iterator.next() + if (key == removedKey) { + iterator.remove() + } + actualCount++ + } + + val resultMap = builder.build() + for ((key, value) in map) { + if (key != removedKey) { + assertTrue(key in resultMap) + assertEquals(resultMap[key], value) + } else { + assertFalse(key in resultMap) + } + } + + assertEquals(expectedCount, actualCount) + } + + @Test + fun `removing twice on iterators throws IllegalStateException`() { + val map: PersistentHashMap = + persistentHashMapOf(1 to "a", 2 to "b", 3 to "c", 0 to "y", 32 to "z") as PersistentHashMap + val builder = map.builder() + val iterator = builder.entries.iterator() + + assertFailsWith { + while (iterator.hasNext()) { + val (key, _) = iterator.next() + if (key == 0) iterator.remove() + if (key == 0) { + iterator.remove() + iterator.remove() + } + } + } + } + + @Test + fun `removing elements from different iterators throws ConcurrentModificationException`() { + val map: PersistentHashMap = + persistentHashMapOf(1 to "a", 2 to "b", 3 to "c", 0 to "y", 32 to "z") as PersistentHashMap + val builder = map.builder() + val iterator1 = builder.entries.iterator() + val iterator2 = builder.entries.iterator() + + assertFailsWith { + while (iterator1.hasNext()) { + val (key, _) = iterator1.next() + iterator2.next() + if (key == 0) iterator1.remove() + if (key == 2) iterator2.remove() + } + } + } + + @Test + fun `removing element from one iterator and accessing another throws ConcurrentModificationException`() { + val map = persistentHashMapOf(1 to "a", 2 to "b", 3 to "c") + val builder = map.builder() + val iterator1 = builder.entries.iterator() + val iterator2 = builder.entries.iterator() + + assertFailsWith { + iterator1.next() + iterator1.remove() + iterator2.next() + } + } +} \ No newline at end of file diff --git a/core/commonTest/src/contract/map/PersistentHashMapTest.kt b/core/commonTest/src/contract/map/PersistentHashMapTest.kt new file mode 100644 index 00000000..c1152f4a --- /dev/null +++ b/core/commonTest/src/contract/map/PersistentHashMapTest.kt @@ -0,0 +1,35 @@ +/* + * Copyright 2016-2025 JetBrains s.r.o. + * Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file. + */ + +package tests.contract.map + +import kotlinx.collections.immutable.implementations.immutableMap.PersistentHashMap +import kotlinx.collections.immutable.persistentHashMapOf +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class PersistentHashMapTest { + + @Test + fun `if the collision is of size 2 and one of the keys is removed the remaining key must be promoted`() { + val map1: PersistentHashMap = + persistentHashMapOf(-1 to "a", 0 to "b", 32 to "c") as PersistentHashMap + val builder = map1.builder() + val map2 = builder.build() + + assertTrue(map1.equals(builder)) + assertEquals(map1, map2.toMap()) + assertEquals(map1, map2) + + val map3 = map1.remove(0) + builder.remove(0) + val map4 = builder.build() + + assertTrue(map3.equals(builder)) + assertEquals(map3, map4.toMap()) + assertEquals(map3, map4) + } +} \ No newline at end of file diff --git a/core/commonTest/src/contract/set/PersistentHashSetBuilderTest.kt b/core/commonTest/src/contract/set/PersistentHashSetBuilderTest.kt new file mode 100644 index 00000000..df1df86a --- /dev/null +++ b/core/commonTest/src/contract/set/PersistentHashSetBuilderTest.kt @@ -0,0 +1,119 @@ +/* + * Copyright 2016-2025 JetBrains s.r.o. + * Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file. + */ + +package tests.contract.set + +import kotlinx.collections.immutable.implementations.immutableSet.PersistentHashSet +import kotlinx.collections.immutable.persistentHashSetOf +import tests.stress.IntWrapper +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class PersistentHashSetBuilderTest { + + @Test + fun `should correctly iterate after removing integer element`() { + val removedElement = 0 + val set: PersistentHashSet = + persistentHashSetOf(1, 2, 3, removedElement, 32) + as PersistentHashSet + + validate(set, removedElement) + } + + @Test + fun `should correctly iterate after removing IntWrapper element`() { + val removedElement = IntWrapper(0, 0) + val set: PersistentHashSet = persistentHashSetOf( + removedElement, + IntWrapper(1, 0), + IntWrapper(2, 32), + IntWrapper(3, 32) + ) as PersistentHashSet + + validate(set, removedElement) + } + + private fun validate(set: PersistentHashSet, removedElement: E) { + val builder = set.builder() + val iterator = builder.iterator() + + val expectedCount = set.size + var actualCount = 0 + + while (iterator.hasNext()) { + val element = iterator.next() + if (element == removedElement) { + iterator.remove() + } + actualCount++ + } + + val resultSet = builder.build() + for (element in set) { + if (element != removedElement) { + assertTrue(element in resultSet) + } else { + assertFalse(element in resultSet) + } + } + + assertEquals(expectedCount, actualCount) + } + + @Test + fun `removing twice on iterators throws IllegalStateException`() { + val set: PersistentHashSet = + persistentHashSetOf(1, 2, 3, 0, 32) as PersistentHashSet + val builder = set.builder() + val iterator = builder.iterator() + + assertFailsWith { + while (iterator.hasNext()) { + val element = iterator.next() + if (element == 0) iterator.remove() + if (element == 0) { + iterator.remove() + iterator.remove() + } + } + } + } + + @Test + fun `removing elements from different iterators throws ConcurrentModificationException`() { + val set: PersistentHashSet = + persistentHashSetOf(1, 2, 3, 0, 32) as PersistentHashSet + val builder = set.builder() + val iterator1 = builder.iterator() + val iterator2 = builder.iterator() + + assertFailsWith { + while (iterator1.hasNext()) { + val element1 = iterator1.next() + iterator2.next() + if (element1 == 0) iterator1.remove() + if (element1 == 2) iterator2.remove() + } + } + } + + @Test + fun `removing element from one iterator and accessing another throws ConcurrentModificationException`() { + val set = persistentHashSetOf(1, 2, 3) + val builder = set.builder() + val iterator1 = builder.iterator() + val iterator2 = builder.iterator() + + assertFailsWith { + iterator1.next() + iterator1.remove() + iterator2.next() + } + } +} \ No newline at end of file diff --git a/core/commonTest/src/contract/set/PersistentHashSetTest.kt b/core/commonTest/src/contract/set/PersistentHashSetTest.kt new file mode 100644 index 00000000..d1a5c777 --- /dev/null +++ b/core/commonTest/src/contract/set/PersistentHashSetTest.kt @@ -0,0 +1,30 @@ +/* + * Copyright 2016-2025 JetBrains s.r.o. + * Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file. + */ + +package tests.contract.set + +import kotlinx.collections.immutable.persistentHashSetOf +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class PersistentHashSetTest { + + @Test + fun `persistentHashSet and their builder should be equal before and after modification`() { + val set1 = persistentHashSetOf(-1, 0, 32) + val builder = set1.builder() + + assertTrue(set1.equals(builder)) + assertEquals(set1, builder.build()) + assertEquals(set1, builder.build().toSet()) + + val set2 = set1.remove(0) + builder.remove(0) + + assertEquals(set2, builder.build().toSet()) + assertEquals(set2, builder.build()) + } +} \ No newline at end of file diff --git a/core/commonTest/src/contract/set/PersistentOrderedSetTest.kt b/core/commonTest/src/contract/set/PersistentOrderedSetTest.kt new file mode 100644 index 00000000..a3d639d9 --- /dev/null +++ b/core/commonTest/src/contract/set/PersistentOrderedSetTest.kt @@ -0,0 +1,33 @@ +/* + * Copyright 2016-2025 JetBrains s.r.o. + * Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file. + */ + +package tests.contract.set + +import kotlinx.collections.immutable.persistentSetOf +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class PersistentOrderedSetTest { + + /** + * Test from issue: https://github.com/Kotlin/kotlinx.collections.immutable/issues/204 + */ + @Test + fun `persistentOrderedSet and their builder should be equal before and after modification`() { + val set1 = persistentSetOf(-486539264, 16777216, 0, 67108864) + val builder = set1.builder() + + assertTrue(set1.equals(builder)) + assertEquals(set1, builder.build()) + assertEquals(set1, builder.build().toSet()) + + val set2 = set1.remove(0) + builder.remove(0) + + assertEquals(set2, builder.build().toSet()) + assertEquals(set2, builder.build()) + } +} \ No newline at end of file