SparkSQL——自定义函数
从Spark2.0以上的版本开始,spark是使用全新的SparkSession接口代替Spark1.6中的SQLcontext和HiveContext
来实现对数据的加载、转换、处理等工作,并且实现了SQLcontext和HiveContext的所有功能。
我们在新版本中并不需要之前那么繁琐的创建很多对象,只需要创建一个SparkSession对象即可。
SparkSession支持从不同的数据源加载数据,并把数据转换成DataFrame,并支持把DataFrame转换成SQLContext自身中的表。
然后使用SQL语句来操作数据,也提供了HiveQL以及其他依赖于Hive的功能支持。
创建SparkSession
SparkSession 是 Spark SQL 的入口。
使用 Dataset 或者 Datafram 编写 Spark SQL 应用的时候,第一个要创建的对象就是 SparkSession。
Builder 是 SparkSession 的构造器。 通过 Builder, 可以添加各种配置。
Builder 的方法如下:
Method Description getOrCreate 获取或者新建一个 sparkSession enableHiveSupport 增加支持 hive Support appName 设置 application 的名字 config 设置各种配置。
关于SparkSQL——DataFrame的创建与使用看这篇文章
虽然spark.sql.function中的已经包含了大多数常用的函数,但是总有一些场景是内置函数无法满足要求的,此时就需要使用自定义函数了(UDF)。
本文主要从以下几个方面介绍Spark中的自定义函数问题:
第一,自定义函数分类
第二,自定义函数的使用
第一,自定义函数的分类
在SparkSQL中支持自定义函数,主要可以分为以下三类:
UDF: 输入一个参数,返回一个结果, 一对一 。类似于to_char
UDTF: 输入一个参数,返回多个结果, 一对多。 spark SQL中没有UDTF(spark中用flatMap可以实现该功能)
UDAF: 输入多个参数,返回一个结果, 多对一 。 自定义聚合函数,类似于count avg sum
第二,自定义函数的使用
自定义UDF函数,以判断一个数是奇数还是偶数为例。
调用session.udf.register函数进行定义并注册函数,而在sql语句中只需调用函数名(参数)的格式进行调用。
package xxx
import java.lang
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
/**
* 自定义UDF,判断一个数是奇数还是偶数
*/
object UDFTest {
def main(args: Array[String]): Unit = {
// 创建SparkSQL的入口
val session: SparkSession = SparkSession.builder().appName("UDFTest").master("local[*]").getOrCreate()
// 数据(1,2,....,10)默认列名为id,类型为long
val num: Dataset[lang.Long] = session.range(1, 11)
// 定义一个自定义函数(UDF),并注册, 在该函数Executor端执行
session.udf.register("oddOrEven", (num: Long) =>{
var result = "未知"
if(num % 2 == 0){
result = "偶数"
}else{
result = "奇数"
}
result
})
// 注册临时表
num.createTempView("v_num")
// 执行sql
val result: DataFrame = session.sql("SELECT id, oddOrEven(id) As oddOrEven FROM v_num")
result.show()
session.close()
}
}
自定义UDAF函数。以实现几何平均数为例。
几何平均数:是n个变量值连乘积的n次方根。
自定义UDAF函数需要自定义UDAF类,并且该类需要继承UserDefinedAggregateFunction类。
/**
* 自定义UDAF类, 实现继承UserDefinedAggregateFunction
*/
class GeoMean extends UserDefinedAggregateFunction{
//输入数据的名称,类型
override def inputSchema: StructType = StructType(List(
StructField("value", DoubleType)
))
//产生中间结果的数据类型
override def bufferSchema: StructType = StructType(List(
//相乘之后返回的积
StructField("product", DoubleType),
//参与运算数字的个数
StructField("counts", LongType)
))
//最终返回的结果类型
override def dataType: DataType = DoubleType
//确保一致性 一般用true
override def deterministic: Boolean = true
//指定初始值
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//相乘的初始值
buffer(0) = 1.0
//参与运算数字的个数的初始值,数据类型与中间结果类型一致
buffer(1) = 0L
}
//每有一条数据参与运算就更新一下中间结果(update相当于在每一个分区中的运算)
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
//每有一个数字参与运算就进行相乘(包含中间结果)
buffer(0) = buffer.getDouble(0) * input.getDouble(0)
//参与运算数据的个数也有更新
buffer(1) = buffer.getLong(1) + 1L
}
//全局聚合 (将每个分区产生的结果进行聚合)
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
//每个分区计算的结果进行相乘
buffer1(0) = buffer1.getDouble(0) * buffer2.getDouble(0)
//每个分区参与预算的中间结果进行相加
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
//计算最终的结果
override def evaluate(buffer: Row): Double = {
math.pow(buffer.getDouble(0), 1.toDouble / buffer.getLong(1))
}
}
结果:
UDAF的使用
package user.defined.function
import java.lang
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
/**
* 用户自定义UADF函数,实现几何平均数
*/
object UDAFTest {
def main(args: Array[String]): Unit = {
val session: SparkSession = SparkSession.builder().appName("UDFTest").master("local[*]").getOrCreate()
// 生成1-10的数据
val value: Dataset[lang.Long] = session.range(1, 11)
// 将自定义UDAF类注册成自定义聚合函数,生成maen函数
val geomean = new GeoMean
session.udf.register("maen", geomean)
// 将数据注册成视图
value.createTempView("v_rang")
// 在sql中调用自定义函数
val result: DataFrame = session.sql("SELECT maen(id) AS results FROM v_rang")
result.show()
session.close()
}
}
结果:
请先 后发表评论~