diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 2c9f518c772c4..534c54c567589 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -300,7 +300,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica val (nodeData, _) = NodeData.build(instance.rootNode, 0) val dataPath = new Path(path, "data").toString val numDataParts = NodeData.inferNumPartitions(instance.numNodes) - sparkSession.createDataFrame(nodeData).repartition(numDataParts).write.parquet(dataPath) + ReadWriteUtils.saveArray(dataPath, nodeData.toArray, sparkSession, numDataParts) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala index 9621c4c9d76bf..b0dba4e3cf9d3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala @@ -345,6 +345,11 @@ class FMClassificationModel private[classification] ( @Since("3.0.0") object FMClassificationModel extends MLReadable[FMClassificationModel] { + private case class Data( + intercept: Double, + linear: Vector, + factors: Matrix + ) @Since("3.0.0") override def read: MLReader[FMClassificationModel] = new FMClassificationModelReader @@ -356,16 +361,11 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] { private[FMClassificationModel] class FMClassificationModelWriter( instance: FMClassificationModel) extends MLWriter with Logging { - private case class Data( - intercept: Double, - linear: Vector, - factors: Matrix) - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.intercept, instance.linear, instance.factors) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -376,11 +376,11 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] { override def load(path: String): FMClassificationModel = { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.format("parquet").load(dataPath) - val Row(intercept: Double, linear: Vector, factors: Matrix) = - data.select("intercept", "linear", "factors").head() - val model = new FMClassificationModel(metadata.uid, intercept, linear, factors) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + val model = new FMClassificationModel( + metadata.uid, data.intercept, data.linear, data.factors + ) metadata.getAndSetParams(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index ec4896fe3c445..e67e7b0daed1a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -445,9 +445,9 @@ class LinearSVCModel private[classification] ( } } - @Since("2.2.0") object LinearSVCModel extends MLReadable[LinearSVCModel] { + private case class 
Data(coefficients: Vector, intercept: Double) @Since("2.2.0") override def read: MLReader[LinearSVCModel] = new LinearSVCReader @@ -460,14 +460,12 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] { class LinearSVCWriter(instance: LinearSVCModel) extends MLWriter with Logging { - private case class Data(coefficients: Vector, intercept: Double) - override protected def saveImpl(path: String): Unit = { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.coefficients, instance.intercept) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -479,10 +477,8 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] { override def load(path: String): LinearSVCModel = { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.format("parquet").load(dataPath) - val Row(coefficients: Vector, intercept: Double) = - data.select("coefficients", "intercept").head() - val model = new LinearSVCModel(metadata.uid, coefficients, intercept) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + val model = new LinearSVCModel(metadata.uid, data.coefficients, data.intercept) metadata.getAndSetParams(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index d0f323eb38434..093f3efba2dd0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1316,9 +1316,14 @@ class LogisticRegressionModel private[spark] ( } } - @Since("1.6.0") object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { + case class Data( + numClasses: Int, + numFeatures: Int, + interceptVector: Vector, + coefficientMatrix: Matrix, + isMultinomial: Boolean) @Since("1.6.0") override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader @@ -1331,13 +1336,6 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { class LogisticRegressionModelWriter(instance: LogisticRegressionModel) extends MLWriter with Logging { - private case class Data( - numClasses: Int, - numFeatures: Int, - interceptVector: Vector, - coefficientMatrix: Matrix, - isMultinomial: Boolean) - override protected def saveImpl(path: String): Unit = { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sparkSession) @@ -1345,7 +1343,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { val data = Data(instance.numClasses, instance.numFeatures, instance.interceptVector, instance.coefficientMatrix, instance.isMultinomial) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -1359,9 +1357,9 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.format("parquet").load(dataPath) val model = if (major < 2 || (major == 2 && minor == 0)) { + val data = 
sparkSession.read.format("parquet").load(dataPath) // 2.0 and before val Row(numClasses: Int, numFeatures: Int, intercept: Double, coefficients: Vector) = MLUtils.convertVectorColumnsToML(data, "coefficients") @@ -1374,12 +1372,9 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { interceptVector, numClasses, isMultinomial = false) } else { // 2.1+ - val Row(numClasses: Int, numFeatures: Int, interceptVector: Vector, - coefficientMatrix: Matrix, isMultinomial: Boolean) = data - .select("numClasses", "numFeatures", "interceptVector", "coefficientMatrix", - "isMultinomial").head() - new LogisticRegressionModel(metadata.uid, coefficientMatrix, interceptVector, - numClasses, isMultinomial) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + new LogisticRegressionModel(metadata.uid, data.coefficientMatrix, data.interceptVector, + data.numClasses, data.isMultinomial) } metadata.getAndSetParams(model) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 099f237b4b10a..f8f41a6a6bece 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -368,6 +368,7 @@ class MultilayerPerceptronClassificationModel private[ml] ( @Since("2.0.0") object MultilayerPerceptronClassificationModel extends MLReadable[MultilayerPerceptronClassificationModel] { + private case class Data(weights: Vector) @Since("2.0.0") override def read: MLReader[MultilayerPerceptronClassificationModel] = @@ -381,15 +382,13 @@ object MultilayerPerceptronClassificationModel class MultilayerPerceptronClassificationModelWriter( instance: MultilayerPerceptronClassificationModel) extends MLWriter { - private case class Data(weights: Vector) - override protected def saveImpl(path: String): Unit = { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sparkSession) // Save model data: weights val data = Data(instance.weights) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -404,17 +403,16 @@ object MultilayerPerceptronClassificationModel val (majorVersion, _) = majorMinorVersion(metadata.sparkVersion) val dataPath = new Path(path, "data").toString - val df = sparkSession.read.parquet(dataPath) val model = if (majorVersion < 3) { // model prior to 3.0.0 + val df = sparkSession.read.parquet(dataPath) val data = df.select("layers", "weights").head() val layers = data.getAs[Seq[Int]](0).toArray val weights = data.getAs[Vector](1) val model = new MultilayerPerceptronClassificationModel(metadata.uid, weights) model.set("layers", layers) } else { - val data = df.select("weights").head() - val weights = data.getAs[Vector](0) - new MultilayerPerceptronClassificationModel(metadata.uid, weights) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + new MultilayerPerceptronClassificationModel(metadata.uid, data.weights) } metadata.getAndSetParams(model) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 4b0f8c311c3d0..c07e3289f6536 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala 
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -598,6 +598,7 @@ class NaiveBayesModel private[ml] ( @Since("1.6.0") object NaiveBayesModel extends MLReadable[NaiveBayesModel] { + private case class Data(pi: Vector, theta: Matrix, sigma: Matrix) @Since("1.6.0") override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader @@ -609,8 +610,6 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter { import NaiveBayes._ - private case class Data(pi: Vector, theta: Matrix, sigma: Matrix) - override protected def saveImpl(path: String): Unit = { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sparkSession) @@ -624,7 +623,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { } val data = Data(instance.pi, instance.theta, instance.sigma) - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -639,21 +638,17 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi") - val model = if (major < 3) { + val data = sparkSession.read.parquet(dataPath) + val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi") val Row(pi: Vector, theta: Matrix) = MLUtils.convertMatrixColumnsToML(vecConverted, "theta") .select("pi", "theta") .head() new NaiveBayesModel(metadata.uid, pi, theta, Matrices.zeros(0, 0)) } else { - val Row(pi: Vector, theta: Matrix, sigma: Matrix) = - MLUtils.convertMatrixColumnsToML(vecConverted, "theta", "sigma") - .select("pi", "theta", "sigma") - .head() - new NaiveBayesModel(metadata.uid, pi, theta, sigma) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + new NaiveBayesModel(metadata.uid, data.pi, data.theta, data.sigma) } metadata.getAndSetParams(model) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index b3a512caa0c18..929bd1541ec6e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -277,6 +277,11 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] { OneVsRestParams.validateParams(instance) override protected def saveImpl(path: String): Unit = { + if (ReadWriteUtils.localSavingModeState.get()) { + throw new UnsupportedOperationException( + "OneVsRestModel does not support saving to local filesystem path." + ) + } val extraJson = ("labelMetadata" -> instance.labelMetadata.json) ~ ("numClasses" -> instance.models.length) OneVsRestParams.saveImpl(path, instance, sparkSession, Some(extraJson)) @@ -293,6 +298,11 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] { private val className = classOf[OneVsRestModel].getName override def load(path: String): OneVsRestModel = { + if (ReadWriteUtils.localSavingModeState.get()) { + throw new UnsupportedOperationException( + "OneVsRestModel does not support loading from local filesystem path." 
+ ) + } implicit val format = DefaultFormats val (metadata, classifier) = OneVsRestParams.loadImpl(path, sparkSession, className) val labelMetadata = Metadata.fromJson((metadata.metadata \ "labelMetadata").extract[String]) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 42ce5d329ce0a..5924a9976c9b7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -223,6 +223,7 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { + private case class Data(weights: Array[Double], mus: Array[OldVector], sigmas: Array[OldMatrix]) @Since("2.0.0") override def read: MLReader[GaussianMixtureModel] = new GaussianMixtureModelReader @@ -234,8 +235,6 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { private[GaussianMixtureModel] class GaussianMixtureModelWriter( instance: GaussianMixtureModel) extends MLWriter { - private case class Data(weights: Array[Double], mus: Array[OldVector], sigmas: Array[OldMatrix]) - override protected def saveImpl(path: String): Unit = { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sparkSession) @@ -246,7 +245,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { val sigmas = gaussians.map(c => OldMatrices.fromML(c.cov)) val data = Data(weights, mus, sigmas) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -259,16 +258,27 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val row = sparkSession.read.parquet(dataPath).select("weights", "mus", "sigmas").head() - val weights = row.getSeq[Double](0).toArray - val mus = row.getSeq[OldVector](1).toArray - val sigmas = row.getSeq[OldMatrix](2).toArray - require(mus.length == sigmas.length, "Length of Mu and Sigma array must match") - require(mus.length == weights.length, "Length of weight and Gaussian array must match") - - val gaussians = mus.zip(sigmas) + + val data = if (ReadWriteUtils.localSavingModeState.get()) { + ReadWriteUtils.loadObjectFromLocal(dataPath) + } else { + val row = sparkSession.read.parquet(dataPath).select("weights", "mus", "sigmas").head() + Data( + row.getSeq[Double](0).toArray, + row.getSeq[OldVector](1).toArray, + row.getSeq[OldMatrix](2).toArray + ) + } + + require(data.mus.length == data.sigmas.length, "Length of Mu and Sigma array must match") + require( + data.mus.length == data.weights.length, + "Length of weight and Gaussian array must match" + ) + + val gaussians = data.mus.zip(data.sigmas) .map { case (mu, sigma) => new MultivariateGaussian(mu.asML, sigma.asML) } - val model = new GaussianMixtureModel(metadata.uid, weights, gaussians) + val model = new GaussianMixtureModel(metadata.uid, data.weights, gaussians) metadata.getAndSetParams(model) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 0821d9a841cc3..ca90097eb01dd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -233,7 +233,7 @@ private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegi ClusterData(idx, center) } val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(data.toImmutableArraySeq).repartition(1).write.parquet(dataPath) + ReadWriteUtils.saveArray[ClusterData](dataPath, data, sparkSession) } } @@ -281,8 +281,8 @@ object KMeansModel extends MLReadable[KMeansModel] { val dataPath = new Path(path, "data").toString val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) { - val data: Dataset[ClusterData] = sparkSession.read.parquet(dataPath).as[ClusterData] - data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML) + val data = ReadWriteUtils.loadArray[ClusterData](dataPath, sparkSession) + data.sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML) } else { // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier. sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 0c52118643856..9fde28502973c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -43,7 +43,6 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.MatrixImplicits._ import org.apache.spark.mllib.linalg.VectorImplicits._ -import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{monotonically_increasing_id, udf} @@ -642,27 +641,27 @@ class LocalLDAModel private[ml] ( } } - @Since("1.6.0") object LocalLDAModel extends MLReadable[LocalLDAModel] { + private case class LocalModelData( + vocabSize: Int, + topicsMatrix: Matrix, + docConcentration: Vector, + topicConcentration: Double, + gammaShape: Double) private[LocalLDAModel] class LocalLDAModelWriter(instance: LocalLDAModel) extends MLWriter { - private case class Data( - vocabSize: Int, - topicsMatrix: Matrix, - docConcentration: Vector, - topicConcentration: Double, - gammaShape: Double) - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val oldModel = instance.oldLocalModel - val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration, - oldModel.topicConcentration, oldModel.gammaShape) + val data = LocalModelData( + instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration, + oldModel.topicConcentration, oldModel.gammaShape + ) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[LocalModelData](dataPath, data, sparkSession) } } @@ -673,16 +672,15 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] { override def load(path: String): LocalLDAModel = { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - val vectorConverted = MLUtils.convertVectorColumnsToML(data, "docConcentration") - val matrixConverted = MLUtils.convertMatrixColumnsToML(vectorConverted, 
"topicsMatrix") - val Row(vocabSize: Int, topicsMatrix: Matrix, docConcentration: Vector, - topicConcentration: Double, gammaShape: Double) = - matrixConverted.select("vocabSize", "topicsMatrix", "docConcentration", - "topicConcentration", "gammaShape").head() - val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration, - gammaShape) - val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sparkSession) + + val data = ReadWriteUtils.loadObject[LocalModelData](dataPath, sparkSession) + val oldModel = new OldLocalLDAModel( + data.topicsMatrix, + data.docConcentration, + data.topicConcentration, + data.gammaShape + ) + val model = new LocalLDAModel(metadata.uid, data.vocabSize, oldModel, sparkSession) LDAParams.getAndSetParams(model, metadata) model } @@ -820,6 +818,11 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] { class DistributedWriter(instance: DistributedLDAModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { + if (ReadWriteUtils.localSavingModeState.get()) { + throw new UnsupportedOperationException( + "DistributedLDAModel does not support saving to local filesystem path." + ) + } DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val modelPath = new Path(path, "oldModel").toString instance.oldDistributedModel.save(sc, modelPath) @@ -831,6 +834,11 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] { private val className = classOf[DistributedLDAModel].getName override def load(path: String): DistributedLDAModel = { + if (ReadWriteUtils.localSavingModeState.get()) { + throw new UnsupportedOperationException( + "DistributedLDAModel does not support loading from local filesystem path." + ) + } val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val modelPath = new Path(path, "oldModel").toString val oldModel = OldDistributedLDAModel.load(sc, modelPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index c429788ee3685..aee51e4be5193 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -26,8 +26,6 @@ import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.HasSeed import org.apache.spark.ml.util._ -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.Row import org.apache.spark.sql.types.StructType import org.apache.spark.util.ArrayImplicits._ @@ -214,6 +212,8 @@ object BucketedRandomProjectionLSH extends DefaultParamsReadable[BucketedRandomP @Since("2.1.0") object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProjectionLSHModel] { + // TODO: Save using the existing format of Array[Vector] once SPARK-12878 is resolved. + private case class Data(randUnitVectors: Matrix) @Since("2.1.0") override def read: MLReader[BucketedRandomProjectionLSHModel] = { @@ -226,14 +226,11 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject private[BucketedRandomProjectionLSHModel] class BucketedRandomProjectionLSHModelWriter( instance: BucketedRandomProjectionLSHModel) extends MLWriter { - // TODO: Save using the existing format of Array[Vector] once SPARK-12878 is resolved. 
- private case class Data(randUnitVectors: Matrix) - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.randMatrix) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -247,11 +244,8 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - val Row(randMatrix: Matrix) = MLUtils.convertMatrixColumnsToML(data, "randUnitVectors") - .select("randUnitVectors") - .head() - val model = new BucketedRandomProjectionLSHModel(metadata.uid, randMatrix) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + val model = new BucketedRandomProjectionLSHModel(metadata.uid, data.randUnitVectors) metadata.getAndSetParams(model) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index ff18efb149399..5205e3965bbc9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -170,16 +170,15 @@ final class ChiSqSelectorModel private[ml] ( @Since("1.6.0") object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { + private case class Data(selectedFeatures: Seq[Int]) class ChiSqSelectorModelWriter(instance: ChiSqSelectorModel) extends MLWriter { - private case class Data(selectedFeatures: Seq[Int]) - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.selectedFeatures.toImmutableArraySeq) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -190,9 +189,8 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { override def load(path: String): ChiSqSelectorModel = { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath).select("selectedFeatures").head() - val selectedFeatures = data.getAs[Seq[Int]](0).toArray - val model = new ChiSqSelectorModel(metadata.uid, selectedFeatures) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + val model = new ChiSqSelectorModel(metadata.uid, data.selectedFeatures.toArray) metadata.getAndSetParams(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 34465248f20df..55e03781ad27e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -368,17 +368,16 @@ class CountVectorizerModel( @Since("1.6.0") object CountVectorizerModel extends MLReadable[CountVectorizerModel] { + private case class Data(vocabulary: Seq[String]) private[CountVectorizerModel] class CountVectorizerModelWriter(instance: CountVectorizerModel) extends MLWriter { - private case class Data(vocabulary: Seq[String]) - override protected def saveImpl(path: String): Unit = { 
DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.vocabulary.toImmutableArraySeq) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -389,11 +388,8 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { override def load(path: String): CountVectorizerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - .select("vocabulary") - .head() - val vocabulary = data.getAs[Seq[String]](0).toArray - val model = new CountVectorizerModel(metadata.uid, vocabulary) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + val model = new CountVectorizerModel(metadata.uid, data.vocabulary.toArray) metadata.getAndSetParams(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index c2b7ff7b00a3c..e4ba7a0adec2a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -195,16 +195,15 @@ class IDFModel private[ml] ( @Since("1.6.0") object IDFModel extends MLReadable[IDFModel] { + private case class Data(idf: Vector, docFreq: Array[Long], numDocs: Long) private[IDFModel] class IDFModelWriter(instance: IDFModel) extends MLWriter { - private case class Data(idf: Vector, docFreq: Array[Long], numDocs: Long) - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.idf, instance.docFreq, instance.numDocs) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -218,10 +217,11 @@ object IDFModel extends MLReadable[IDFModel] { val data = sparkSession.read.parquet(dataPath) val model = if (majorVersion(metadata.sparkVersion) >= 3) { - val Row(idf: Vector, df: scala.collection.Seq[_], numDocs: Long) = - data.select("idf", "docFreq", "numDocs").head() - new IDFModel(metadata.uid, new feature.IDFModel(OldVectors.fromML(idf), - df.asInstanceOf[scala.collection.Seq[Long]].toArray, numDocs)) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + new IDFModel( + metadata.uid, + new feature.IDFModel(OldVectors.fromML(data.idf), data.docFreq, data.numDocs) + ) } else { val Row(idf: Vector) = MLUtils.convertVectorColumnsToML(data, "idf") .select("idf") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index b9fb20d14933f..a4109a8ad9e1e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -313,7 +313,13 @@ object ImputerModel extends MLReadable[ImputerModel] { override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val dataPath = new Path(path, "data").toString - instance.surrogateDF.repartition(1).write.parquet(dataPath) + if (ReadWriteUtils.localSavingModeState.get()) { + ReadWriteUtils.saveObjectToLocal[(Array[String], Array[Double])]( + dataPath, (instance.columnNames, instance.surrogates) + ) + } else { + 
instance.surrogateDF.repartition(1).write.parquet(dataPath) + } } } @@ -324,11 +330,16 @@ object ImputerModel extends MLReadable[ImputerModel] { override def load(path: String): ImputerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val row = sparkSession.read.parquet(dataPath).head() - val (columnNames, surrogates) = row.schema.fieldNames.zipWithIndex - .map { case (name, index) => (name, row.getDouble(index)) } - .unzip - val model = new ImputerModel(metadata.uid, columnNames, surrogates) + val model = if (ReadWriteUtils.localSavingModeState.get()) { + val data = ReadWriteUtils.loadObjectFromLocal[(Array[String], Array[Double])](dataPath) + new ImputerModel(metadata.uid, data._1, data._2) + } else { + val row = sparkSession.read.parquet(dataPath).head() + val (columnNames, surrogates) = row.schema.fieldNames.zipWithIndex + .map { case (name, index) => (name, row.getDouble(index)) } + .unzip + new ImputerModel(metadata.uid, columnNames, surrogates) + } metadata.getAndSetParams(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index a9f1cd34ba3ed..a15578ae31851 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -158,17 +158,16 @@ class MaxAbsScalerModel private[ml] ( @Since("2.0.0") object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { + private case class Data(maxAbs: Vector) private[MaxAbsScalerModel] class MaxAbsScalerModelWriter(instance: MaxAbsScalerModel) extends MLWriter { - private case class Data(maxAbs: Vector) - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = new Data(instance.maxAbs) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -179,10 +178,8 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { override def load(path: String): MaxAbsScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val Row(maxAbs: Vector) = sparkSession.read.parquet(dataPath) - .select("maxAbs") - .head() - val model = new MaxAbsScalerModel(metadata.uid, maxAbs) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + val model = new MaxAbsScalerModel(metadata.uid, data.maxAbs) metadata.getAndSetParams(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index 96d341b163474..1bddc67f8f810 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -210,6 +210,7 @@ object MinHashLSH extends DefaultParamsReadable[MinHashLSH] { @Since("2.1.0") object MinHashLSHModel extends MLReadable[MinHashLSHModel] { + private case class Data(randCoefficients: Array[Int]) @Since("2.1.0") override def read: MLReader[MinHashLSHModel] = new MinHashLSHModelReader @@ -220,13 +221,11 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] { private[MinHashLSHModel] class MinHashLSHModelWriter(instance: MinHashLSHModel) extends MLWriter { - private case class 
Data(randCoefficients: Array[Int]) - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.randCoefficients.flatMap(tuple => Array(tuple._1, tuple._2))) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -239,10 +238,11 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath).select("randCoefficients").head() - val randCoefficients = data.getSeq[Int](0).grouped(2) - .map(tuple => (tuple(0), tuple(1))).toArray - val model = new MinHashLSHModel(metadata.uid, randCoefficients) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + val model = new MinHashLSHModel( + metadata.uid, + data.randCoefficients.grouped(2).map(tuple => (tuple(0), tuple(1))).toArray + ) metadata.getAndSetParams(model) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index c54e64f97953e..e806d4a29d333 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -26,7 +26,6 @@ import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.stat.Summarizer import org.apache.spark.ml.util._ -import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -243,17 +242,16 @@ class MinMaxScalerModel private[ml] ( @Since("1.6.0") object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { + private case class Data(originalMin: Vector, originalMax: Vector) private[MinMaxScalerModel] class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends MLWriter { - private case class Data(originalMin: Vector, originalMax: Vector) - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) - val data = new Data(instance.originalMin, instance.originalMax) + val data = Data(instance.originalMin, instance.originalMax) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -264,12 +262,8 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { override def load(path: String): MinMaxScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - val Row(originalMin: Vector, originalMax: Vector) = - MLUtils.convertVectorColumnsToML(data, "originalMin", "originalMax") - .select("originalMin", "originalMax") - .head() - val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + val model = new MinMaxScalerModel(metadata.uid, data.originalMin, data.originalMax) metadata.getAndSetParams(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 3eaff518e8fc3..d34ffbfc202f4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -401,17 +401,16 @@ class OneHotEncoderModel private[ml] ( @Since("3.0.0") object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { + private case class Data(categorySizes: Array[Int]) private[OneHotEncoderModel] class OneHotEncoderModelWriter(instance: OneHotEncoderModel) extends MLWriter { - private case class Data(categorySizes: Array[Int]) - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.categorySizes) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -422,11 +421,8 @@ object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { override def load(path: String): OneHotEncoderModel = { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - .select("categorySizes") - .head() - val categorySizes = data.getAs[Seq[Int]](0).toArray - val model = new OneHotEncoderModel(metadata.uid, categorySizes) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + val model = new OneHotEncoderModel(metadata.uid, data.categorySizes) metadata.getAndSetParams(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 6b61e761f5894..0c80d442114c0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -181,16 +181,15 @@ class PCAModel private[ml] ( @Since("1.6.0") object PCAModel extends MLReadable[PCAModel] { + private case class Data(pc: Matrix, explainedVariance: Vector) private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter { - private case class Data(pc: DenseMatrix, explainedVariance: DenseVector) - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.pc, instance.explainedVariance) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -212,11 +211,8 @@ object PCAModel extends MLReadable[PCAModel] { val dataPath = new Path(path, "data").toString val model = if (majorVersion(metadata.sparkVersion) >= 2) { - val Row(pc: DenseMatrix, explainedVariance: DenseVector) = - sparkSession.read.parquet(dataPath) - .select("pc", "explainedVariance") - .head() - new PCAModel(metadata.uid, pc, explainedVariance) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + new PCAModel(metadata.uid, data.pc.toDense, data.explainedVariance.toDense) } else { // pc field is the old matrix format in Spark <= 1.6 // explainedVariance field is not present in Spark <= 1.6 diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index d2191185dddd5..abb69d7e873d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -441,8 +441,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) // Save model data: resolvedFormula val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(instance.resolvedFormula)) - .repartition(1).write.parquet(dataPath) + ReadWriteUtils.saveObject[ResolvedRFormula](dataPath, instance.resolvedFormula, sparkSession) // Save pipeline model val pmPath = new Path(path, "pipelineModel").toString instance.pipelineModel.save(pmPath) @@ -458,11 +457,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath).select("label", "terms", "hasIntercept").head() - val label = data.getString(0) - val terms = data.getSeq[scala.collection.Seq[String]](1).map(_.toSeq) - val hasIntercept = data.getBoolean(2) - val resolvedRFormula = ResolvedRFormula(label, terms, hasIntercept) + val resolvedRFormula = ReadWriteUtils.loadObject[ResolvedRFormula](dataPath, sparkSession) val pmPath = new Path(path, "pipelineModel").toString val pipelineModel = PipelineModel.load(pmPath) @@ -501,6 +496,7 @@ private class ColumnPruner(override val uid: String, val columnsToPrune: Set[Str } private object ColumnPruner extends MLReadable[ColumnPruner] { + private case class Data(columnsToPrune: Seq[String]) override def read: MLReader[ColumnPruner] = new ColumnPrunerReader @@ -509,15 +505,13 @@ private object ColumnPruner extends MLReadable[ColumnPruner] { /** [[MLWriter]] instance for [[ColumnPruner]] */ private[ColumnPruner] class ColumnPrunerWriter(instance: ColumnPruner) extends MLWriter { - private case class Data(columnsToPrune: Seq[String]) - override protected def saveImpl(path: String): Unit = { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sparkSession) // Save model data: columnsToPrune val data = Data(instance.columnsToPrune.toSeq) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -530,9 +524,8 @@ private object ColumnPruner extends MLReadable[ColumnPruner] { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath).select("columnsToPrune").head() - val columnsToPrune = data.getAs[Seq[String]](0).toSet - val pruner = new ColumnPruner(metadata.uid, columnsToPrune) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + val pruner = new ColumnPruner(metadata.uid, data.columnsToPrune.toSet) metadata.getAndSetParams(pruner) pruner @@ -595,6 +588,7 @@ private class VectorAttributeRewriter( } private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewriter] { + private case class Data(vectorCol: String, prefixesToRewrite: Map[String, String]) override def read: MLReader[VectorAttributeRewriter] = new VectorAttributeRewriterReader @@ -604,15 +598,13 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite private[VectorAttributeRewriter] class VectorAttributeRewriterWriter(instance: VectorAttributeRewriter) extends MLWriter { - private case class Data(vectorCol: String, prefixesToRewrite: Map[String, String]) - override protected def saveImpl(path: 
String): Unit = { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sparkSession) // Save model data: vectorCol, prefixesToRewrite val data = Data(instance.vectorCol, instance.prefixesToRewrite) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -625,10 +617,10 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath).select("vectorCol", "prefixesToRewrite").head() - val vectorCol = data.getString(0) - val prefixesToRewrite = data.getAs[Map[String, String]](1) - val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + val rewriter = new VectorAttributeRewriter( + metadata.uid, data.vectorCol, data.prefixesToRewrite + ) metadata.getAndSetParams(rewriter) rewriter diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala index 1779f0d6278f0..bb0179613b7b4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala @@ -25,7 +25,6 @@ import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasRelativeError} import org.apache.spark.ml.util._ -import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.QuantileSummaries @@ -280,17 +279,16 @@ class RobustScalerModel private[ml] ( @Since("3.0.0") object RobustScalerModel extends MLReadable[RobustScalerModel] { + private case class Data(range: Vector, median: Vector) private[RobustScalerModel] class RobustScalerModelWriter(instance: RobustScalerModel) extends MLWriter { - private case class Data(range: Vector, median: Vector) - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.range, instance.median) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -301,12 +299,8 @@ object RobustScalerModel extends MLReadable[RobustScalerModel] { override def load(path: String): RobustScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - val Row(range: Vector, median: Vector) = MLUtils - .convertVectorColumnsToML(data, "range", "median") - .select("range", "median") - .head() - val model = new RobustScalerModel(metadata.uid, range, median) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + val model = new RobustScalerModel(metadata.uid, data.range, data.median) metadata.getAndSetParams(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index c1ac1fdbba7d8..19c3e4ca25cca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -26,7 +26,6 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.stat.Summarizer import org.apache.spark.ml.util._ -import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -201,17 +200,16 @@ class StandardScalerModel private[ml] ( @Since("1.6.0") object StandardScalerModel extends MLReadable[StandardScalerModel] { + private case class Data(std: Vector, mean: Vector) private[StandardScalerModel] class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter { - private case class Data(std: Vector, mean: Vector) - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.std, instance.mean) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -222,11 +220,8 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { override def load(path: String): StandardScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - val Row(std: Vector, mean: Vector) = MLUtils.convertVectorColumnsToML(data, "std", "mean") - .select("std", "mean") - .head() - val model = new StandardScalerModel(metadata.uid, std, mean) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + val model = new StandardScalerModel(metadata.uid, data.std, data.mean) metadata.getAndSetParams(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 06a88e9b1c499..6518b0d9cf92a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -469,17 +469,16 @@ class StringIndexerModel ( @Since("1.6.0") object StringIndexerModel extends MLReadable[StringIndexerModel] { + private case class Data(labelsArray: Seq[Seq[String]]) private[StringIndexerModel] class StringIndexModelWriter(instance: StringIndexerModel) extends MLWriter { - private case class Data(labelsArray: Array[Array[String]]) - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) - val data = Data(instance.labelsArray) + val data = Data(instance.labelsArray.map(_.toImmutableArraySeq).toImmutableArraySeq) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -502,11 +501,8 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { val labels = data.getAs[Seq[String]](0).toArray Array(labels) } else { - // After Spark 3.0. 
- val data = sparkSession.read.parquet(dataPath) - .select("labelsArray") - .head() - data.getSeq[scala.collection.Seq[String]](0).map(_.toArray).toArray + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + data.labelsArray.map(_.toArray).toArray } val model = new StringIndexerModel(metadata.uid, labelsArray) metadata.getAndSetParams(model) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala index 39ffaf32a1f36..8634779b0bc92 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala @@ -402,13 +402,13 @@ class TargetEncoderModel private[ml] ( @Since("4.0.0") object TargetEncoderModel extends MLReadable[TargetEncoderModel] { + private case class Data( + index: Int, categories: Array[Double], + counts: Array[Double], stats: Array[Double]) private[TargetEncoderModel] class TargetEncoderModelWriter(instance: TargetEncoderModel) extends MLWriter { - private case class Data(index: Int, categories: Array[Double], - counts: Array[Double], stats: Array[Double]) - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val datum = instance.stats.iterator.zipWithIndex.map { case (stat, index) => @@ -417,7 +417,7 @@ object TargetEncoderModel extends MLReadable[TargetEncoderModel] { Data(index, _categories.toArray, _counts.toArray, _stats.toArray) }.toSeq val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(datum).write.parquet(dataPath) + ReadWriteUtils.saveArray[Data](dataPath, datum.toArray, sparkSession) } } @@ -429,16 +429,10 @@ object TargetEncoderModel extends MLReadable[TargetEncoderModel] { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val stats = sparkSession.read.parquet(dataPath) - .select("index", "categories", "counts", "stats") - .collect() - .map { row => - val index = row.getInt(0) - val categories = row.getAs[Seq[Double]](1).toArray - val counts = row.getAs[Seq[Double]](2).toArray - val stats = row.getAs[Seq[Double]](3).toArray - (index, categories.zip(counts.zip(stats)).toMap) - }.sortBy(_._1).map(_._2) + val datum = ReadWriteUtils.loadArray[Data](dataPath, sparkSession) + val stats = datum.map { data => + (data.index, data.categories.zip(data.counts.zip(data.stats)).toMap) + }.sortBy(_._1).map(_._2) val model = new TargetEncoderModel(metadata.uid, stats) metadata.getAndSetParams(model) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala index 704166d9b6575..75ff263d61b34 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala @@ -338,6 +338,7 @@ class UnivariateFeatureSelectorModel private[ml]( @Since("3.1.1") object UnivariateFeatureSelectorModel extends MLReadable[UnivariateFeatureSelectorModel] { + private case class Data(selectedFeatures: Seq[Int]) @Since("3.1.1") override def read: MLReader[UnivariateFeatureSelectorModel] = @@ -349,13 +350,11 @@ object UnivariateFeatureSelectorModel extends MLReadable[UnivariateFeatureSelect private[UnivariateFeatureSelectorModel] class UnivariateFeatureSelectorModelWriter( instance: UnivariateFeatureSelectorModel) 
extends MLWriter { - private case class Data(selectedFeatures: Seq[Int]) - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.selectedFeatures.toImmutableArraySeq) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -368,10 +367,8 @@ object UnivariateFeatureSelectorModel extends MLReadable[UnivariateFeatureSelect override def load(path: String): UnivariateFeatureSelectorModel = { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - .select("selectedFeatures").head() - val selectedFeatures = data.getAs[Seq[Int]](0).toArray - val model = new UnivariateFeatureSelectorModel(metadata.uid, selectedFeatures) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + val model = new UnivariateFeatureSelectorModel(metadata.uid, data.selectedFeatures.toArray) metadata.getAndSetParams(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala index cd1905b90ace8..08ba51b413d22 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala @@ -176,6 +176,7 @@ class VarianceThresholdSelectorModel private[ml]( @Since("3.1.0") object VarianceThresholdSelectorModel extends MLReadable[VarianceThresholdSelectorModel] { + private case class Data(selectedFeatures: Seq[Int]) @Since("3.1.0") override def read: MLReader[VarianceThresholdSelectorModel] = @@ -187,13 +188,11 @@ object VarianceThresholdSelectorModel extends MLReadable[VarianceThresholdSelect private[VarianceThresholdSelectorModel] class VarianceThresholdSelectorWriter( instance: VarianceThresholdSelectorModel) extends MLWriter { - private case class Data(selectedFeatures: Seq[Int]) - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.selectedFeatures.toImmutableArraySeq) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -206,10 +205,8 @@ object VarianceThresholdSelectorModel extends MLReadable[VarianceThresholdSelect override def load(path: String): VarianceThresholdSelectorModel = { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - .select("selectedFeatures").head() - val selectedFeatures = data.getAs[Seq[Int]](0).toArray - val model = new VarianceThresholdSelectorModel(metadata.uid, selectedFeatures) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + val model = new VarianceThresholdSelectorModel(metadata.uid, data.selectedFeatures.toArray) metadata.getAndSetParams(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 091e209227827..48ad67af09347 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -528,17 +528,16 @@ class VectorIndexerModel private[ml] ( @Since("1.6.0") object VectorIndexerModel extends MLReadable[VectorIndexerModel] { + private case class Data(numFeatures: Int, categoryMaps: Map[Int, Map[Double, Int]]) private[VectorIndexerModel] class VectorIndexerModelWriter(instance: VectorIndexerModel) extends MLWriter { - private case class Data(numFeatures: Int, categoryMaps: Map[Int, Map[Double, Int]]) - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.numFeatures, instance.categoryMaps) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -549,12 +548,8 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] { override def load(path: String): VectorIndexerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - .select("numFeatures", "categoryMaps") - .head() - val numFeatures = data.getAs[Int](0) - val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1) - val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + val model = new VectorIndexerModel(metadata.uid, data.numFeatures, data.categoryMaps) metadata.getAndSetParams(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 7d6765b231b5c..50e25ccf092c3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -363,14 +363,8 @@ object Word2VecModel extends MLReadable[Word2VecModel] { sc.conf.get(KRYO_SERIALIZER_MAX_BUFFER_SIZE.key, "64m")) val numPartitions = Word2VecModelWriter.calculateNumberOfPartitions( bufferSizeInBytes, instance.wordVectors.wordIndex.size, instance.getVectorSize) - val spark = sparkSession - import spark.implicits._ - spark.createDataset[(String, Array[Float])](wordVectors.toSeq) - .repartition(numPartitions) - .map { case (word, vector) => Data(word, vector) } - .toDF() - .write - .parquet(dataPath) + val datum = wordVectors.toArray.map { case (word, vector) => Data(word, vector) } + ReadWriteUtils.saveArray[Data](dataPath, datum, sparkSession, numPartitions) } } @@ -408,7 +402,6 @@ object Word2VecModel extends MLReadable[Word2VecModel] { override def load(path: String): Word2VecModel = { val spark = sparkSession - import spark.implicits._ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion) @@ -423,10 +416,8 @@ object Word2VecModel extends MLReadable[Word2VecModel] { val wordVectors = data.getAs[Seq[Float]](1).toArray new feature.Word2VecModel(wordIndex, wordVectors) } else { - val wordVectorsMap = spark.read.parquet(dataPath).as[Data] - .collect() - .map(wordVector => (wordVector.word, wordVector.vector)) - .toMap + val datum = ReadWriteUtils.loadArray[Data](dataPath, sparkSession) + val wordVectorsMap = datum.map(wordVector => (wordVector.word, wordVector.vector)).toMap new feature.Word2VecModel(wordVectorsMap) } diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 7a932d250cee0..6fd20ceb562b1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -343,6 +343,11 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { + if (ReadWriteUtils.localSavingModeState.get()) { + throw new UnsupportedOperationException( + "FPGrowthModel does not support saving to local filesystem path." + ) + } val extraMetadata: JObject = Map("numTrainingRecords" -> instance.numTrainingRecords) DefaultParamsWriter.saveMetadata(instance, path, sparkSession, extraMetadata = Some(extraMetadata)) @@ -357,6 +362,11 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { private val className = classOf[FPGrowthModel].getName override def load(path: String): FPGrowthModel = { + if (ReadWriteUtils.localSavingModeState.get()) { + throw new UnsupportedOperationException( + "FPGrowthModel does not support loading from local filesystem path." + ) + } implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 36255d3df0f1f..0dd10691c5d26 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -547,6 +547,8 @@ class ALSModel private[ml] ( } } +private case class FeatureData(id: Int, features: Array[Float]) + @Since("1.6.0") object ALSModel extends MLReadable[ALSModel] { @@ -569,9 +571,21 @@ object ALSModel extends MLReadable[ALSModel] { val extraMetadata = "rank" -> instance.rank DefaultParamsWriter.saveMetadata(instance, path, sparkSession, Some(extraMetadata)) val userPath = new Path(path, "userFactors").toString - instance.userFactors.write.format("parquet").save(userPath) val itemPath = new Path(path, "itemFactors").toString - instance.itemFactors.write.format("parquet").save(itemPath) + + if (ReadWriteUtils.localSavingModeState.get()) { + // Import implicits for Dataset Encoder + val sparkSession = super.sparkSession + import sparkSession.implicits._ + + val userFactorsData = instance.userFactors.as[FeatureData].collect() + ReadWriteUtils.saveArray(userPath, userFactorsData, sparkSession) + val itemFactorsData = instance.itemFactors.as[FeatureData].collect() + ReadWriteUtils.saveArray(itemPath, itemFactorsData, sparkSession) + } else { + instance.userFactors.write.format("parquet").save(userPath) + instance.itemFactors.write.format("parquet").save(itemPath) + } } } @@ -585,9 +599,20 @@ object ALSModel extends MLReadable[ALSModel] { implicit val format = DefaultFormats val rank = (metadata.metadata \ "rank").extract[Int] val userPath = new Path(path, "userFactors").toString - val userFactors = sparkSession.read.format("parquet").load(userPath) val itemPath = new Path(path, "itemFactors").toString - val itemFactors = sparkSession.read.format("parquet").load(itemPath) + + val (userFactors, itemFactors) = if (ReadWriteUtils.localSavingModeState.get()) { + import org.apache.spark.util.ArrayImplicits._ + val userFactorsData = 
ReadWriteUtils.loadArray[FeatureData](userPath, sparkSession) + val userFactors = sparkSession.createDataFrame(userFactorsData.toImmutableArraySeq) + val itemFactorsData = ReadWriteUtils.loadArray[FeatureData](itemPath, sparkSession) + val itemFactors = sparkSession.createDataFrame(itemFactorsData.toImmutableArraySeq) + (userFactors, itemFactors) + } else { + val userFactors = sparkSession.read.format("parquet").load(userPath) + val itemFactors = sparkSession.read.format("parquet").load(itemPath) + (userFactors, itemFactors) + } val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 5326313456663..1b77c1d4d51a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -497,6 +497,7 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] { + private case class Data(coefficients: Vector, intercept: Double, scale: Double) @Since("1.6.0") override def read: MLReader[AFTSurvivalRegressionModel] = new AFTSurvivalRegressionModelReader @@ -509,15 +510,13 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] instance: AFTSurvivalRegressionModel ) extends MLWriter with Logging { - private case class Data(coefficients: Vector, intercept: Double, scale: Double) - override protected def saveImpl(path: String): Unit = { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sparkSession) // Save model data: coefficients, intercept, scale val data = Data(instance.coefficients, instance.intercept, instance.scale) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -530,12 +529,10 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - val Row(coefficients: Vector, intercept: Double, scale: Double) = - MLUtils.convertVectorColumnsToML(data, "coefficients") - .select("coefficients", "intercept", "scale") - .head() - val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + val model = new AFTSurvivalRegressionModel( + metadata.uid, data.coefficients, data.intercept, data.scale + ) metadata.getAndSetParams(model) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 4f38d87574132..50de0c54b8c3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -309,7 +309,7 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode val (nodeData, _) = NodeData.build(instance.rootNode, 0) val dataPath = new Path(path, "data").toString val numDataParts = NodeData.inferNumPartitions(instance.numNodes) - 
sparkSession.createDataFrame(nodeData).repartition(numDataParts).write.parquet(dataPath) + ReadWriteUtils.saveArray(dataPath, nodeData.toArray, sparkSession, numDataParts) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala index 5cc93e14fa3d5..09df9295d618d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala @@ -510,6 +510,10 @@ class FMRegressionModel private[regression] ( @Since("3.0.0") object FMRegressionModel extends MLReadable[FMRegressionModel] { + private case class Data( + intercept: Double, + linear: Vector, + factors: Matrix) @Since("3.0.0") override def read: MLReader[FMRegressionModel] = new FMRegressionModelReader @@ -521,16 +525,11 @@ object FMRegressionModel extends MLReadable[FMRegressionModel] { private[FMRegressionModel] class FMRegressionModelWriter( instance: FMRegressionModel) extends MLWriter with Logging { - private case class Data( - intercept: Double, - linear: Vector, - factors: Matrix) - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.intercept, instance.linear, instance.factors) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -541,11 +540,8 @@ object FMRegressionModel extends MLReadable[FMRegressionModel] { override def load(path: String): FMRegressionModel = { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.format("parquet").load(dataPath) - - val Row(intercept: Double, linear: Vector, factors: Matrix) = data - .select("intercept", "linear", "factors").head() - val model = new FMRegressionModel(metadata.uid, intercept, linear, factors) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) + val model = new FMRegressionModel(metadata.uid, data.intercept, data.linear, data.factors) metadata.getAndSetParams(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 6d4669ec78af9..0584a21d25fcf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1143,6 +1143,7 @@ class GeneralizedLinearRegressionModel private[ml] ( @Since("2.0.0") object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegressionModel] { + private case class Data(intercept: Double, coefficients: Vector) @Since("2.0.0") override def read: MLReader[GeneralizedLinearRegressionModel] = @@ -1156,15 +1157,13 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr class GeneralizedLinearRegressionModelWriter(instance: GeneralizedLinearRegressionModel) extends MLWriter with Logging { - private case class Data(intercept: Double, coefficients: Vector) - override protected def saveImpl(path: String): Unit = { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sparkSession) // Save model data: intercept, coefficients val data = Data(instance.intercept, instance.coefficients) val dataPath = 
new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -1178,12 +1177,11 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - .select("intercept", "coefficients").head() - val intercept = data.getDouble(0) - val coefficients = data.getAs[Vector](1) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) - val model = new GeneralizedLinearRegressionModel(metadata.uid, coefficients, intercept) + val model = new GeneralizedLinearRegressionModel( + metadata.uid, data.coefficients, data.intercept + ) metadata.getAndSetParams(model) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index e1bfff068cfe2..5d93541ab245c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -285,6 +285,10 @@ class IsotonicRegressionModel private[ml] ( @Since("1.6.0") object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { + private case class Data( + boundaries: Array[Double], + predictions: Array[Double], + isotonic: Boolean) @Since("1.6.0") override def read: MLReader[IsotonicRegressionModel] = new IsotonicRegressionModelReader @@ -297,11 +301,6 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { instance: IsotonicRegressionModel ) extends MLWriter with Logging { - private case class Data( - boundaries: Array[Double], - predictions: Array[Double], - isotonic: Boolean) - override protected def saveImpl(path: String): Unit = { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sparkSession) @@ -309,7 +308,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { val data = Data( instance.oldModel.boundaries, instance.oldModel.predictions, instance.oldModel.isotonic) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession) } } @@ -322,13 +321,11 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - .select("boundaries", "predictions", "isotonic").head() - val boundaries = data.getAs[Seq[Double]](0).toArray - val predictions = data.getAs[Seq[Double]](1).toArray - val isotonic = data.getBoolean(2) + val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession) val model = new IsotonicRegressionModel( - metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic)) + metadata.uid, + new MLlibIsotonicRegressionModel(data.boundaries, data.predictions, data.isotonic) + ) metadata.getAndSetParams(model) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index c7baf097c591c..ea27afa755516 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -786,6 +786,8 @@ class LinearRegressionModel private[ml] ( } } +private case class LinearModelData(intercept: Double, coefficients: Vector, scale: Double) + /** A writer for LinearRegression that handles the "internal" (or default) format */ private class InternalLinearRegressionModelWriter extends MLWriterFormat with MLFormatRegister { @@ -793,8 +795,6 @@ private class InternalLinearRegressionModelWriter override def format(): String = "internal" override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" - private case class Data(intercept: Double, coefficients: Vector, scale: Double) - override def write(path: String, sparkSession: SparkSession, optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { val instance = stage.asInstanceOf[LinearRegressionModel] @@ -802,9 +802,9 @@ private class InternalLinearRegressionModelWriter // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sparkSession) // Save model data: intercept, coefficients, scale - val data = Data(instance.intercept, instance.coefficients, instance.scale) + val data = LinearModelData(instance.intercept, instance.coefficients, instance.scale) val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) + ReadWriteUtils.saveObject[LinearModelData](dataPath, data, sparkSession) } } @@ -847,20 +847,20 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.format("parquet").load(dataPath) val (majorVersion, minorVersion) = majorMinorVersion(metadata.sparkVersion) val model = if (majorVersion < 2 || (majorVersion == 2 && minorVersion <= 2)) { // Spark 2.2 and before + val data = sparkSession.read.format("parquet").load(dataPath) val Row(intercept: Double, coefficients: Vector) = MLUtils.convertVectorColumnsToML(data, "coefficients") .select("intercept", "coefficients") .head() new LinearRegressionModel(metadata.uid, coefficients, intercept) } else { - // Spark 2.3 and later - val Row(intercept: Double, coefficients: Vector, scale: Double) = - data.select("intercept", "coefficients", "scale").head() - new LinearRegressionModel(metadata.uid, coefficients, intercept, scale) + val data = ReadWriteUtils.loadObject[LinearModelData](dataPath, sparkSession) + new LinearRegressionModel( + metadata.uid, data.coefficients, data.intercept, data.scale + ) } metadata.getAndSetParams(model) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index a001edf3e0456..73a54405c752d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.{Param, Params} import org.apache.spark.ml.tree.DecisionTreeModelReadWrite.NodeData -import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter} +import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, ReadWriteUtils} import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.{DecisionTreeModel 
=> OldDecisionTreeModel} @@ -415,13 +415,17 @@ private[ml] object DecisionTreeModelReadWrite { } val dataPath = new Path(path, "data").toString - var df = sparkSession.read.parquet(dataPath) val (major, _) = VersionUtils.majorMinorVersion(metadata.sparkVersion) - if (major < 3) { + val nodeDataArray = if (major < 3) { + var df = sparkSession.read.parquet(dataPath) df = df.withColumn("rawCount", lit(-1L)) + df.as[NodeData].collect() + } else { + import org.apache.spark.ml.util.ReadWriteUtils + ReadWriteUtils.loadArray[NodeData](dataPath, sparkSession) } - buildTreeFromNodes(df.as[NodeData].collect(), impurityType) + buildTreeFromNodes(nodeDataArray, impurityType) } /** @@ -480,19 +484,19 @@ private[ml] object EnsembleModelReadWrite { instance.treeWeights(treeID)) } val treesMetadataPath = new Path(path, "treesMetadata").toString - sparkSession.createDataFrame(treesMetadataWeights.toImmutableArraySeq) - .toDF("treeID", "metadata", "weights") - .repartition(1) - .write.parquet(treesMetadataPath) + ReadWriteUtils.saveArray[(Int, String, Double)]( + treesMetadataPath, treesMetadataWeights, sparkSession, numDataParts = 1 + ) val dataPath = new Path(path, "data").toString val numDataParts = NodeData.inferNumPartitions(instance.trees.map(_.numNodes.toLong).sum) - val nodeDataRDD = sparkSession.sparkContext - .parallelize(instance.trees.zipWithIndex.toImmutableArraySeq) - .flatMap { case (tree, treeID) => EnsembleNodeData.build(tree, treeID) } - sparkSession.createDataFrame(nodeDataRDD) - .repartition(numDataParts) - .write.parquet(dataPath) + + val nodeDataArray = instance.trees.zipWithIndex.flatMap { + case (tree, treeID) => EnsembleNodeData.build(tree, treeID) + } + ReadWriteUtils.saveArray[EnsembleNodeData]( + dataPath, nodeDataArray, sparkSession, numDataParts + ) } /** @@ -521,37 +525,39 @@ private[ml] object EnsembleModelReadWrite { } val treesMetadataPath = new Path(path, "treesMetadata").toString - val treesMetadataRDD = sparkSession.read.parquet(treesMetadataPath) - .select("treeID", "metadata", "weights") - .as[(Int, String, Double)].rdd - .map { case (treeID: Int, json: String, weights: Double) => - treeID -> ((DefaultParamsReader.parseMetadata(json, treeClassName), weights)) - } - val treesMetadataWeights = treesMetadataRDD.sortByKey().values.collect() + val treesMetadataWeights = ReadWriteUtils.loadArray[(Int, String, Double)]( + treesMetadataPath, sparkSession + ).map { case (treeID: Int, json: String, weights: Double) => + treeID -> ((DefaultParamsReader.parseMetadata(json, treeClassName), weights)) + }.sortBy(_._1).map(_._2) + val treesMetadata = treesMetadataWeights.map(_._1) val treesWeights = treesMetadataWeights.map(_._2) val dataPath = new Path(path, "data").toString - var df = sparkSession.read.parquet(dataPath) val (major, _) = VersionUtils.majorMinorVersion(metadata.sparkVersion) - if (major < 3) { + val ensembleNodeDataArray = if (major < 3) { + var df = sparkSession.read.parquet(dataPath) val newNodeDataCol = df.schema("nodeData").dataType match { case StructType(fields) => val cols = fields.map(f => col(s"nodeData.${f.name}")) :+ lit(-1L).as("rawCount") - import org.apache.spark.util.ArrayImplicits._ struct(cols.toImmutableArraySeq: _*) } df = df.withColumn("nodeData", newNodeDataCol) + df.as[EnsembleNodeData].collect() + } else { + ReadWriteUtils.loadArray[EnsembleNodeData](dataPath, sparkSession) } + val rootNodes = ensembleNodeDataArray + .groupBy(_.treeID) + .map { case (treeID: Int, ensembleNodeDataArrayPerTree: Array[EnsembleNodeData]) => + val nodeDataArray = 
ensembleNodeDataArrayPerTree.map(_.nodeData) + treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeDataArray, impurityType) + }.toSeq + .sortBy(_._1) + .map(_._2) - val rootNodesRDD = df.as[EnsembleNodeData].rdd - .map(d => (d.treeID, d.nodeData)) - .groupByKey() - .map { case (treeID: Int, nodeData: Iterable[NodeData]) => - treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType) - } - val rootNodes = rootNodesRDD.sortByKey().values.collect() (metadata, treesMetadata.zip(rootNodes), treesWeights) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index d023c8990e76d..bc6b747344e31 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -392,6 +392,11 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { ValidatorParams.validateParams(instance) override protected def saveImpl(path: String): Unit = { + if (ReadWriteUtils.localSavingModeState.get()) { + throw new UnsupportedOperationException( + "CrossValidatorModel does not support saving to local filesystem path." + ) + } val persistSubModelsParam = optionMap.getOrElse("persistsubmodels", if (instance.hasSubModels) "true" else "false") @@ -429,6 +434,11 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { private val className = classOf[CrossValidatorModel].getName override def load(path: String): CrossValidatorModel = { + if (ReadWriteUtils.localSavingModeState.get()) { + throw new UnsupportedOperationException( + "CrossValidatorModel does not support loading from local filesystem path." + ) + } implicit val format = DefaultFormats val (metadata, estimator, evaluator, estimatorParamMaps) = diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index ebfcac2e4952b..324a08ba0b5ab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -357,6 +357,11 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { ValidatorParams.validateParams(instance) override protected def saveImpl(path: String): Unit = { + if (ReadWriteUtils.localSavingModeState.get()) { + throw new UnsupportedOperationException( + "TrainValidationSplitModel does not support saving to local filesystem path." + ) + } val persistSubModelsParam = optionMap.getOrElse("persistsubmodels", if (instance.hasSubModels) "true" else "false") @@ -391,6 +396,11 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { private val className = classOf[TrainValidationSplitModel].getName override def load(path: String): TrainValidationSplitModel = { + if (ReadWriteUtils.localSavingModeState.get()) { + throw new UnsupportedOperationException( + "TrainValidationSplitModel does not support loading from local filesystem path." 
+ ) + } implicit val format = DefaultFormats val (metadata, estimator, evaluator, estimatorParamMaps) = diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index d155f257d2300..8665efd23b3ed 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -17,20 +17,24 @@ package org.apache.spark.ml.util -import java.io.IOException +import java.io.{File, IOException} +import java.nio.file.{Files, Paths} import java.util.{Locale, ServiceLoader} import scala.collection.mutable import scala.jdk.CollectionConverters._ +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag import scala.util.{Failure, Success, Try} +import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path import org.json4s._ import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.{SparkContext, SparkEnv, SparkException} import org.apache.spark.annotation.{Since, Unstable} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.PATH @@ -169,6 +173,19 @@ abstract class MLWriter extends BaseReadWrite with Logging { saveImpl(path) } + /** + * Saves the ML instances to the local file system path. + */ + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + private[spark] def saveToLocal(path: String): Unit = { + ReadWriteUtils.localSavingModeState.set(true) + try { + save(path) + } finally { + ReadWriteUtils.localSavingModeState.set(false) + } + } + /** * `save()` handles overwriting and then calls this method. Subclasses should override this * method to implement the actual saving of the instance. @@ -329,6 +346,18 @@ abstract class MLReader[T] extends BaseReadWrite { @Since("1.6.0") def load(path: String): T + /** + * Loads the ML component from the local file system path. + */ + private[spark] def loadFromLocal(path: String): T = { + ReadWriteUtils.localSavingModeState.set(true) + try { + load(path) + } finally { + ReadWriteUtils.localSavingModeState.set(false) + } + } + // override for Java compatibility override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) } @@ -442,7 +471,7 @@ private[ml] object DefaultParamsWriter { val metadataJson = getMetadataToSave(instance, spark, extraMetadata, paramMap) // Note that we should write single file. If there are more than one row // it produces more partitions. 
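The two entry points added in the hunk above, MLWriter.saveToLocal and MLReader.loadFromLocal, only flip the thread-local ReadWriteUtils.localSavingModeState flag around the ordinary save/load call. A minimal sketch of how Spark-internal code could use them; both methods are private[spark], and the LinearSVCModel and path are illustrative, not part of this patch:

```scala
import org.apache.spark.ml.classification.LinearSVCModel

// Assumed to live under the org.apache.spark package, since saveToLocal/loadFromLocal
// are private[spark]. `model` is an already trained LinearSVCModel.
def roundTripLocally(model: LinearSVCModel, path: String): LinearSVCModel = {
  // Sets localSavingModeState, then delegates to save(path); the writers rewritten in
  // this patch detect the flag and use java.nio instead of Hadoop + Parquet.
  model.write.overwrite().saveToLocal(path)
  // The same flag guards the read side, so metadata and data are read as plain local files.
  LinearSVCModel.read.loadFromLocal(path)
}
```

The try/finally in both wrappers resets the flag, so a failed local save does not leak the mode into later distributed writes on the same thread.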
- spark.createDataFrame(Seq(Tuple1(metadataJson))).write.text(metadataPath) + ReadWriteUtils.saveText(metadataPath, metadataJson, spark) } def saveMetadata( @@ -662,7 +691,7 @@ private[ml] object DefaultParamsReader { def loadMetadata(path: String, spark: SparkSession, expectedClassName: String): Metadata = { val metadataPath = new Path(path, "metadata").toString - val metadataStr = spark.read.text(metadataPath).first().getString(0) + val metadataStr = ReadWriteUtils.loadText(metadataPath, spark) parseMetadata(metadataStr, expectedClassName) } @@ -757,20 +786,136 @@ private[ml] object MetaAlgorithmReadWrite { private[spark] class FileSystemOverwrite extends Logging { def handleOverwrite(path: String, shouldOverwrite: Boolean, session: SparkSession): Unit = { - val hadoopConf = session.sessionState.newHadoopConf() - val outputPath = new Path(path) - val fs = outputPath.getFileSystem(hadoopConf) - val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - if (fs.exists(qualifiedOutputPath)) { - if (shouldOverwrite) { - logInfo(log"Path ${MDC(PATH, path)} already exists. It will be overwritten.") - // TODO: Revert back to the original content if save is not successful. - fs.delete(qualifiedOutputPath, true) + val errMsg = s"Path $path already exists. To overwrite it, " + + s"please use write.overwrite().save(path) for Scala and use " + + s"write().overwrite().save(path) for Java and Python." + + if (ReadWriteUtils.localSavingModeState.get()) { + val filePath = new File(path) + if (filePath.exists()) { + if (shouldOverwrite) { + FileUtils.deleteDirectory(filePath) + } else { + throw new IOException(errMsg) + } + } + + } else { + val hadoopConf = session.sessionState.newHadoopConf() + val outputPath = new Path(path) + val fs = outputPath.getFileSystem(hadoopConf) + val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + if (fs.exists(qualifiedOutputPath)) { + if (shouldOverwrite) { + logInfo(log"Path ${MDC(PATH, path)} already exists. It will be overwritten.") + // TODO: Revert back to the original content if save is not successful. 
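With saveMetadata and loadMetadata now routed through ReadWriteUtils.saveText/loadText, the metadata JSON in local mode becomes a single plain file named metadata instead of a one-row text DataFrame directory. Roughly, the local branch of that round trip reduces to the following sketch; metadataPath and metadataJson mirror the names used above, and this is an approximation of the utility, not the utility itself:

```scala
import java.nio.file.{Files, Paths}

// Approximate behaviour of ReadWriteUtils.saveText / loadText when
// localSavingModeState is set; Files.writeString/readString need Java 11+.
def saveMetadataLocally(metadataPath: String, metadataJson: String): Unit = {
  val p = Paths.get(metadataPath)
  Files.createDirectories(p.getParent) // the model directory may not exist yet
  Files.writeString(p, metadataJson)   // one file, so the single-file note above still holds
}

def loadMetadataLocally(metadataPath: String): String =
  Files.readString(Paths.get(metadataPath))
```

This is also why DefaultReadWriteTest (further down in this patch) asserts that new File(path, "metadata").isFile() after saveToLocal.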
+ fs.delete(qualifiedOutputPath, true) + } else { + throw new IOException(errMsg) + } + } + } + } +} + + +private[spark] object ReadWriteUtils { + + val localSavingModeState = new ThreadLocal[Boolean]() { + override def initialValue: Boolean = false + } + + def saveText(path: String, data: String, spark: SparkSession): Unit = { + if (localSavingModeState.get()) { + val filePath = Paths.get(path) + + Files.createDirectories(filePath.getParent) + Files.writeString(filePath, data) + } else { + spark.createDataFrame(Seq(Tuple1(data))).write.text(path) + } + } + + def loadText(path: String, spark: SparkSession): String = { + if (localSavingModeState.get()) { + Files.readString(Paths.get(path)) + } else { + spark.read.text(path).first().getString(0) + } + } + + def saveObjectToLocal[T <: Product: ClassTag: TypeTag](path: String, data: T): Unit = { + val serializer = SparkEnv.get.serializer.newInstance() + val dataBuffer = serializer.serialize(data) + val dataBytes = new Array[Byte](dataBuffer.limit) + dataBuffer.get(dataBytes) + + val filePath = Paths.get(path) + + Files.createDirectories(filePath.getParent) + Files.write(filePath, dataBytes) + } + + def saveObject[T <: Product: ClassTag: TypeTag]( + path: String, data: T, spark: SparkSession + ): Unit = { + if (localSavingModeState.get()) { + saveObjectToLocal(path, data) + } else { + spark.createDataFrame[T](Seq(data)).write.parquet(path) + } + } + + def loadObjectFromLocal[T <: Product: ClassTag: TypeTag](path: String): T = { + val serializer = SparkEnv.get.serializer.newInstance() + + val dataBytes = Files.readAllBytes(Paths.get(path)) + serializer.deserialize[T](java.nio.ByteBuffer.wrap(dataBytes)) + } + + def loadObject[T <: Product: ClassTag: TypeTag](path: String, spark: SparkSession): T = { + if (localSavingModeState.get()) { + loadObjectFromLocal(path) + } else { + import spark.implicits._ + spark.read.parquet(path).as[T].head() + } + } + + def saveArray[T <: Product: ClassTag: TypeTag]( + path: String, data: Array[T], spark: SparkSession, + numDataParts: Int = -1 + ): Unit = { + if (localSavingModeState.get()) { + val serializer = SparkEnv.get.serializer.newInstance() + val dataBuffer = serializer.serialize(data) + val dataBytes = new Array[Byte](dataBuffer.limit) + dataBuffer.get(dataBytes) + + val filePath = Paths.get(path) + + Files.createDirectories(filePath.getParent) + Files.write(filePath, dataBytes) + } else { + import org.apache.spark.util.ArrayImplicits._ + val df = spark.createDataFrame[T](data.toImmutableArraySeq) + if (numDataParts == -1) { + df.write.parquet(path) } else { - throw new IOException(s"Path $path already exists. 
To overwrite it, " + - s"please use write.overwrite().save(path) for Scala and use " + - s"write().overwrite().save(path) for Java and Python.") + df.repartition(numDataParts).write.parquet(path) } } } + + def loadArray[T <: Product: ClassTag: TypeTag](path: String, spark: SparkSession): Array[T] = { + if (localSavingModeState.get()) { + val serializer = SparkEnv.get.serializer.newInstance() + + val dataBytes = Files.readAllBytes(Paths.get(path)) + serializer.deserialize[Array[T]](java.nio.ByteBuffer.wrap(dataBytes)) + } else { + import spark.implicits._ + spark.read.parquet(path).as[T].collect() + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index a0223396da317..4ae1d3ce24a6c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -274,7 +274,8 @@ class LDASuite extends MLTest with DefaultReadWriteTest { val lda = new LDA() testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings ++ Map("optimizer" -> "em"), - LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData) + LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData, + skipTestSaveLocal = true) } test("EM LDA checkpointing: save last checkpoint") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 3d994366b8918..1630a5d07d8e5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -165,7 +165,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } val fPGrowth = new FPGrowth() testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings, - FPGrowthSuite.allParamSettings, checkModelData) + FPGrowthSuite.allParamSettings, checkModelData, skipTestSaveLocal = true) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index c5bf202a2d337..537deedfbcbb5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -41,20 +41,38 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => */ def testDefaultReadWrite[T <: Params with MLWritable]( instance: T, - testParams: Boolean = true): T = { + testParams: Boolean = true, + testSaveToLocal: Boolean = false): T = { val uid = instance.uid val subdirName = Identifiable.randomUID("test") val subdir = new File(tempDir, subdirName) val path = new File(subdir, uid).getPath - instance.save(path) - intercept[IOException] { + if (testSaveToLocal) { + instance.write.saveToLocal(path) + assert( + new File(path, "metadata").isFile(), + "saveToLocal should generate metadata as a file." 
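The ReadWriteUtils object defined above is the primitive every rewritten writer and reader in this patch builds on: in local mode it serializes the case class (or array of case classes) with SparkEnv.get.serializer into a single file, otherwise it keeps the historical Parquet layout. A rough usage sketch, with a hypothetical Payload case class standing in for the per-model Data classes; ReadWriteUtils is private[spark], so this is assumed to run from Spark-internal code:

```scala
import org.apache.spark.ml.util.ReadWriteUtils
import org.apache.spark.sql.SparkSession

// Hypothetical payload; real callers use model-specific classes such as
// Data(coefficients: Vector, intercept: Double).
case class Payload(weights: Array[Double], bias: Double)

def roundTrip(spark: SparkSession, dir: String): Unit = {
  val one = Payload(Array(1.0, 2.0), 0.5)
  // Single object: a one-row Parquet dataset normally, one serialized blob in local mode.
  ReadWriteUtils.saveObject[Payload](s"$dir/data", one, spark)
  assert(ReadWriteUtils.loadObject[Payload](s"$dir/data", spark).bias == 0.5)

  // Arrays follow the same dispatch; numDataParts only affects the Parquet path.
  val many = Array(one, Payload(Array(3.0), -1.0))
  ReadWriteUtils.saveArray[Payload](s"$dir/nodes", many, spark, numDataParts = 1)
  assert(ReadWriteUtils.loadArray[Payload](s"$dir/nodes", spark).length == 2)
}
```

Since the local format is whatever the configured SparkEnv serializer emits, it is presumably not intended as a cross-version exchange format the way the Parquet layout is.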
+ ) + intercept[IOException] { + instance.write.saveToLocal(path) + } + instance.write.overwrite().saveToLocal(path) + } else { instance.save(path) + intercept[IOException] { + instance.save(path) + } + instance.write.overwrite().save(path) } - instance.write.overwrite().save(path) + val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]] - val newInstance = loader.load(path) + val newInstance = if (testSaveToLocal) { + loader.loadFromLocal(path) + } else { + loader.load(path) + } assert(newInstance.uid === instance.uid) if (testParams) { instance.params.foreach { p => @@ -73,9 +91,14 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => } } } - - val load = instance.getClass.getMethod("load", classOf[String]) - val another = load.invoke(instance, path).asInstanceOf[T] + val another = if (testSaveToLocal) { + val read = instance.getClass.getMethod("read") + val reader = read.invoke(instance).asInstanceOf[MLReader[T]] + reader.loadFromLocal(path) + } else { + val load = instance.getClass.getMethod("load", classOf[String]) + load.invoke(instance, path).asInstanceOf[T] + } assert(another.uid === instance.uid) another } @@ -104,7 +127,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => dataset: Dataset[_], testEstimatorParams: Map[String, Any], testModelParams: Map[String, Any], - checkModelData: (M, M) => Unit): Unit = { + checkModelData: (M, M) => Unit, + skipTestSaveLocal: Boolean = false): Unit = { // Set some Params to make sure set Params are serialized. testEstimatorParams.foreach { case (p, v) => estimator.set(estimator.getParam(p), v) @@ -119,13 +143,20 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => } // Test Model save/load - val model2 = testDefaultReadWrite(model) - testModelParams.foreach { case (p, v) => - val param = model.getParam(p) - assert(model.get(param).get === model2.get(param).get) + val testTargets = if (skipTestSaveLocal) { + Seq(false) + } else { + Seq(false, true) } + for (testSaveToLocal <- testTargets) { + val model2 = testDefaultReadWrite(model, testSaveToLocal = testSaveToLocal) + testModelParams.foreach { case (p, v) => + val param = model.getParam(p) + assert(model.get(param).get === model2.get(param).get) + } - checkModelData(model, model2) + checkModelData(model, model2) + } } }
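With the harness changes above, a suite opts into the local path either implicitly through testEstimatorAndModelReadWrite (unless it sets skipTestSaveLocal, as LDASuite and FPGrowthSuite do) or directly by passing testSaveToLocal = true. A small sketch of the direct form, using Binarizer as a stand-in for any DefaultParamsWritable stage; the threshold value and test name are arbitrary:

```scala
import org.apache.spark.ml.feature.Binarizer

// Inside a suite that mixes in DefaultReadWriteTest; testDefaultReadWrite comes from the trait.
test("read/write to local filesystem") {
  val t = new Binarizer().setThreshold(0.5)
  // Exercises write.saveToLocal, the metadata-as-file assertion, the IOException on
  // re-save without overwrite, and read.loadFromLocal.
  val t2 = testDefaultReadWrite(t, testSaveToLocal = true)
  assert(t2.getThreshold === 0.5)
}
```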