发布日期:2022-11-13 VIP内容

航班延误数据集分析

我们将使用美国交通部的一些航班信息,探索最导致航班延误的航班属性。使用Spark Dataset,我们将探索这些航班数据来回答以下问题:当航班延误超过40分钟时,

  • 哪家航空公司的航班延误次数最多?
  • 每周哪几天的航班延误次数最多?
  • 哪些始发机场的航班延误次数最多?
  • 每天什么时候的航班延误次数最多?

航班数据是JSON文件,每个航班记录有以下信息:

属性	         含义
id	             ID,由由承运人、日期、出发地、目的地、航班号组成
dofW	         星期几(1 = Monday星期一,7 = Sunday星期日)
carrier	         承运人代码
origin	         起始机场代码
dest	             目的地机场代码
crsdephour	 规定起飞时间hour(scheduled departure hour )
crsdeptime	 规定起飞时间time(scheduled departure time)
depdelay	     起飞延误分钟数(departure delay in minutes)
crsarrtime	 预定到达时间(scheduled arrival time)
arrdelay	     到达延误分钟数(arrival delay minutes)
crselapsedtime	 飞行时间
dist	             距离(distance)

每条航班信息的格式如下:

{    
  "_id": "AA_2017-01-01_ATL_LGA_1678",
  "dofW": 7,
  "carrier": "AA",
  "origin": "ATL",
  "dest": "LGA",
  "crsdephour": 17,
  "crsdeptime": 1700,
  "depdelay": 0.0,
  "crsarrtime": 1912,
  "arrdelay": 0.0,
  "crselapsedtime": 132.0,
  "dist": 762.0
}

代码实现如下。

import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

/**
  * 航班延误数据集分析
  */
object FlightDelayDemo {

  // 域对象
  case class Flight(_id: String, dofW: Integer, carrier: String, origin: String,
                    dest: String, crsdephour: Integer, crsdeptime: Double, depdelay: Double,
                    crsarrtime: Double, arrdelay: Double, crselapsedtime: Double, dist: Double)
    extends Serializable

  // schema
  val schema = StructType(Array(
    StructField("_id", StringType, true),
    StructField("dofW", IntegerType, true),
    StructField("carrier", StringType, true),
    StructField("origin", StringType, true),
    StructField("dest", StringType, true),
    StructField("crsdephour", IntegerType, true),
    StructField("crsdeptime", DoubleType, true),
    StructField("depdelay", DoubleType, true),
    StructField("crsarrtime", DoubleType, true),
    StructField("arrdelay", DoubleType, true),
    StructField("crselapsedtime", DoubleType, true),
    StructField("dist", DoubleType, true)
  ))

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("flightdataset")
      .master("local[*]")
      .getOrCreate()

    var file: String = "input/flightdelay/flights20170102.json"

    // 如果是通过命令行传入数据集路径的话
    if (args.length == 1) {
      file = args(0)
    }

    import spark.implicits._
    val df: Dataset[Flight] = spark.read
      .format("json")
      .option("inferSchema", "false")
      .schema(schema)
      .load(file)
      .as[Flight]

    println("训练集")
    println(s"数据集中包含记录数量${df.count()}条")

    // 查看前5条记录
    df.show(5)

    println("过滤上午10点起飞的航班。take 3")
    df.filter(flight => flight.crsdephour == 10).take(3)

    // 按承运人(carrier)分组计数
    println("按承运人(carrier)分组计数:")
    df.groupBy("carrier").count().show()

    // 按目的地统计超过40分钟的起飞延误,并按延误时间倒序排序。
    println("按目的地统计超过40分钟的起飞延误,并按延误时间倒序排序")
    df.filter($"depdelay" > 40).groupBy("dest").count().orderBy(desc("count")).show(3)

    // 数据探索
    // 以列格式在内存中缓存DataFrame
    df.cache        // 缓存
    // 创建临时表视图
    df.createOrReplaceTempView("flights")
    // 以列格式在内存中缓存表
    spark.catalog.cacheTable("flights")     // 缓存表

    // 1)显示前5个最长的航班延误信息 - top 5
    println("最长的延误")
    // 使用DataFrame transformation
    df.select($"carrier", $"origin", $"dest", $"depdelay", $"crsdephour")
      .filter($"depdelay" > 40)
      .orderBy(desc("depdelay"))
      .show(5)

    // 使用SQL
    spark.sql(
      """
        |select carrier,origin, dest, depdelay,crsdephour, dist, dofW
        |from flights
        |where depdelay > 40
        |order by depdelay desc
        |limit 5
      """.stripMargin).show

