[SPARK-51867][ML] Make Scala models support save / load methods against local filesystem path #50665

Closed · wants to merge 17 commits
@@ -300,7 +300,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassificationModel]
val (nodeData, _) = NodeData.build(instance.rootNode, 0)
val dataPath = new Path(path, "data").toString
val numDataParts = NodeData.inferNumPartitions(instance.numNodes)
sparkSession.createDataFrame(nodeData).repartition(numDataParts).write.parquet(dataPath)
ReadWriteUtils.saveArray(dataPath, nodeData.toArray, sparkSession, numDataParts)
}
}

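This first hunk shows the pattern that repeats through the whole diff: direct sparkSession.createDataFrame(...).write.parquet(...) and sparkSession.read.parquet(...) calls are swapped for shared ReadWriteUtils helpers (saveObject / loadObject for single records, saveArray / loadArray for arrays) that can also target a plain local filesystem path. The implementation of ReadWriteUtils itself is not part of this diff; the sketch below only illustrates the dispatch shape the call sites rely on. The object name ReadWriteUtilsSketch, the Java-serialization format, and the exact signatures are assumptions for illustration, not the PR's actual code.

import java.io.{File, FileOutputStream, ObjectOutputStream}
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.SparkSession

// Illustrative sketch only: not the PR's actual ReadWriteUtils.
object ReadWriteUtilsSketch {
  // Thread-local flag consulted by the OneVsRest and GaussianMixture hunks
  // further down; assumed here to be a plain ThreadLocal[Boolean].
  val localSavingModeState: ThreadLocal[Boolean] = new ThreadLocal[Boolean] {
    override def initialValue(): Boolean = false
  }

  // Single-record flavor used by most writers below (FMClassification,
  // LinearSVC, LogisticRegression, MLP, NaiveBayes, GaussianMixture).
  def saveObject[T <: Product with Serializable : TypeTag](
      path: String, data: T, spark: SparkSession): Unit = {
    if (localSavingModeState.get()) {
      // Local mode: write the object straight to the given path, no Spark job.
      val out = new ObjectOutputStream(new FileOutputStream(new File(path)))
      try out.writeObject(data) finally out.close()
    } else {
      // Distributed mode: same behavior as the code being replaced.
      spark.createDataFrame(Seq(data)).write.parquet(path)
    }
  }
}

A matching loadObject would branch the same way: deserialize from the local path when the flag is set, otherwise read the Parquet data back as before. The array flavor used by the DecisionTree hunk above (and the KMeans hunk at the end) is sketched after the last hunk.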
@@ -345,6 +345,11 @@ class FMClassificationModel private[classification] (

@Since("3.0.0")
object FMClassificationModel extends MLReadable[FMClassificationModel] {
private case class Data(
intercept: Double,
linear: Vector,
factors: Matrix
)

@Since("3.0.0")
override def read: MLReader[FMClassificationModel] = new FMClassificationModelReader
@@ -356,16 +361,11 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] {
private[FMClassificationModel] class FMClassificationModelWriter(
instance: FMClassificationModel) extends MLWriter with Logging {

private case class Data(
intercept: Double,
linear: Vector,
factors: Matrix)

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.intercept, instance.linear, instance.factors)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
}
}

@@ -376,11 +376,11 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] {
override def load(path: String): FMClassificationModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.format("parquet").load(dataPath)

val Row(intercept: Double, linear: Vector, factors: Matrix) =
data.select("intercept", "linear", "factors").head()
val model = new FMClassificationModel(metadata.uid, intercept, linear, factors)
val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
val model = new FMClassificationModel(
metadata.uid, data.intercept, data.linear, data.factors
)
metadata.getAndSetParams(model)
model
}
@@ -445,9 +445,9 @@ class LinearSVCModel private[classification] (
}
}


@Since("2.2.0")
object LinearSVCModel extends MLReadable[LinearSVCModel] {
private case class Data(coefficients: Vector, intercept: Double)

@Since("2.2.0")
override def read: MLReader[LinearSVCModel] = new LinearSVCReader
@@ -460,14 +460,12 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
class LinearSVCWriter(instance: LinearSVCModel)
extends MLWriter with Logging {

private case class Data(coefficients: Vector, intercept: Double)

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.coefficients, instance.intercept)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
}
}

@@ -479,10 +477,8 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
override def load(path: String): LinearSVCModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.format("parquet").load(dataPath)
val Row(coefficients: Vector, intercept: Double) =
data.select("coefficients", "intercept").head()
val model = new LinearSVCModel(metadata.uid, coefficients, intercept)
val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
val model = new LinearSVCModel(metadata.uid, data.coefficients, data.intercept)
metadata.getAndSetParams(model)
model
}
@@ -1316,9 +1316,14 @@ class LogisticRegressionModel private[spark] (
}
}


@Since("1.6.0")
object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
case class Data(
numClasses: Int,
numFeatures: Int,
interceptVector: Vector,
coefficientMatrix: Matrix,
isMultinomial: Boolean)

Contributor review comment on this case class: I don't know why it fails MLSuite, but we should not make it public. It seems private[spark] can work. #50760

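For reference, the narrower visibility the comment suggests would look roughly like the following; this is only a sketch of the direction of the follow-up #50760, not its actual diff:

private[spark] case class Data(
    numClasses: Int,
    numFeatures: Int,
    interceptVector: Vector,
    coefficientMatrix: Matrix,
    isMultinomial: Boolean)

Package-private visibility keeps the serialization case class out of the public API surface while both the writer and the reader in this companion object can still reach it.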
@Since("1.6.0")
override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader
@@ -1331,21 +1336,14 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
class LogisticRegressionModelWriter(instance: LogisticRegressionModel)
extends MLWriter with Logging {

private case class Data(
numClasses: Int,
numFeatures: Int,
interceptVector: Vector,
coefficientMatrix: Matrix,
isMultinomial: Boolean)

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: numClasses, numFeatures, intercept, coefficients
val data = Data(instance.numClasses, instance.numFeatures, instance.interceptVector,
instance.coefficientMatrix, instance.isMultinomial)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
}
}

@@ -1359,9 +1357,9 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)

val dataPath = new Path(path, "data").toString
val data = sparkSession.read.format("parquet").load(dataPath)

val model = if (major < 2 || (major == 2 && minor == 0)) {
val data = sparkSession.read.format("parquet").load(dataPath)
// 2.0 and before
val Row(numClasses: Int, numFeatures: Int, intercept: Double, coefficients: Vector) =
MLUtils.convertVectorColumnsToML(data, "coefficients")
@@ -1374,12 +1372,9 @@ interceptVector, numClasses, isMultinomial = false)
interceptVector, numClasses, isMultinomial = false)
} else {
// 2.1+
val Row(numClasses: Int, numFeatures: Int, interceptVector: Vector,
coefficientMatrix: Matrix, isMultinomial: Boolean) = data
.select("numClasses", "numFeatures", "interceptVector", "coefficientMatrix",
"isMultinomial").head()
new LogisticRegressionModel(metadata.uid, coefficientMatrix, interceptVector,
numClasses, isMultinomial)
val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
new LogisticRegressionModel(metadata.uid, data.coefficientMatrix, data.interceptVector,
data.numClasses, data.isMultinomial)
}

metadata.getAndSetParams(model)
@@ -368,6 +368,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
@Since("2.0.0")
object MultilayerPerceptronClassificationModel
extends MLReadable[MultilayerPerceptronClassificationModel] {
private case class Data(weights: Vector)

@Since("2.0.0")
override def read: MLReader[MultilayerPerceptronClassificationModel] =
@@ -381,15 +382,13 @@ object MultilayerPerceptronClassificationModel
class MultilayerPerceptronClassificationModelWriter(
instance: MultilayerPerceptronClassificationModel) extends MLWriter {

private case class Data(weights: Vector)

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: weights
val data = Data(instance.weights)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
}
}

@@ -404,17 +403,16 @@ object MultilayerPerceptronClassificationModel
val (majorVersion, _) = majorMinorVersion(metadata.sparkVersion)

val dataPath = new Path(path, "data").toString
val df = sparkSession.read.parquet(dataPath)
val model = if (majorVersion < 3) { // model prior to 3.0.0
val df = sparkSession.read.parquet(dataPath)
val data = df.select("layers", "weights").head()
val layers = data.getAs[Seq[Int]](0).toArray
val weights = data.getAs[Vector](1)
val model = new MultilayerPerceptronClassificationModel(metadata.uid, weights)
model.set("layers", layers)
} else {
val data = df.select("weights").head()
val weights = data.getAs[Vector](0)
new MultilayerPerceptronClassificationModel(metadata.uid, weights)
val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
new MultilayerPerceptronClassificationModel(metadata.uid, data.weights)
}
metadata.getAndSetParams(model)
model
@@ -598,6 +598,7 @@ class NaiveBayesModel private[ml] (

@Since("1.6.0")
object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
private case class Data(pi: Vector, theta: Matrix, sigma: Matrix)

@Since("1.6.0")
override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader
@@ -609,8 +610,6 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter {
import NaiveBayes._

private case class Data(pi: Vector, theta: Matrix, sigma: Matrix)

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
@@ -624,7 +623,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
}

val data = Data(instance.pi, instance.theta, instance.sigma)
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
}
}

@@ -639,21 +638,17 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)

val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi")

val model = if (major < 3) {
val data = sparkSession.read.parquet(dataPath)
val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi")
val Row(pi: Vector, theta: Matrix) =
MLUtils.convertMatrixColumnsToML(vecConverted, "theta")
.select("pi", "theta")
.head()
new NaiveBayesModel(metadata.uid, pi, theta, Matrices.zeros(0, 0))
} else {
val Row(pi: Vector, theta: Matrix, sigma: Matrix) =
MLUtils.convertMatrixColumnsToML(vecConverted, "theta", "sigma")
.select("pi", "theta", "sigma")
.head()
new NaiveBayesModel(metadata.uid, pi, theta, sigma)
val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
new NaiveBayesModel(metadata.uid, data.pi, data.theta, data.sigma)
}

metadata.getAndSetParams(model)
@@ -277,6 +277,11 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] {
OneVsRestParams.validateParams(instance)

override protected def saveImpl(path: String): Unit = {
if (ReadWriteUtils.localSavingModeState.get()) {
throw new UnsupportedOperationException(
"OneVsRestModel does not support saving to local filesystem path."
)
}
val extraJson = ("labelMetadata" -> instance.labelMetadata.json) ~
("numClasses" -> instance.models.length)
OneVsRestParams.saveImpl(path, instance, sparkSession, Some(extraJson))
@@ -293,6 +298,11 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] {
private val className = classOf[OneVsRestModel].getName

override def load(path: String): OneVsRestModel = {
if (ReadWriteUtils.localSavingModeState.get()) {
throw new UnsupportedOperationException(
"OneVsRestModel does not support loading from local filesystem path."
)
}
implicit val format = DefaultFormats
val (metadata, classifier) = OneVsRestParams.loadImpl(path, sparkSession, className)
val labelMetadata = Metadata.fromJson((metadata.metadata \ "labelMetadata").extract[String])
@@ -223,6 +223,7 @@ class GaussianMixtureModel private[ml] (

@Since("2.0.0")
object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
private case class Data(weights: Array[Double], mus: Array[OldVector], sigmas: Array[OldMatrix])

@Since("2.0.0")
override def read: MLReader[GaussianMixtureModel] = new GaussianMixtureModelReader
@@ -234,8 +235,6 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
private[GaussianMixtureModel] class GaussianMixtureModelWriter(
instance: GaussianMixtureModel) extends MLWriter {

private case class Data(weights: Array[Double], mus: Array[OldVector], sigmas: Array[OldMatrix])

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
@@ -246,7 +245,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
val sigmas = gaussians.map(c => OldMatrices.fromML(c.cov))
val data = Data(weights, mus, sigmas)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
}
}

@@ -259,16 +258,27 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)

val dataPath = new Path(path, "data").toString
val row = sparkSession.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
val weights = row.getSeq[Double](0).toArray
val mus = row.getSeq[OldVector](1).toArray
val sigmas = row.getSeq[OldMatrix](2).toArray
require(mus.length == sigmas.length, "Length of Mu and Sigma array must match")
require(mus.length == weights.length, "Length of weight and Gaussian array must match")

val gaussians = mus.zip(sigmas)

val data = if (ReadWriteUtils.localSavingModeState.get()) {
ReadWriteUtils.loadObjectFromLocal(dataPath)
} else {
val row = sparkSession.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
Data(
row.getSeq[Double](0).toArray,
row.getSeq[OldVector](1).toArray,
row.getSeq[OldMatrix](2).toArray
)
}

require(data.mus.length == data.sigmas.length, "Length of Mu and Sigma array must match")
require(
data.mus.length == data.weights.length,
"Length of weight and Gaussian array must match"
)

val gaussians = data.mus.zip(data.sigmas)
.map { case (mu, sigma) => new MultivariateGaussian(mu.asML, sigma.asML) }
val model = new GaussianMixtureModel(metadata.uid, weights, gaussians)
val model = new GaussianMixtureModel(metadata.uid, data.weights, gaussians)

metadata.getAndSetParams(model)
model
@@ -233,7 +233,7 @@ private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister
ClusterData(idx, center)
}
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(data.toImmutableArraySeq).repartition(1).write.parquet(dataPath)
ReadWriteUtils.saveArray[ClusterData](dataPath, data, sparkSession)
}
}

@@ -281,8 +281,8 @@ object KMeansModel extends MLReadable[KMeansModel] {
val dataPath = new Path(path, "data").toString

val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
val data: Dataset[ClusterData] = sparkSession.read.parquet(dataPath).as[ClusterData]
data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
val data = ReadWriteUtils.loadArray[ClusterData](dataPath, sparkSession)
data.sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
} else {
// Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.
sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters
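The DecisionTree and KMeans writers above call the array-flavored variants of the same helpers (saveArray with an optional partition count, loadArray on the read side). Continuing the illustrative ReadWriteUtilsSketch from the top of the diff, such a method might look like the following; the signature is inferred from the call sites and the body is an assumption:

// Continuation of the illustrative ReadWriteUtilsSketch object shown earlier.
def saveArray[T <: Product with Serializable : TypeTag](
    path: String, data: Array[T], spark: SparkSession, numDataParts: Int = 1): Unit = {
  if (localSavingModeState.get()) {
    // Local mode: serialize the whole array to the given path in one write.
    val out = new ObjectOutputStream(new FileOutputStream(new File(path)))
    try out.writeObject(data) finally out.close()
  } else {
    // Distributed mode: keep the previous partitioned-Parquet behavior.
    spark.createDataFrame(data.toSeq).repartition(numDataParts).write.parquet(path)
  }
}

loadArray would mirror this, deserializing the array directly in local mode and otherwise collecting the Parquet rows as the old KMeans reader did.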