PySpark SQL高级分析函数

PySpark SQL提供有许多的高级分析函数,如多维聚合函数、时间窗口聚合函数和窗口分析函数等。本节就介绍这些高级分析函数的使用。

1. 使用多维聚合函数

常用的多维聚合函数包括rollup()和cube(),它们基本上是在多列上进行分组的高级版本,通常用于在这些列的组合和排列中生成子总数和大总数。

1) rollup()

当使用分层数据时,比如不同部门和分部的销售收入数据等,rollup()可以很容易地计算出它们的子总数和总数。rollup()按给定的列集的层次结构,并且总是在层次结构中的第一列启动rolling up过程。使用rollup()函数的代码如下:

from pyspark.sql import SparkSession
from pyspark.sql.functions import *

spark = SparkSession.builder \
   .master("spark://localhost:7077") \
   .appName("pyspark demo") \
   .getOrCreate()

# 读取超市订单汇总数据
filePath = "/data/spark/超市订单.csv"
ordersDF = spark \
	.read \
	.option("header", "true") \
	.option("inferSchema","true") \
	.csv(filePath )

print(f"订单数量:{ordersDF.count()}")
ordersDF.printSchema()

执行以上代码,输出结果如下:

订单数量:10000
root
 |-- 行 ID: integer (nullable = true)
 |-- 订单 ID: string (nullable = true)
 |-- 订单日期: string (nullable = true)
 |-- 发货日期: string (nullable = true)
 |-- 邮寄方式: string (nullable = true)
 |-- 客户 ID: string (nullable = true)
 |-- 客户名称: string (nullable = true)
 |-- 细分: string (nullable = true)
 |-- 城市: string (nullable = true)
 |-- 省/自治区: string (nullable = true)
 |-- 国家: string (nullable = true)
 |-- 地区: string (nullable = true)
 |-- 产品 ID: string (nullable = true)
 |-- 类别: string (nullable = true)
 |-- 子类别: string (nullable = true)
 |-- 产品名称: string (nullable = true)
 |-- 销售额: double (nullable = true)
 |-- 数量: integer (nullable = true)
 |-- 折扣: double (nullable = true)
 |-- 利润: double (nullable = true)

查看前10条数据,代码如下:

ordersDF.show(10)

执行以上代码,输出结果如下:

+-----+---------------+----------+----------+--------+----------+--------+------+------+---------+----+----+--------------------+--------+------+--------------------------+--------+----+----+-------+
|行 ID|        订单 ID|  订单日期|  发货日期|邮寄方式|   客户 ID|客户名称|  细分|  城市|省/自治区|国家|地区|             产品 ID|    类别|子类别|                  产品名称|  销售额|数量|折扣|   利润|
+-----+---------------+----------+----------+--------+----------+--------+------+------+---------+----+----+--------------------+--------+------+--------------------------+--------+----+----+-------+
|    1|US-2017-1357144| 2017/4/27| 2017/4/29|    二级|曾惠-14485|    曾惠|  公司|  杭州|     浙江|中国|华东|办公用-用品-10002717|办公用品|  用品|        Fiskars 剪刀, 蓝色| 129.696|   2| 0.4|-60.704|
|    2|CN-2017-1973789| 2017/6/15| 2017/6/19|  标准级|许安-10165|    许安|消费者|  内江|     四川|中国|西南|办公用-信封-10004832|办公用品|  信封|  GlobeWeis 搭扣信封, 红色|  125.44|   2| 0.0|  42.56|
|    3|CN-2017-1973789| 2017/6/15| 2017/6/19|  标准级|许安-10165|    许安|消费者|  内江|     四川|中国|西南|办公用-装订-10001505|办公用品|装订机| Cardinal 孔加固材料, 回收|   31.92|   2| 0.4|    4.2|
|    4|US-2017-3017568| 2017/12/9|2017/12/13|  标准级|宋良-17170|    宋良|  公司|  镇江|     江苏|中国|华东|办公用-用品-10003746|办公用品|  用品|     Kleencut 开信刀, 工业| 321.216|   4| 0.4|-27.104|
|    5|CN-2016-2975416| 2016/5/31|  2016/6/2|    二级|万兰-15730|    万兰|消费者|  汕头|     广东|中国|中南|办公用-器具-10003452|办公用品|  器具|   KitchenAid 搅拌机, 黑色| 1375.92|   3| 0.0|  550.2|
|    6|CN-2015-4497736|2015/10/27|2015/10/31|  标准级|俞明-18325|    俞明|消费者|景德镇|     江西|中国|华东|  技术-设备-10001640|    技术|  设备|       柯尼卡 打印机, 红色|11129.58|   9| 0.0|3783.78|
|    7|CN-2015-4497736|2015/10/27|2015/10/31|  标准级|俞明-18325|    俞明|消费者|景德镇|     江西|中国|华东|办公用-装订-10001029|办公用品|装订机|        Ibico 订书机, 实惠|  479.92|   2| 0.0| 172.76|
|    8|CN-2015-4497736|2015/10/27|2015/10/31|  标准级|俞明-18325|    俞明|消费者|景德镇|     江西|中国|华东|  家具-椅子-10000578|    家具|  椅子|        SAFCO 扶手椅, 可调| 8659.84|   4| 0.0|2684.08|
|    9|CN-2015-4497736|2015/10/27|2015/10/31|  标准级|俞明-18325|    俞明|消费者|景德镇|     江西|中国|华东|办公用-纸张-10001629|办公用品|  纸张|Green Bar 计划信息表, 多色|   588.0|   5| 0.0|   46.9|
|   10|CN-2015-4497736|2015/10/27|2015/10/31|  标准级|俞明-18325|    俞明|消费者|景德镇|     江西|中国|华东|办公用-系固-10004801|办公用品|系固件|    Stockwell 橡皮筋, 整包|  154.28|   2| 0.0|  33.88|
+-----+---------------+----------+----------+--------+----------+--------+------+------+---------+----+----+--------------------+--------+------+--------------------------+--------+----+----+-------+
only showing top 10 rows

