PySpark SQL聚合与透视函数

对大数据进行分析通常都需要对数据进行聚合操作。聚合通常需要某种形式的分组,要么在整个数据集上,要么在一个或多个列上,然后对它们应用聚合函数,比如对每个组进行求和、计数或求平均值等。PySpark SQL提供了许多常用的聚合函数。

1. 聚合函数

在PySpark中,所有的聚合都是通过函数完成的。聚合函数被设计用来在一组行上执行聚合,不管那组行是由DataFrame中的所有行还是一组子行组成的。

为了演示这些函数的用法,下面的示例中将使用“2018年11月14日深圳市价格定期监测信息”数据集。这个数据集包含一些主要副食品的监测信息,以csv格式存储在文件中。请将该数据集上传到HDFS分布式文件系统的/data/spark目录下。

首先读取价格监测信息数据集,并创建DataFrame,代码如下:

from pyspark.sql import SparkSession
from pyspark.sql.functions import *

spark = SparkSession.builder \
   .master("spark://localhost:7077") \
   .appName("pyspark demo") \
   .getOrCreate()

# 读取数据源文件,创建DataFrame
filePath = "/data/spark/2018年11月14日深圳市价格定期监测信息.csv"
priceDF = spark \
      .read \
      .option("header","true") \
      .option("inferSchema","true") \
      .csv(filePath)

priceDF.printSchema()
priceDF.show(5)

执行以上代码,输出结果如下:

root
 |-- RECORDID: string (nullable = true)
 |-- JCLB: string (nullable = true)
 |-- JCMC: string (nullable = true)
 |-- BQ: double (nullable = true)
 |-- SQ: double (nullable = true)
 |-- TB: double (nullable = true)
 |-- HB: double (nullable = true)

+--------------------+----+--------+-----+-----+------+------+
|              RECORDID|JCLB|     JCMC|    BQ|   SQ|     TB|     HB|
+--------------------+----+--------+-----+-----+------+------+
|537B9A6E0C836F36E...|null|      椰菜| 2.27|2.261|-0.067| 0.004|
|537B9A6E0C846F36E...|null|    东北米|2.863|2.843|-0.067| 0.007|
|537B9A6E0C856F36E...|null|    早籼米| 3.08|3.044| 0.219| 0.012|
|537B9A6E0C866F36E...|null|    晚籼米|3.217| 3.22| 0.081|-0.001|
|537B9A6E0C876F36E...|null|  泰国香米| 9.48| 9.48| 0.047|   0.0|
+--------------------+----+--------+-----+-----+------+------+
only showing top 5 rows

数据集中每一行代表从一条商品的价格监测信息。其中各个字段的含义如下:

(1) JCLB:监测类别。
(2) JCMC:监测名称。
(3) BQ:本期价格。
(4) SQ:上期价格。
(5) TB:同比价格变化。
(6) HB:环比价格变化。

对数据集进行简单探索。首先,找出这个数据集总共有多少行,代码如下:

print(f"监测的商品数量有:{priceDF.count()}")

执行以上代码,输出结果如下:

监测的商品数量有:331

下面使用一些常用的聚合函数进行统计

1) count(col)

统计指定列的数量。例如,统计数据集中商品(JCMC)的数量和监测类别(JCLB)的数量,代码如下:

# 商品数量
priceDF.select(count("JCMC").alias("监测商品")).show()

# 当统计一列中的项目数量时,count(col)函数不包括计数中的null值
priceDF.select(count("JCLB").alias("监测类别")).show()

# 判断"JCLB"列为null值的有多少
nullJclb = priceDF.where(col("JCLB").isNull()).count()
print("\"JCLB\"列为null值的有:", nullJclb)

执行以上代码,输出结果如下:

+--------+
|监测商品  |
+--------+
|     331 |
+--------+

+--------+
|监测类别  |
+--------+
|     298 |
+--------+

"JCLB"列为null值的有:33

2) countDistinct(col)

它只计算每个组的唯一项。例如,统计总共有多少个商品类别,多少个商品,代码如下:

priceDF.select(
   countDistinct("JCLB").alias("监测类别"),
   countDistinct("JCMC").alias("监测商品"),
   count("*").alias("总数量")
).show()

执行以上代码,输出结果如下:

