diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 4aa95ad42ec7f..12943dbe0840c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -126,6 +126,7 @@ abstract class StatePartitionReaderBase( stateStoreColFamilySchema.keyStateEncoderSpec.get, useMultipleValuesPerKey = useMultipleValuesPerKey, isInternal = isInternal) + store.abort() } provider } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 15794cada6753..3974aa1304b83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkConf, SparkEnv, SparkException} +import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.io.CompressionCodec @@ -43,14 +43,120 @@ private[sql] class RocksDBStateStoreProvider with SupportsFineGrainedReplay { import RocksDBStateStoreProvider._ - class RocksDBStateStore(lastVersion: Long) extends StateStore { + /** + * Implementation of a state store that uses RocksDB as the backing data store. + * + * This store implements a state machine with the following states: + * - UPDATING: The store is being updated and has not yet been committed or aborted + * - COMMITTED: Updates have been successfully committed + * - ABORTED: Updates have been aborted + * + * Operations are validated against the current state to ensure proper usage: + * - Get/put/remove/iterator operations are only allowed in UPDATING state + * - Commit is only allowed in UPDATING state + * - Abort is allowed in UPDATING or ABORTED state + * - Metrics retrieval is only allowed in COMMITTED or ABORTED state + * + * Each store instance is assigned a unique stamp when created, which is used to + * verify that operations are performed by the owning thread and to prevent + * concurrent modifications to the same store. + */ + class RocksDBStateStore(lastVersion: Long, stamp: Long) extends StateStore { /** Trait and classes representing the internal state of the store */ trait STATE case object UPDATING extends STATE case object COMMITTED extends STATE case object ABORTED extends STATE + private sealed trait TRANSITION + private case object UPDATE extends TRANSITION + private case object ABORT extends TRANSITION + private case object COMMIT extends TRANSITION + private case object METRICS extends TRANSITION + @volatile private var state: STATE = UPDATING + + override def getReadStamp: Long = { + stamp + } + + /** + * Validates the expected state, throws exception if state is not as expected. 
+ * Returns the current state + * + * @param possibleStates Expected possible states + */ + private def validateState(possibleStates: List[STATE]): STATE = { + if (!possibleStates.contains(state)) { + throw StateStoreErrors.stateStoreOperationOutOfOrder( + s"Expected possible states $possibleStates but found $state") + } + state + } + + /** + * Throws error if transition is illegal. + * MUST be called for every StateStore method. + * + * @param transition The transition type of the operation. + */ + private def validateAndTransitionState(transition: TRANSITION): Unit = { + val newState = transition match { + case UPDATE => + stateMachine.verifyStamp(stamp) + state match { + case UPDATING => UPDATING + case COMMITTED => throw StateStoreErrors.stateStoreOperationOutOfOrder( + s"Cannot update after committed") + case ABORTED => throw StateStoreErrors.stateStoreOperationOutOfOrder( + "Cannot update after aborted") + } + case ABORT => + state match { + case UPDATING => + stateMachine.verifyStamp(stamp) + ABORTED + case COMMITTED => throw StateStoreErrors.stateStoreOperationOutOfOrder( + "Cannot abort after committed") + case ABORTED => ABORTED + } + case COMMIT => + state match { + case UPDATING => + stateMachine.verifyStamp(stamp) + COMMITTED + case COMMITTED => throw StateStoreErrors.stateStoreOperationOutOfOrder( + "Cannot commit after committed") + case ABORTED => throw StateStoreErrors.stateStoreOperationOutOfOrder( + "Cannot commit after aborted") + } + case METRICS => + state match { + case UPDATING => throw StateStoreErrors.stateStoreOperationOutOfOrder( + "Cannot get metrics in UPDATING state") + case COMMITTED => COMMITTED + case ABORTED => ABORTED + } + } + state = newState + } + + // Add a listener for task threads to abort when the task completes and hasn't released + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit] { + _ => + try { + if (state == UPDATING) { + abort() + } + } catch { + case NonFatal(e) => + logWarning("Failed to abort state store", e) + } finally { + stateMachine.releaseStore(stamp, throwEx = false) + } + }) + + // State row format validated @volatile private var isValidated = false override def id: StateStoreId = RocksDBStateStoreProvider.this.stateStoreId @@ -64,6 +170,8 @@ private[sql] class RocksDBStateStoreProvider keyStateEncoderSpec: KeyStateEncoderSpec, useMultipleValuesPerKey: Boolean = false, isInternal: Boolean = false): Unit = { + validateAndTransitionState(UPDATE) + verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal) val cfId = rocksDB.createColFamilyIfAbsent(colFamilyName, isInternal) val dataEncoderCacheKey = StateRowEncoderCacheKey( @@ -105,6 +213,8 @@ private[sql] class RocksDBStateStoreProvider } override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = { + validateAndTransitionState(UPDATE) + verify(key != null, "Key cannot be null") verifyColFamilyOperations("get", colFamilyName) @@ -131,6 +241,8 @@ private[sql] class RocksDBStateStoreProvider * values per key. 
*/ override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = { + validateAndTransitionState(UPDATE) + verify(key != null, "Key cannot be null") verifyColFamilyOperations("valuesIterator", colFamilyName) @@ -147,7 +259,8 @@ private[sql] class RocksDBStateStoreProvider override def merge(key: UnsafeRow, value: UnsafeRow, colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { - verify(state == UPDATING, "Cannot merge after already committed or aborted") + validateAndTransitionState(UPDATE) + verifyColFamilyOperations("merge", colFamilyName) val kvEncoder = keyValueEncoderMap.get(colFamilyName) @@ -162,7 +275,8 @@ private[sql] class RocksDBStateStoreProvider } override def put(key: UnsafeRow, value: UnsafeRow, colFamilyName: String): Unit = { - verify(state == UPDATING, "Cannot put after already committed or aborted") + validateAndTransitionState(UPDATE) + verify(key != null, "Key cannot be null") require(value != null, "Cannot put a null value") verifyColFamilyOperations("put", colFamilyName) @@ -172,7 +286,8 @@ private[sql] class RocksDBStateStoreProvider } override def remove(key: UnsafeRow, colFamilyName: String): Unit = { - verify(state == UPDATING, "Cannot remove after already committed or aborted") + validateAndTransitionState(UPDATE) + verify(key != null, "Key cannot be null") verifyColFamilyOperations("remove", colFamilyName) @@ -181,6 +296,7 @@ private[sql] class RocksDBStateStoreProvider } override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = { + validateAndTransitionState(UPDATE) // Note this verify function only verify on the colFamilyName being valid, // we are actually doing prefix when useColumnFamilies, // but pass "iterator" to throw correct error message @@ -215,6 +331,8 @@ private[sql] class RocksDBStateStoreProvider override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String): Iterator[UnsafeRowPair] = { + validateAndTransitionState(UPDATE) + verifyColFamilyOperations("prefixScan", colFamilyName) val kvEncoder = keyValueEncoderMap.get(colFamilyName) @@ -231,12 +349,18 @@ private[sql] class RocksDBStateStoreProvider } var checkpointInfo: Option[StateStoreCheckpointInfo] = None + private var storedMetrics: Option[RocksDBMetrics] = None + override def commit(): Long = synchronized { + validateState(List(UPDATING)) try { verify(state == UPDATING, "Cannot commit after already committed or aborted") val (newVersion, newCheckpointInfo) = rocksDB.commit() checkpointInfo = Some(newCheckpointInfo) - state = COMMITTED + storedMetrics = rocksDB.metricsOpt + validateAndTransitionState(COMMIT) + stateMachine.releaseStore(stamp) + logInfo(log"Committed ${MDC(VERSION_NUM, newVersion)} " + log"for ${MDC(STATE_STORE_ID, id)}") newVersion @@ -247,14 +371,18 @@ private[sql] class RocksDBStateStoreProvider } override def abort(): Unit = { - verify(state == UPDATING || state == ABORTED, "Cannot abort after already committed") - logInfo(log"Aborting ${MDC(VERSION_NUM, version + 1)} " + - log"for ${MDC(STATE_STORE_ID, id)}") - rocksDB.rollback() - state = ABORTED + if (validateState(List(UPDATING, ABORTED)) != ABORTED) { + logInfo(log"Aborting ${MDC(VERSION_NUM, version + 1)} " + + log"for ${MDC(STATE_STORE_ID, id)}") + rocksDB.rollback() + validateAndTransitionState(ABORT) + stateMachine.releaseStore(stamp) + } } override def metrics: StateStoreMetrics = { + validateAndTransitionState(METRICS) + val rocksDBMetricsOpt = rocksDB.metricsOpt if (rocksDBMetricsOpt.isDefined) { @@ -337,6 +465,8 @@ private[sql] class 
RocksDBStateStoreProvider } override def getStateStoreCheckpointInfo(): StateStoreCheckpointInfo = { + validateAndTransitionState(METRICS) + checkpointInfo match { case Some(info) => info case None => throw StateStoreErrors.stateStoreOperationOutOfOrder( @@ -356,6 +486,8 @@ private[sql] class RocksDBStateStoreProvider /** Remove column family if exists */ override def removeColFamilyIfExists(colFamilyName: String): Boolean = { + validateAndTransitionState(UPDATE) + verifyColFamilyCreationOrDeletion("remove_col_family", colFamilyName) verify(useColumnFamilies, "Column families are not supported in this store") @@ -444,17 +576,57 @@ private[sql] class RocksDBStateStoreProvider override def stateStoreId: StateStoreId = stateStoreId_ - override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = { + private lazy val stateMachine: RocksDBStateStoreProviderStateMachine = + new RocksDBStateStoreProviderStateMachine(stateStoreId, RocksDBConf(storeConf)) + + /** + * Creates and returns a state store with the specified parameters. + * + * @param version The version of the state store to load + * @param uniqueId Optional unique identifier for checkpoint + * @param readOnly Whether to open the store in read-only mode + * @param existingStore Optional existing store to reuse instead of creating a new one + * @return The loaded state store + */ + private def loadStateStore( + version: Long, + uniqueId: Option[String], + readOnly: Boolean, + existingStore: Option[ReadStateStore] = None): StateStore = { try { if (version < 0) { throw QueryExecutionErrors.unexpectedStateStoreVersion(version) } - rocksDB.load( - version, - stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None) - new RocksDBStateStore(version) - } - catch { + + // Determine stamp - either use existing or acquire new + val stamp = existingStore.map(_.getReadStamp).getOrElse { + stateMachine.acquireStore() + } + + try { + // Load RocksDB store + rocksDB.load( + version, + stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None, + readOnly = readOnly) + + // Return appropriate store instance + existingStore match { + case Some(stateStore: RocksDBStateStore) => + // Reuse existing store for getWriteStore case + stateStore + case Some(_) => + throw new IllegalArgumentException("Existing store must be a RocksDBStateStore") + case None => + // Create new store instance for getStore/getReadStore cases + new RocksDBStateStore(version, stamp) + } + } catch { + case e: Throwable => + stateMachine.releaseStore(stamp) + throw e + } + } catch { case e: SparkException if Option(e.getCondition).exists(_.contains("CANNOT_LOAD_STATE_STORE")) => throw e @@ -467,32 +639,25 @@ private[sql] class RocksDBStateStoreProvider } } + override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = { + loadStateStore(version, uniqueId, readOnly = false) + } + + override def getWriteStore( + readStore: ReadStateStore, + version: Long, + uniqueId: Option[String] = None): StateStore = { + assert(version == readStore.version) + loadStateStore(version, uniqueId, readOnly = false, existingStore = Some(readStore)) + } + override def getReadStore(version: Long, uniqueId: Option[String] = None): StateStore = { - try { - if (version < 0) { - throw QueryExecutionErrors.unexpectedStateStoreVersion(version) - } - rocksDB.load( - version, - stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None, - readOnly = true) - new RocksDBStateStore(version) - } - catch { 
- case e: SparkException - if Option(e.getCondition).exists(_.contains("CANNOT_LOAD_STATE_STORE")) => - throw e - case e: OutOfMemoryError => - throw QueryExecutionErrors.notEnoughMemoryToLoadStore( - stateStoreId.toString, - "ROCKSDB_STORE_PROVIDER", - e) - case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e) - } + loadStateStore(version, uniqueId, readOnly = true) } override def doMaintenance(): Unit = { try { + stateMachine.maintenanceStore() rocksDB.doMaintenance() } catch { // SPARK-46547 - Swallow non-fatal exception in maintenance task to avoid deadlock between @@ -504,6 +669,7 @@ private[sql] class RocksDBStateStoreProvider } override def close(): Unit = { + stateMachine.closeStore() rocksDB.close() } @@ -560,8 +726,15 @@ private[sql] class RocksDBStateStoreProvider if (endVersion < snapshotVersion) { throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion) } - rocksDB.loadFromSnapshot(snapshotVersion, endVersion) - new RocksDBStateStore(endVersion) + val stamp = stateMachine.acquireStore() + try { + rocksDB.loadFromSnapshot(snapshotVersion, endVersion) + new RocksDBStateStore(endVersion, stamp) + } catch { + case e: Throwable => + stateMachine.releaseStore(stamp) + throw e + } } catch { case e: OutOfMemoryError => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProviderStateMachine.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProviderStateMachine.scala new file mode 100644 index 0000000000000..b494e3de5131e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProviderStateMachine.scala @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicLong +import javax.annotation.concurrent.GuardedBy + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.errors.QueryExecutionErrors + +/** + * A state machine that manages the lifecycle of RocksDB state store instances. + * + * This class enforces proper state transitions and ensures thread-safety for accessing + * state stores. It prevents concurrent modifications to the same state store by using + * a stamp-based locking mechanism. 
+ * + * State Lifecycle: + * - RELEASED: The store is not being accessed by any thread + * - ACQUIRED: The store is currently being accessed by a thread + * - CLOSED: The store has been closed and can no longer be used + * + * Valid Transitions: + * - RELEASED -> ACQUIRED: When a thread acquires the store + * - ACQUIRED -> RELEASED: When a thread releases the store + * - RELEASED -> CLOSED: When the store is shut down + * - ACQUIRED -> MAINTENANCE: Maintenance can be performed on an acquired store + * - RELEASED -> MAINTENANCE: Maintenance can be performed on a released store + * + * Stamps: + * Each time a store is acquired, a unique stamp is generated. This stamp must be presented + * when performing operations on the store and when releasing it. This ensures that only + * the thread that acquired the store can release it or perform operations on it. + */ +class RocksDBStateStoreProviderStateMachine( + stateStoreId: StateStoreId, + rocksDBConf: RocksDBConf) extends Logging { + + private sealed trait STATE + private case object RELEASED extends STATE + private case object ACQUIRED extends STATE + private case object CLOSED extends STATE + + private sealed abstract class TRANSITION(name: String) { + override def toString: String = name + } + private case object LOAD extends TRANSITION("load") + private case object RELEASE extends TRANSITION("release") + private case object CLOSE extends TRANSITION("close") + private case object MAINTENANCE extends TRANSITION("maintenance") + + private val instanceLock = new Object() + @GuardedBy("instanceLock") + private var state: STATE = RELEASED + @GuardedBy("instanceLock") + private var acquiredThreadInfo: AcquiredThreadInfo = _ + + // Can be read without holding any locks, but should only be updated when + // instanceLock is held. + // -1 indicates that the store is not locked. + private[sql] val currentValidStamp = new AtomicLong(-1L) + @GuardedBy("instanceLock") + private var lastValidStamp: Long = 0L + + // Instance lock must be held. + private def incAndGetStamp: Long = { + lastValidStamp += 1 + currentValidStamp.set(lastValidStamp) + lastValidStamp + } + + // Instance lock must be held. + private def awaitNotLocked(transition: TRANSITION): Unit = { + val waitStartTime = System.nanoTime() + def timeWaitedMs = { + val elapsedNanos = System.nanoTime() - waitStartTime + // Convert from nanoseconds to milliseconds + TimeUnit.MILLISECONDS.convert(elapsedNanos, TimeUnit.NANOSECONDS) + } + while (state == ACQUIRED && timeWaitedMs < rocksDBConf.lockAcquireTimeoutMs) { + instanceLock.wait(10) + } + if (state == ACQUIRED) { + val newAcquiredThreadInfo = AcquiredThreadInfo() + val stackTraceOutput = acquiredThreadInfo.threadRef.get.get.getStackTrace.mkString("\n") + val loggingId = s"StateStoreId(opId=${stateStoreId.operatorId}," + + s"partId=${stateStoreId.partitionId},name=${stateStoreId.storeName})" + throw QueryExecutionErrors.unreleasedThreadError(loggingId, transition.toString, + newAcquiredThreadInfo.toString(), acquiredThreadInfo.toString(), timeWaitedMs, + stackTraceOutput) + } + } + + /** + * Returns oldState, newState. + * Throws error if transition is illegal. + * MUST be called for every StateStoreProvider method. + * Caller MUST hold instance lock. 
+   */
+  private def validateAndTransitionState(transition: TRANSITION): (STATE, STATE) = {
+    val oldState = state
+    val newState = transition match {
+      case LOAD =>
+        oldState match {
+          case RELEASED => ACQUIRED
+          case ACQUIRED => throw new IllegalStateException("Cannot acquire when state is ACQUIRED")
+          case CLOSED => throw new IllegalStateException("Cannot acquire when state is CLOSED")
+        }
+      case RELEASE =>
+        oldState match {
+          case RELEASED => throw new IllegalStateException("Cannot release when state is RELEASED")
+          case ACQUIRED => RELEASED
+          case CLOSED => throw new IllegalStateException("Cannot release when state is CLOSED")
+        }
+      case CLOSE =>
+        oldState match {
+          case RELEASED => CLOSED
+          case ACQUIRED => throw new IllegalStateException("Cannot close when state is ACQUIRED")
+          case CLOSED => CLOSED
+        }
+      case MAINTENANCE =>
+        oldState match {
+          case RELEASED => RELEASED
+          case ACQUIRED => ACQUIRED
+          case CLOSED => throw new IllegalStateException("Cannot do maintenance when state is " +
+            "CLOSED")
+        }
+    }
+    state = newState
+    if (newState == ACQUIRED) {
+      acquiredThreadInfo = AcquiredThreadInfo()
+    }
+    (oldState, newState)
+  }
+
+  def verifyStamp(stamp: Long): Unit = {
+    if (stamp != currentValidStamp.get()) {
+      throw new IllegalStateException(s"Invalid stamp $stamp, " +
+        s"currentStamp: ${currentValidStamp.get()}")
+    }
+  }
+
+  // Returns whether the store was successfully released
+  def releaseStore(stamp: Long, throwEx: Boolean = true): Boolean = instanceLock.synchronized {
+    if (!currentValidStamp.compareAndSet(stamp, -1L)) {
+      if (throwEx) {
+        throw new IllegalStateException("Invalid stamp for release")
+      } else {
+        return false
+      }
+    }
+    validateAndTransitionState(RELEASE)
+    true
+  }
+
+  def acquireStore(): Long = instanceLock.synchronized {
+    awaitNotLocked(LOAD)
+    validateAndTransitionState(LOAD)
+    incAndGetStamp
+  }
+
+  def maintenanceStore(): Unit = instanceLock.synchronized {
+    validateAndTransitionState(MAINTENANCE)
+  }
+
+  def closeStore(): Unit = instanceLock.synchronized {
+    awaitNotLocked(CLOSE)
+    validateAndTransitionState(CLOSE)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 33a21c79f3db2..124478373231c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -74,6 +74,8 @@ trait ReadStateStore {
   /** Version of the data in this store before committing updates. */
   def version: Long
 
+  def getReadStamp: Long = -1
+
   /**
    * Get the current value of a non-null key.
    * @return a non-null row if the key exists in the store, otherwise null.
@@ -220,7 +222,7 @@ trait StateStore extends ReadStateStore {
 }
 
 /** Wraps the instance of StateStore to make the instance read-only. */
-class WrappedReadStateStore(store: StateStore) extends ReadStateStore {
+class WrappedReadStateStore(private[state] val store: StateStore) extends ReadStateStore {
   override def id: StateStoreId = store.id
 
   override def version: Long = store.version
@@ -554,7 +556,11 @@ trait StateStoreProvider {
    */
   def stateStoreId: StateStoreId
 
-  /** Called when the provider instance is unloaded from the executor */
+  /**
+   * Called when the provider instance is unloaded from the executor
+   * WARNING: IF THE PROVIDER IS FROM [[StateStore.loadedProviders]],
+   * CLOSE MUST ONLY BE CALLED FROM THE MAINTENANCE THREAD!
+ */ def close(): Unit /** @@ -565,6 +571,11 @@ trait StateStoreProvider { version: Long, stateStoreCkptId: Option[String] = None): StateStore + def getWriteStore( + readStore: ReadStateStore, + version: Long, + uniqueId: Option[String] = None): StateStore = getStore(version, uniqueId) + /** * Return an instance of [[ReadStateStore]] representing state data of the given version * and uniqueID if provided. @@ -924,6 +935,30 @@ object StateStore extends Logging { stateSchemaBroadcast) storeProvider.getStore(version, stateStoreCkptId) } + + def getWriteStore( + readStore: ReadStateStore, + storeProviderId: StateStoreProviderId, + keySchema: StructType, + valueSchema: StructType, + keyStateEncoderSpec: KeyStateEncoderSpec, + version: Long, + stateStoreCkptId: Option[String], + stateSchemaBroadcast: Option[StateSchemaBroadcast], + useColumnFamilies: Boolean, + storeConf: StateStoreConf, + hadoopConf: Configuration, + useMultipleValuesPerKey: Boolean = false): StateStore = { + hadoopConf.set(StreamExecution.RUN_ID_KEY, storeProviderId.queryRunId.toString) + if (version < 0) { + throw QueryExecutionErrors.unexpectedStateStoreVersion(version) + } + hadoopConf.set(StreamExecution.RUN_ID_KEY, storeProviderId.queryRunId.toString) + val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema, + keyStateEncoderSpec, useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey, + stateSchemaBroadcast) + storeProvider.getWriteStore(readStore, version, stateStoreCkptId) + } // scalastyle:on private def getStateStoreProvider( @@ -960,14 +995,45 @@ object StateStore extends Logging { val otherProviderIds = loadedProviders.keys.filter(_ != storeProviderId).toSeq val providerIdsToUnload = reportActiveStoreInstance(storeProviderId, otherProviderIds) - providerIdsToUnload.foreach(unload(_)) + providerIdsToUnload.foreach(id => { + loadedProviders.remove(id).foreach( provider => { + // Trigger maintenance thread to immediately do maintenance on and close the provider. + // Doing maintenance first allows us to do maintenance for a constantly-moving state + // store. + logInfo(log"Task thread trigger maintenance on " + + log"provider=${MDC(LogKeys.STATE_STORE_PROVIDER, id)}") + doMaintenanceOnProvider(id, provider, alreadyRemovedFromLoadedProviders = true) + }) + }) provider } } - /** Unload a state store provider */ - def unload(storeProviderId: StateStoreProviderId): Unit = loadedProviders.synchronized { - loadedProviders.remove(storeProviderId).foreach(_.close()) + /** + * Unload a state store provider. + * If alreadyRemovedFromLoadedProviders is None, provider will be + * removed from loadedProviders and closed. + * If alreadyRemovedFromLoadedProviders is Some, provider will be closed + * using passed in provider. + * WARNING: CAN ONLY BE CALLED FROM MAINTENANCE THREAD! + */ + def unload(storeProviderId: StateStoreProviderId, + alreadyRemovedStoreFromLoadedProviders: Option[StateStoreProvider] = None): Unit = { + var toCloseProviders: List[StateStoreProvider] = Nil + + alreadyRemovedStoreFromLoadedProviders match { + case Some(provider) => + toCloseProviders = provider :: toCloseProviders + case None => + // Copy provider to a local list so we can release loadedProviders lock when closing. 
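// Illustrative sketch (editor's addition, not part of this patch): the rewritten unload()
// here follows a "mutate the registry under its lock, close outside the lock" pattern, so that
// a potentially slow provider.close() cannot block other threads that only need to touch
// loadedProviders. A self-contained model of that pattern, with hypothetical names:
import scala.collection.mutable

object UnloadPatternSketch {
  trait CloseableProvider { def close(): Unit }

  private val loaded = mutable.Map.empty[String, CloseableProvider]

  def unload(key: String): Unit = {
    // Remove the entry while holding the lock...
    val toClose = loaded.synchronized { loaded.remove(key).toList }
    // ...then run close() without holding it.
    toClose.foreach(_.close())
  }
}
// --- end of editorial sketch ---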
+ loadedProviders.synchronized { + loadedProviders.remove(storeProviderId).foreach { provider => + toCloseProviders = provider :: toCloseProviders + } + } + } + + toCloseProviders.foreach(_.close()) } /** Unload all state store providers: unit test purpose */ @@ -1038,6 +1104,14 @@ object StateStore extends Logging { } } + // Block until we can process this partition + private def awaitProcessThisPartition(id: StateStoreProviderId): Unit = + maintenanceThreadPoolLock.synchronized { + while (!processThisPartition(id)) { + maintenanceThreadPoolLock.wait() + } + } + /** * Execute background maintenance task in all the loaded store providers if they are still * the active instances according to the coordinator. @@ -1051,47 +1125,7 @@ object StateStore extends Logging { loadedProviders.toSeq }.foreach { case (id, provider) => if (processThisPartition(id)) { - maintenanceThreadPool.execute(() => { - val startTime = System.currentTimeMillis() - try { - provider.doMaintenance() - if (!verifyIfStoreInstanceActive(id)) { - unload(id) - logInfo(log"Unloaded ${MDC(LogKeys.STATE_STORE_PROVIDER, provider)}") - } - } catch { - case NonFatal(e) => - logWarning(log"Error managing ${MDC(LogKeys.STATE_STORE_PROVIDER, provider)}, " + - log"unloading state store provider", e) - // When we get a non-fatal exception, we just unload the provider. - // - // By not bubbling the exception to the maintenance task thread or the query execution - // thread, it's possible for a maintenance thread pool task to continue failing on - // the same partition. Additionally, if there is some global issue that will cause - // all maintenance thread pool tasks to fail, then bubbling the exception and - // stopping the pool is faster than waiting for all tasks to see the same exception. - // - // However, we assume that repeated failures on the same partition and global issues - // are rare. The benefit to unloading just the partition with an exception is that - // transient issues on a given provider do not affect any other providers; so, in - // most cases, this should be a more performant solution. - unload(id) - } finally { - val duration = System.currentTimeMillis() - startTime - val logMsg = - log"Finished maintenance task for " + - log"provider=${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}" + - log" in elapsed_time=${MDC(LogKeys.TIME_UNITS, duration)}\n" - if (duration > 5000) { - logInfo(logMsg) - } else { - logDebug(logMsg) - } - maintenanceThreadPoolLock.synchronized { - maintenancePartitions.remove(id) - } - } - }) + doMaintenanceOnProvider(id, provider) } else { logInfo(log"Not processing partition ${MDC(LogKeys.PARTITION_ID, id)} " + log"for maintenance because it is currently " + @@ -1100,6 +1134,69 @@ object StateStore extends Logging { } } + private def doMaintenanceOnProvider(id: StateStoreProviderId, provider: StateStoreProvider, + alreadyRemovedFromLoadedProviders: Boolean = false): Unit = { + maintenanceThreadPool.execute(() => { + val startTime = System.currentTimeMillis() + if (alreadyRemovedFromLoadedProviders) { + // If provider is already removed from loadedProviders, we MUST process + // this partition to close it, so we block until we can. + awaitProcessThisPartition(id) + } + val awaitingPartitionDuration = System.currentTimeMillis() - startTime + try { + provider.doMaintenance() + // If shouldRemoveFromLoadedProviders is false, we don't need to verify + // with the coordinator as we know it definitely should be unloaded. 
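// Illustrative sketch (editor's addition, not part of this patch): awaitProcessThisPartition()
// above is a standard monitor wait loop: block until the partition can be claimed, re-checking
// the predicate each time the monitor is notified. A stripped-down, self-contained model with
// hypothetical names (the real code coordinates through maintenancePartitions and
// maintenanceThreadPoolLock):
object PartitionGateSketch {
  private val lock = new Object
  private val inFlight = scala.collection.mutable.Set.empty[Int]

  /** Claim a partition if nobody else holds it. */
  def tryClaim(partition: Int): Boolean = lock.synchronized {
    if (inFlight.contains(partition)) false else { inFlight += partition; true }
  }

  /** Block until the partition can be claimed. */
  def awaitClaim(partition: Int): Unit = lock.synchronized {
    while (!tryClaim(partition)) lock.wait()
  }

  /** Release a claimed partition and wake up waiters. */
  def release(partition: Int): Unit = lock.synchronized {
    inFlight -= partition
    lock.notifyAll()
  }
}
// --- end of editorial sketch ---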
+ if (alreadyRemovedFromLoadedProviders || !verifyIfStoreInstanceActive(id)) { + if (alreadyRemovedFromLoadedProviders) { + unload(id, Some(provider)) + } else { + unload(id) + } + logInfo(log"Unloaded ${MDC(LogKeys.STATE_STORE_PROVIDER, provider)}") + } + } catch { + case NonFatal(e) => + logWarning(log"Error managing ${MDC(LogKeys.STATE_STORE_PROVIDER, provider)}, " + + log"unloading state store provider", e) + // When we get a non-fatal exception, we just unload the provider. + // + // By not bubbling the exception to the maintenance task thread or the query execution + // thread, it's possible for a maintenance thread pool task to continue failing on + // the same partition. Additionally, if there is some global issue that will cause + // all maintenance thread pool tasks to fail, then bubbling the exception and + // stopping the pool is faster than waiting for all tasks to see the same exception. + // + // However, we assume that repeated failures on the same partition and global issues + // are rare. The benefit to unloading just the partition with an exception is that + // transient issues on a given provider do not affect any other providers; so, in + // most cases, this should be a more performant solution. + if (alreadyRemovedFromLoadedProviders) { + unload(id, Some(provider)) + } else { + unload(id) + } + } finally { + val duration = System.currentTimeMillis() - startTime + val logMsg = + log"Finished maintenance task for " + + log"provider=${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}" + + log" in elapsed_time=${MDC(LogKeys.TIME_UNITS, duration)}" + + log" and awaiting_partition_time=" + + log"${MDC(LogKeys.TIME_UNITS, awaitingPartitionDuration)}\n" + if (duration > 5000) { + logInfo(logMsg) + } else { + logDebug(logMsg) + } + maintenanceThreadPoolLock.synchronized { + maintenancePartitions.remove(id) + } + } + }) + } + private def reportActiveStoreInstance( storeProviderId: StateStoreProviderId, otherProviderIds: Seq[StateStoreProviderId]): Seq[StateStoreProviderId] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index d51db6e606e13..99497317e0e94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -27,6 +27,30 @@ import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration +/** + * A trait that provides access to state stores for specific partitions in an RDD. + * + * This trait enables the state store optimization pattern for stateful operations + * where the same partition needs to be accessed for both reading and writing. + * Implementing classes maintain a mapping between partition IDs and their associated + * state stores, allowing lookups without creating duplicate connections. + * + * The primary use case is enabling the common pattern for stateful operations: + * 1. A read-only state store is opened to retrieve existing state + * 2. The same state store is then converted to read-write mode for updates + * 3. This avoids having two separate open connections to the same state store + * which would cause blocking or contention issues + */ +trait StateStoreRDDProvider { + /** + * Returns the state store associated with the specified partition, if one exists. 
+ * + * @param partitionId The ID of the partition whose state store should be retrieved + * @return Some(store) if a state store exists for the given partition, None otherwise + */ + def getStateStoreForPartition(partitionId: Int): Option[ReadStateStore] +} + abstract class BaseStateStoreRDD[T: ClassTag, U: ClassTag]( dataRDD: RDD[T], checkpointLocation: String, @@ -82,19 +106,37 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag]( useColumnFamilies: Boolean = false, extraOptions: Map[String, String] = Map.empty) extends BaseStateStoreRDD[T, U](dataRDD, checkpointLocation, queryRunId, operatorId, - sessionState, storeCoordinator, extraOptions) { + sessionState, storeCoordinator, extraOptions) with StateStoreRDDProvider { + + // Using a ConcurrentHashMap to track state stores by partition ID + @transient private lazy val partitionStores = + new java.util.concurrent.ConcurrentHashMap[Int, ReadStateStore]() + + override def getStateStoreForPartition(partitionId: Int): Option[ReadStateStore] = { + Option(partitionStores.get(partitionId)) + } override protected def getPartitions: Array[Partition] = dataRDD.partitions override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { val storeProviderId = getStateProviderId(partition) + val partitionId = partition.index val inputIter = dataRDD.iterator(partition, ctxt) val store = StateStore.getReadOnly( storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion, - stateStoreCkptIds.map(_.apply(partition.index).head), + stateStoreCkptIds.map(_.apply(partitionId).head), stateSchemaBroadcast, useColumnFamilies, storeConf, hadoopConfBroadcast.value.value) + + // Store reference for this partition + partitionStores.put(partitionId, store) + + // Register a cleanup callback to be executed when the task completes + ctxt.addTaskCompletionListener[Unit](_ => { + partitionStores.remove(partitionId) + }) + storeReadFunction(store, inputIter) } } @@ -126,16 +168,78 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( override protected def getPartitions: Array[Partition] = dataRDD.partitions + /** + * Recursively searches the RDD lineage to find a StateStoreRDDProvider containing + * an already-opened state store for the current partition. + * + * This method helps implement the read-then-write pattern for stateful operations + * without creating contention issues. Instead of opening separate read and write + * stores that would block each other (since a state store provider can only handle + * one open store at a time), this allows us to: + * 1. Find an existing read store in the RDD lineage + * 2. Convert it to a write store using getWriteStore() + * + * This is particularly important for stateful aggregations where StateStoreRestoreExec + * first reads previous state and StateStoreSaveExec then updates it. + * + * The method performs a depth-first search through the RDD dependency graph. 
+ * + * @param rdd The starting RDD to search from + * @return Some(provider) if a StateStoreRDDProvider is found in the lineage, None otherwise + */ + private def findStateStoreProvider(rdd: RDD[_]): Option[StateStoreRDDProvider] = { + rdd match { + case null => None + case provider: StateStoreRDDProvider => Some(provider) + case _ if rdd.dependencies.isEmpty => None + case _ => + // Search all dependencies + rdd.dependencies.view + .map(dep => findStateStoreProvider(dep.rdd)) + .find(_.isDefined) + .flatten + } + } + override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { val storeProviderId = getStateProviderId(partition) - val inputIter = dataRDD.iterator(partition, ctxt) - val store = StateStore.get( - storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion, - uniqueId.map(_.apply(partition.index).head), - stateSchemaBroadcast, - useColumnFamilies, storeConf, hadoopConfBroadcast.value.value, - useMultipleValuesPerKey) + + // Try to find a state store provider in the RDD lineage + val store = findStateStoreProvider(dataRDD).flatMap { provider => + provider.getStateStoreForPartition(partition.index) + } match { + case Some(readStore) => + // Convert the read store to a writable store + StateStore.getWriteStore( + readStore, + storeProviderId, + keySchema, + valueSchema, + keyStateEncoderSpec, + storeVersion, + uniqueId.map(_.apply(partition.index).head), + stateSchemaBroadcast, + useColumnFamilies, + storeConf, + hadoopConfBroadcast.value.value, + useMultipleValuesPerKey) + + case None => + // Fall back to creating a new store + StateStore.get( + storeProviderId, + keySchema, + valueSchema, + keyStateEncoderSpec, + storeVersion, + uniqueId.map(_.apply(partition.index).head), + stateSchemaBroadcast, + useColumnFamilies, + storeConf, + hadoopConfBroadcast.value.value, + useMultipleValuesPerKey) + } storeUpdateFunction(store, inputIter) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index a82eff4812953..efa0e103f9481 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType +import org.apache.spark.util.TaskFailureListener package object state { @@ -109,8 +110,9 @@ package object state { val cleanedF = dataRDD.sparkContext.clean(storeReadFn) val wrappedF = (store: ReadStateStore, iter: Iterator[T]) => { // Clean up the state store. 
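// Editor's note (illustrative, not part of this patch): the change below replaces an
// unconditional completion-time abort with a failure-only abort. Presumably this is because the
// RocksDBStateStore introduced in this patch registers its own task-completion listener that
// aborts and releases the store, and because a read store that has been upgraded to a write
// store via getWriteStore() must not be aborted again when the task succeeds. The shape of the
// two registrations, assuming a TaskContext `ctxt` and a ReadStateStore `store` in scope:
//
//   // old behaviour: always abort on task completion
//   ctxt.addTaskCompletionListener[Unit] { _ => store.abort() }
//
//   // new behaviour: abort only if the task fails
//   ctxt.addTaskFailureListener(new TaskFailureListener {
//     override def onTaskFailure(context: TaskContext, error: Throwable): Unit = store.abort()
//   })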
- TaskContext.get().addTaskCompletionListener[Unit](_ => { - store.abort() + TaskContext.get().addTaskFailureListener(new TaskFailureListener { + override def onTaskFailure(context: TaskContext, error: Throwable): Unit = + store.abort() }) cleanedF(store, iter) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala index fca7d16012cee..0c2804e58e5ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala @@ -1107,6 +1107,7 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass assert(get(result, "a", 2).get == 2) assert(get(result, "a", 3).get == 3) assert(get(result, "a", 4).isEmpty) + result.abort() provider.close() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala index 476b43e42cb87..744dfffad47a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala @@ -56,6 +56,7 @@ class ListStateSuite extends StateVariableSuiteBase { sqlState = Some("42601"), parameters = Map("stateName" -> "listState") ) + store.abort() } } @@ -96,6 +97,7 @@ class ListStateSuite extends StateVariableSuiteBase { testState.clear() assert(!testState.exists()) assert(testState.get().toSeq === Seq.empty[Long]) + store.commit() } } @@ -280,6 +282,7 @@ class ListStateSuite extends StateVariableSuiteBase { assert(!testState2.exists()) assert(testState1.exists()) assert(testState2.get().toSeq === Seq.empty[Long]) + store.commit() } } @@ -311,6 +314,7 @@ class ListStateSuite extends StateVariableSuiteBase { assert(listState2.exists()) assert(!valueState.exists()) assert(listState1.get().toSeq === Seq.empty[Long]) + store.commit() } } @@ -366,6 +370,7 @@ class ListStateSuite extends StateVariableSuiteBase { nextBatchTestState.clear() assert(!nextBatchTestState.exists()) assert(nextBatchTestState.get().isEmpty) + store.commit() } } @@ -393,6 +398,7 @@ class ListStateSuite extends StateVariableSuiteBase { matchPVals = true ) } + store.abort() } } @@ -421,6 +427,7 @@ class ListStateSuite extends StateVariableSuiteBase { assert(ttlValues.forall(_._2 === ttlExpirationMs)) val ttlStateValue = testState.getValueInTTLState() assert(ttlStateValue.isDefined) + store.commit() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala index 9a0a891d538ec..9d551d10361ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala @@ -70,6 +70,7 @@ class MapStateSuite extends StateVariableSuiteBase { testState.clear() assert(!testState.exists()) assert(testState.iterator().hasNext === false) + store.commit() } } @@ -110,6 +111,7 @@ class MapStateSuite extends StateVariableSuiteBase { assert(!testState2.exists()) assert(testState1.iterator().hasNext === false) 
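// Editor's note (illustrative, not part of this patch): the store.commit() / store.abort()
// calls added throughout these suites all follow one rule: a store obtained from a provider
// must be committed or aborted (releasing its stamp) before the provider hands out another
// store or is closed. A hypothetical helper capturing that pattern could look like:
//
//   def withStore[T](provider: StateStoreProvider, version: Long)(f: StateStore => T): T = {
//     val store = provider.getStore(version)
//     try f(store) finally { if (!store.hasCommitted) store.abort() }
//   }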
assert(testState2.iterator().hasNext === false) + store.commit() } } @@ -170,6 +172,7 @@ class MapStateSuite extends StateVariableSuiteBase { assert(!mapTestState1.exists()) assert(mapTestState2.exists()) assert(mapTestState2.iterator().toList === List(("k2", 4))) + store.commit() } } @@ -228,6 +231,7 @@ class MapStateSuite extends StateVariableSuiteBase { nextBatchTestState.clear() assert(!nextBatchTestState.exists()) assert(nextBatchTestState.getValue("k1") === null) + store.commit() } } @@ -256,6 +260,7 @@ class MapStateSuite extends StateVariableSuiteBase { matchPVals = true ) } + store.abort() } } @@ -286,6 +291,7 @@ class MapStateSuite extends StateVariableSuiteBase { assert(ttlValue.get._2 === ttlExpirationMs) val ttlStateValueIterator = testState.getKeyValuesInTTLState().map(_._2) assert(ttlStateValueIterator.hasNext) + store.commit() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala index 22150ffde5db6..916c12c39d59c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala @@ -61,6 +61,8 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta override def id: StateStoreId = innerStore.id override def version: Long = innerStore.version + override def getReadStamp: Long = innerStore.getReadStamp + override def get( key: UnsafeRow, colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow = { @@ -174,12 +176,40 @@ class CkptIdCollectingStateStoreProviderWrapper extends StateStoreProvider { val innerStateStore = innerProvider.getStore(version, stateStoreCkptId) CkptIdCollectingStateStoreWrapper(innerStateStore) } - override def getReadStore(version: Long, uniqueId: Option[String] = None): ReadStateStore = { new WrappedReadStateStore( CkptIdCollectingStateStoreWrapper(innerProvider.getReadStore(version, uniqueId))) } + override def getWriteStore( + readStore: ReadStateStore, + version: Long, + uniqueId: Option[String] = None): StateStore = { + + // Handle the case where we're given a WrappedReadStateStore + readStore match { + case wrappedStore: WrappedReadStateStore => + // Unwrap to get our wrapper + wrappedStore.store match { + case wrapper: CkptIdCollectingStateStoreWrapper => + // Get the inner store from our wrapper + val innerReadStore = wrapper.innerStore + // Call the inner provider's getWriteStore with the inner store + val innerWriteStore = innerProvider.getWriteStore(innerReadStore, version, uniqueId) + // Wrap the result + CkptIdCollectingStateStoreWrapper(innerWriteStore) + case other => + throw new IllegalArgumentException( + "Expected CkptIdCollectingStateStoreWrapper but got " + other.getClass.getName) + } + + case other => + throw new IllegalArgumentException( + "Expected WrappedReadStateStore but got " + + other.getClass.getName) + } + } + override def doMaintenance(): Unit = innerProvider.doMaintenance() override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 5aea0077e2aa8..4b6a342f12c20 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -26,7 +26,7 @@ import org.apache.avro.AvroTypeException import org.apache.hadoop.conf.Configuration import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkConf, SparkUnsupportedOperationException} +import org.apache.spark.{SparkConf, SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.LocalSparkSession.withSparkSession import org.apache.spark.sql.SparkSession @@ -73,6 +73,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid val iter = provider.rocksDB.iterator() assert(iter.hasNext) val kv = iter.next() + store.commit() // Verify the version encoded in first byte of the key and value byte arrays assert(Platform.getByte(kv.key, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION) @@ -259,7 +260,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid tryWithProviderResource(newStoreProvider(keySchemaWithSomeUnsupportedTypeCols, RangeKeyScanStateEncoderSpec(keySchemaWithSomeUnsupportedTypeCols, Seq(index)), colFamiliesEnabled)) { provider => - provider.getStore(0) + provider.getStore(0).abort() } } @@ -354,6 +355,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid key._1 }.toSeq assert(result1 === (timerTimestamps ++ timerTimestamps1).sorted) + store1.commit() } } @@ -521,6 +523,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid key._1 }.toSeq assert(result1 === (timerTimestamps ++ timerTimestamps1).sorted) + store1.commit() } } @@ -627,6 +630,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid (key._1, key._2) }.toSeq assert(result === timerTimestamps.sorted) + store.commit() } } @@ -1396,12 +1400,13 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } // verify that ordering for non-null columns on the right in still maintained - val result1: Seq[Int] = store.iterator(cfName).map { kv => + val result1: Seq[Int] = store1.iterator(cfName).map { kv => val keyRow = kv.key keyRow.getInt(1) }.toSeq assert(result1 === timerTimestamps1.map(_._2).sorted) + store1.commit() } } @@ -1456,6 +1461,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid (key._1, key._2) }.toSeq assert(result.map(_._1) === timerTimestamps.map(_._1).sorted) + store.commit() } } @@ -1501,6 +1507,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid (key._1, key._2) }.toSeq assert(result === timerTimestamps.sorted) + store.commit() } } @@ -1544,6 +1551,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid }.toSeq assert(result.size === 1) } + store.commit() } } @@ -1581,6 +1589,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid }.toSeq assert(result.size === idx + 1) } + store.commit() } } @@ -1607,6 +1616,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid assert(valueRowToData(store.get(keyRow2)) === 2) store.remove(keyRow2) assert(store.get(keyRow2) === null) + store.commit() } } @@ -1641,6 +1651,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid assert(!iterator2.hasNext) assert(get(store, "a", 0).isEmpty) + store.commit() } } } @@ 
-1680,6 +1691,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid ) } } + store.abort() } } @@ -1716,6 +1728,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid ) } } + store.abort() } } @@ -1750,6 +1763,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid verifyStoreOperationUnsupported("prefixScan", colFamiliesEnabled, colFamilyName) { store.prefixScan(dataToKeyRow("a", 1), colFamilyName) } + store.commit() } } @@ -1804,6 +1818,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid sqlState = Some("42802"), parameters = Map("operationType" -> "get", "colFamilyName" -> colFamily1) ) + store.abort() store = provider.getStore(1) // version 1 data recovered correctly @@ -1856,12 +1871,14 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid // this should return the old id, because we didn't remove this colFamily for version 1 store.createColFamilyIfAbsent(colFamily1, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema)) + store.commit() store = provider.getRocksDBStateStore(3) store.createColFamilyIfAbsent(colFamily4, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema)) store.createColFamilyIfAbsent(colFamily5, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema)) + store.commit() } } @@ -1942,10 +1959,12 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid val metricPair = store .metrics.customMetrics.find(_._1.name == "rocksdbNumInternalColFamiliesKeys") assert(metricPair.isDefined && metricPair.get._2 === 4) - assert(rowPairsToDataSet(store.iterator(cfName)) === + val store1 = provider.getStore(1) + assert(rowPairsToDataSet(store1.iterator(cfName)) === Set(("a", 0) -> 1, ("b", 0) -> 2, ("c", 0) -> 3, ("d", 0) -> 4, ("e", 0) -> 5)) - assert(rowPairsToDataSet(store.iterator(internalCfName)) === + assert(rowPairsToDataSet(store1.iterator(internalCfName)) === Set(("a", 0) -> 1, ("m", 0) -> 2, ("n", 0) -> 3, ("b", 0) -> 4)) + store1.abort() // Reload the store and remove some keys val reloadedProvider = newStoreProvider(store.id, colFamiliesEnabled) @@ -1962,10 +1981,12 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid val metricPairUpdated = reloadedStore .metrics.customMetrics.find(_._1.name == "rocksdbNumInternalColFamiliesKeys") assert(metricPairUpdated.isDefined && metricPairUpdated.get._2 === 3) - assert(rowPairsToDataSet(reloadedStore.iterator(cfName)) === + val reloadedStore2 = reloadedProvider.getStore(2) + assert(rowPairsToDataSet(reloadedStore2.iterator(cfName)) === Set(("a", 0) -> 1, ("c", 0) -> 3, ("d", 0) -> 4, ("e", 0) -> 5)) - assert(rowPairsToDataSet(reloadedStore.iterator(internalCfName)) === + assert(rowPairsToDataSet(reloadedStore2.iterator(internalCfName)) === Set(("a", 0) -> 1, ("n", 0) -> 3, ("b", 0) -> 4)) + reloadedStore2.commit() } } } @@ -2008,6 +2029,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid sqlState = Some("42802"), parameters = Map("operationType" -> "iterator", "colFamilyName" -> cfName) ) + store.abort() } } } @@ -2044,6 +2066,259 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } + test("state transitions with commit and illegal operations") { + tryWithProviderResource(newStoreProvider(useColumnFamilies = false)) { provider => + // Get a store and put some data + val store = provider.getStore(0) + put(store, "a", 0, 1) + 
put(store, "b", 0, 2) + + // Verify data is accessible before commit + assert(get(store, "a", 0) === Some(1)) + assert(get(store, "b", 0) === Some(2)) + + // Commit the changes + assert(store.commit() === 1) + assert(store.hasCommitted) + + // Operations after commit should fail with IllegalStateException + val exception = intercept[IllegalStateException] { + put(store, "c", 0, 3) + } + assert(exception.getMessage.contains("Invalid stamp")) + + // Getting a new store for the same version should work + val store1 = provider.getStore(1) + assert(get(store1, "a", 0) === Some(1)) + assert(get(store1, "b", 0) === Some(2)) + + // Can update the new store instance + put(store1, "c", 0, 3) + assert(get(store1, "c", 0) === Some(3)) + + // Commit the new changes + assert(store1.commit() === 2) + } + } + + test("state transitions with abort and subsequent operations") { + tryWithProviderResource(newStoreProvider(useColumnFamilies = false)) { provider => + // Get a store and put some data + val store = provider.getStore(0) + put(store, "a", 0, 1) + put(store, "b", 0, 2) + + // Abort the changes + store.abort() + + // Operations after abort should fail with IllegalStateException + val exception = intercept[IllegalStateException] { + put(store, "c", 0, 3) + } + assert(exception.getMessage.contains("Invalid stamp")) + + // Get a new store, should be empty since previous changes were aborted + val store1 = provider.getStore(0) + assert(store1.iterator().isEmpty) + + // Put data and commit + put(store1, "d", 0, 4) + assert(store1.commit() === 1) + + // Get a new store and verify data + val store2 = provider.getStore(1) + assert(get(store2, "d", 0) === Some(4)) + store2.commit() + } + } + + test("abort after commit throws StateStoreOperationOutOfOrder") { + tryWithProviderResource(newStoreProvider(useColumnFamilies = false)) { provider => + val store = provider.getStore(0) + put(store, "a", 0, 1) + assert(store.commit() === 1) + + // Abort after commit should throw a SparkRuntimeException + val exception = intercept[SparkRuntimeException] { + store.abort() + } + checkError( + exception, + condition = "STATE_STORE_OPERATION_OUT_OF_ORDER", + parameters = Map("errorMsg" -> + "Expected possible states List(UPDATING, ABORTED) but found COMMITTED") + ) + + // Get a new store and verify data was committed + val store1 = provider.getStore(1) + assert(get(store1, "a", 0) === Some(1)) + store1.commit() + } + } + + test("multiple aborts are idempotent") { + tryWithProviderResource(newStoreProvider(useColumnFamilies = false)) { provider => + val store = provider.getStore(0) + put(store, "a", 0, 1) + + // First abort + store.abort() + + // Second abort should not throw + store.abort() + + // Operations should still fail + val exception = intercept[IllegalStateException] { + put(store, "b", 0, 2) + } + assert(exception.getMessage.contains("Invalid stamp")) + } + } + + test("multiple commits throw exception") { + tryWithProviderResource(newStoreProvider(useColumnFamilies = false)) { provider => + val store = provider.getStore(0) + put(store, "a", 0, 1) + assert(store.commit() === 1) + + // Second commit should fail with stamp verification + val exception = intercept[SparkRuntimeException] { + store.commit() + } + checkError( + exception, + condition = "STATE_STORE_OPERATION_OUT_OF_ORDER", + parameters = Map("errorMsg" -> + "Expected possible states List(UPDATING) but found COMMITTED") + ) + } + } + + test("get metrics works only after commit") { + tryWithProviderResource(newStoreProvider(useColumnFamilies = false)) { 
+      val store = provider.getStore(0)
+      put(store, "a", 0, 1)
+
+      // Getting metrics before commit should throw
+      val exception = intercept[SparkRuntimeException] {
+        store.metrics
+      }
+      checkError(
+        exception,
+        condition = "STATE_STORE_OPERATION_OUT_OF_ORDER",
+        parameters = Map("errorMsg" -> "Cannot get metrics in UPDATING state")
+      )
+
+      // Commit the changes
+      assert(store.commit() === 1)
+
+      // Getting metrics after commit should work
+      val metrics = store.metrics
+      assert(metrics.numKeys === 1)
+    }
+  }
+
+  test("get checkpoint info works only after commit") {
+    tryWithProviderResource(newStoreProvider(useColumnFamilies = false)) { provider =>
+      val store = provider.getStore(0)
+      put(store, "a", 0, 1)
+
+      // Getting checkpoint info before commit should throw
+      val exception = intercept[SparkRuntimeException] {
+        store.getStateStoreCheckpointInfo()
+      }
+      checkError(
+        exception,
+        condition = "STATE_STORE_OPERATION_OUT_OF_ORDER",
+        parameters = Map("errorMsg" -> "Cannot get metrics in UPDATING state")
+      )
+
+      // Commit the changes
+      assert(store.commit() === 1)
+
+      // Getting checkpoint info after commit should work
+      val checkpointInfo = store.getStateStoreCheckpointInfo()
+      assert(checkpointInfo != null)
+    }
+  }
+
+  test("read store and write store with common stamp") {
+    tryWithProviderResource(newStoreProvider(useColumnFamilies = false)) { provider =>
+      // First prepare some data
+      val initialStore = provider.getStore(0)
+      put(initialStore, "a", 0, 1)
+      assert(initialStore.commit() === 1)
+
+      // Get a read store
+      val readStore = provider.getReadStore(1)
+      assert(get(readStore, "a", 0) === Some(1))
+
+      // Get a write store from the read store
+      val writeStore = provider.getWriteStore(readStore, 1)
+
+      // Verify data access
+      assert(get(writeStore, "a", 0) === Some(1))
+
+      // Update through write store
+      put(writeStore, "b", 0, 2)
+      assert(get(writeStore, "b", 0) === Some(2))
+
+      // Commit the write store
+      assert(writeStore.commit() === 2)
+
+      // Get a new store and verify
+      val newStore = provider.getStore(2)
+      assert(get(newStore, "a", 0) === Some(1))
+      assert(get(newStore, "b", 0) === Some(2))
+      newStore.commit()
+    }
+  }
+
+  test("verify operation validation before and after commit") {
+    tryWithProviderResource(newStoreProvider(useColumnFamilies = false)) { provider =>
+      val store = provider.getStore(0)
+
+      // Put operations should work in UPDATING state
+      put(store, "a", 0, 1)
+      assert(get(store, "a", 0) === Some(1))
+
+      // Remove operations should work in UPDATING state
+      remove(store, _._1 == "a")
+      assert(get(store, "a", 0) === None)
+
+      // Iterator operations should work in UPDATING state
+      put(store, "b", 0, 2)
+      assert(rowPairsToDataSet(store.iterator()) === Set(("b", 0) -> 2))
+
+      // Commit should work in UPDATING state
+      assert(store.commit() === 1)
+
+      // After commit, state validation should prevent operations due to invalid stamp
+      // We now expect IllegalStateException instead of SparkRuntimeException
+      intercept[IllegalStateException] {
+        put(store, "c", 0, 3)
+      }
+
+      intercept[IllegalStateException] {
+        remove(store, _._1 == "b")
+      }
+
+      intercept[IllegalStateException] {
+        store.iterator()
+      }
+
+      // Get a new store for the next version
+      val store1 = provider.getStore(1)
+
+      // Abort the store
+      store1.abort()
+
+      // Operations after abort should fail due to invalid stamp
+      intercept[IllegalStateException] {
+        put(store1, "c", 0, 3)
+      }
+    }
+  }
+
   override def newStoreProvider(): RocksDBStateStoreProvider = {
     newStoreProvider(StateStoreId(newDir(), Random.nextInt(), 0))
   }
@@ -2123,7 +2398,10 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     tryWithProviderResource(newStoreProvider(provider.stateStoreId, useColumnFamilies)) {
       reloadedProvider =>
         val versionToRead = if (version < 0) reloadedProvider.latestVersion else version
-        reloadedProvider.getStore(versionToRead).iterator().map(rowPairToDataPair).toSet
+        val store = reloadedProvider.getStore(versionToRead)
+        val res = store.iterator().map(rowPairToDataPair).toSet
+        store.abort()
+        res
     }
   }
 
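Aside, not part of the patch: the tests added above pin down the lifecycle that the reworked RocksDBStateStore enforces. A store obtained from the provider starts in UPDATING, must be finished with exactly one commit() or abort(), rejects further reads and writes afterwards (the stamp check raises IllegalStateException("Invalid stamp")), and only exposes metrics or checkpoint info once COMMITTED or ABORTED. Written against the suite's own helpers (newStoreProvider, put), the expected caller pattern is roughly:

// Sketch only, using the test helpers above; not part of the diff.
def updateOneVersion(provider: RocksDBStateStoreProvider, version: Long): Long = {
  val store = provider.getStore(version)   // starts in UPDATING
  try {
    put(store, "a", 0, 1)                  // get/put/remove/iterator are legal only while UPDATING
    store.commit()                         // UPDATING -> COMMITTED, returns version + 1
  } catch {
    case e: Throwable =>
      store.abort()                        // UPDATING -> ABORTED; a repeated abort() is tolerated
      throw e
  }
}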
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 08648148b4af4..23702d2c30d01 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -122,6 +122,38 @@ private object FakeStateStoreProviderWithMaintenanceError {
   val errorOnMaintenance = new AtomicBoolean(false)
 }
 
+class FakeStateStoreProviderTracksCloseThread extends StateStoreProvider {
+  import FakeStateStoreProviderTracksCloseThread._
+  private var id: StateStoreId = null
+
+  override def init(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      keyStateEncoderSpec: KeyStateEncoderSpec,
+      useColumnFamilies: Boolean,
+      storeConfs: StateStoreConf,
+      hadoopConf: Configuration,
+      useMultipleValuesPerKey: Boolean = false,
+      stateSchemaProvider: Option[StateSchemaProvider] = None): Unit = {
+    id = stateStoreId
+  }
+
+  override def stateStoreId: StateStoreId = id
+
+  override def close(): Unit = {
+    closeThreadNames = Thread.currentThread.getName :: closeThreadNames
+  }
+
+  override def getStore(version: Long, uniqueId: Option[String]): StateStore = null
+
+  override def doMaintenance(): Unit = {}
+}
+
+private object FakeStateStoreProviderTracksCloseThread {
+  var closeThreadNames: List[String] = Nil
+}
+
 @ExtendedSQLTest
 class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
   with BeforeAndAfter {
@@ -563,8 +595,11 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
 
       StateStore.get(storeProviderId2, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
         0, None, None, useColumnFamilies = false, storeConf, hadoopConf)
-      assert(!StateStore.isLoaded(storeProviderId1))
-      assert(StateStore.isLoaded(storeProviderId2))
+      // Close runs asynchronously, so we need to call eventually with a small timeout
+      eventually(timeout(5.seconds)) {
+        assert(!StateStore.isLoaded(storeProviderId1))
+        assert(StateStore.isLoaded(storeProviderId2))
+      }
     }
   }
 
@@ -1082,7 +1117,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
 }
 
 abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
-  extends StateStoreCodecsTest with PrivateMethodTester {
+  extends StateStoreCodecsTest with PrivateMethodTester with BeforeAndAfter {
   import StateStoreTestsHelper._
 
   type MapType = mutable.HashMap[UnsafeRow, UnsafeRow]
@@ -1093,13 +1128,11 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
   testWithAllCodec("get, put, remove, commit, and all data iterator") { colFamiliesEnabled =>
     tryWithProviderResource(newStoreProvider(colFamiliesEnabled)) { provider =>
       // Verify state before starting a new set of updates
-      assert(getLatestData(provider, useColumnFamilies = colFamiliesEnabled).isEmpty)
-
       val store = provider.getStore(0)
+      assert(store.iterator().isEmpty)
       assert(!store.hasCommitted)
       assert(get(store, "a", 0) === None)
       assert(store.iterator().isEmpty)
-      assert(store.metrics.numKeys === 0)
 
       // Verify state after updating
       put(store, "a", 0, 1)
@@ -1113,9 +1146,12 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
       put(store, "aa", 0, 3)
       remove(store, _._1.startsWith("a"))
       assert(store.commit() === 1)
+      assert(store.metrics.numKeys === 1)
       assert(store.hasCommitted)
-      assert(rowPairsToDataSet(store.iterator()) === Set(("b", 0) -> 2))
+      val store1 = provider.getStore(1)
+      assert(rowPairsToDataSet(store1.iterator()) === Set(("b", 0) -> 2))
+      store1.abort()
 
       assert(getLatestData(provider, useColumnFamilies = colFamiliesEnabled)
         === Set(("b", 0) -> 2))
@@ -1135,7 +1171,10 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
       val reloadedStore = reloadedProvider.getStore(1)
       put(reloadedStore, "c", 0, 4)
       assert(reloadedStore.commit() === 2)
-      assert(rowPairsToDataSet(reloadedStore.iterator()) === Set(("b", 0) -> 2, ("c", 0) -> 4))
+      val reloadedStore2 = reloadedProvider.getStore(2)
+      assert(rowPairsToDataSet(reloadedStore2.iterator()) ===
+        Set(("b", 0) -> 2, ("c", 0) -> 4))
+      reloadedStore2.commit()
       assert(getLatestData(provider, useColumnFamilies = colFamiliesEnabled)
         === Set(("b", 0) -> 2, ("c", 0) -> 4))
       assert(getData(provider, version = 1, useColumnFamilies = colFamiliesEnabled)
@@ -1147,10 +1186,9 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
   testWithAllCodec("prefix scan") { colFamiliesEnabled =>
     tryWithProviderResource(newStoreProvider(keySchema,
       PrefixKeyScanStateEncoderSpec(keySchema, 1), colFamiliesEnabled)) { provider =>
-      // Verify state before starting a new set of updates
-      assert(getLatestData(provider, useColumnFamilies = false).isEmpty)
 
       var store = provider.getStore(0)
+      assert(store.iterator().isEmpty)
 
       def putCompositeKeys(keys: Seq[(String, Int)]): Unit = {
         val randomizedKeys = scala.util.Random.shuffle(keys.toList)
@@ -1201,15 +1239,15 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
       // prefix scan should not reflect the uncommitted changes
       verifyScan(key1AtVersion0, key2AtVersion0)
       verifyScan(Seq("d"), Seq.empty)
+      store.abort()
     }
   }
 
   testWithAllCodec(s"numKeys metrics") { colFamiliesEnabled =>
     tryWithProviderResource(newStoreProvider(colFamiliesEnabled)) { provider =>
-      // Verify state before starting a new set of updates
-      assert(getLatestData(provider, useColumnFamilies = colFamiliesEnabled).isEmpty)
 
       val store = provider.getStore(0)
+      assert(store.iterator().isEmpty)
       put(store, "a", 0, 1)
       put(store, "b", 0, 2)
       put(store, "c", 0, 3)
@@ -1217,38 +1255,45 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
       put(store, "e", 0, 5)
       assert(store.commit() === 1)
       assert(store.metrics.numKeys === 5)
-      assert(rowPairsToDataSet(store.iterator()) ===
-        Set(("a", 0) -> 1, ("b", 0) -> 2, ("c", 0) -> 3, ("d", 0) -> 4, ("e", 0) -> 5))
 
       val reloadedProvider = newStoreProvider(store.id, colFamiliesEnabled)
       val reloadedStore = reloadedProvider.getStore(1)
+
+      assert(rowPairsToDataSet(reloadedStore.iterator()) ===
+        Set(("a", 0) -> 1, ("b", 0) -> 2, ("c", 0) -> 3, ("d", 0) -> 4, ("e", 0) -> 5))
+
       remove(reloadedStore, _._1 == "b")
       assert(reloadedStore.commit() === 2)
       assert(reloadedStore.metrics.numKeys === 4)
-      assert(rowPairsToDataSet(reloadedStore.iterator()) ===
+      val store2 = reloadedProvider.getStore(2)
+      assert(rowPairsToDataSet(store2.iterator()) ===
         Set(("a", 0) -> 1, ("c", 0) -> 3, ("d", 0) -> 4, ("e", 0) -> 5))
+      store2.commit()
     }
   }
 
   testWithAllCodec(s"removing while iterating") { colFamiliesEnabled =>
     tryWithProviderResource(newStoreProvider(colFamiliesEnabled)) { provider =>
       // Verify state before starting a new set of updates
-      assert(getLatestData(provider, useColumnFamilies = colFamiliesEnabled).isEmpty)
       val store = provider.getStore(0)
+      assert(store.iterator().isEmpty)
       put(store, "a", 0, 1)
       put(store, "b", 0, 2)
 
       // Updates should work while iterating of filtered entries
-      val filtered = store.iterator().filter { tuple => keyRowToData(tuple.key) == ("a", 0) }
+      val filtered = store.iterator().filter {
+        tuple => keyRowToData(tuple.key) == ("a", 0) }
       filtered.foreach { tuple =>
         store.put(tuple.key, dataToValueRow(valueRowToData(tuple.value) + 1))
       }
       assert(get(store, "a", 0) === Some(2))
 
       // Removes should work while iterating of filtered entries
-      val filtered2 = store.iterator().filter { tuple => keyRowToData(tuple.key) == ("b", 0) }
+      val filtered2 = store.iterator().filter {
+        tuple => keyRowToData(tuple.key) == ("b", 0) }
       filtered2.foreach { tuple => store.remove(tuple.key) }
       assert(get(store, "b", 0) === None)
+      store.commit()
     }
   }
 
@@ -1257,10 +1302,9 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
       val store = provider.getStore(0)
       put(store, "a", 0, 1)
       store.commit()
-      assert(rowPairsToDataSet(store.iterator()) === Set(("a", 0) -> 1))
-
       // cancelUpdates should not change the data in the files
       val store1 = provider.getStore(1)
+      assert(rowPairsToDataSet(store1.iterator()) === Set(("a", 0) -> 1))
       put(store1, "b", 0, 1)
       store1.abort()
     }
@@ -1303,10 +1347,10 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
       val store = provider.getStore(0)
       put(store, "a", 0, 1)
       assert(store.commit() === 1)
-      assert(rowPairsToDataSet(store.iterator()) === Set(("a", 0) -> 1))
 
       val store1_ = provider.getStore(1)
       assert(rowPairsToDataSet(store1_.iterator()) === Set(("a", 0) -> 1))
+      store1_.abort()
 
       checkInvalidVersion(-1, provider.isInstanceOf[HDFSBackedStateStoreProvider])
       checkInvalidVersion(2, provider.isInstanceOf[HDFSBackedStateStoreProvider])
@@ -1316,8 +1360,9 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
       assert(rowPairsToDataSet(store1.iterator()) === Set(("a", 0) -> 1))
       put(store1, "b", 0, 1)
       assert(store1.commit() === 2)
-      assert(rowPairsToDataSet(store1.iterator()) === Set(("a", 0) -> 1, ("b", 0) -> 1))
-
+      val store2 = provider.getStore(2)
+      assert(rowPairsToDataSet(store2.iterator()) === Set(("a", 0) -> 1, ("b", 0) -> 1))
+      store2.abort()
       checkInvalidVersion(-1, provider.isInstanceOf[HDFSBackedStateStoreProvider])
       checkInvalidVersion(3, provider.isInstanceOf[HDFSBackedStateStoreProvider])
     }
@@ -1326,7 +1371,8 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
   testWithAllCodec("two concurrent StateStores - one for read-only and one for read-write") {
     colFamiliesEnabled =>
     // During Streaming Aggregation, we have two StateStores per task, one used as read-only in
-    // `StateStoreRestoreExec`, and one read-write used in `StateStoreSaveExec`. `StateStore.abort`
+    // `StateStoreRestoreExec`, and one read-write used in `StateStoreSaveExec`.
+    // `StateStore.abort`
     // will be called for these StateStores if they haven't committed their results. We need to
     // make sure that `abort` in read-only store after a `commit` in the read-write store doesn't
     // accidentally lead to the deletion of state.
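Aside, not part of the patch: the comment above describes the read-only restore store and read-write save store that streaming aggregation holds per task. The next hunk moves the save side to provider.getWriteStore(restoreStore, version), so both handles share one stamp and only the write store has to be committed; the explicit restoreStore.abort() that used to follow is removed. Assuming the suite's get/put helpers, the new shape of the pattern is roughly:

// Sketch only, not part of the diff: the restore store is upgraded in place to a save store.
val restoreStore = provider.getReadStore(version)              // read-only view of `version`
val previous = get(restoreStore, key1, key2).getOrElse(0)      // restore phase
val saveStore = provider.getWriteStore(restoreStore, version)  // reuses restoreStore's stamp
put(saveStore, key1, key2, previous + 1)                       // save phase
saveStore.commit()                                             // finishes both handles; no separate restoreStore.abort()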
@@ -1341,23 +1387,25 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
 
       put(store, key1, key2, 1)
       store.commit()
-      assert(rowPairsToDataSet(store.iterator()) === Set((key1, key2) -> 1))
+      val store1 = provider0.getStore(1)
+      assert(rowPairsToDataSet(store1.iterator()) === Set((key1, key2) -> 1))
+      store1.commit()
     }
 
     // two state stores
     tryWithProviderResource(newStoreProvider(storeId, colFamiliesEnabled)) { provider1 =>
       val restoreStore = provider1.getReadStore(1)
-      val saveStore = provider1.getStore(1)
+      val saveStore = provider1.getWriteStore(restoreStore, 1)
       put(saveStore, key1, key2, get(restoreStore, key1, key2).get + 1)
       saveStore.commit()
-      restoreStore.abort()
     }
 
     // check that state is correct for next batch
     tryWithProviderResource(newStoreProvider(storeId, colFamiliesEnabled)) { provider2 =>
       val finalStore = provider2.getStore(2)
       assert(rowPairsToDataSet(finalStore.iterator()) === Set((key1, key2) -> 2))
+      finalStore.commit()
     }
   }
 
@@ -1374,6 +1422,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
       val store = provider.getStore(0)
       put(store, "a", 0, 0)
       val e = intercept[SparkException](quietly { store.commit() } )
+      store.abort()
       assert(e.getCondition == "CANNOT_WRITE_STATE_STORE.CANNOT_COMMIT")
 
       if (store.getClass.getName contains ROCKSDB_STATE_STORE) {
@@ -1407,7 +1456,8 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     val itr3 = store.iterator() // itr3 is created after all writes.
 
     val intermediateState = Set(("1", 11) -> 100, ("2", 22) -> 200) // The intermediate state.
-    val finalState = Set(("1", 11) -> 101, ("2", 22) -> 200, ("3", 33) -> 300) // The final state.
+    val finalState = Set(("1", 11) -> 101,
+      ("2", 22) -> 200, ("3", 33) -> 300) // The final state.
     // Itr1 does not see any updates - original state of the store (SPARK-38320)
     assert(rowPairsToDataSet(itr1) === Set.empty[Set[((String, Int), Int)]])
     if (store.getClass.getName contains ROCKSDB_STATE_STORE) {
@@ -1555,11 +1605,14 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     // minDeltasForSnapshot = 1 to enable snapshot generation here.
     tryWithProviderResource(newStoreProvider(minDeltasForSnapshot = 1,
       numOfVersToRetainInMemory = 1)) { provider =>
-      val store = provider.getStore(0)
-      val noDataMemoryUsed = store.metrics.memoryUsedBytes
-      put(store, "a", 0, 1)
-      store.commit()
-      assert(store.metrics.memoryUsedBytes > noDataMemoryUsed)
+      val store0 = provider.getStore(0)
+      put(store0, "a", 0, 1)
+      store0.commit()
+      val store0MemoryUsed = store0.metrics.memoryUsedBytes
+      val store1 = provider.getStore(1)
+      put(store1, "b", 0, 1)
+      store1.commit()
+      assert(store1.metrics.memoryUsedBytes > store0MemoryUsed)
     }
   }
 
@@ -1588,13 +1641,14 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
   test("SPARK-35659: StateStore.put cannot put null value") {
     tryWithProviderResource(newStoreProvider()) { provider =>
       // Verify state before starting a new set of updates
-      assert(getLatestData(provider, useColumnFamilies = false).isEmpty)
 
       val store = provider.getStore(0)
+      assert(store.iterator().isEmpty)
       val err = intercept[IllegalArgumentException] {
         store.put(dataToKeyRow("key", 0), null)
       }
       assert(err.getMessage.contains("Cannot put a null value"))
+      store.commit()
     }
   }
 
@@ -1718,6 +1772,60 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     assert(encoderSpec == deserializedEncoderSpec)
   }
 
+  test("SPARK-51596: unloading only occurs on maintenance thread but occurs promptly") {
+    // Reset closeThreadNames
+    FakeStateStoreProviderTracksCloseThread.closeThreadNames = Nil
+
+    val sqlConf = getDefaultSQLConf(
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
+      SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get
+    )
+    // Make maintenance interval very large (30s) so that task thread runs before maintenance.
+    sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 30000L)
+    // Use the `FakeStateStoreProviderTracksCloseThread` provider so close() records its thread
+    sqlConf.setConf(
+      SQLConf.STATE_STORE_PROVIDER_CLASS,
+      classOf[FakeStateStoreProviderTracksCloseThread].getName
+    )
+
+    val conf = new SparkConf().setMaster("local").setAppName("test")
+
+    withSpark(SparkContext.getOrCreate(conf)) { sc =>
+      withCoordinatorRef(sc) { coordinatorRef =>
+        val rootLocation = s"${Utils.createTempDir().getAbsolutePath}/spark-48997"
+        val providerId =
+          StateStoreProviderId(StateStoreId(rootLocation, 0, 0), UUID.randomUUID)
+        val providerId2 =
+          StateStoreProviderId(StateStoreId(rootLocation, 0, 1), UUID.randomUUID)
+
+        // Create provider to start the maintenance task + pool
+        StateStore.get(
+          providerId,
+          keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
+          0, None, None, useColumnFamilies = false, new StateStoreConf(sqlConf), new Configuration()
+        )
+
+        // Report instance active on another executor
+        coordinatorRef.reportActiveInstance(providerId, "otherhost", "otherexec", Seq.empty)
+
+        // Load another provider to trigger task unload
+        StateStore.get(
+          providerId2,
+          keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
+          0, None, None, useColumnFamilies = false, new StateStoreConf(sqlConf), new Configuration()
+        )
+
+        // Wait for close to occur. Timeout is less than maintenance interval,
+        // so should only close by task triggering.
+        eventually(timeout(5.seconds)) {
+          assert(FakeStateStoreProviderTracksCloseThread.closeThreadNames.size == 1)
+          FakeStateStoreProviderTracksCloseThread.closeThreadNames.foreach { name =>
+            assert(name.contains("state-store-maintenance-thread"))}
+        }
+      }
+    }
+  }
+
   /** Return a new provider with a random id */
   def newStoreProvider(): ProviderClass
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
index 0d74aade67194..7e8fd8af0f59b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
@@ -51,6 +51,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
         UUID.randomUUID(), stringEncoder, getTimeMode(timeMode))
       assert(handle.getHandleState === StatefulProcessorHandleState.CREATED)
       handle.getValueState[Long]("testState", TTLConfig.NONE)
+      store.commit()
     }
   }
 }
@@ -99,6 +100,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
           createValueStateInstance(handle)
         }
       }
+      store.commit()
     }
   }
 }
@@ -135,6 +137,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
         ),
         matchPVals = true
       )
+      store.abort()
     }
   }
 
@@ -155,6 +158,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
 
       ImplicitGroupingKeyTracker.removeImplicitKey()
      assert(ImplicitGroupingKeyTracker.getImplicitKeyOption.isEmpty)
+      store.commit()
     }
   }
 }
@@ -195,6 +199,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
       assert(timers2.toSeq.sorted === timerTimestamps2.sorted)
       ImplicitGroupingKeyTracker.removeImplicitKey()
       assert(ImplicitGroupingKeyTracker.getImplicitKeyOption.isEmpty)
+      store.commit()
     }
   }
 }
@@ -213,6 +218,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
           registerTimer(handle)
         }
       }
+      store.commit()
     }
   }
 }
@@ -232,6 +238,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
 
       assert(handle.ttlStates.size() === 1)
       assert(handle.ttlStates.get(0) === valueStateWithTTL)
+      store.commit()
     }
   }
 
@@ -250,6 +257,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
 
       assert(handle.ttlStates.size() === 1)
       assert(handle.ttlStates.get(0) === listStateWithTTL)
+      store.commit()
     }
   }
 
@@ -269,6 +277,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
 
       assert(handle.ttlStates.size() === 1)
       assert(handle.ttlStates.get(0) === mapStateWithTTL)
+      store.commit()
     }
   }
 
@@ -284,6 +293,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
         TTLConfig.NONE)
 
       assert(handle.ttlStates.isEmpty)
+      store.commit()
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
index 428845d5ebcbb..43a9fd443298b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
@@ -55,6 +55,7 @@ class TimerSuite extends StateVariableSuiteBase {
       assert(timerState.listTimers().toSet === Set(20000L, 1000L))
       timerState.deleteTimer(20000L)
       assert(timerState.listTimers().toSet === Set(1000L))
+      store.commit()
     }
   }
 
@@ -81,6 +82,7 @@ class TimerSuite extends StateVariableSuiteBase {
       assert(timerState1.listTimers().toSet === Set(20000L, 15000L, 1000L))
       timerState1.deleteTimer(20000L)
       assert(timerState1.listTimers().toSet === Set(15000L, 1000L))
+      store.commit()
     }
   }
 
@@ -113,6 +115,7 @@ class TimerSuite extends StateVariableSuiteBase {
       ImplicitGroupingKeyTracker.setImplicitKey("test_key2")
       assert(timerState2.listTimers().toSet === Set(15000L))
       assert(timerState2.getExpiredTimers(1500L).toSeq === Seq(("test_key1", 1000L)))
+      store.commit()
     }
   }
 
@@ -132,6 +135,7 @@ class TimerSuite extends StateVariableSuiteBase {
         timerTimerstamps.sorted.takeWhile(_ <= 4200L))
       assert(timerState.getExpiredTimers(Long.MinValue).toSeq === Seq.empty)
       ImplicitGroupingKeyTracker.removeImplicitKey()
+      store.commit()
     }
   }
 
@@ -164,6 +168,7 @@ class TimerSuite extends StateVariableSuiteBase {
       assert(timerState1.getExpiredTimers(Long.MinValue).toSeq === Seq.empty)
       assert(timerState1.getExpiredTimers(8000L).toSeq.map(_._2) ===
         (timerTimestamps1 ++ timerTimestamps2 ++ timerTimerStamps3).sorted.takeWhile(_ <= 8000L))
+      store.commit()
     }
   }
 
@@ -183,6 +188,7 @@ class TimerSuite extends StateVariableSuiteBase {
       assert(timerState.listTimers().toSet === Set(20000L, 1000L))
       timerState.deleteTimer(20000L)
       assert(timerState.listTimers().toSet === Set(1000L))
+      store.commit()
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
index 8af42d6dec269..537827d7f8592 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
@@ -88,6 +88,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
         ),
         matchPVals = true
       )
+      store.abort()
     }
   }
 
@@ -115,6 +116,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
       testState.clear()
       assert(!testState.exists())
       assert(testState.get() === null)
+      store.commit()
     }
   }
 
@@ -160,6 +162,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
       testState2.clear()
       assert(!testState2.exists())
       assert(testState2.get() === null)
+      store.commit()
     }
   }
 
@@ -181,6 +184,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
         ),
         matchPVals = false
       )
+      store.abort()
     }
   }
 
@@ -226,6 +230,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
       testState.clear()
       assert(!testState.exists())
       assert(testState.get() === null)
+      store.commit()
     }
   }
 
@@ -252,6 +257,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
       testState.clear()
       assert(!testState.exists())
       assert(testState.get() === null)
+      store.commit()
     }
   }
 
@@ -278,6 +284,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
       testState.clear()
       assert(!testState.exists())
       assert(testState.get() === null)
+      store.commit()
     }
   }
 
@@ -304,6 +311,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
       testState.clear()
       assert(!testState.exists())
       assert(testState.get() === null)
+      store.commit()
     }
   }
 
@@ -359,6 +367,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
       nextBatchTestState.clear()
       assert(!nextBatchTestState.exists())
       assert(nextBatchTestState.get() === null)
+      store.commit()
     }
   }
 
@@ -386,6 +395,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
           matchPVals = true
         )
       }
+      store.abort()
     }
   }
 
@@ -413,6 +423,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
       assert(ttlValue.get._2 === ttlExpirationMs)
       val ttlStateValueIterator = testState.getValueInTTLState()
       assert(ttlStateValueIterator.isDefined)
+      store.commit()
     }
   }
 }
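Aside, not part of the patch: the handle, timer and value-state suites above are all adjusted the same way, because every store a provider hands out must now be finished explicitly; the task-completion listener added in RocksDBStateStoreProvider only fires inside Spark tasks, not on test threads. A hypothetical shared helper (withStore is not a name from the patch) could capture the pattern these call sites repeat:

// Hypothetical helper, not in the patch: run a block against a store and always finish it,
// committing on success and aborting on failure.
def withStore[T](provider: StateStoreProvider, version: Long)(body: StateStore => T): T = {
  val store = provider.getStore(version)
  try {
    val result = body(store)
    store.commit()
    result
  } catch {
    case e: Throwable =>
      store.abort()
      throw e
  }
}

// Usage mirroring the edits above, e.g. in TimerSuite:
//   withStore(provider, 0) { store => timerState.registerTimer(1000L) }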