发布日期:2023-06-15
VIP内容
示例:如何pivot和unpivot DataFrame
Spark pivot()函数用于将数据从一个DataFrame/Dataset列pivot/旋转到多个列(将行转换为列),而unpivot用于将其转换回来(将列转换为行)。
pivot()是一种聚合,其中一个分组列值转换为具有不同数据的单独列。
首先,创建一个示例DataFrame,代码如下:
val data = Seq(
("Banana",1000,"美国"),
("Carrots",1500,"美国"),
("Beans",1600,"美国"),
("Orange",2000,"美国"),
("Orange",2000,"美国"),
("Banana",400,"中国"),
("Carrots",1200,"中国"),
("Beans",1500,"中国"),
("Orange",4000,"中国"),
("Banana",2000,"加拿大"),
("Carrots",2000,"加拿大"),
("Beans",2000,"墨西哥"))
import spark.sqlContext.implicits._
val df = data.toDF("Product","Amount","Country")
df.show()
执行以上代码,输出内容如下:
+-------+------+-------+ |Product|Amount|Country| +-------+------+-------+ | Banana| 1000| 美国| |Carrots| 1500| 美国| | Beans| 1600| 美国| | Orange| 2000| 美国| | Orange| 2000| 美国| | Banana| 400| 中国| |Carrots| 1200| 中国| | Beans| 1500| 中国| | Orange| 4000| 中国| | Banana| 2000| 加拿大| |Carrots| 2000| 加拿大| | Beans| 2000| 墨西哥| +-------+------+-------+
Pivot Spark DataFrame
Spark SQL提供pivot()函数来将数据从一列旋转到多列(将行转置到列)。它是一种聚合,其中一个分组列值转换为具有不同数据的单个列。从上面的DataFrame中,要获得每种产品出口到每个国家的总金额,将按产品分组,按国家旋转,并计算金额的总和。
val pivotDF = df.groupBy("Product").pivot("Country").sum("Amount")
pivotDF.show()
这将把Country从DataFrame行转换为列,并产生以下输出。当数据不存在时,默认情况下它表示为空。
+-------+----+------+------+----+ |Product|中国|加拿大|墨西哥|美国| +-------+----+------+------+----+ | Orange|4000| null| null|4000| | Beans|1500| null| 2000|1600| | Banana| 400| 2000| null|1000| |Carrots|1200| 2000| null|1500| +-------+----+------+------+----+
注意,pivot是一个非常昂贵的操作,因此建议提供列数据(如果已知)作为函数的参数,如下所示。
val countries = Seq("美国","中国","加拿大","墨西哥")
val pivotDF = df.groupBy("Product").pivot("Country", countries).sum("Amount")
pivotDF.show()
执行以上代码,输出内容如下:
+-------+----+----+------+------+ |Product|美国|中国|加拿大|墨西哥| +-------+----+----+------+------+ | Orange|4000|4000| null| null| | Beans|1600|1500| null| 2000| | Banana|1000| 400| 2000| null| |Carrots|1500|1200| 2000| null| +-------+----+----+------+------+
另一种方法是进行两阶段聚合,例如:
val pivotDF = df.groupBy("Product","Country")
.sum("Amount")
.groupBy("Product")
.pivot("Country")
.sum("sum(Amount)")
pivotDF.show()
执行以上代码,输出内容如下:
+-------+----+------+------+----+ |Product|中国|加拿大|墨西哥|美国| +-------+----+------+------+----+ | Orange|4000| null| null|4000| | Beans|1500| null| 2000|1600| | Banana| 400| 2000| null|1000| |Carrots|1200| 2000| null|1500| +-------+----+------+------+----+
上面两个示例返回相同的输出,但性能更好。
Unpivot Spark DataFrame
Unpivot是一个反向操作,我们可以通过将列值旋转成行值来实现。Spark SQL没有unpivot函数,因此将使用stack()函数。
下面的代码将列Country转换为行:
// unpivot
val unPivotDF = pivotDF
.select($"Product", expr("stack(3, '加拿大', `加拿大`, '中国', `中国`, '墨西哥', `墨西哥`) as (Country,Total)"))
.where("Total is not null")
unPivotDF.show()
请注意上面代码中,表达式中列名引用使用单引号,而列值引用使用反撇号。执行以上代码,输出内容如下:
+-------+-------+-----+ |Product|Country|Total| +-------+-------+-----+ | Orange| 中国| 4000| | Beans| 中国| 1500| | Beans| 墨西哥| 2000| | Banana| 加拿大| 2000| | Banana| 中国| 400| |Carrots| 加拿大| 2000| |Carrots| 中国| 1200| +-------+-------+-----+
不使用聚合进行pivot转换
如果DataFrame足够小,那么可以收集列名来形成模式,收集值来形成行,然后创建一个新的DataFrame。请看下面的示例:
// 创建一个DataFrame
val df = Seq(
("col1", "val1"),
("col2", "val2"),
("col3", "val3"),
("col4", "val4"),
("col5", "val5")
).toDF("COLUMN_NAME", "VALUE")
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
// 从现有DataFrame创建模式(schema)
val schema = StructType(df.select(collect_list("COLUMN_NAME")).first().getAs[Seq[String]](0).map(x => StructField(x, StringType)))
// 创建 RDD[Row]
val values = sc.parallelize(Seq(Row.fromSeq(df.select(collect_list("VALUE")).first().getAs[Seq[String]](0))))
// 创建新的DataFrame
val df_new = spark.createDataFrame(values, schema)
df_new.show(false)
执行以上代码,输出内容如下:
+----+----+----+----+----+ |col1|col2|col3|col4|col5| +----+----+----+----+----+ |val1|val2|val3|val4|val5| +----+----+----+----+----+