    // 2)显示承运人的平均起飞延误时间
    println("承运人的平均起飞延误")
    df.groupBy("carrier").agg(avg("depdelay")).show

    // 3)一周内每天的平均延误起飞时间
    println("一周内每天的平均延误起飞时间")
    spark.sql(
      """
        |SELECT dofW, avg(depdelay) as avgdelay
        |FROM flights
        |GROUP BY dofW
        |ORDER BY avgdelay desc
      """.stripMargin).show

    // 4)按承运人统计起飞延误(延误>=40分钟)
    println("按承运人统计起飞延误")
    // 使用transformation API
    df.filter($"depdelay" > 40)
      .groupBy("carrier")
      .count
      .orderBy(desc("count"))
      .show(5)

    // 使用SQL语句
    spark.sql(
      """
        |select carrier, count(depdelay)
        |from flights
        |where depdelay > 40
        |group by carrier
      """.stripMargin).show

    // 5)按出发机场统计延误次数
    println("如果按出发地机场延误分钟>40,那么延误起飞的次数是多少")
    spark.sql(
      """
        |select origin, count(depdelay)
        |from flights
        |where depdelay > 40
        |group by origin
        |ORDER BY count(depdelay) desc
      """.stripMargin).show

    // 6)按每周每天统计延误起飞次数
    println("按每周天数统计延误起飞次数, 延误以>40计")
    // 使用transformation API
    df.filter($"depdelay" > 40).groupBy("dofW").count.orderBy("dofW").show()
    // 使用SQL
    spark.sql(
      """
        |select dofW, count(depdelay)
        |from flights
        |where depdelay > 40
        |group by dofW
        |order by dofW
      """.stripMargin).show()

    // 7)按小时统计起飞延误次数
    println("按小时统计起飞延误次数")
    spark.sql(
      """
        |select crsdephour, count(depdelay)
        |from flights
        |where depdelay > 40
        |group by crsdephour
        |order by crsdephour
      """.stripMargin).show()

    // 8)按航线统计延误次数
    println("按航线统计延误次数")
    spark.sql(
      """
        |select origin,dest,count(depdelay)
        |from flights
        |where depdelay > 40
        |group by origin,dest
        |ORDER BY count(depdelay) desc
      """.stripMargin).show

    // --------------------------------------------------------------------
    // 另一种方式
    // 对延误时间列(depdelay)数据分桶(40分钟以内的和40分钟以外的)
    val delaybucketizer = new Bucketizer()
      .setInputCol("depdelay")
      .setOutputCol("delayed")
      .setSplits(Array(0.0, 40.0, Double.PositiveInfinity))

    val df4 = delaybucketizer.transform(df)

    // 按是否延误统计
    df4.groupBy("delayed").count.show

    // 创建临时视图
    df4.createOrReplaceTempView("flights")

    // 按出发机场分别统计起飞延误和没有起飞延误的数量
    println("按出发机场分别统计起飞延误和没有起飞延误的数量")
    spark.sql(
      """
        |select origin, delayed, count(delayed)
        |from flights
        |group by origin, delayed
        |order by origin
      """.stripMargin).show

    // 按目的地机场统计起飞延误数量
    println("按目的地机场统计起飞延误数量")
    spark.sql(
      """
        |select dest, delayed, count(delayed)
        |from flights
        |where delayed=1
        |group by dest, delayed
        |order by dest
      """.stripMargin).show

    // 按航线(出发机场,目的机场)统计起飞延误数量
    println("按航线(出发机场,目的机场)统计起飞延误数量")
    spark.sql(
      """
        |select origin, dest, delayed, count(delayed)
        |from flights
        |where delayed=1
        |group by origin,dest,delayed
        |order by origin,dest
      """.stripMargin).show

    // 按dofW(每周每天)统计起飞延误次数
    println("按dofW(每周每天)统计起飞延误次数")
    spark.sql(
      """
        |select dofW, delayed, count(delayed)
        |from flights
        |where delayed=1
        |group by dofW, delayed
        |order by dofW
      """.stripMargin).show

    // 按小时统计的起飞延误的次数
    println("按小时统计的起飞延误的次数")
    spark.sql(
      """
        |select crsdephour, delayed, count(delayed)
        |from flights
        |where delayed=1
        |group by crsdephour, delayed
        |order by crsdephour
      """.stripMargin).show

    // 按承运人统计起飞延误的次数
    println("按承运人统计起飞延误的次数")
    spark.sql(
      """
        |select carrier, delayed, count(delayed)
        |from flights
        |where delayed=1
        |group by carrier, delayed
        |order by carrier
      """.stripMargin).show
  }
}