发布日期:2023-06-15 VIP内容

示例:如何pivot和unpivot DataFrame

Spark pivot()函数用于将数据从一个DataFrame/Dataset列pivot/旋转到多个列(将行转换为列),而unpivot用于将其转换回来(将列转换为行)。

pivot()是一种聚合,其中一个分组列值转换为具有不同数据的单独列。

首先,创建一个示例DataFrame,代码如下:

val data = Seq(
	("Banana",1000,"美国"), 
	("Carrots",1500,"美国"), 
	("Beans",1600,"美国"),
	("Orange",2000,"美国"),
	("Orange",2000,"美国"),
	("Banana",400,"中国"),
	("Carrots",1200,"中国"),
	("Beans",1500,"中国"),
	("Orange",4000,"中国"),
	("Banana",2000,"加拿大"),
	("Carrots",2000,"加拿大"),
	("Beans",2000,"墨西哥"))

import spark.sqlContext.implicits._

val df = data.toDF("Product","Amount","Country")
df.show()

执行以上代码,输出内容如下:

+-------+------+-------+
|Product|Amount|Country|
+-------+------+-------+
| Banana|  1000|   美国|
|Carrots|  1500|   美国|
|  Beans|  1600|   美国|
| Orange|  2000|   美国|
| Orange|  2000|   美国|
| Banana|   400|   中国|
|Carrots|  1200|   中国|
|  Beans|  1500|   中国|
| Orange|  4000|   中国|
| Banana|  2000| 加拿大|
|Carrots|  2000| 加拿大|
|  Beans|  2000| 墨西哥|
+-------+------+-------+

Pivot Spark DataFrame

Spark SQL提供pivot()函数来将数据从一列旋转到多列(将行转置到列)。它是一种聚合,其中一个分组列值转换为具有不同数据的单个列。从上面的DataFrame中,要获得每种产品出口到每个国家的总金额,将按产品分组,按国家旋转,并计算金额的总和。

val pivotDF = df.groupBy("Product").pivot("Country").sum("Amount")
pivotDF.show()

这将把Country从DataFrame行转换为列,并产生以下输出。当数据不存在时,默认情况下它表示为空。

+-------+----+------+------+----+
|Product|中国|加拿大|墨西哥|美国|
+-------+----+------+------+----+
| Orange|4000|  null|  null|4000|
|  Beans|1500|  null|  2000|1600|
| Banana| 400|  2000|  null|1000|
|Carrots|1200|  2000|  null|1500|
+-------+----+------+------+----+

注意,pivot是一个非常昂贵的操作,因此建议提供列数据(如果已知)作为函数的参数,如下所示。

val countries = Seq("美国","中国","加拿大","墨西哥")
val pivotDF = df.groupBy("Product").pivot("Country", countries).sum("Amount")
pivotDF.show()

执行以上代码,输出内容如下:

+-------+----+----+------+------+
|Product|美国|中国|加拿大|墨西哥|
+-------+----+----+------+------+
| Orange|4000|4000|  null|  null|
|  Beans|1600|1500|  null|  2000|
| Banana|1000| 400|  2000|  null|
|Carrots|1500|1200|  2000|  null|
+-------+----+----+------+------+

另一种方法是进行两阶段聚合,例如:

val pivotDF = df.groupBy("Product","Country")
      .sum("Amount")
      .groupBy("Product")
      .pivot("Country")
      .sum("sum(Amount)")
      
pivotDF.show()

执行以上代码,输出内容如下:

+-------+----+------+------+----+
|Product|中国|加拿大|墨西哥|美国|
+-------+----+------+------+----+
| Orange|4000|  null|  null|4000|
|  Beans|1500|  null|  2000|1600|
| Banana| 400|  2000|  null|1000|
|Carrots|1200|  2000|  null|1500|
+-------+----+------+------+----+

上面两个示例返回相同的输出,但性能更好。

Unpivot Spark DataFrame

Unpivot是一个反向操作,我们可以通过将列值旋转成行值来实现。Spark SQL没有unpivot函数,因此将使用stack()函数。

下面的代码将列Country转换为行:

// unpivot
val unPivotDF = pivotDF
  .select($"Product", expr("stack(3, '加拿大', `加拿大`, '中国', `中国`, '墨西哥', `墨西哥`) as (Country,Total)"))
  .where("Total is not null")
  
unPivotDF.show() 

请注意上面代码中,表达式中列名引用使用单引号,而列值引用使用反撇号。执行以上代码,输出内容如下:

+-------+-------+-----+
|Product|Country|Total|
+-------+-------+-----+
| Orange|   中国| 4000|
|  Beans|   中国| 1500|
|  Beans| 墨西哥| 2000|
| Banana| 加拿大| 2000|
| Banana|   中国|  400|
|Carrots| 加拿大| 2000|
|Carrots|   中国| 1200|
+-------+-------+-----+

不使用聚合进行pivot转换

如果DataFrame足够小,那么可以收集列名来形成模式,收集值来形成行,然后创建一个新的DataFrame。请看下面的示例:

// 创建一个DataFrame
val df = Seq(
  ("col1", "val1"),
  ("col2", "val2"),
  ("col3", "val3"),
  ("col4", "val4"),
  ("col5", "val5")
).toDF("COLUMN_NAME", "VALUE")

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

// 从现有DataFrame创建模式(schema)
val schema = StructType(df.select(collect_list("COLUMN_NAME")).first().getAs[Seq[String]](0).map(x => StructField(x, StringType)))

// 创建 RDD[Row] 
val values = sc.parallelize(Seq(Row.fromSeq(df.select(collect_list("VALUE")).first().getAs[Seq[String]](0))))

// 创建新的DataFrame
val df_new = spark.createDataFrame(values, schema)

df_new.show(false) 

执行以上代码,输出内容如下:

+----+----+----+----+----+
|col1|col2|col3|col4|col5|
+----+----+----+----+----+
|val1|val2|val3|val4|val5|
+----+----+----+----+----+