用户自定义函数(UDF)

尽管Spark SQL为大多数常见用例提供了大量的内置函数,但总会有一些情况下,这些功能都不能提供我们的用例所需要的功能。Spark SQL提供了一个相当简单的工具来编写用户定义的函数(UDF),并在Spark数据处理逻辑或应用程序中使用它们,就像使用内置函数一样。

UDF实际上是我们可以扩展Spark的功能以满足特定需求的一种方式。Spark的UDF可以用Python、Java或Scala来写,它们可以利用和集成任何必要的库。

用户定义标量函数

使用UDF涉及有三个步骤。第一步是编写一个函数并进行测试。第二步是通过将函数名及其签名传递给Spark的udf函数来注册该函数。最后一步是在DataFrame代码或发出SQL查询时使用UDF。在SQL查询中使用UDF时,注册过程略有不同。

下面的代码用一个简单的UDF演示前面提到的三个步骤。

// 在Scala中一个简单的UDF,将数字等级转换为考查等级

// 导入依赖包
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// 定义case class
case class Student(name:String, score:Int)

  def main(args: Array[String]): Unit = {

    // 创建SparkSession的实例
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("Spark Basic Example")
      .getOrCreate()

    // 在Scala中一个简单的UDF,将数字等级转换为考查等级
    import spark.implicits._

    // 创建学生成绩DataFrame
    val studentDF = Seq(
      Student("张三", 85),
      Student("李四", 90),
      Student("王老五", 55)
    ).toDF()

    // 注册为视图
    studentDF.createOrReplaceTempView("students")

    spark.sql("select * from students").show()

    // 创建一个函数(普通的Scala函数)将成绩转换到考察等级
    def convertGrade(score:Int) : String = {
      score match {
        case `score` if score > 100 => "作弊"
        case `score` if score >= 90 => "优秀"
        case `score` if score >= 80 => "良好"
        case `score` if score >= 70 => "中等"
        case _ => "不及格"
      }
    }

    // 注册为一个UDF(在DSL API中使用时的注册方法)
    val convertGradeUDF = udf(convertGrade(_:Int):String)

    // 使用该UDF将成绩转换为字母等级
    studentDF.select($"name",$"score", convertGradeUDF($"score").as("grade")).show()
  }

执行以上代码,输出结果如下所示:

+------+-----+
|  name|score|
+------+-----+
|  张三|   85|
|  李四|   90|
|王老五|   55|
+------+-----+

+------+-----+------+
|  name|score| grade|
+------+-----+------+
|  张三|   85|  良好|
|  李四|   90|  优秀|
|王老五|   55|不及格|
+------+-----+------+

在SQL查询中使用UDF时,注册过程略有不同:

// 注册为UDF,在SQL中使用
spark.udf.register("convertGrade", convertGrade(_: Int): String)
spark.sql("select name, score, convertGrade(score) as grade from students").show()

执行以上代码,输出结果如下所示:

+------+-----+------+
|  name|score| grade|
+------+-----+------+
|  张三|   85|  良好|
|  李四|   90|  优秀|
|王老五|   55|不及格|
+------+-----+------+

内置的DataFrames函数提供常见的聚合,如count()、countDistinct()、avg()、max()、min()等。虽然这些函数是为DataFrames设计的,但Spark SQL也有类型安全的版本,其中一些在Scala和Java中可以使用强类型数据集。此外,用户不仅限于预定义的聚合函数,还可以创建自己的聚合函数。

无类型的用户定义聚合函数(UDAF)

用户必须继承UserDefinedAggregateFunction抽象类来实现自定义的无类型聚合函数。

例如,用户定义的计算平均值的聚合函数可以如下(这里为了演示如何创建UDAF,因为Spark内置了avg函数):

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._

class MyAvg extends UserDefinedAggregateFunction {
  // 此聚合函数的输入参数的数据类型
  def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)

  // 聚合缓冲区中值的数据类型
  def bufferSchema: StructType = { StructType(
    StructField("sum", DoubleType) :: 
    StructField("count", LongType) :: Nil)
  }

  // 返回值的数据类型
  def dataType: DataType = DoubleType

  // 这个函数是否总是在相同的输入上返回相同的输出
  def deterministic: Boolean = true

  // 初始化给定的聚合缓冲区。 
  // 缓冲区本身是一个' Row ',除了在索引处检索值(例如,get(), getBoolean())等标准方法外,
  // 它还提供了更新其值的机会。 
  // 注意,缓冲区内的数组和map映射仍然是不可变的。
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0
    buffer(1) = 0L
  }

  // 用“input”中的新输入数据更新给定的聚合缓冲区“buffer”
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getAs[Double](0) + input.getAs[Double](0)
      buffer(1) = buffer.getAs[Long](1) + 1
    }
  }

  // 合并两个聚合缓冲区并将更新后的缓冲区值存储回' buffer1 '
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Double](0) + buffer2.getAs[Double](0)
    buffer1(1) = buffer1.getAs[Long](1) + buffer2.getAs[Long](1)
  }

  // 计算最终结果
  def evaluate(buffer: Row): Double = buffer.getDouble(0) / buffer.getLong(1)
}

// 注册函数以访问它
spark.udf.register("myAvg", MyAvg)

val df = spark.read.json("examples/src/main/resources/employees.json")
df.createOrReplaceTempView("employees")
df.show()

val result = spark.sql("SELECT myAvg(salary) as avg_salary FROM employees")
result.show()

虽然UDAF编写起来很复杂,但与Dataset上的mapGroups或简单地编写RDD上的aggregateByKey等选项相比,UDAF的性能相当好。然后,我们可以直接在列上使用UDAF,也可以像对非聚合UDF那样将其添加到函数注册中心。

类型安全的用户定义聚合函数(UDAF)

强类型Dataset的用户定义聚合函数(UDAF)围绕Aggregator抽象类进行。

例如,一个类型安全的用户定义的平计算均值的聚合函数看起来像下面这样:

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

case class Employee(name: String, salary: Long)
case class Average(var sum: Long, var count: Long)

object MyAvg extends Aggregator[Employee, Average, Double] {
  // 这个聚合的零值,应该满足任意b + 0 = b的性质
  def zero: Average = Average(0L, 0L)

  // 合并两个值以生成一个新值。 
  // 为了提高性能,该函数可以修改' buffer '并返回它,而不是构造一个新对象
  def reduce(buffer: Average, employee: Employee): Average = {
    buffer.sum += employee.salary
    buffer.count += 1
    buffer
  }

  // 合并两个中间值
  def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  // 转换reduce的输出
  def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count

  // 指定中间值类型的编码器
  def bufferEncoder: Encoder[Average] = Encoders.product

  // 指定最终输出值类型的编码器
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
ds.show()

// 将函数转换为'TypedColumn'并给它一个名称
val avgSalary = MyAvg.toColumn.name("avg_salary")
val result = ds.select(avgSalary)
result.show()

《Spark原理深入与编程实战》