2025-09-20
Spark

Contents

UDF
UDAF

UDF

1. Create a DataFrame

```scala
scala> val df = spark.read.json("data/user.json")
df: org.apache.spark.sql.DataFrame = [age: bigint, username: string]
```
2. Register the UDF

```scala
scala> spark.udf.register("addName", (x: String) => "Name:" + x)
res9: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))
```
3. Create a temporary view

```scala
scala> df.createOrReplaceTempView("people")
```
4. Apply the UDF (note the column is named username, per the schema above)

```scala
scala> spark.sql("select addName(username), age from people").show()
```
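The same function can also be applied through the DataFrame API, without registering it by name. A minimal sketch, assuming the spark-shell session above (udf comes from org.apache.spark.sql.functions; the $ column syntax comes from spark.implicits._, which the shell pre-imports):

```scala
import org.apache.spark.sql.functions.udf

// Wrap the same lambda as a column-level UDF instead of a SQL-registered one
val addName = udf((x: String) => "Name:" + x)
df.select(addName($"username"), $"age").show()
```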

UDAF

Both the strongly typed Dataset and the weakly typed DataFrame provide built-in aggregate functions such as count(), countDistinct(), avg(), max(), and min(). Beyond these, users can define their own aggregate functions. A weakly typed user-defined aggregate function is implemented by extending UserDefinedAggregateFunction. As of Spark 3.0, UserDefinedAggregateFunction is deprecated; the strongly typed Aggregator is the recommended, unified approach.
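For the built-in functions no custom code is needed at all. A minimal sketch for comparison, assuming the df and the people view created in the UDF section:

```scala
import org.apache.spark.sql.functions.avg

// Built-in aggregate via SQL
spark.sql("select avg(age) from people").show()
// The same aggregate via the DataFrame API
df.select(avg("age")).show()
```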

Requirement: compute the average age.

The same requirement can be implemented in several different ways.

1. Implementation: RDD

```scala
import org.apache.spark.{SparkConf, SparkContext}

val conf: SparkConf = new SparkConf().setAppName("app").setMaster("local[*]")
val sc: SparkContext = new SparkContext(conf)

val res: (Int, Int) = sc
  .makeRDD(List(("zhangsan", 20), ("lisi", 30), ("wangw", 40)))
  .map { case (name, age) => (age, 1) }
  .reduce { (t1, t2) => (t1._1 + t2._1, t1._2 + t2._2) }

// Note: integer division truncates; use res._1.toDouble / res._2 for a fractional average
println(res._1 / res._2)

// Stop the SparkContext
sc.stop()
```
2. Implementation: accumulator

```scala
import org.apache.spark.util.AccumulatorV2

class MyAC extends AccumulatorV2[Int, Int] {
  var sum: Int = 0
  var count: Int = 0

  override def isZero: Boolean = sum == 0 && count == 0

  override def copy(): AccumulatorV2[Int, Int] = {
    val newMyAc = new MyAC
    newMyAc.sum = this.sum
    newMyAc.count = this.count
    newMyAc
  }

  override def reset(): Unit = {
    sum = 0
    count = 0
  }

  override def add(v: Int): Unit = {
    sum += v
    count += 1
  }

  override def merge(other: AccumulatorV2[Int, Int]): Unit = other match {
    case o: MyAC =>
      sum += o.sum
      count += o.count
    case _ =>
  }

  // The average: accumulated sum divided by the accumulated count
  override def value: Int = sum / count
}
```
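The class above only defines the accumulator; computing the average still requires registering an instance and feeding it from an action. A minimal usage sketch, assuming the sc and the (name, age) data from the RDD example:

```scala
val acc = new MyAC
sc.register(acc, "avgAge")

sc.makeRDD(List(("zhangsan", 20), ("lisi", 30), ("wangw", 40)))
  .foreach { case (_, age) => acc.add(age) }

// value is read on the driver after the action has run
println(acc.value)
```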
3. Implementation: UDAF (weakly typed)

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/*
 * Extend UserDefinedAggregateFunction and override its methods
 */
class MyAveragUDAF extends UserDefinedAggregateFunction {
  // Data type of the aggregate function's input parameters
  def inputSchema: StructType = StructType(Array(StructField("age", IntegerType)))

  // Data types of the values in the aggregation buffer: (sum, count)
  def bufferSchema: StructType = {
    StructType(Array(StructField("sum", LongType), StructField("count", LongType)))
  }

  // Data type of the return value
  def dataType: DataType = DoubleType

  // Stability: whether the function always returns the same output for the same input
  def deterministic: Boolean = true

  // Initialize the aggregation buffer
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    // Running sum of ages
    buffer(0) = 0L
    // Running count of ages
    buffer(1) = 0L
  }

  // Update the buffer with one input row
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getInt(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }

  // Merge two buffers
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Compute the final result
  def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
}

// ...

// Assumes the DataFrame has been exposed as a "user" view,
// e.g. df.createOrReplaceTempView("user")

// Create the aggregate function
val myAverage = new MyAveragUDAF
// Register it with Spark
spark.udf.register("avgAge", myAverage)
spark.sql("select avgAge(age) from user").show()
```
4. Implementation: UDAF (strongly typed)

```scala
import org.apache.spark.sql.{Dataset, Encoder, Encoders, TypedColumn}
import org.apache.spark.sql.expressions.Aggregator

// Input type
case class User01(username: String, age: Long)
// Buffer type
case class AgeBuffer(var sum: Long, var count: Long)

/**
 * Extend org.apache.spark.sql.expressions.Aggregator
 * and override its methods
 */
class MyAveragUDAF1 extends Aggregator[User01, AgeBuffer, Double] {
  override def zero: AgeBuffer = AgeBuffer(0L, 0L)

  override def reduce(b: AgeBuffer, a: User01): AgeBuffer = {
    b.sum = b.sum + a.age
    b.count = b.count + 1
    b
  }

  override def merge(b1: AgeBuffer, b2: AgeBuffer): AgeBuffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  override def finish(buff: AgeBuffer): Double = buff.sum.toDouble / buff.count

  // Encoders the Dataset uses for serialization; effectively boilerplate:
  // Encoders.product for custom (case class) types, the matching built-in encoder otherwise
  override def bufferEncoder: Encoder[AgeBuffer] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// ...

// Wrap the DataFrame as a Dataset
val ds: Dataset[User01] = df.as[User01]
// Create the aggregate function
val myAgeUdaf1 = new MyAveragUDAF1
// Turn the aggregate function into a typed column that can be selected
val col: TypedColumn[User01, Double] = myAgeUdaf1.toColumn
// Query
ds.select(col).show()
```

Since Spark 3.0, the strongly typed Aggregator can also be registered for use in SQL, replacing UserDefinedAggregateFunction:

```scala
import org.apache.spark.sql.{Encoder, Encoders, functions}
import org.apache.spark.sql.expressions.Aggregator
import spark.implicits._ // implicit Encoder[Long] required by functions.udaf

// TODO Create the UDAF
val udaf = new MyAvgAgeUDAF
// TODO Register it with Spark SQL
spark.udf.register("avgAge", functions.udaf(udaf))
// TODO Use the aggregate function in SQL (assumes the "user" view from above)
spark.sql("select avgAge(age) from user").show

// **************************************************

// Definition of the user-defined aggregate function.
// The buffer holds the total age and the count.
case class Buff(var sum: Long, var cnt: Long)

class MyAvgAgeUDAF extends Aggregator[Long, Buff, Double] {
  override def zero: Buff = Buff(0, 0)

  override def reduce(b: Buff, a: Long): Buff = {
    b.sum += a
    b.cnt += 1
    b
  }

  override def merge(b1: Buff, b2: Buff): Buff = {
    b1.sum += b2.sum
    b1.cnt += b2.cnt
    b1
  }

  override def finish(reduction: Buff): Double = reduction.sum.toDouble / reduction.cnt

  override def bufferEncoder: Encoder[Buff] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
```
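Because functions.udaf returns an ordinary UserDefinedFunction, the same Aggregator can also be used as a column expression in the DataFrame API. A minimal sketch, assuming the df from the UDF section (whose age column is bigint, i.e. Long):

```scala
import org.apache.spark.sql.functions
import spark.implicits._ // provides the implicit Encoder[Long] that functions.udaf needs

val avgAge = functions.udaf(new MyAvgAgeUDAF)
df.select(avgAge(df("age"))).show()
```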

Author: Dewar

Copyright: Unless otherwise noted, all posts on this blog are licensed under CC BY-NC-SA. Please credit the source when reposting!