+--------+--------+------+
|  监测类别| 监测商品| 总数量|
+--------+--------+------+
|        6|       35|    331|
+--------+--------+------+

注意,它和distinct()不同。distinct()函数用来按行去重,包括null值。例如,计检测的商品类别有多少,代码如下:

priceDF.select("JCLB").distinct().count()       # 7
priceDF.select("JCLB").distinct().show()

执行以上代码,输出结果如下:

7
+------+
|  JCLB|
+------+
|   粮食|
| 食用油|
|  null|
| 水产品|
| 肉奶蛋|
|   蔬菜|
| 肉蛋奶|
+------+

3) approx_count_distinct (col, max_estimated_error=0.05)

近似唯一计数。在一个大数据集里计算每个组中唯一项的确切数量是一个成本很高且很耗时的操作。在某些用例中,有一个近似惟一的计数就足够了。例如,在线广告业务中,每小时有数亿个广告曝光并且需要生成一份报告来显示每个特定类型的成员段的独立访问者的数量。PySpark实现了approx_count_distinct()函数用来统计近似惟一计数。因为唯一计数是一个近似值,所以会有一定数量的误差。这个函数允许指定一个可接受估算误差的值。使用approx_count_distinct()函数,代码如下:

# 统计price DataFrame的"JCMC"列。默认估算误差是0.05 (5%)
priceDF.select(
      count("JCMC"),
      countDistinct("JCMC"), 
      approx_count_distinct("JCMC", 0.05)
    ).show()

执行以上代码,输出结果如下:

+-----------+--------------------+---------------------------+
|count(JCMC)|count(DISTINCT JCMC)|approx_count_distinct(JCMC)|
+-----------+--------------------+---------------------------+
|          331|                     35|                              33|
+-----------+--------------------+---------------------------+

4) min(col), max(col)

获取col列的最小值和最大值。例如,统计本期价格的最大值和最小值,代码如下:

# 统计本期价格(BQ)最大值和最小值
priceDF.select(
      min("BQ").alias("最便宜的"), 
      max("BQ").alias("最贵的")
 ).show()

执行以上代码,输出结果如下:

+--------+------+
|  最便宜的| 最贵的| 
+--------+------+
|      1.7|43.515|
+--------+------+

5) sum(col)

这个函数计算一个数字列中的值的总和。例如,计算本期价格之和,代码如下:

priceDF.select(sum("BQ")).show()

执行以上代码,输出结果如下:

+------------------+
|             sum(BQ)|
+------------------+
|2904.4770000000003|
+------------------+

6) sumDistinct(col)

该函数只汇总了一个数字列的不同值。例如,计算本期价格(唯一值)之和,代码如下:

priceDF.select(sumDistinct("BQ")).show()

执行以上代码,输出结果如下:

+------------------+
|  sum(DISTINCT BQ)|
+------------------+
|2544.9589999999994|
+------------------+

7) avg(col)

这个函数计算一个数字列的平均值。这个方便的函数简单地取总并除以项目的数量。例如,计算本期的平均价格,代码如下:

priceDF.select(avg("BQ"), sum("BQ") / count("BQ")).show()

执行以上代码,输出结果如下:

+-----------------+---------------------+
|            avg(BQ)|(sum(BQ) / count(BQ))|
+-----------------+---------------------+
|8.774854984894262|     8.774854984894262|
+-----------------+---------------------+

2. 分组聚合

分组聚合不会在DataFrame中对全局组执行聚合,而是在DataFrame中的每个子组中执行聚合。通常分组执行聚合的过程分为两步。第一步是通过使用groupBy(col1、col2、……)转换来执行分组,也就是指定要按哪些列分组。与其他返回DataFrame的转换不同,这个groupBy()转换返回一个RelationalGroupedDataset类的实例。类RelationalGroupedDataset提供了一组标准的聚合函数,可以将它们应用到每个子组中。这些聚合函数有avg(cols)、count()、mean(cols)、min(cols)、max(cols)和sum(cols)。除了count()函数之外,其余所有的函数都在数字列上执行。

例如,要按JCLB(检测类别)列分组并执行一个count()聚合(groupBy()列将自动包含在输出中),代码如下:

# 按检测的商品大类分组统计
priceDF.groupBy("JCLB").count().show(truncate=False)

执行以上代码,输出结果如下:

