SparkSQL——自定义函数

从Spark2.0以上的版本开始,spark是使用全新的SparkSession接口代替Spark1.6中的SQLcontext和HiveContext

来实现对数据的加载、转换、处理等工作,并且实现了SQLcontext和HiveContext的所有功能。

我们在新版本中并不需要之前那么繁琐的创建很多对象,只需要创建一个SparkSession对象即可。

SparkSession支持从不同的数据源加载数据,并把数据转换成DataFrame,并支持把DataFrame转换成SQLContext自身中的表。

然后使用SQL语句来操作数据,也提供了HiveQL以及其他依赖于Hive的功能支持。

创建SparkSession

SparkSession 是 Spark SQL 的入口。

使用 Dataset 或者 Datafram 编写 Spark SQL 应用的时候,第一个要创建的对象就是 SparkSession。

Builder 是 SparkSession 的构造器。 通过 Builder, 可以添加各种配置。

Builder 的方法如下:

Method Description getOrCreate 获取或者新建一个 sparkSession enableHiveSupport 增加支持 hive Support appName 设置 application 的名字 config 设置各种配置。


关于SparkSQL——DataFrame的创建与使用看这篇文章


虽然spark.sql.function中的已经包含了大多数常用的函数,但是总有一些场景是内置函数无法满足要求的,此时就需要使用自定义函数了(UDF)。

本文主要从以下几个方面介绍Spark中的自定义函数问题:

第一,自定义函数分类

第二,自定义函数的使用


第一,自定义函数的分类

在SparkSQL中支持自定义函数,主要可以分为以下三类:

UDF: 输入一个参数,返回一个结果, 一对一 。类似于to_char

UDTF: 输入一个参数,返回多个结果, 一对多。 spark SQL中没有UDTF(spark中用flatMap可以实现该功能)

UDAF: 输入多个参数,返回一个结果, 多对一 。 自定义聚合函数,类似于count avg sum


第二,自定义函数的使用

自定义UDF函数,以判断一个数是奇数还是偶数为例。

调用session.udf.register函数进行定义并注册函数,而在sql语句中只需调用函数名(参数)的格式进行调用。

package xxx

import java.lang

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

/**
 * 自定义UDF,判断一个数是奇数还是偶数
 */
object UDFTest {
  def main(args: Array[String]): Unit = {

    // 创建SparkSQL的入口
    val session: SparkSession = SparkSession.builder().appName("UDFTest").master("local[*]").getOrCreate()

    // 数据(1,2,....,10)默认列名为id,类型为long
    val num: Dataset[lang.Long] = session.range(1, 11)

    // 定义一个自定义函数(UDF),并注册, 在该函数Executor端执行
    session.udf.register("oddOrEven", (num: Long) =>{
      var result = "未知"
      if(num % 2 == 0){
        result = "偶数"
      }else{
        result = "奇数"
      }
      result
    })

    // 注册临时表
    num.createTempView("v_num")
    // 执行sql
    val result: DataFrame = session.sql("SELECT id, oddOrEven(id) As oddOrEven FROM v_num")

    result.show()

    session.close()
  }

}

自定义UDAF函数。以实现几何平均数为例。

几何平均数:是n个变量值连乘积的n次方根。

自定义UDAF函数需要自定义UDAF类,并且该类需要继承UserDefinedAggregateFunction类。

 /**
   * 自定义UDAF类, 实现继承UserDefinedAggregateFunction
   */
  class GeoMean extends UserDefinedAggregateFunction{
    //输入数据的名称,类型
    override def inputSchema: StructType = StructType(List(
      StructField("value", DoubleType)
    ))

    //产生中间结果的数据类型
    override def bufferSchema: StructType = StructType(List(
      //相乘之后返回的积
      StructField("product", DoubleType),
      //参与运算数字的个数
      StructField("counts", LongType)
    ))

    //最终返回的结果类型
    override def dataType: DataType = DoubleType


    //确保一致性 一般用true
    override def deterministic: Boolean = true

    //指定初始值
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      //相乘的初始值
      buffer(0) = 1.0
      //参与运算数字的个数的初始值,数据类型与中间结果类型一致
      buffer(1) = 0L
    }

    //每有一条数据参与运算就更新一下中间结果(update相当于在每一个分区中的运算)
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      //每有一个数字参与运算就进行相乘(包含中间结果)
      buffer(0) = buffer.getDouble(0) * input.getDouble(0)
      //参与运算数据的个数也有更新
      buffer(1) = buffer.getLong(1) + 1L
    }

    //全局聚合 (将每个分区产生的结果进行聚合)
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      //每个分区计算的结果进行相乘
      buffer1(0) =  buffer1.getDouble(0) * buffer2.getDouble(0)
      //每个分区参与预算的中间结果进行相加
      buffer1(1) =  buffer1.getLong(1) + buffer2.getLong(1)
    }

    //计算最终的结果
    override def evaluate(buffer: Row): Double = {
      math.pow(buffer.getDouble(0), 1.toDouble / buffer.getLong(1))
    }

  }

结果:



UDAF的使用

package user.defined.function

import java.lang

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}


/**
 * 用户自定义UADF函数,实现几何平均数
 */
object UDAFTest {

  def main(args: Array[String]): Unit = {
    val session: SparkSession = SparkSession.builder().appName("UDFTest").master("local[*]").getOrCreate()

    // 生成1-10的数据
    val value: Dataset[lang.Long] = session.range(1, 11)

    // 将自定义UDAF类注册成自定义聚合函数,生成maen函数
    val geomean = new GeoMean
   session.udf.register("maen", geomean)

    // 将数据注册成视图
   value.createTempView("v_rang")

  // 在sql中调用自定义函数
  val result: DataFrame = session.sql("SELECT maen(id) AS results FROM v_rang")

    result.show()

    session.close()
  }
}

结果:


举报
评论 0