From 404f5f2390018f954d243aa1d9b0c1ff4339ba2d Mon Sep 17 00:00:00 2001
From: Xi Lyu <xi.lyu@databricks.com>
Date: Mon, 7 Apr 2025 13:41:21 +0000
Subject: [PATCH 1/2] Memory-based MLCache eviction policy

---
 python/pyspark/testing/connectutils.py        |  3 +
 .../spark/sql/connect/config/Connect.scala    | 19 ++++++
 .../apache/spark/sql/connect/ml/MLCache.scala | 58 ++++++++++++++-----
 .../sql/connect/service/SessionHolder.scala   |  2 +-
 .../apache/spark/sql/connect/ml/MLSuite.scala | 31 ++++++++++
 5 files changed, 96 insertions(+), 17 deletions(-)

diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 423a717e8ab5e..e1e0356f4d426 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -158,6 +158,9 @@ def conf(cls):
         # Set a static token for all tests so the parallelism doesn't overwrite each
         # tests' environment variables
         conf.set("spark.connect.authenticate.token", "deadbeef")
+        # Make the max size of ML Cache larger, to avoid CONNECT_ML.CACHE_INVALID issues
+        # in tests.
+        conf.set("spark.connect.session.connectML.mlCache.maxSize", "1g")
         return conf
 
     @classmethod
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index f853e115cdfe1..98f19fe261e01 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -333,4 +333,23 @@ object Connect {
       Option(System.getenv.get(CONNECT_AUTHENTICATE_TOKEN_ENV))
     }
   }
+
+  val CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE =
+    buildConf("spark.connect.session.connectML.mlCache.maxSize")
+      .doc("Maximum size of the MLCache per session. The cache will evict the least recently " +
+        "used models if the size exceeds this limit. The size is in bytes.")
+      .version("4.0.0")
+      .internal()
+      .bytesConf(ByteUnit.BYTE)
+      // By default, 1/3 of total designated memory (the configured -Xmx).
+      .createWithDefault(Runtime.getRuntime.maxMemory() / 3)
+
+  val CONNECT_SESSION_CONNECT_ML_CACHE_TIMEOUT =
+    buildConf("spark.connect.session.connectML.mlCache.timeout")
+      .doc("Timeout of models in MLCache. Models will be evicted from the cache if they are not " +
+        "used for this amount of time. The timeout is in minutes.")
+      .version("4.0.0")
+      .internal()
+      .timeConf(TimeUnit.MINUTES)
+      .createWithDefault(15)
 }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index e8d8585020722..ea3bf6def997b 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -18,27 +18,59 @@ package org.apache.spark.sql.connect.ml
 
 import java.util.UUID
 import java.util.concurrent.{ConcurrentMap, TimeUnit}
+import java.util.concurrent.atomic.AtomicLong
 
-import com.google.common.cache.CacheBuilder
+import com.google.common.cache.{CacheBuilder, RemovalNotification}
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.ml.Model
 import org.apache.spark.ml.util.ConnectHelper
+import org.apache.spark.sql.connect.config.Connect
+import org.apache.spark.sql.connect.service.SessionHolder
 
 /**
  * MLCache is for caching ML objects, typically for models and summaries evaluated by a model.
  */
-private[connect] class MLCache extends Logging {
+private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
   private val helper = new ConnectHelper()
   private val helperID = "______ML_CONNECT_HELPER______"
 
-  private val cachedModel: ConcurrentMap[String, Object] = CacheBuilder
+  private def getMaxCacheSizeKB: Long = {
+    sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE) / 1024
+  }
+
+  private def getTimeoutMinute: Long = {
+    sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_TIMEOUT)
+  }
+
+  private[ml] case class CacheItem(obj: Object, sizeBytes: Long)
+  private[ml] val cachedModel: ConcurrentMap[String, CacheItem] = CacheBuilder
     .newBuilder()
     .softValues()
-    .maximumSize(MLCache.MAX_CACHED_ITEMS)
-    .expireAfterAccess(MLCache.CACHE_TIMEOUT_MINUTE, TimeUnit.MINUTES)
-    .build[String, Object]()
+    .maximumWeight(getMaxCacheSizeKB)
+    .expireAfterAccess(getTimeoutMinute, TimeUnit.MINUTES)
+    .weigher((key: String, value: CacheItem) => {
+      Math.ceil(value.sizeBytes.toDouble / 1024).toInt
+    })
+    .removalListener(
+      (removed: RemovalNotification[String, CacheItem]) =>
+        totalSizeBytes.addAndGet(-removed.getValue.sizeBytes)
+    )
+    .build[String, CacheItem]()
     .asMap()
 
+  private[ml] val totalSizeBytes: AtomicLong = new AtomicLong(0)
+
+  private def estimateObjectSize(obj: Object): Long = {
+    obj match {
+      case model: Model[_] =>
+        model.asInstanceOf[Model[_]].estimatedSize
+      case _ =>
+        // There can only be Models in the cache, so we should never reach here.
+        1
+    }
+  }
+
   /**
    * Cache an object into a map of MLCache, and return its key
    * @param obj
@@ -48,7 +80,9 @@ private[connect] class MLCache extends Logging {
    */
   def register(obj: Object): String = {
     val objectId = UUID.randomUUID().toString
-    cachedModel.put(objectId, obj)
+    val sizeBytes = estimateObjectSize(obj)
+    totalSizeBytes.addAndGet(sizeBytes)
+    cachedModel.put(objectId, CacheItem(obj, sizeBytes))
     objectId
   }
 
@@ -63,7 +97,7 @@ private[connect] class MLCache extends Logging {
     if (refId == helperID) {
       helper
     } else {
-      cachedModel.get(refId)
+      Option(cachedModel.get(refId)).map(_.obj).orNull
     }
   }
 
@@ -83,11 +117,3 @@ private[connect] class MLCache extends Logging {
     cachedModel.clear()
   }
 }
-
-private[connect] object MLCache {
-  // The maximum number of distinct items in the cache.
-  private val MAX_CACHED_ITEMS = 100
-
-  // The maximum time for an item to stay in the cache.
-  private val CACHE_TIMEOUT_MINUTE = 60
-}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 631885a5d741c..6f252c0cd9480 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -112,7 +112,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
     new ConcurrentHashMap()
 
   // ML model cache
-  private[connect] lazy val mlCache = new MLCache()
+  private[connect] lazy val mlCache = new MLCache(this)
 
   // Mapping from id to StreamingQueryListener. Used for methods like removeListener() in
   // StreamingQueryManager.
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index 76ce34a67e748..a54c5bf90e7f9 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.connect.SparkConnectTestUtils
+import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.service.SessionHolder
 
 trait FakeArrayParams extends Params {
@@ -379,4 +380,34 @@ class MLSuite extends MLHelper {
       .map(_.getString)
       .toArray sameElements Array("a", "b", "c"))
   }
+
+  test("Memory limitation of MLCache works") {
+    val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+    val memorySizeBytes = 1024 * 16
+    sessionHolder.session.conf.set(
+      Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE.key,
+      memorySizeBytes
+    )
+    trainLogisticRegressionModel(sessionHolder)
+    assert(sessionHolder.mlCache.cachedModel.size() == 1)
+    assert(sessionHolder.mlCache.totalSizeBytes.get() > 0)
+    val modelSizeBytes = sessionHolder.mlCache.totalSizeBytes.get()
+    val maxNumModels = memorySizeBytes / modelSizeBytes.toInt
+
+    // All models will be kept if the total size is less than the memory limit.
+    for (i <- 1 until maxNumModels) {
+      trainLogisticRegressionModel(sessionHolder)
+      assert(sessionHolder.mlCache.cachedModel.size() == i + 1)
+      assert(sessionHolder.mlCache.totalSizeBytes.get() > 0)
+      assert(sessionHolder.mlCache.totalSizeBytes.get() <= memorySizeBytes)
+    }
+
+    // Old models will be removed if new ones are added and the total size exceeds the memory limit.
+    for (_ <- 0 until 3) {
+      trainLogisticRegressionModel(sessionHolder)
+      assert(sessionHolder.mlCache.cachedModel.size() == maxNumModels)
+      assert(sessionHolder.mlCache.totalSizeBytes.get() > 0)
+      assert(sessionHolder.mlCache.totalSizeBytes.get() <= memorySizeBytes)
+    }
+  }
 }

From b40c9dc870dab427b7dc931115c758034baf5200 Mon Sep 17 00:00:00 2001
From: Xi Lyu <xi.lyu@databricks.com>
Date: Tue, 8 Apr 2025 08:01:22 +0000
Subject: [PATCH 2/2] Make config version 4.1.0, reformat code

---
 .../org/apache/spark/sql/connect/config/Connect.scala   | 9 +++++----
 .../scala/org/apache/spark/sql/connect/ml/MLCache.scala | 6 ++----
 .../scala/org/apache/spark/sql/connect/ml/MLSuite.scala | 6 ++----
 3 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 98f19fe261e01..1b9f770e9e96a 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -338,7 +338,7 @@ object Connect {
     buildConf("spark.connect.session.connectML.mlCache.maxSize")
       .doc("Maximum size of the MLCache per session. The cache will evict the least recently " +
         "used models if the size exceeds this limit. The size is in bytes.")
-      .version("4.0.0")
+      .version("4.1.0")
       .internal()
       .bytesConf(ByteUnit.BYTE)
       // By default, 1/3 of total designated memory (the configured -Xmx).
@@ -346,9 +346,10 @@ object Connect {
 
   val CONNECT_SESSION_CONNECT_ML_CACHE_TIMEOUT =
     buildConf("spark.connect.session.connectML.mlCache.timeout")
-      .doc("Timeout of models in MLCache. Models will be evicted from the cache if they are not " +
-        "used for this amount of time. The timeout is in minutes.")
-      .version("4.0.0")
+      .doc(
+        "Timeout of models in MLCache. Models will be evicted from the cache if they are not " +
+          "used for this amount of time. The timeout is in minutes.")
+      .version("4.1.0")
       .internal()
       .timeConf(TimeUnit.MINUTES)
       .createWithDefault(15)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index ea3bf6def997b..0f7acd0a6b527 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -52,10 +52,8 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
     .weigher((key: String, value: CacheItem) => {
       Math.ceil(value.sizeBytes.toDouble / 1024).toInt
     })
-    .removalListener(
-      (removed: RemovalNotification[String, CacheItem]) =>
-        totalSizeBytes.addAndGet(-removed.getValue.sizeBytes)
-    )
+    .removalListener((removed: RemovalNotification[String, CacheItem]) =>
+      totalSizeBytes.addAndGet(-removed.getValue.sizeBytes))
     .build[String, CacheItem]()
     .asMap()
 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index a54c5bf90e7f9..73bc1f2086aef 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -384,10 +384,8 @@ class MLSuite extends MLHelper {
   test("Memory limitation of MLCache works") {
     val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
     val memorySizeBytes = 1024 * 16
-    sessionHolder.session.conf.set(
-      Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE.key,
-      memorySizeBytes
-    )
+    sessionHolder.session.conf
+      .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE.key, memorySizeBytes)
     trainLogisticRegressionModel(sessionHolder)
     assert(sessionHolder.mlCache.cachedModel.size() == 1)
     assert(sessionHolder.mlCache.totalSizeBytes.get() > 0)
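Note for readers outside the Spark codebase: the eviction mechanism in patch 1 reduces to Guava's maximumWeight/weigher pair plus a removalListener that keeps a running byte count in sync. The sketch below is a minimal, self-contained rendering of that pattern; it is an illustration only, not part of the patch series. `MLCacheEvictionSketch`, `register`, `get`, and the 16 KB budget are hypothetical stand-ins, with `sizeBytes` playing the role of `Model.estimatedSize`.

  import java.util.UUID
  import java.util.concurrent.TimeUnit
  import java.util.concurrent.atomic.AtomicLong

  import com.google.common.cache.{CacheBuilder, RemovalNotification}

  object MLCacheEvictionSketch {
    // Each entry carries its estimated size, because Guava weighs an entry
    // once, at insertion time.
    case class CacheItem(obj: AnyRef, sizeBytes: Long)

    // Running total of live bytes; the removal listener decrements it on
    // eviction, whether from size pressure or from the access timeout.
    val totalSizeBytes = new AtomicLong(0)

    private val cache = CacheBuilder
      .newBuilder()
      .softValues()
      .maximumWeight(16L) // budget in KB (16 KB here), matching the weigher's unit
      .expireAfterAccess(15, TimeUnit.MINUTES)
      // Weigh each entry by its size in KB, rounded up, as the patch does.
      .weigher((key: String, value: CacheItem) =>
        Math.ceil(value.sizeBytes.toDouble / 1024).toInt)
      .removalListener((removed: RemovalNotification[String, CacheItem]) =>
        totalSizeBytes.addAndGet(-removed.getValue.sizeBytes))
      .build[String, CacheItem]()
      .asMap()

    def register(obj: AnyRef, sizeBytes: Long): String = {
      val id = UUID.randomUUID().toString
      totalSizeBytes.addAndGet(sizeBytes)
      cache.put(id, CacheItem(obj, sizeBytes))
      id
    }

    def get(id: String): AnyRef =
      Option(cache.get(id)).map(_.obj).orNull
  }

Weighing in kilobytes rather than bytes is a deliberate choice: Guava entry weights are Ints, so a byte-granular weigher would cap a single entry near 2 GB, while dividing by 1024 keeps even very large models in range.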
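The two new settings are internal session confs, but they can be overridden at runtime the same way the test and the connectutils.py hunk do. A hedged usage example follows; the keys come from this patch, while the values are illustrative.

  // On a SparkSession visible to the Connect server. "1g" mirrors the value
  // set in connectutils.py; the timeout conf is interpreted in minutes.
  spark.conf.set("spark.connect.session.connectML.mlCache.maxSize", "1g")
  spark.conf.set("spark.connect.session.connectML.mlCache.timeout", "30")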