+------+-----+
|JCLB  |count|
+------+-----+
|粮食    |55   |
|食用油  |37   |
|水产品  |27   |
|null   |33   |
|肉奶蛋  |7    |
|蔬菜    |117  |
|肉蛋奶  |55   |
+------+-----+

按JCLB和JCMC分组之后,执行count()聚合,并按统计数量降序排序,代码如下:

# 按商品大类和商品小类分组统计
priceDF \
      .groupBy("JCLB", "JCMC") \
      .count() \
      .orderBy(col("count").desc()) \
      .show(truncate=False)
    
priceDF \
      .groupBy("JCLB", "JCMC") \
      .count() \
      .where("JCMC=='花生油'") \
      .orderBy(col("count").desc()) \
      .show(truncate=False)

执行以上代码,输出结果如下:

+------+----------+-----+
|JCLB  |JCMC       | count|
+------+----------+-----+
|粮食   |东北米      |10    |
|食用油 |花生油      |10    |
|蔬菜   |蔬菜均价    |9     |
|蔬菜   |大白菜      |9     |
|粮食   |散装面粉    |9     |
|食用油 |菜籽油      |9     |
|蔬菜   |萝卜        |9     |
|粮食   |袋装面粉    |9     |
|蔬菜   |菠菜        |9     |
|蔬菜   |芹菜        |9     |
|蔬菜   |其中:青椒  |9     |
|蔬菜   |西红柿      |9     |
|蔬菜   |黄瓜        |9     |
|蔬菜   |茄子        |9     |
|水产品 |大头鱼      |9     |
|食用油 |调和油      |9     |
|粮食   |泰国香米    |9     |
|粮食   |早籼米      |9     |
|水产品 |草鱼        |9     |
|粮食   |晚籼米      |9     |
+------+----------+-----+
only showing top 20 rows

+------+------+-----+
|JCLB  |JCMC   |count|
+------+------+-----+
|食用油 |花生油 |10    |
|null  |花生油 |1     |
+------+------+-----+ 

有时需要在同一时间对每个组执行多个聚合。例如,除了计数之外,还想知道最小值和最大值。RelationalGroupedDataset类提供一个名为agg()的功能强大的函数,它接受一个或多个列表达式,这意味着可以使用任何聚合函数。这些聚合函数返回Column类的一个实例,这样就可以使用所提供的函数来应用任何列表达式。一个常见的需求是在聚合完成后重命名列,使之更短、更可读、更易于引用。

例如,按JCLB分组之后,执行多个聚合,代码如下:

# 同时对每个组执行多个聚合
from pyspark.sql.functions import *

priceDF.na.drop() \
      .groupBy("JCLB") \
      .agg(
        count("BQ").alias("本期数量"),
        min("BQ").alias("本期最低价格"),
        max("BQ").alias("本期最高价格"),
        avg("BQ").alias("本期平均价格")
      ).show()

执行以上代码,输出结果如下:

+------+--------+------------+------------+------------------+
|  JCLB|本期数量  |本期最低价格  |本期最高价格   |      本期平均价格   |
+------+--------+------------+------------+------------------+
|  粮食 |       55|        2.844|          9.48| 4.258527272727273|
|食用油 |       37|          4.5|         11.43| 7.179567567567568|
|水产品 |       27|         8.33|         37.26| 17.51562962962963|
|肉奶蛋 |        7|         2.99|         42.93|19.195714285714285|
|  蔬菜 |      117|          1.7|         8.765| 4.372042735042735|
|肉蛋奶 |       55|          2.8|        43.515|18.250181818181815|
+------+--------+------------+------------+------------------+

函数collect_list(col)和collect_set(col)用于在应用分组后收集特定组的所有值。一旦每个组的值被收集到一个集合中,那么就可以自由地以任何选择的方式对其进行操作。这两个函数的返回集合之间有一个小的区别,那就是惟一性。collect_list()函数返回一个可能包含重复值的集合,collect_set()函数返回一个只包含唯一值的集合。

例如,使用collection_list()函数来收集每个商品大类下的商品名称,代码如下:

from pyspark.sql.functions import *

priceDF.na.drop() \
      .groupBy(col('JCLB').alias("监测类别")) \
      .agg(collect_set("JCMC").alias("监测的商品")) \
      .withColumn("监测商品数量",size('监测的商品')) \
      .show(truncate=False)