对数据进行过滤,以便更容易地看到rollup()的结果,代码如下:

twoSummary = ordersDF.select("地区","省/自治区","订单 ID").where("`地区`=='华东' or `地区`=='华北'")

# 让我们看看数据是什么样子的
print("数据量:", twoSummary.count())
twoSummary.show()

执行以上代码,输出结果如下:

数据量:4327
+----+---------+---------------+
| 地区| 省/自治区|           订单 ID|
+----+---------+---------------+
| 华东|      浙江|US-2017-1357144|
| 华东|      江苏|US-2017-3017568|
| 华东|      江西|CN-2015-4497736|
| 华东|      江西|CN-2015-4497736|
| 华东|      江西|CN-2015-4497736|
| 华东|      江西|CN-2015-4497736|
| 华东|      江西|CN-2015-4497736|
| 华东|      山东|CN-2015-2752724|
| 华东|      山东|CN-2015-2752724|
| 华东|      山东|CN-2015-2752724|
| 华东|      江苏|US-2016-2511714|
| 华东|      江苏|US-2016-2511714|
| 华东|      上海|CN-2017-5631342|
| 华东|      上海|CN-2017-5631342|
| 华东|      上海|CN-2017-5631342|
| 华东|      上海|CN-2017-5631342|
| 华东|      上海|CN-2017-5631342|
| 华东|      上海|CN-2017-5631342|
| 华东|      上海|CN-2017-5631342|
| 华东|      浙江|US-2016-4150614|
+----+---------+---------------+
only showing top 20 rows

接下来按地区、省/自治区执行rollup()操作,然后计算计数的总和,最后按null排序,代码如下:

from pyspark.sql.functions import *

twoSummary.rollup("地区", "省/自治区") \
          .agg(count("订单 ID").alias("total")) \
          .orderBy(col("地区").asc_nulls_last(), col("省/自治区").asc_nulls_last()) \
          .show()

执行以上代码,输出结果如下:

+----+---------+-----+
| 地区| 省/自治区|total|
+----+---------+-----+
| 华东|      上海|  292|
| 华东|      安徽|  347|
| 华东|      山东|  914|
| 华东|      江苏|  583|
| 华东|      江西|  139|
| 华东|      浙江|  424|
| 华东|      福建|  259|
| 华东|     null| 2958|
| 华北|    内蒙古|  224|
| 华北|      北京|  252|
| 华北|      天津|  304|
| 华北|      山西|  201|
| 华北|      河北|  388|
| 华北|     null| 1369|
|null|     null| 4327|
+----+---------+-----+

