Skip to content

[SPARK-51745] Enforce State Machine for RocksDBStateStore #50497

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
state store impl
liviazhu-db authored and ericm-db committed Apr 2, 2025
commit c70af5f8920658fdb50eee04a1ba594644355ec6
13 changes: 13 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
@@ -4853,6 +4853,19 @@
],
"sqlState" : "XXKST"
},
"STATE_STORE_LOCK_VIOLATION" : {
"message" : [
"An lock model violation occurred in the state store."
],
"subClass" : {
"OPERATION_OUT_OF_ORDER" : {
"message": [
"A state store state machine operation is out of order. errorMsg=<errorMsg>"
]
}
},
"sqlState" : "XXKST"
},
"STATE_STORE_NULL_TYPE_ORDERING_COLS_NOT_SUPPORTED" : {
"message" : [
"Null type ordering column with name=<fieldName> at index=<index> is not supported for range scan encoder."
Original file line number Diff line number Diff line change
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.{SparkConf, SparkEnv, SparkException}
import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext}
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.io.CompressionCodec
@@ -50,7 +50,82 @@ private[sql] class RocksDBStateStoreProvider
case object COMMITTED extends STATE
case object ABORTED extends STATE

// Kinds of operation a caller can request from the store. Each StateStore method passes one of
// these to validateAndTransitionState, which decides whether it is legal in the current state.
private sealed trait TRANSITION
// Requested by the read/write and column-family methods (e.g. get, put, remove, iterator).
private case object UPDATE extends TRANSITION
// Requested by abort() once it has rolled back.
private case object ABORT extends TRANSITION
// Requested by commit() after rocksDB.commit() returns.
private case object COMMIT extends TRANSITION
// Requested by metrics and getStateStoreCheckpointInfo().
private case object METRICS extends TRANSITION

// Current state of this store; it starts in UPDATING. Marked @volatile.
@volatile private var state: STATE = UPDATING

/**
 * Checks that the store is currently in one of the expected states and returns that state.
 *
 * The `@volatile` field `state` is read exactly once, so the membership check, the error
 * message and the returned value all describe the same observation even if another thread
 * changes `state` in the meantime.
 *
 * Throws the error built by `StateStoreErrors.stateStoreOperationOutOfOrder` when the
 * observed state is not in `possibleStates`.
 *
 * @param possibleStates Expected possible states
 * @return the state that was observed and validated
 */
private def validateState(possibleStates: List[STATE]): STATE = {
  val observed = state
  if (!possibleStates.contains(observed)) {
    throw StateStoreErrors.stateStoreOperationOutOfOrder(
      s"Expected possible states $possibleStates but found $observed")
  }
  observed
}

/**
 * Applies the given transition to the store's state, or fails if it is not permitted.
 * MUST be called for every StateStore method.
 *
 * The legal moves are laid out as a single (transition, current state) table:
 *  - UPDATE is allowed only while UPDATING and leaves the state unchanged.
 *  - ABORT moves UPDATING to ABORTED and is a no-op when already ABORTED.
 *  - COMMIT moves UPDATING to COMMITTED.
 *  - METRICS is allowed only once the store is COMMITTED or ABORTED.
 * Every other combination throws the error built by
 * `StateStoreErrors.stateStoreOperationOutOfOrder`.
 *
 * @param transition The transition type of the operation.
 */
private def validateAndTransitionState(transition: TRANSITION): Unit = {
  // Local helper so each illegal row of the table is a single line.
  def outOfOrder(reason: String): Nothing =
    throw StateStoreErrors.stateStoreOperationOutOfOrder(reason)

  // `state` is read once to pick the row, then overwritten with the resulting state.
  state = (transition, state) match {
    case (UPDATE, UPDATING) => UPDATING
    case (UPDATE, COMMITTED) => outOfOrder("Cannot update after committed")
    case (UPDATE, ABORTED) => outOfOrder("Cannot update after aborted")

    case (ABORT, UPDATING) => ABORTED
    case (ABORT, COMMITTED) => outOfOrder("Cannot abort after committed")
    case (ABORT, ABORTED) => ABORTED

    case (COMMIT, UPDATING) => COMMITTED
    case (COMMIT, COMMITTED) => outOfOrder("Cannot commit after committed")
    case (COMMIT, ABORTED) => outOfOrder("Cannot commit after aborted")

    case (METRICS, UPDATING) => outOfOrder("Cannot get metrics in UPDATING state")
    case (METRICS, COMMITTED) => COMMITTED
    case (METRICS, ABORTED) => ABORTED
  }
}

// When running inside a task, register a completion listener that aborts the store if the task
// finished while the store was still UPDATING (i.e. it was neither committed nor aborted).
// abort() rejects the COMMITTED state, so calling it unconditionally would log a spurious
// "Failed to abort state store" warning after every successful commit.
Option(TaskContext.get()).foreach { taskContext =>
  taskContext.addTaskCompletionListener[Unit] { _ =>
    if (state == UPDATING) {
      try {
        abort()
      } catch {
        // Best effort: a failure to abort must not fail task completion.
        case NonFatal(e) =>
          logWarning("Failed to abort state store", e)
      }
    }
  }
}

// State row format validated
@volatile private var isValidated = false

// Identifier of this store, taken from the enclosing provider's stateStoreId.
override def id: StateStoreId = RocksDBStateStoreProvider.this.stateStoreId
@@ -64,6 +139,8 @@ private[sql] class RocksDBStateStoreProvider
keyStateEncoderSpec: KeyStateEncoderSpec,
useMultipleValuesPerKey: Boolean = false,
isInternal: Boolean = false): Unit = {
validateAndTransitionState(UPDATE)

verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal)
val cfId = rocksDB.createColFamilyIfAbsent(colFamilyName, isInternal)
val dataEncoderCacheKey = StateRowEncoderCacheKey(
@@ -105,6 +182,8 @@ private[sql] class RocksDBStateStoreProvider
}