执行以上代码,输出结果如下:

+--------+-------------------------------+------------+
|  监测类别|                           监测的商品|  监测商品数量|
+--------+-------------------------------+------------+
|      粮食|    [早籼米, 散装面粉, 泰国香米,...|             6|
|    食用油|      [花生油, 菜籽油, 豆油, 调和油]|             4|
|    水产品|                [带鱼, 草鱼, 大头鱼]|             3|
|    肉奶蛋|      [其中:精瘦肉, 鸡蛋, 牛肉, ...|             7|
|      蔬菜|     [蔬菜均价, 西红柿, 大白菜, ...|            14|
|    肉蛋奶|      [其中:精瘦肉, 鸡蛋, 牛肉, ...|             8|
+--------+-------------------------------+------------+

3. 数据透视

数据透视是一种通过聚合和旋转把数据行转换成数据列的技术,它是一种将行转换成列同时应用一个或多个聚合时的方法。这样一来,分类值就会从行转到单独的列中。这种技术通常用于数据分析或报告。

【示例】有一个包含学生信息的数据集,每行包含学生姓名、性别、体重、毕业年份。现在想要知道每个毕业年份每个性别的平均体重。

首先,创建一个DataFrame,代码如下:

from pyspark.sql import SparkSession
from pyspark.sql.functions import *

spark = SparkSession.builder \
   .master("spark://localhost:7077") \
   .appName("pyspark demo") \
   .getOrCreate()

# 构造DataFrame
studentsDF = spark.createDataFrame([
      ("刘宏明", "男", 180, 2015),
      ("赵薇", "女", 110, 2015),
      ("黄海波", "男", 200, 2015),
      ("杨幂", "女", 109, 2015),
      ("楼一萱", "女", 105, 2015),
      ("龙梅子", "女", 115, 2016),
      ("陈知远", "男", 195, 2016)
], ["name","gender","weight","graduation_year"])

studentsDF.show()

执行以上代码,输出结果如下:

+------+------+------+---------------+
|  name|gender|weight|graduation_year|
+------+------+------+---------------+
| 刘宏明|     男|    180|             2015|
|   赵薇|     女|    110|             2015|
| 黄海波|     男|    200|             2015|
|   杨幂|     女|    109|             2015|
| 楼一萱|     女|    105|             2015|
| 龙梅子|     女|    115|             2016|
| 陈知远|     男|    195|             2016|
+------+------+------+---------------+

然后调用pivot()函数在gender列上旋转,统计不同性别的平均体重,代码如下:

# 计算每年每个性别的平均体重
studentsDF \
	.groupBy("graduation_year") \
	.pivot("gender") \
	.avg("weight") \
	.show()

执行以上代码,输出结果如下:

+---------------+-----+-----+
|graduation_year|    女|    男|
+---------------+-----+-----+
|             2015|108.0|190.0|
|             2016|115.0|195.0|
+---------------+-----+-----+

可以利用agg()函数来执行多个聚合,这会在结果表中创建更多的列,代码如下:

studentsDF \
      .groupBy("graduation_year") \
      .pivot("gender") \
      .agg(
        min("weight").alias("min"),
        max("weight").alias("max"),
        avg("weight").alias("avg")
      ).show()

执行以上代码,输出结果如下:

+---------------+------+------+------+------+------+------+
|graduation_year|女_min| 女_max|女_avg| 男_min|男_max|男_avg|
+---------------+------+------+------+------+------+------+
|             2015|    105|    110| 108.0|   180|   200| 190.0|
|             2016|    115|    115| 115.0|   195|   195| 195.0|
+---------------+------+------+------+------+------+------+

如果pivot列有许多不同的值,可以选择性地选择生成聚合的值,代码如下:

studentsDF \
      .groupBy("graduation_year") \
      .pivot("gender", ["男"]) \
      .agg(
        min("weight").alias("min"),
        max("weight").alias("max"),
        avg("weight").alias("avg")
      ).show()

执行以上代码,输出结果如下:

+---------------+------+------+------+
|graduation_year|男_min| 男_max|男_avg|
+---------------+------+------+------+
|             2015|   180|    200| 190.0|
|             2016|   195|    195| 195.0|
+---------------+------+------+------+

为pivot列指定一个distinct值的列表实际上会加速旋转过程。


《PySpark原理深入与编程实战》