这个输出显示了华东区和华北区的每个城市的子总数,而总计显示在最后一行,并带有在“地区”和“省/自治区”的列上的null值。注意带有asc_nulls_last选项进行排序,因此PySpark SQL会将null值排序到最后位置。

2) cube()

一个cube()函数可以看作是rollup()函数的更高级版本。它在分组列的所有组合中执行聚合。因此,结果包括rollup()提供的以及其他组合所提供的。在上面的“地区”和“省/自治区”的例子中,结果将包括每个“省/自治区”的聚合。使用cube()函数的方法类似于使用rollup()函数,代码如下:

twoSummary \
	.cube("地区", "省/自治区") \
	.agg(count("订单 ID").alias("total")) \
	.orderBy(col("地区").asc_nulls_last(), col("省/自治区").asc_nulls_last()) \
	.show(30)

执行以上代码,输出结果如下所示:

+----+---------+-----+
| 地区| 省/自治区|total|
+----+---------+-----+
| 华东|      上海|  292|
| 华东|      安徽|  347|
| 华东|      山东|  914|
| 华东|      江苏|  583|
| 华东|      江西|  139|
| 华东|      浙江|  424|
| 华东|      福建|  259|
| 华东|     null| 2958|
| 华北|    内蒙古|  224|
| 华北|      北京|  252|
| 华北|      天津|  304|
| 华北|      山西|  201|
| 华北|      河北|  388|
| 华北|     null| 1369|
|null|      上海|  292|
|null|    内蒙古|  224|
|null|      北京|  252|
|null|      天津|  304|
|null|      安徽|  347|
|null|      山东|  914|
|null|      山西|  201|
|null|      江苏|  583|
|null|      江西|  139|
|null|      河北|  388|
|null|      浙江|  424|
|null|      福建|  259|
|null|     null| 4327|
+----+---------+-----+

在结果表格中,在“地区”列中有null值的行表示一个地区中所有城市的聚合。因此,一个cube()的计算结果总是比会rollup()的结果有更多的行。

2. 使用时间窗口聚合

在高级分析函数中,第二个功能是基于时间窗口执行聚合,这在处理来自物联网设备的事务或传感器值等时间序列数据时非常有用。这些时序数据由一系列的时间顺序数据点组成。这种数据集在金融或电信等行业很常见。

在PySpark 2.0中引入了时间窗口的聚合,使其能够轻松地处理时间序列数据。可以使用时间窗口聚合分析时序数据,比如京东股票的周平均收盘价,或者京东股票跨每一周的月移动平均收盘价。

有两种类型的时间窗口:滚动窗口和滑动窗口。与滚动窗口(也叫固定窗口)相比,滑动窗口需要提供额外的输入参数,用来说明在计算下一个桶时,一个时间窗口应该滑动多少。

【示例】编写PySpark SQL批处理程序,分析京东股票历史交易数据,统计京东股票的周平均价格,以及京东股票的月平均收盘价(每周计算一次)。

要计算京东股票的周平均价格,代码如下:

from pyspark.sql import SparkSession
from pyspark.sql.functions import *

spark = SparkSession.builder \
   .master("spark://localhost:7077") \
   .appName("pyspark demo") \
   .getOrCreate()

# 加载京东股票历史交易数据
csvPath = "/data/spark/jd/jd-formated.csv"
jdDF = spark \
    .read \
    .option("header", "true") \
    .option("inferSchema","true") \
    .csv(csvPath)

# 显示该schema, 第一列是交易日期
jdDF.printSchema()
jdDF.show(10)

执行以上代码,输出结果如下:

root
 |-- Date: string (nullable = true)
 |-- Close: double (nullable = true)
 |-- Volume: integer (nullable = true)
 |-- Open: double (nullable = true)
 |-- High: double (nullable = true)
 |-- Low: double (nullable = true)