override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = {
validateAndTransitionState(UPDATE)

verify(key != null, "Key cannot be null")
verifyColFamilyOperations("get", colFamilyName)

@@ -131,6 +210,8 @@ private[sql] class RocksDBStateStoreProvider
* values per key.
*/
override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = {
validateAndTransitionState(UPDATE)

verify(key != null, "Key cannot be null")
verifyColFamilyOperations("valuesIterator", colFamilyName)

@@ -147,7 +228,8 @@ private[sql] class RocksDBStateStoreProvider

override def merge(key: UnsafeRow, value: UnsafeRow,
colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
verify(state == UPDATING, "Cannot merge after already committed or aborted")
validateAndTransitionState(UPDATE)

verifyColFamilyOperations("merge", colFamilyName)

val kvEncoder = keyValueEncoderMap.get(colFamilyName)
@@ -162,7 +244,8 @@ private[sql] class RocksDBStateStoreProvider
}

override def put(key: UnsafeRow, value: UnsafeRow, colFamilyName: String): Unit = {
verify(state == UPDATING, "Cannot put after already committed or aborted")
validateAndTransitionState(UPDATE)

verify(key != null, "Key cannot be null")
require(value != null, "Cannot put a null value")
verifyColFamilyOperations("put", colFamilyName)
@@ -172,7 +255,8 @@ private[sql] class RocksDBStateStoreProvider
}

override def remove(key: UnsafeRow, colFamilyName: String): Unit = {
verify(state == UPDATING, "Cannot remove after already committed or aborted")
validateAndTransitionState(UPDATE)

verify(key != null, "Key cannot be null")
verifyColFamilyOperations("remove", colFamilyName)

@@ -181,6 +265,8 @@ private[sql] class RocksDBStateStoreProvider
}

override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = {
validateAndTransitionState(UPDATE)

// Note this verify function only verify on the colFamilyName being valid,
// we are actually doing prefix when useColumnFamilies,
// but pass "iterator" to throw correct error message
@@ -215,6 +301,8 @@ private[sql] class RocksDBStateStoreProvider

override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String):
Iterator[UnsafeRowPair] = {
validateAndTransitionState(UPDATE)

verifyColFamilyOperations("prefixScan", colFamilyName)

val kvEncoder = keyValueEncoderMap.get(colFamilyName)
@@ -232,10 +320,13 @@ private[sql] class RocksDBStateStoreProvider

var checkpointInfo: Option[StateStoreCheckpointInfo] = None
override def commit(): Long = synchronized {
validateState(List(UPDATING))

try {
verify(state == UPDATING, "Cannot commit after already committed or aborted")
val (newVersion, newCheckpointInfo) = rocksDB.commit()
checkpointInfo = Some(newCheckpointInfo)
validateAndTransitionState(COMMIT)
state = COMMITTED
logInfo(log"Committed ${MDC(VERSION_NUM, newVersion)} " +
log"for ${MDC(STATE_STORE_ID, id)}")
@@ -247,14 +338,17 @@ private[sql] class RocksDBStateStoreProvider
}

/**
 * Aborts the in-progress update of this store.
 *
 * If the store is UPDATING, this logs the abort, calls `rocksDB.rollback()` and moves the
 * state machine to ABORTED. If the store is already ABORTED the call is a no-op, so a repeated
 * abort does not roll back a second time. If the store is COMMITTED, `validateState` throws
 * the out-of-order error.
 */
override def abort(): Unit = {
  // validateState both rejects COMMITTED and tells us whether any work is left to do.
  if (validateState(List(UPDATING, ABORTED)) != ABORTED) {
    logInfo(log"Aborting ${MDC(VERSION_NUM, version + 1)} " +
      log"for ${MDC(STATE_STORE_ID, id)}")
    rocksDB.rollback()
    // Record the transition only after the rollback has succeeded.
    validateAndTransitionState(ABORT)
  }
}

override def metrics: StateStoreMetrics = {
validateAndTransitionState(METRICS)

val rocksDBMetricsOpt = rocksDB.metricsOpt

if (rocksDBMetricsOpt.isDefined) {
@@ -337,6 +431,8 @@ private[sql] class RocksDBStateStoreProvider
}

override def getStateStoreCheckpointInfo(): StateStoreCheckpointInfo = {
validateAndTransitionState(METRICS)

checkpointInfo match {
case Some(info) => info
case None => throw StateStoreErrors.stateStoreOperationOutOfOrder(
@@ -356,6 +452,8 @@ private[sql] class RocksDBStateStoreProvider

/** Remove column family if exists */
override def removeColFamilyIfExists(colFamilyName: String): Boolean = {
validateAndTransitionState(UPDATE)

verifyColFamilyCreationOrDeletion("remove_col_family", colFamilyName)
verify(useColumnFamilies, "Column families are not supported in this store")