import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{OneHotEncoder, RFormula, StringIndexer, VectorDisassembler}
import org.apache.spark.sql.SparkSession
/**
 * Minimal reproduction of a schema mismatch in a Spark ML [[Pipeline]].
 *
 * The pipeline runs four stages:
 *  1. It indexes the iris `species` column.
 *  2. It one-hot encodes the index.
 *  3. It splits the encoded vector into scalar columns with `VectorDisassembler`.
 *  4. It references those scalar columns by name in an [[RFormula]].
 *
 * `pipeline.fit(df)` completes, but `model.transform(df)` fails with
 * `IllegalArgumentException: Field "versicolor" does not exist.`
 * The reporter's analysis is that the failure comes from the disassembler's
 * `transformSchema`. That method reports a column named `species_enc_0`,
 * whereas its actual `transform` output contains `versicolor` and `virginica`.
 * The `RFormula` stage needs those two columns, so schema validation fails
 * before any data is processed.
 */
object Test {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("test").getOrCreate()
    try {
      // Data: https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv
      val df = spark.read
        .option("header", "true")
        .option("inferSchema", "true")
        .csv("iris.csv")

      val indexer = new StringIndexer()
        .setInputCol("species")
        .setOutputCol("species_idx")
      val encoder = new OneHotEncoder()
        .setInputCol(indexer.getOutputCol)
        .setOutputCol("species_enc")
      val disassembler = new VectorDisassembler()
        .setInputCol(encoder.getOutputCol)

      // `versicolor` and `virginica` are the columns the disassembler produces
      // when it transforms real data. They do not appear in the schema its
      // `transformSchema` reports (that schema has only `species_enc_0`).
      val formula = new RFormula()
        .setFormula("petal_width ~ petal_length + versicolor + virginica")

      val pipeline = new Pipeline().setStages(Array(indexer, encoder, disassembler, formula))
      val model = pipeline.fit(df)

      // Expected to throw:
      //   java.lang.IllegalArgumentException: Field "versicolor" does not exist.
      // NOTE(review): applying `model.stages` one at a time, so that each stage
      // derives its schema from the actual data, may avoid this failure. Untested.
      model.transform(df)
    } finally {
      // Release the SparkContext even though the reproduction is expected to throw.
      spark.stop()
    }
  }
}
I think `transformSchema` is what causes the error; the snippets below show the difference:
scala> new Pipeline().setStages(Array(indexer, encoder, disassembler)).fit(df).transform(df).schema
res1: org.apache.spark.sql.types.StructType = StructType(StructField(sepal_length,DoubleType,true), StructField(sepal_width,DoubleType,true), StructField(petal_length,DoubleType,true), StructField(petal_width,DoubleType,true), StructField(species,StringType,true), StructField(species_idx,DoubleType,true), StructField(species_enc,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true), StructField(versicolor,DoubleType,true), StructField(virginica,DoubleType,true))
scala> new Pipeline().setStages(Array(indexer, encoder, disassembler)).fit(df).transformSchema(df.schema)
res2: org.apache.spark.sql.types.StructType = StructType(StructField(sepal_length,DoubleType,true), StructField(sepal_width,DoubleType,true), StructField(petal_length,DoubleType,true), StructField(petal_width,DoubleType,true), StructField(species,StringType,true), StructField(species_idx,DoubleType,false), StructField(species_enc,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,false), StructField(species_enc_0,DoubleType,false))
Hi,
I'm trying to use `VectorDisassembler` to disassemble a one-hot encoded vector. The code below shows that the `VectorDisassembler` stage fails to generate an accurate schema. I think `transformSchema` is what causes the error; the snippets below show the difference:
It's clear that `transform().schema` contains the `versicolor` and `virginica` columns, while `transformSchema` produces a single `species_enc_0` column instead.