+----------+-----+-------+------+-----+-----+
|       Date|Close| Volume|  Open|  High|  Low|
+----------+-----+-------+------+-----+-----+
|2022-02-15|76.13|6766205| 75.35|76.35| 74.8|
|2022-02-14|74.45|5244967| 73.94|74.62|73.01|
|2022-02-11|73.98|6673354| 75.97|76.55|73.55|
|2022-02-10| 76.4|6432184|75.955|78.39|75.24|
|2022-02-09|78.29|7061571| 76.83|78.67|76.61|
|2022-02-08|75.36|7903249| 73.12|76.07|72.05|
|2022-02-07|73.15|6135832| 74.09|74.99|72.81|
|2022-02-04|73.77|6082889| 71.94|74.95|71.86|
|2022-02-03|71.85|7493688| 72.08| 73.3|71.33|
|2022-02-02|73.21|5887066| 75.58|75.71|72.41|
+----------+-----+-------+------+-----+-----+
only showing top 10 rows

注意到其中Date字段被自动推断为string类型,因此需要对交易日期进行整理(将Date字段的字符串类型转为Date类型),代码如下:

import org.apache.spark.sql.functions._

jdStock = jdDF.withColumn("Date",to_date('Date',"yyyy/M/d"))
jdStock.printSchema()
jdStock.show(10)

执行以上代码,输出结果如下:

root
 |-- Date: date (nullable = true)
 |-- Close: double (nullable = true)
 |-- Volume: integer (nullable = true)
 |-- Open: double (nullable = true)
 |-- High: double (nullable = true)
 |-- Low: double (nullable = true)

+----------+-----+-------+------+-----+-----+
|       Date|Close| Volume|  Open|  High|  Low|
+----------+-----+-------+------+-----+-----+
|2022-02-15|76.13|6766205| 75.35|76.35| 74.8|
|2022-02-14|74.45|5244967| 73.94|74.62|73.01|
|2022-02-11|73.98|6673354| 75.97|76.55|73.55|
|2022-02-10| 76.4|6432184|75.955|78.39|75.24|
|2022-02-09|78.29|7061571| 76.83|78.67|76.61|
|2022-02-08|75.36|7903249| 73.12|76.07|72.05|
|2022-02-07|73.15|6135832| 74.09|74.99|72.81|
|2022-02-04|73.77|6082889| 71.94|74.95|71.86|
|2022-02-03|71.85|7493688| 72.08| 73.3|71.33|
|2022-02-02|73.21|5887066| 75.58|75.71|72.41|
+----------+-----+-------+------+-----+-----+
only showing top 10 rows

可以看到,Date字段的数据类型已经被转换为了date日期类型。接下来,使用时间窗口函数来计算京东股票的周平均收盘价,代码如下:

# 使用窗口函数计算groupBy变换内的周平均价格
# 这是一个滚动窗口的例子,也就是固定窗口
jdWeeklyAvg = jdStock \
	.groupBy(window('Date', "1 week")) \
	.agg(avg("Close").alias("weekly_avg"))

# 结果模式有窗口启动和结束时间
jdWeeklyAvg.printSchema()

# 按开始时间顺序显示结果,并四舍五入到小数点后2位
jdWeeklyAvg \
    .orderBy("window.start") \
    .selectExpr("window.start", "window.end", "round(weekly_avg, 2) as weekly_avg") \
    .show(10)

执行以上代码,输出结果如下所示:

root
 |-- window: struct (nullable = false)
 |    |-- start: timestamp (nullable = true)
 |    |-- end: timestamp (nullable = true)
 |-- weekly_avg: double (nullable = true)

+-------------------+-------------------+----------+
|                 start|                   end|weekly_avg|
+-------------------+-------------------+----------+
|2017-02-09 08:00:00|2017-02-16 08:00:00|      30.23|
|2017-02-16 08:00:00|2017-02-23 08:00:00|      30.29|
|2017-02-23 08:00:00|2017-03-02 08:00:00|      30.65|
|2017-03-02 08:00:00|2017-03-09 08:00:00|      30.93|
|2017-03-09 08:00:00|2017-03-16 08:00:00|      31.41|
|2017-03-16 08:00:00|2017-03-23 08:00:00|      31.09|
|2017-03-23 08:00:00|2017-03-30 08:00:00|      31.45|
|2017-03-30 08:00:00|2017-04-06 08:00:00|      31.65|
|2017-04-06 08:00:00|2017-04-13 08:00:00|      32.43|
|2017-04-13 08:00:00|2017-04-20 08:00:00|      33.22|
+-------------------+-------------------+----------+
only showing top 10 rows

