
Commit ad5ae1d

WeichenXu123 authored and ericm-db committed
[SPARK-51867][ML] Make Scala models support save / load methods against local filesystem paths
### What changes were proposed in this pull request?

Make Scala models support save / load methods (developer API) against a local filesystem path.

### Why are the changes needed?

This is required by Spark Connect server model cache management.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#50665 from WeichenXu123/ml-save-to-local.

Authored-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: yangjie01 <yangjie01@baidu.com>
1 parent abc8d64 · commit ad5ae1d


44 files changed (+555, −384 lines)
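Every model touched here swaps a one-row (or per-record) DataFrame parquet write for new `ReadWriteUtils` helpers (`saveObject`, `loadObject`, `saveArray`, `loadArray`) that can target either the usual distributed path or a plain local filesystem path. A minimal sketch of the shape such helpers could take — the names mirror the call sites in the diffs below, but the Java-serialization format and exact signatures are illustrative assumptions, not the PR's actual `ReadWriteUtils` source:

```scala
import java.io.{FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream}
import java.nio.file.{Files, Paths}

// Illustrative stand-in for the local-path half of the PR's ReadWriteUtils.
object LocalReadWriteSketch {
  // When true, model writers/readers bypass Spark and use the local filesystem.
  val localSavingModeState: ThreadLocal[Boolean] = new ThreadLocal[Boolean] {
    override def initialValue(): Boolean = false
  }

  // Persist one model-data object (e.g. a model's private Data case class).
  def saveObjectToLocal(path: String, obj: AnyRef): Unit = {
    val p = Paths.get(path)
    if (p.getParent != null) Files.createDirectories(p.getParent)
    val out = new ObjectOutputStream(new FileOutputStream(path))
    try out.writeObject(obj) finally out.close()
  }

  // Read it back; the caller states the expected type.
  def loadObjectFromLocal[T](path: String): T = {
    val in = new ObjectInputStream(new FileInputStream(path))
    try in.readObject().asInstanceOf[T] finally in.close()
  }
}
```

The real helpers also take the `sparkSession` and fall back to the existing parquet code path when the flag is unset; the GaussianMixture diff below shows that branch spelled out at a call site.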

mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -300,7 +300,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassificationModel] {
       val (nodeData, _) = NodeData.build(instance.rootNode, 0)
       val dataPath = new Path(path, "data").toString
       val numDataParts = NodeData.inferNumPartitions(instance.numNodes)
-      sparkSession.createDataFrame(nodeData).repartition(numDataParts).write.parquet(dataPath)
+      ReadWriteUtils.saveArray(dataPath, nodeData.toArray, sparkSession, numDataParts)
     }
   }
```

mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala

Lines changed: 10 additions & 10 deletions
```diff
@@ -345,6 +345,11 @@ class FMClassificationModel private[classification] (
 
 @Since("3.0.0")
 object FMClassificationModel extends MLReadable[FMClassificationModel] {
+  private case class Data(
+    intercept: Double,
+    linear: Vector,
+    factors: Matrix
+  )
 
   @Since("3.0.0")
   override def read: MLReader[FMClassificationModel] = new FMClassificationModelReader
@@ -356,16 +361,11 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] {
   private[FMClassificationModel] class FMClassificationModelWriter(
       instance: FMClassificationModel) extends MLWriter with Logging {
 
-    private case class Data(
-      intercept: Double,
-      linear: Vector,
-      factors: Matrix)
-
     override protected def saveImpl(path: String): Unit = {
       DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = Data(instance.intercept, instance.linear, instance.factors)
       val dataPath = new Path(path, "data").toString
-      sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
+      ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
     }
   }
 
@@ -376,11 +376,11 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] {
     override def load(path: String): FMClassificationModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
       val dataPath = new Path(path, "data").toString
-      val data = sparkSession.read.format("parquet").load(dataPath)
 
-      val Row(intercept: Double, linear: Vector, factors: Matrix) =
-        data.select("intercept", "linear", "factors").head()
-      val model = new FMClassificationModel(metadata.uid, intercept, linear, factors)
+      val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+      val model = new FMClassificationModel(
+        metadata.uid, data.intercept, data.linear, data.factors
+      )
       metadata.getAndSetParams(model)
       model
     }
```
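A pattern repeated across most files in this commit: the private `Data` case class moves from the writer class up to the model's companion object. The old reader selected untyped columns from parquet, so only the writer needed the type; the new reader deserializes a typed `Data` value via `loadObject[Data]`, so writer and reader must now share one definition.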

mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala

Lines changed: 4 additions & 8 deletions
```diff
@@ -445,9 +445,9 @@ class LinearSVCModel private[classification] (
   }
 }
 
-
 @Since("2.2.0")
 object LinearSVCModel extends MLReadable[LinearSVCModel] {
+  private case class Data(coefficients: Vector, intercept: Double)
 
   @Since("2.2.0")
   override def read: MLReader[LinearSVCModel] = new LinearSVCReader
@@ -460,14 +460,12 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
   class LinearSVCWriter(instance: LinearSVCModel)
     extends MLWriter with Logging {
 
-    private case class Data(coefficients: Vector, intercept: Double)
-
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
       DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = Data(instance.coefficients, instance.intercept)
       val dataPath = new Path(path, "data").toString
-      sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
+      ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
     }
   }
 
@@ -479,10 +477,8 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
     override def load(path: String): LinearSVCModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
       val dataPath = new Path(path, "data").toString
-      val data = sparkSession.read.format("parquet").load(dataPath)
-      val Row(coefficients: Vector, intercept: Double) =
-        data.select("coefficients", "intercept").head()
-      val model = new LinearSVCModel(metadata.uid, coefficients, intercept)
+      val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+      val model = new LinearSVCModel(metadata.uid, data.coefficients, data.intercept)
       metadata.getAndSetParams(model)
       model
     }
```

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 11 additions & 16 deletions
```diff
@@ -1316,9 +1316,14 @@ class LogisticRegressionModel private[spark] (
   }
 }
 
-
 @Since("1.6.0")
 object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
+  case class Data(
+    numClasses: Int,
+    numFeatures: Int,
+    interceptVector: Vector,
+    coefficientMatrix: Matrix,
+    isMultinomial: Boolean)
 
   @Since("1.6.0")
   override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader
@@ -1331,21 +1336,14 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
   class LogisticRegressionModelWriter(instance: LogisticRegressionModel)
     extends MLWriter with Logging {
 
-    private case class Data(
-      numClasses: Int,
-      numFeatures: Int,
-      interceptVector: Vector,
-      coefficientMatrix: Matrix,
-      isMultinomial: Boolean)
-
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
       DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       // Save model data: numClasses, numFeatures, intercept, coefficients
       val data = Data(instance.numClasses, instance.numFeatures, instance.interceptVector,
         instance.coefficientMatrix, instance.isMultinomial)
       val dataPath = new Path(path, "data").toString
-      sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
+      ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
     }
   }
 
@@ -1359,9 +1357,9 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
       val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
 
       val dataPath = new Path(path, "data").toString
-      val data = sparkSession.read.format("parquet").load(dataPath)
 
       val model = if (major < 2 || (major == 2 && minor == 0)) {
+        val data = sparkSession.read.format("parquet").load(dataPath)
         // 2.0 and before
         val Row(numClasses: Int, numFeatures: Int, intercept: Double, coefficients: Vector) =
           MLUtils.convertVectorColumnsToML(data, "coefficients")
@@ -1374,12 +1372,9 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
           interceptVector, numClasses, isMultinomial = false)
       } else {
         // 2.1+
-        val Row(numClasses: Int, numFeatures: Int, interceptVector: Vector,
-          coefficientMatrix: Matrix, isMultinomial: Boolean) = data
-          .select("numClasses", "numFeatures", "interceptVector", "coefficientMatrix",
-            "isMultinomial").head()
-        new LogisticRegressionModel(metadata.uid, coefficientMatrix, interceptVector,
-          numClasses, isMultinomial)
+        val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+        new LogisticRegressionModel(metadata.uid, data.coefficientMatrix, data.interceptVector,
+          data.numClasses, data.isMultinomial)
       }
 
       metadata.getAndSetParams(model)
```

mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala

Lines changed: 5 additions & 7 deletions
```diff
@@ -368,6 +368,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
 @Since("2.0.0")
 object MultilayerPerceptronClassificationModel
   extends MLReadable[MultilayerPerceptronClassificationModel] {
+  private case class Data(weights: Vector)
 
   @Since("2.0.0")
   override def read: MLReader[MultilayerPerceptronClassificationModel] =
@@ -381,15 +382,13 @@ object MultilayerPerceptronClassificationModel
   class MultilayerPerceptronClassificationModelWriter(
       instance: MultilayerPerceptronClassificationModel) extends MLWriter {
 
-    private case class Data(weights: Vector)
-
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
       DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       // Save model data: weights
       val data = Data(instance.weights)
       val dataPath = new Path(path, "data").toString
-      sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
+      ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
     }
   }
 
@@ -404,17 +403,16 @@ object MultilayerPerceptronClassificationModel
       val (majorVersion, _) = majorMinorVersion(metadata.sparkVersion)
 
       val dataPath = new Path(path, "data").toString
-      val df = sparkSession.read.parquet(dataPath)
       val model = if (majorVersion < 3) { // model prior to 3.0.0
+        val df = sparkSession.read.parquet(dataPath)
         val data = df.select("layers", "weights").head()
         val layers = data.getAs[Seq[Int]](0).toArray
         val weights = data.getAs[Vector](1)
         val model = new MultilayerPerceptronClassificationModel(metadata.uid, weights)
         model.set("layers", layers)
       } else {
-        val data = df.select("weights").head()
-        val weights = data.getAs[Vector](0)
-        new MultilayerPerceptronClassificationModel(metadata.uid, weights)
+        val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+        new MultilayerPerceptronClassificationModel(metadata.uid, data.weights)
       }
       metadata.getAndSetParams(model)
       model
```

mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala

Lines changed: 6 additions & 11 deletions
```diff
@@ -598,6 +598,7 @@ class NaiveBayesModel private[ml] (
 
 @Since("1.6.0")
 object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
+  private case class Data(pi: Vector, theta: Matrix, sigma: Matrix)
 
   @Since("1.6.0")
   override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader
@@ -609,8 +610,6 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
   private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter {
     import NaiveBayes._
 
-    private case class Data(pi: Vector, theta: Matrix, sigma: Matrix)
-
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
       DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
@@ -624,7 +623,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
       }
 
       val data = Data(instance.pi, instance.theta, instance.sigma)
-      sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
+      ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
     }
   }
 
@@ -639,21 +638,17 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
       val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
 
       val dataPath = new Path(path, "data").toString
-      val data = sparkSession.read.parquet(dataPath)
-      val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi")
-
       val model = if (major < 3) {
+        val data = sparkSession.read.parquet(dataPath)
+        val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi")
         val Row(pi: Vector, theta: Matrix) =
           MLUtils.convertMatrixColumnsToML(vecConverted, "theta")
             .select("pi", "theta")
             .head()
         new NaiveBayesModel(metadata.uid, pi, theta, Matrices.zeros(0, 0))
       } else {
-        val Row(pi: Vector, theta: Matrix, sigma: Matrix) =
-          MLUtils.convertMatrixColumnsToML(vecConverted, "theta", "sigma")
-            .select("pi", "theta", "sigma")
-            .head()
-        new NaiveBayesModel(metadata.uid, pi, theta, sigma)
+        val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+        new NaiveBayesModel(metadata.uid, data.pi, data.theta, data.sigma)
       }
 
       metadata.getAndSetParams(model)
```

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 10 additions & 0 deletions
```diff
@@ -277,6 +277,11 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] {
     OneVsRestParams.validateParams(instance)
 
     override protected def saveImpl(path: String): Unit = {
+      if (ReadWriteUtils.localSavingModeState.get()) {
+        throw new UnsupportedOperationException(
+          "OneVsRestModel does not support saving to local filesystem path."
+        )
+      }
       val extraJson = ("labelMetadata" -> instance.labelMetadata.json) ~
         ("numClasses" -> instance.models.length)
       OneVsRestParams.saveImpl(path, instance, sparkSession, Some(extraJson))
@@ -293,6 +298,11 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] {
     private val className = classOf[OneVsRestModel].getName
 
     override def load(path: String): OneVsRestModel = {
+      if (ReadWriteUtils.localSavingModeState.get()) {
+        throw new UnsupportedOperationException(
+          "OneVsRestModel does not support loading from local filesystem path."
+        )
+      }
      implicit val format = DefaultFormats
       val (metadata, classifier) = OneVsRestParams.loadImpl(path, sparkSession, className)
       val labelMetadata = Metadata.fromJson((metadata.metadata \ "labelMetadata").extract[String])
```
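`OneVsRestModel` delegates to arbitrary sub-model writers, so rather than recursing it fails fast in local mode. On the caller side, a thread-local flag implies a scoped toggle around save/load calls; a hypothetical wrapper in terms of the sketch object above (the PR's actual Spark Connect integration may expose this differently):

```scala
// Hypothetical caller-side scope; assumes the flag is a settable ThreadLocal,
// consistent with the localSavingModeState.get() calls in this diff.
def withLocalSavingMode[T](body: => T): T = {
  LocalReadWriteSketch.localSavingModeState.set(true)
  try body
  finally LocalReadWriteSketch.localSavingModeState.set(false)
}
```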

mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala

Lines changed: 22 additions & 12 deletions
```diff
@@ -223,6 +223,7 @@ class GaussianMixtureModel private[ml] (
 
 @Since("2.0.0")
 object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
+  private case class Data(weights: Array[Double], mus: Array[OldVector], sigmas: Array[OldMatrix])
 
   @Since("2.0.0")
   override def read: MLReader[GaussianMixtureModel] = new GaussianMixtureModelReader
@@ -234,8 +235,6 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
   private[GaussianMixtureModel] class GaussianMixtureModelWriter(
       instance: GaussianMixtureModel) extends MLWriter {
 
-    private case class Data(weights: Array[Double], mus: Array[OldVector], sigmas: Array[OldMatrix])
-
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
       DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
@@ -246,7 +245,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
       val sigmas = gaussians.map(c => OldMatrices.fromML(c.cov))
       val data = Data(weights, mus, sigmas)
       val dataPath = new Path(path, "data").toString
-      sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
+      ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
     }
   }
 
@@ -259,16 +258,27 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
       val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
 
       val dataPath = new Path(path, "data").toString
-      val row = sparkSession.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
-      val weights = row.getSeq[Double](0).toArray
-      val mus = row.getSeq[OldVector](1).toArray
-      val sigmas = row.getSeq[OldMatrix](2).toArray
-      require(mus.length == sigmas.length, "Length of Mu and Sigma array must match")
-      require(mus.length == weights.length, "Length of weight and Gaussian array must match")
-
-      val gaussians = mus.zip(sigmas)
+
+      val data = if (ReadWriteUtils.localSavingModeState.get()) {
+        ReadWriteUtils.loadObjectFromLocal(dataPath)
+      } else {
+        val row = sparkSession.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
+        Data(
+          row.getSeq[Double](0).toArray,
+          row.getSeq[OldVector](1).toArray,
+          row.getSeq[OldMatrix](2).toArray
+        )
+      }
+
+      require(data.mus.length == data.sigmas.length, "Length of Mu and Sigma array must match")
+      require(
+        data.mus.length == data.weights.length,
+        "Length of weight and Gaussian array must match"
+      )
+
+      val gaussians = data.mus.zip(data.sigmas)
         .map { case (mu, sigma) => new MultivariateGaussian(mu.asML, sigma.asML) }
-      val model = new GaussianMixtureModel(metadata.uid, weights, gaussians)
+      val model = new GaussianMixtureModel(metadata.uid, data.weights, gaussians)
 
       metadata.getAndSetParams(model)
       model
```

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -233,7 +233,7 @@ private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister {
       ClusterData(idx, center)
     }
     val dataPath = new Path(path, "data").toString
-    sparkSession.createDataFrame(data.toImmutableArraySeq).repartition(1).write.parquet(dataPath)
+    ReadWriteUtils.saveArray[ClusterData](dataPath, data, sparkSession)
   }
 }
 
@@ -281,8 +281,8 @@ object KMeansModel extends MLReadable[KMeansModel] {
     val dataPath = new Path(path, "data").toString
 
     val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
-      val data: Dataset[ClusterData] = sparkSession.read.parquet(dataPath).as[ClusterData]
-      data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
+      val data = ReadWriteUtils.loadArray[ClusterData](dataPath, sparkSession)
+      data.sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
     } else {
       // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.
       sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters
```
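Tree and clustering models persist per-node or per-cluster records rather than a single object, hence the `saveArray` / `loadArray` variants used here and in the DecisionTreeClassifier diff above. Under the same illustrative serialization assumption as the first sketch, their local-path halves could be as small as:

```scala
// Array counterparts to LocalReadWriteSketch (illustrative only). The
// numDataParts argument seen at some saveArray call sites governs parquet
// partitioning, so a local write can ignore it.
def saveArrayToLocal[T](path: String, values: Array[T]): Unit =
  LocalReadWriteSketch.saveObjectToLocal(path, values)

def loadArrayFromLocal[T](path: String): Array[T] =
  LocalReadWriteSketch.loadObjectFromLocal[Array[T]](path)
```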
