[SPARK-51711][ML][CONNECT] Memory based MLCache eviction policy #50530

3 changes: 3 additions & 0 deletions python/pyspark/testing/connectutils.py
@@ -158,6 +158,9 @@ def conf(cls):
# Set a static token for all tests so the parallelism doesn't overwrite each
# test's environment variables
conf.set("spark.connect.authenticate.token", "deadbeef")
# Make the max size of ML Cache larger, to avoid CONNECT_ML.CACHE_INVALID issues
# in tests.
conf.set("spark.connect.session.connectML.mlCache.maxSize", "1g")
return conf

@classmethod
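For context, the "1g" above is a Spark byte string; the conf it targets is declared below with bytesConf(ByteUnit.BYTE). A one-line sketch of how such a value resolves, assuming the JavaUtils parser from spark-network-common that backs bytesConf:

import org.apache.spark.network.util.JavaUtils

// Byte strings use binary units, so "1g" means 1 GiB.
val maxSizeBytes: Long = JavaUtils.byteStringAsBytes("1g") // 1073741824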
@@ -333,4 +333,24 @@ object Connect {
Option(System.getenv.get(CONNECT_AUTHENTICATE_TOKEN_ENV))
}
}

val CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE =
buildConf("spark.connect.session.connectML.mlCache.maxSize")
.doc("Maximum size of the MLCache per session. The cache will evict the least recently" +
"used models if the size exceeds this limit. The size is in bytes.")
.version("4.1.0")
.internal()
.bytesConf(ByteUnit.BYTE)
// By default, 1/3 of total designated memory (the configured -Xmx).
.createWithDefault(Runtime.getRuntime.maxMemory() / 3)

val CONNECT_SESSION_CONNECT_ML_CACHE_TIMEOUT =
buildConf("spark.connect.session.connectML.mlCache.timeout")
.doc(
"Timeout of models in MLCache. Models will be evicted from the cache if they are not " +
"used for this amount of time. The timeout is in minutes.")
.version("4.1.0")
.internal()
.timeConf(TimeUnit.MINUTES)
.createWithDefault(15)
}
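Both limits are plain session confs read when a session's MLCache is built, so they can be overridden per session before the first ML operation, the same way the new test below does it. A minimal sketch, assuming spark is the server-side session (e.g. obtained from a SessionHolder):

// Hypothetical override: a 256 MB cache cap and a 5 minute idle timeout.
// maxSize is in bytes and timeout in minutes, matching the conf types above.
spark.conf.set("spark.connect.session.connectML.mlCache.maxSize", 256L * 1024 * 1024)
spark.conf.set("spark.connect.session.connectML.mlCache.timeout", 5L)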
@@ -18,27 +18,57 @@ package org.apache.spark.sql.connect.ml

import java.util.UUID
import java.util.concurrent.{ConcurrentMap, TimeUnit}
import java.util.concurrent.atomic.AtomicLong

import com.google.common.cache.CacheBuilder
import com.google.common.cache.{CacheBuilder, RemovalNotification}

import org.apache.spark.internal.Logging
import org.apache.spark.ml.Model
import org.apache.spark.ml.util.ConnectHelper
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.service.SessionHolder

/**
* MLCache is for caching ML objects, typically for models and summaries evaluated by a model.
*/
private[connect] class MLCache extends Logging {
private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
private val helper = new ConnectHelper()
private val helperID = "______ML_CONNECT_HELPER______"

private val cachedModel: ConcurrentMap[String, Object] = CacheBuilder
private def getMaxCacheSizeKB: Long = {
sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE) / 1024
}

private def getTimeoutMinute: Long = {
sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_TIMEOUT)
}

private[ml] case class CacheItem(obj: Object, sizeBytes: Long)
private[ml] val cachedModel: ConcurrentMap[String, CacheItem] = CacheBuilder
.newBuilder()
.softValues()
.maximumSize(MLCache.MAX_CACHED_ITEMS)
.expireAfterAccess(MLCache.CACHE_TIMEOUT_MINUTE, TimeUnit.MINUTES)
.build[String, Object]()
.maximumWeight(getMaxCacheSizeKB)
.expireAfterAccess(getTimeoutMinute, TimeUnit.MINUTES)
.weigher((key: String, value: CacheItem) => {
Math.ceil(value.sizeBytes.toDouble / 1024).toInt
})
.removalListener((removed: RemovalNotification[String, CacheItem]) =>
totalSizeBytes.addAndGet(-removed.getValue.sizeBytes))
.build[String, CacheItem]()
.asMap()

private[ml] val totalSizeBytes: AtomicLong = new AtomicLong(0)

private def estimateObjectSize(obj: Object): Long = {
obj match {
case model: Model[_] =>
model.asInstanceOf[Model[_]].estimatedSize
case _ =>
// There can only be Models in the cache, so we should never reach here.
1
}
}

/**
* Cache an object into a map of MLCache, and return its key
* @param obj
@@ -48,7 +78,9 @@ private[connect] class MLCache extends Logging {
*/
def register(obj: Object): String = {
val objectId = UUID.randomUUID().toString
cachedModel.put(objectId, obj)
val sizeBytes = estimateObjectSize(obj)
totalSizeBytes.addAndGet(sizeBytes)
cachedModel.put(objectId, CacheItem(obj, sizeBytes))
objectId
}

@@ -63,7 +95,7 @@ private[connect] class MLCache extends Logging {
if (refId == helperID) {
helper
} else {
cachedModel.get(refId)
Option(cachedModel.get(refId)).map(_.obj).orNull
}
}

@@ -83,11 +115,3 @@ private[connect] class MLCache extends Logging {
cachedModel.clear()
}
}

private[connect] object MLCache {
// The maximum number of distinct items in the cache.
private val MAX_CACHED_ITEMS = 100

// The maximum time for an item to stay in the cache.
private val CACHE_TIMEOUT_MINUTE = 60
}
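The eviction mechanics above come straight from Guava: maximumWeight plus a weigher gives size-based eviction, and the removalListener keeps totalSizeBytes in sync. A self-contained sketch of the same pattern (demo names only, not part of this PR):

import java.util.concurrent.atomic.AtomicLong

import com.google.common.cache.{CacheBuilder, RemovalNotification}

object WeightedCacheDemo {
  final case class Entry(payload: Array[Byte])

  def main(args: Array[String]): Unit = {
    val totalSizeBytes = new AtomicLong(0)
    val cache = CacheBuilder
      .newBuilder()
      // Single segment, so the weight cap applies globally; Guava otherwise
      // splits maximumWeight across its internal segments.
      .concurrencyLevel(1)
      // Cap the total weight at 8 units; with the KB weigher below, ~8 KB.
      .maximumWeight(8)
      .weigher((key: String, value: Entry) =>
        Math.ceil(value.payload.length.toDouble / 1024).toInt)
      // Evictions (and explicit removals) decrement the running byte counter,
      // exactly like the removalListener in MLCache above.
      .removalListener((removed: RemovalNotification[String, Entry]) =>
        totalSizeBytes.addAndGet(-removed.getValue.payload.length.toLong))
      .build[String, Entry]()

    for (i <- 0 until 6) {
      val entry = Entry(new Array[Byte](2048)) // each entry weighs 2 KB
      totalSizeBytes.addAndGet(entry.payload.length.toLong)
      cache.put(i.toString, entry)
    }
    // Only four 2 KB entries fit under the 8 KB cap: the two least recently
    // used entries are evicted and the listener adjusts the counter.
    println(s"entries=${cache.size()}, trackedBytes=${totalSizeBytes.get()}")
  }
}

The KB-granularity weigher mirrors getMaxCacheSizeKB above: Guava entry weights are Ints, so tracking KB instead of bytes presumably keeps multi-GB caches inside the Int range.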
@@ -112,7 +112,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSession
new ConcurrentHashMap()

// ML model cache
private[connect] lazy val mlCache = new MLCache()
private[connect] lazy val mlCache = new MLCache(this)

// Mapping from id to StreamingQueryListener. Used for methods like removeListener() in
// StreamingQueryManager.
@@ -25,6 +25,7 @@ import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.connect.SparkConnectTestUtils
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.service.SessionHolder

trait FakeArrayParams extends Params {
@@ -379,4 +380,32 @@ class MLSuite extends MLHelper {
.map(_.getString)
.toArray sameElements Array("a", "b", "c"))
}

test("Memory limitation of MLCache works") {
val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
val memorySizeBytes = 1024 * 16
sessionHolder.session.conf
.set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE.key, memorySizeBytes)
trainLogisticRegressionModel(sessionHolder)
assert(sessionHolder.mlCache.cachedModel.size() == 1)
assert(sessionHolder.mlCache.totalSizeBytes.get() > 0)
val modelSizeBytes = sessionHolder.mlCache.totalSizeBytes.get()
val maxNumModels = memorySizeBytes / modelSizeBytes.toInt

// All models will be kept if the total size is less than the memory limit.
for (i <- 1 until maxNumModels) {
trainLogisticRegressionModel(sessionHolder)
assert(sessionHolder.mlCache.cachedModel.size() == i + 1)
assert(sessionHolder.mlCache.totalSizeBytes.get() > 0)
assert(sessionHolder.mlCache.totalSizeBytes.get() <= memorySizeBytes)
}

// Old models will be removed if new ones are added and the total size exceeds the memory limit.
for (_ <- 0 until 3) {
trainLogisticRegressionModel(sessionHolder)
assert(sessionHolder.mlCache.cachedModel.size() == maxNumModels)
assert(sessionHolder.mlCache.totalSizeBytes.get() > 0)
assert(sessionHolder.mlCache.totalSizeBytes.get() <= memorySizeBytes)
}
}
}
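For concreteness: with the 16 KB cap above and a logistic-regression model of, say, ~4 KB (an illustrative figure; the real value comes from estimatedSize), maxNumModels is 4, so the second loop trains three extra models and checks that the cache stays pinned at four entries while totalSizeBytes never exceeds the cap.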