上面的例子使用了一个星期的滚动窗口,其中交易数据没有重叠。因此,每个交易只使用一次来计算移动平均值。而下面的例子使用了滑动窗口来计算京东股票的月平均收盘价,每周计算一次。这意味着在计算平均每月移动平均值时,一些交易数据将被多次使用。在这个滑动窗口中,窗口的大小是四个星期,每个窗口一次滑动一个星期。代码如下:

# 使用时间窗口函数来计算京东股票的月平均收盘价
# 4周窗口长度,每次1周的滑动
jdMonthlyAvg = jdStock \
    .groupBy(window('Date', "4 week", "1 week")) \
    .agg(avg("Close").alias("monthly_avg"))

# 按开始时间显示结果
jdMonthlyAvg \
    .orderBy("window.start") \
    .selectExpr("window.start", "window.end", "round(monthly_avg, 2) as monthly_avg") \
    .show(10)

执行以上代码,输出结果如下:

+-------------------+-------------------+-----------+
|                 start|                   end|monthly_avg|
+-------------------+-------------------+-----------+
|2017-01-19 08:00:00|2017-02-16 08:00:00|       30.23|
|2017-01-26 08:00:00|2017-02-23 08:00:00|       30.28|
|2017-02-02 08:00:00|2017-03-02 08:00:00|       30.46|
|2017-02-09 08:00:00|2017-03-09 08:00:00|       30.62|
|2017-02-16 08:00:00|2017-03-16 08:00:00|       30.85|
|2017-02-23 08:00:00|2017-03-23 08:00:00|       31.02|
|2017-03-02 08:00:00|2017-03-30 08:00:00|       31.22|
|2017-03-09 08:00:00|2017-04-06 08:00:00|        31.4|
|2017-03-16 08:00:00|2017-04-13 08:00:00|       31.66|
|2017-03-23 08:00:00|2017-04-20 08:00:00|       32.13|
+-------------------+-------------------+-----------+
only showing top 10 rows

3. 使用窗口分析函数

第三类高级分析函数是在逻辑分组中执行聚合的函数,这个逻辑分组被称为窗口,这些函数被称为窗口函数。有时需要对一组数据进行操作,并为每组输入行返回一个值,而窗口函数提供了这种独特的功能,使其易于执行计算,如移动平均、累积和或每一行的rank。使用窗口函数,能够轻松地执行例如移动平均、累积和/或每一行的排名这样的计算。它们显著提高了PySpark的SQL和DataFrame API的表达能力。

使用窗口函数有两个主要步骤,如下所述:

  • (1) 第一步是定义一个窗口规范,该规范定义了称为frame的行逻辑分组,这是每一行被计算的上下文。
  • (2) 第二步是应用一个合适的窗口函数。

窗口规范定义了窗口函数将使用的三个重要组件,这三个组件分别介绍如下:

  • (1) 第一个组件被称为partition by,指定用来对行进行分组的列(一个或多个列)。
  • (2) 第二个组件称为order by,它定义了如何根据一个或多个列来排序各行,以及顺序是升序或降序。
  • (3) 最后一个组件称为frame,它定义了窗口相对于当前行的边界。换句话说,frame限制了在计算当前行的值时包括哪些行。可以通过行索引或order by表达式的实际值来指定在window frame中包含的一系列行。

最后一个组件frame是可选的,有的窗口函数需要,有的窗口函数或场景不需要。窗口规范是使用在pyspark.sql.Window类中定义的函数构建的。rowsBetween和rangeBetween函数分别用来定义行索引和实际值的范围。

窗口函数可分为三种类型:排序函数、分析函数和聚合函数。在下面两个表中分别描述了排序函数和分析函数。对于聚合函数,可以使用前面提到的任何聚合函数作为窗口函数。

排序函数:

函数名称 描述
rank 返回一个frame内行的排名和排序,基于一些排序规则
dense_rank 类似于rank,但是在不同的排名之间没有间隔,紧密衔接显示
ntile(n) 在一个有序的窗口分区中返回ntile分组ID。比如,如果n是4,那么前25%行得到的ID值为1,第二个%25行得到的ID值为2,依次类推。
row_number 返回一个序列号,每个frame从1开始

分析函数:

函数名称 描述
cume_dist 返回一个frame的值的累积分布。换句话说,低于当前行的行的比例。
lag(col,offset) 返回当前行之前offset行的列值
lead(col,offset) 返回当前行之后offset行的列值

