From 404f5f2390018f954d243aa1d9b0c1ff4339ba2d Mon Sep 17 00:00:00 2001
From: Xi Lyu <xi.lyu@databricks.com>
Date: Mon, 7 Apr 2025 13:41:21 +0000
Subject: [PATCH 1/2] Memory-based MLCache eviction policy
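
Replace the fixed 100-item cap on the per-session MLCache with a
memory-based LRU policy: each cached model is weighed by its estimated
size, and Guava's maximumWeight/weigher evicts the least recently used
entries once the configured limit is exceeded. The limit defaults to one
third of the JVM max heap and the entry timeout to 15 minutes; both are
configurable per session, e.g. (illustrative values):

    spark.connect.session.connectML.mlCache.maxSize=1g
    spark.connect.session.connectML.mlCache.timeout=15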

---
 python/pyspark/testing/connectutils.py        |  3 +
 .../spark/sql/connect/config/Connect.scala    | 19 ++++++
 .../apache/spark/sql/connect/ml/MLCache.scala | 58 ++++++++++++++-----
 .../sql/connect/service/SessionHolder.scala   |  2 +-
 .../apache/spark/sql/connect/ml/MLSuite.scala | 31 ++++++++++
 5 files changed, 96 insertions(+), 17 deletions(-)

diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 423a717e8ab5e..e1e0356f4d426 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -158,6 +158,9 @@ def conf(cls):
         # Set a static token for all tests so the parallelism doesn't overwrite each
         # tests' environment variables
         conf.set("spark.connect.authenticate.token", "deadbeef")
+        # Increase the max size of the ML cache to avoid CONNECT_ML.CACHE_INVALID issues
+        # in tests.
+        conf.set("spark.connect.session.connectML.mlCache.maxSize", "1g")
         return conf
 
     @classmethod
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index f853e115cdfe1..98f19fe261e01 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -333,4 +333,23 @@ object Connect {
       Option(System.getenv.get(CONNECT_AUTHENTICATE_TOKEN_ENV))
     }
   }
+
+  val CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE =
+    buildConf("spark.connect.session.connectML.mlCache.maxSize")
+      .doc("Maximum size of the MLCache per session. The cache will evict the least recently" +
+        "used models if the size exceeds this limit. The size is in bytes.")
+      .version("4.0.0")
+      .internal()
+      .bytesConf(ByteUnit.BYTE)
+      // By default, 1/3 of total designated memory (the configured -Xmx).
+      .createWithDefault(Runtime.getRuntime.maxMemory() / 3)
+
+  val CONNECT_SESSION_CONNECT_ML_CACHE_TIMEOUT =
+    buildConf("spark.connect.session.connectML.mlCache.timeout")
+      .doc("Timeout of models in MLCache. Models will be evicted from the cache if they are not " +
+        "used for this amount of time. The timeout is in minutes.")
+      .version("4.0.0")
+      .internal()
+      .timeConf(TimeUnit.MINUTES)
+      .createWithDefault(15)
 }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index e8d8585020722..ea3bf6def997b 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -18,27 +18,59 @@ package org.apache.spark.sql.connect.ml
 
 import java.util.UUID
 import java.util.concurrent.{ConcurrentMap, TimeUnit}
+import java.util.concurrent.atomic.AtomicLong
 
-import com.google.common.cache.CacheBuilder
+import com.google.common.cache.{CacheBuilder, RemovalNotification}
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.ml.Model
 import org.apache.spark.ml.util.ConnectHelper
+import org.apache.spark.sql.connect.config.Connect
+import org.apache.spark.sql.connect.service.SessionHolder
 
 /**
  * MLCache is for caching ML objects, typically for models and summaries evaluated by a model.
  */
-private[connect] class MLCache extends Logging {
+private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
   private val helper = new ConnectHelper()
   private val helperID = "______ML_CONNECT_HELPER______"
 
-  private val cachedModel: ConcurrentMap[String, Object] = CacheBuilder
+  private def getMaxCacheSizeKB: Long = {
+    sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE) / 1024
+  }
+
+  private def getTimeoutMinute: Long = {
+    sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_TIMEOUT)
+  }
+
+  private[ml] case class CacheItem(obj: Object, sizeBytes: Long)
+  private[ml] val cachedModel: ConcurrentMap[String, CacheItem] = CacheBuilder
     .newBuilder()
     .softValues()
-    .maximumSize(MLCache.MAX_CACHED_ITEMS)
-    .expireAfterAccess(MLCache.CACHE_TIMEOUT_MINUTE, TimeUnit.MINUTES)
-    .build[String, Object]()
+    .maximumWeight(getMaxCacheSizeKB)
+    .expireAfterAccess(getTimeoutMinute, TimeUnit.MINUTES)
+    .weigher((key: String, value: CacheItem) => {
+      Math.ceil(value.sizeBytes.toDouble / 1024).toInt
+    })
+    .removalListener(
+      (removed: RemovalNotification[String, CacheItem]) =>
+        totalSizeBytes.addAndGet(-removed.getValue.sizeBytes)
+    )
+    .build[String, CacheItem]()
     .asMap()
 
+  private[ml] val totalSizeBytes: AtomicLong = new AtomicLong(0)
+
+  private def estimateObjectSize(obj: Object): Long = {
+    obj match {
+      case model: Model[_] =>
+        model.asInstanceOf[Model[_]].estimatedSize
+      case _ =>
+        // There can only be Models in the cache, so we should never reach here.
+        1
+    }
+  }
+
   /**
    * Cache an object into a map of MLCache, and return its key
    * @param obj
@@ -48,7 +80,9 @@ private[connect] class MLCache extends Logging {
    */
   def register(obj: Object): String = {
     val objectId = UUID.randomUUID().toString
-    cachedModel.put(objectId, obj)
+    val sizeBytes = estimateObjectSize(obj)
+    totalSizeBytes.addAndGet(sizeBytes)
+    cachedModel.put(objectId, CacheItem(obj, sizeBytes))
     objectId
   }
 
@@ -63,7 +97,7 @@ private[connect] class MLCache extends Logging {
     if (refId == helperID) {
       helper
     } else {
-      cachedModel.get(refId)
+      Option(cachedModel.get(refId)).map(_.obj).orNull
     }
   }
 
@@ -83,11 +117,3 @@ private[connect] class MLCache extends Logging {
     cachedModel.clear()
   }
 }
-
-private[connect] object MLCache {
-  // The maximum number of distinct items in the cache.
-  private val MAX_CACHED_ITEMS = 100
-
-  // The maximum time for an item to stay in the cache.
-  private val CACHE_TIMEOUT_MINUTE = 60
-}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 631885a5d741c..6f252c0cd9480 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -112,7 +112,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
     new ConcurrentHashMap()
 
   // ML model cache
-  private[connect] lazy val mlCache = new MLCache()
+  private[connect] lazy val mlCache = new MLCache(this)
 
   // Mapping from id to StreamingQueryListener. Used for methods like removeListener() in
   // StreamingQueryManager.
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index 76ce34a67e748..a54c5bf90e7f9 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.connect.SparkConnectTestUtils
+import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.service.SessionHolder
 
 trait FakeArrayParams extends Params {
@@ -379,4 +380,34 @@ class MLSuite extends MLHelper {
         .map(_.getString)
         .toArray sameElements Array("a", "b", "c"))
   }
+
+  test("Memory limitation of MLCache works") {
+    val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+    val memorySizeBytes = 1024 * 16
+    sessionHolder.session.conf.set(
+      Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE.key,
+      memorySizeBytes
+    )
+    trainLogisticRegressionModel(sessionHolder)
+    assert(sessionHolder.mlCache.cachedModel.size() == 1)
+    assert(sessionHolder.mlCache.totalSizeBytes.get() > 0)
+    val modelSizeBytes = sessionHolder.mlCache.totalSizeBytes.get()
+    val maxNumModels = memorySizeBytes / modelSizeBytes.toInt
+
+    // All models will be kept if the total size is less than the memory limit.
+    for (i <- 1 until maxNumModels) {
+      trainLogisticRegressionModel(sessionHolder)
+      assert(sessionHolder.mlCache.cachedModel.size() == i + 1)
+      assert(sessionHolder.mlCache.totalSizeBytes.get() > 0)
+      assert(sessionHolder.mlCache.totalSizeBytes.get() <= memorySizeBytes)
+    }
+
+    // Old models will be removed if new ones are added and the total size exceeds the memory limit.
+    for (_ <- 0 until 3) {
+      trainLogisticRegressionModel(sessionHolder)
+      assert(sessionHolder.mlCache.cachedModel.size() == maxNumModels)
+      assert(sessionHolder.mlCache.totalSizeBytes.get() > 0)
+      assert(sessionHolder.mlCache.totalSizeBytes.get() <= memorySizeBytes)
+    }
+  }
 }

From b40c9dc870dab427b7dc931115c758034baf5200 Mon Sep 17 00:00:00 2001
From: Xi Lyu <xi.lyu@databricks.com>
Date: Tue, 8 Apr 2025 08:01:22 +0000
Subject: [PATCH 2/2] Make config version 4.1.0, reformat code

---
 .../org/apache/spark/sql/connect/config/Connect.scala    | 9 +++++----
 .../scala/org/apache/spark/sql/connect/ml/MLCache.scala  | 6 ++----
 .../scala/org/apache/spark/sql/connect/ml/MLSuite.scala  | 6 ++----
 3 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 98f19fe261e01..1b9f770e9e96a 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -338,7 +338,7 @@ object Connect {
     buildConf("spark.connect.session.connectML.mlCache.maxSize")
       .doc("Maximum size of the MLCache per session. The cache will evict the least recently" +
         "used models if the size exceeds this limit. The size is in bytes.")
-      .version("4.0.0")
+      .version("4.1.0")
       .internal()
       .bytesConf(ByteUnit.BYTE)
       // By default, 1/3 of total designated memory (the configured -Xmx).
@@ -346,9 +346,10 @@ object Connect {
 
   val CONNECT_SESSION_CONNECT_ML_CACHE_TIMEOUT =
     buildConf("spark.connect.session.connectML.mlCache.timeout")
-      .doc("Timeout of models in MLCache. Models will be evicted from the cache if they are not " +
-        "used for this amount of time. The timeout is in minutes.")
-      .version("4.0.0")
+      .doc(
+        "Timeout of models in MLCache. Models will be evicted from the cache if they are not " +
+          "used for this amount of time. The timeout is in minutes.")
+      .version("4.1.0")
       .internal()
       .timeConf(TimeUnit.MINUTES)
       .createWithDefault(15)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index ea3bf6def997b..0f7acd0a6b527 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -52,10 +52,8 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
     .weigher((key: String, value: CacheItem) => {
       Math.ceil(value.sizeBytes.toDouble / 1024).toInt
     })
-    .removalListener(
-      (removed: RemovalNotification[String, CacheItem]) =>
-        totalSizeBytes.addAndGet(-removed.getValue.sizeBytes)
-    )
+    .removalListener((removed: RemovalNotification[String, CacheItem]) =>
+      totalSizeBytes.addAndGet(-removed.getValue.sizeBytes))
     .build[String, CacheItem]()
     .asMap()
 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index a54c5bf90e7f9..73bc1f2086aef 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -384,10 +384,8 @@ class MLSuite extends MLHelper {
   test("Memory limitation of MLCache works") {
     val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
     val memorySizeBytes = 1024 * 16
-    sessionHolder.session.conf.set(
-      Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE.key,
-      memorySizeBytes
-    )
+    sessionHolder.session.conf
+      .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE.key, memorySizeBytes)
     trainLogisticRegressionModel(sessionHolder)
     assert(sessionHolder.mlCache.cachedModel.size() == 1)
     assert(sessionHolder.mlCache.totalSizeBytes.get() > 0)