对于聚合函数,可以使用前面提到的任何聚合函数作为窗口函数。

【示例】现在假设有两个用户user01和user02,两个用户的购物交易数据如下:

用户ID	交易日期		交易金额
user01	2018-07-02	13.35
user01	2018-07-06	27.33
user01	2018-07-04	21.72
user02	2018-07-07	69.74
user02	2018-07-01	59.44
user02	2018-07-05	80.14

有了这个购物交易数据,请尝试使用窗口函数来回答以下问题:

  • (1) 对于每一个用户,最高的交易金额是多少?
  • (2) 每个用户的交易金额和最高交易金额之间的差是多少?
  • (3) 每个用户的交易金额相对上一次交易的变化是多少?
  • (4) 每个用户的移动平均交易金额是多少?
  • (5) 每个用户的累计交易金额是多少?

首先,构造一个DataFrame,包含这个小型购物交易数据,代码如下:

from pyspark.sql import SparkSession
from pyspark.sql.functions import *

spark = SparkSession.builder \
   .master("spark://localhost:7077") \
   .appName("pyspark demo") \
   .getOrCreate()

# 为两个用户设置的小型购物交易数据集
txDataDF= spark.createDataFrame([
    ("user01", "2018-07-02", 13.35),
    ("user01", "2018-07-06", 27.33),
    ("user01", "2018-07-04", 21.72),
    ("user02", "2018-07-07", 69.74),
    ("user02", "2018-07-01", 59.44),
    ("user02", "2018-07-05", 80.14)
],["uid", "tx_date", "amount"])
    
txDataDF.printSchema()
txDataDF.show()

执行以上代码,输出结果如下:

root
 |-- uid: string (nullable = true)
 |-- tx_date: string (nullable = true)
 |-- amount: double (nullable = false)

+------+----------+------+
|    uid|   tx_date|amount|
+------+----------+------+
|user01|2018-07-02| 13.35|
|user01|2018-07-06| 27.33|
|user01|2018-07-04| 21.72|
|user02|2018-07-07| 69.74|
|user02|2018-07-01| 59.44|
|user02|2018-07-05| 80.14|
+------+----------+------+

下面应用窗口函数来回答这些问题。

(1) 为了回答第一个问题,可以将rank()窗口函数应用于一个窗口规范,该规范按用户ID对数据进行分区,并按交易金额对其进行降序排序。rank()窗口函数根据每一frame中每一行的排序顺序给每一行分配一个排名,代码如下:

# 导入Window类
from pyspark.sql import Window

# 定义window规范,按用户id分区,按数量降序排序
w = Window.partitionBy("uid").orderBy(desc("amount"))

# 增加一个新列,以包含每行的等级,应用rank函数以对每行分级(rank)
txDataWithRankDF = txDataDF.withColumn("rank", rank().over(w))
# txDataWithRankDF.show()

# 根据等级过滤行,以找到第一名并显示结果
txDataWithRankDF.where('rank == 1').show()

执行以上代码,输出结果如下:

+------+----------+------+----+
|    uid|   tx_date|amount|rank|
+------+----------+------+----+
|user02|2018-07-05| 80.14|   1|
|user01|2018-07-06| 27.33|   1|
+------+----------+------+----+

可以看出,用户user01的最高交易金额是27.33,用户user02的最高交易金额是80.14。

(2) 解决第二个问题的方法是在每个分区的所有行的amount列上应用max()函数。除了按用户ID分区之外,它还需要定义一个包含每个分区中所有行的frame边界。要定义这个frame,可以使用Window.rangeBetween()函数,以Window.unboundedPreceding作为开始值,以Window.unboundedFollowing作为结束值,代码如下:

# 使用rangeBetween来定义frame边界,它包含每个frame中的所有行
w = Window \
    .partitionBy("uid") \
    .orderBy(desc("amount")) \
    .rangeBetween(Window.unboundedPreceding,Window.unboundedFollowing)

# 增加amount_diff列,将max()函数应用于amount列,然后计算差值
txDiffWithHighestDF = txDataDF.withColumn(
    "amount_diff", round((max("amount").over(w) - col("amount")), 3)
)

# 显示结果
txDiffWithHighestDF.show()

执行以上代码,输出结果如下:

+------+----------+------+-----------+
|    uid|   tx_date|amount|amount_diff|
+------+----------+------+-----------+
|user02|2018-07-05| 80.14|         0.0|
|user02|2018-07-07| 69.74|        10.4|
|user02|2018-07-01| 59.44|        20.7|
|user01|2018-07-06| 27.33|         0.0|
|user01|2018-07-04| 21.72|        5.61|
|user01|2018-07-02| 13.35|       13.98|
+------+----------+------+-----------+

(3) 解决第三个问题的方法是使用每个分区的当前行的amount列减去上一行的amount列。获取上一行的指定字段用lag()函数。除了按用户ID分区之外,它还需要定义一个包含每个分区中所有行的frame边界。默认frame包括所有前面的行和当前行,代码如下:

# 定义window specification
w = Window.partitionBy("uid").orderBy("tx_date")

# 增加amount_diff列,计算交易量的变动值
lagDF = txDataDF.withColumn(
    "amount_var", round((col("amount") - lag("amount",1).over(w)), 3)
)

# 显示结果
lagDF.show()

执行以上代码,输出结果如下:

+------+----------+------+----------+
|    uid|   tx_date|amount|amount_var|
+------+----------+------+----------+
|user02|2018-07-01| 59.44|      null|
|user02|2018-07-05| 80.14|      20.7|
|user02|2018-07-07| 69.74|     -10.4|
|user01|2018-07-02| 13.35|      null|
|user01|2018-07-04| 21.72|      8.37|
|user01|2018-07-06| 27.33|      5.61|
+------+----------+------+----------+

(4) 为了计算每个用户按交易日期顺序移动的平均移动数量,将利用avg()函数来根据frame中的一组行计算每一行的平均数量。这里希望每一frame都包含三行:当前行加上前面的一行和后面的一行。与前面的例子类似,窗口规范将按用户ID对数据进行分区,但是每一个frame中的行将按交易日期排序,代码如下:

# 应用avg窗口函数来计算移动平均交易量
# 定义window规范,一个好的做法是指定相对于Window.currentRow的偏移量 
w = Window \
	.partitionBy("uid") \
	.orderBy("tx_date") \
	.rowsBetween(Window.currentRow-1, Window.currentRow+1)

# 在窗口上应用avg()函数到amount列,并将移动平均量四舍五入为2个小数
avgDF = txDataDF.withColumn("moving_avg",round(avg("amount").over(w), 2))

# 显示结果
avgDF.show()

执行以上代码,输出结果如下:

+------+----------+------+----------+
|    uid|   tx_date|amount|moving_avg|
+------+----------+------+----------+
|user02|2018-07-01| 59.44|     69.79|
|user02|2018-07-05| 80.14|     69.77|
|user02|2018-07-07| 69.74|     74.94|
|user01|2018-07-02| 13.35|     17.54|
|user01|2018-07-04| 21.72|      20.8|
|user01|2018-07-06| 27.33|     24.53|
+------+----------+------+----------+

(5) 为了计算每个用户的交易金额的累积总和,将把sum()函数应用于一个frame,该frame由所有行到当前行组成。其partitionBy()和orderBy()方法与移动平均示例相同,实现代码如下:

# 定义每个frame的窗口规范包括所有以前的行和当前行
w = Window \
	.partitionBy("uid") \
	.orderBy("tx_date") \
	.rowsBetween(Window.unboundedPreceding, Window.currentRow)

# 将sum函数应用于窗口规范
sumDF = txDataDF.withColumn("culm_sum",round(sum("amount").over(w),2))

# 显示结果
sumDF.show()

执行以上代码,输出结果如下:

+------+----------+------+--------+
|    uid|   tx_date|amount|culm_sum|
+------+----------+------+--------+
|user02|2018-07-01| 59.44|   59.44|
|user02|2018-07-05| 80.14|  139.58|
|user02|2018-07-07| 69.74|  209.32|
|user01|2018-07-02| 13.35|   13.35|
|user01|2018-07-04| 21.72|   35.07|
|user01|2018-07-06| 27.33|    62.4|
+------+----------+------+--------+

窗口规范的默认frame包括所有前面的行和当前行。对于前面的例子,没有必要指定frame,所以应该得到相同的结果。


《Spark原理深入与编程实战》