Adding a column in PySpark with a UDF

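There are several ways to add a new column to a PySpark DataFrame. Every example below is run against the original two-column df created here, so each output shows only the column that example adds: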

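from pyspark.sql import SparkSession, types
from pyspark.sql.functions import udf

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([('p1', 56), ('p2', 23), ('p3', 11), ('p4', 40), ('p5', 29)], ['name', 'age'])
df.show()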
===>>
+----+---+
|name|age|
+----+---+
|  p1| 56|
|  p2| 23|
|  p3| 11|
|  p4| 40|
|  p5| 29|
+----+---+

1. Plain arithmetic on an existing column:

df = df.withColumn('add_column', df.age + 2)
df.show()
===>>
+----+---+----------+
|name|age|add_column|
+----+---+----------+
|  p1| 56|        58|
|  p2| 23|        25|
|  p3| 11|        13|
|  p4| 40|        42|
|  p5| 29|        31|
+----+---+----------+

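2. A lambda expression wrapped as a user-defined function (UDF):

from pyspark.sql import functions as F, types

# Without an explicit return type a Python UDF defaults to StringType,
# so declare IntegerType to keep add_column numeric
df = df.withColumn(
    'add_column', F.udf(lambda obj: int(obj) + 2, types.IntegerType())(df.age))
df.show()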
===>>
+----+---+----------+
|name|age|add_column|
+----+---+----------+
|  p1| 56|        58|
|  p2| 23|        25|
|  p3| 11|        13|
|  p4| 40|        42|
|  p5| 29|        31|
+----+---+----------+
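Spark passes a SQL NULL to a Python UDF as None, so the lambda above would raise a TypeError from int(None) if age contained a null. A minimal null-safe variant, sketched in plain Python so it can be checked without a cluster (the name `add_two` and the sample values are illustrative, not from the original):

```python
# Hypothetical null-safe version of the lambda from example 2:
# pass None (SQL NULL) through instead of calling int(None).
add_two = lambda obj: None if obj is None else int(obj) + 2

# Apply it to sample values the way Spark would, one value per row.
ages = [56, 23, None, 40]
print([add_two(a) for a in ages])  # [58, 25, None, 42]
```

Wrapped as F.udf(add_two, types.IntegerType()), it produces null for a null age instead of failing the task.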

3. Using the udf function with a named Python function:

def get_age_group(age):
    if age <= 15:
        return 'Little'
    if age <= 30:
        return 'Young'
    if age <= 55:
        return 'Mature'
    return 'Senior'

# udf takes two arguments: the Python function and the type of the value it returns
df = df.withColumn('age_group', udf(get_age_group, types.StringType())(df['age']))
df.show()
===>>
+----+---+---------+
|name|age|age_group|
+----+---+---------+
|  p1| 56|   Senior|
|  p2| 23|    Young|
|  p3| 11|   Little|
|  p4| 40|   Mature|
|  p5| 29|    Young|
+----+---+---------+
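Because get_age_group is an ordinary Python function, its logic can be checked without Spark before it is wrapped with udf. The sketch below is a hypothetical table-driven rewrite (the names `age_group`, `AGE_BOUNDS` and `AGE_LABELS` are illustrative) that keeps the same thresholds, adds a None check for null ages, and exercises the boundary values:

```python
import bisect

# Inclusive upper bounds and labels taken from get_age_group above.
AGE_BOUNDS = [15, 30, 55]
AGE_LABELS = ['Little', 'Young', 'Mature', 'Senior']


def age_group(age):
    """Table-driven equivalent of get_age_group; returns None for a null age."""
    if age is None:
        return None
    # bisect_left returns the first index whose bound is >= age,
    # which matches the "age <= bound" tests in the original if-chain.
    return AGE_LABELS[bisect.bisect_left(AGE_BOUNDS, age)]


# Boundary checks, runnable without Spark.
for age in [11, 15, 16, 30, 31, 55, 56]:
    print(age, age_group(age))
```

It can be registered the same way as the original: udf(age_group, types.StringType())(df['age']). Adding a group then only means editing the two lists.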

4. Using udf as a decorator:

@udf(returnType=types.StringType())
def get_age_group(age):
    if age <= 15:
        return 'Little'
    if age <= 30:
        return 'Young'
    if age <= 55:
        return 'Mature'
    return 'Senior'

df = df.withColumn('age_group', get_age_group(df['age']))
df.show()
===>>
+----+---+---------+
|name|age|age_group|
+----+---+---------+
|  p1| 56|   Senior|
|  p2| 23|    Young|
|  p3| 11|   Little|
|  p4| 40|   Mature|
|  p5| 29|    Young|
+----+---+---------+
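@udf(returnType=...) works because udf can also be called without a function, in which case it returns a decorator; after decoration, the name get_age_group refers to the UDF rather than the plain Python function. The mechanics can be sketched in plain Python with a hypothetical `mini_udf` that imitates only those two call forms and maps the function over a list standing in for a column. It is an illustration, not PySpark's implementation:

```python
import functools


def mini_udf(f=None, returnType='string'):
    """Hypothetical stand-in for udf, mimicking only its two call forms.

    mini_udf(func, type) wraps func directly; mini_udf(returnType=...) with
    no function returns a decorator. The wrapper maps func over a list of
    values, which stands in for a column here.
    """
    def wrap(func):
        @functools.wraps(func)
        def apply_to_column(values):
            return [func(v) for v in values]
        apply_to_column.returnType = returnType
        return apply_to_column

    if f is None:
        return wrap   # decorator form: @mini_udf(returnType=...)
    return wrap(f)    # direct form: mini_udf(func, returnType)


@mini_udf(returnType='string')
def get_age_group(age):
    if age <= 15:
        return 'Little'
    if age <= 30:
        return 'Young'
    if age <= 55:
        return 'Mature'
    return 'Senior'


print(get_age_group([56, 23, 11, 40, 29]))
# ['Senior', 'Young', 'Little', 'Mature', 'Young']
```

The same consequence applies in PySpark: once decorated, get_age_group is meant to be applied to columns, so keep an undecorated copy if you also want to unit-test the logic on plain values.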

5. A more complex return type, for example an array:

import random

def get_labels():
    # Return a list of strings: a random subset of 1 to 4 labels
    labels = ['A', 'B', 'C', 'D', 'E']
    random.shuffle(labels)
    count = random.randint(1, len(labels) - 1)
    return labels[:count]

# ArrayType represents an array column; this UDF takes no input columns.
# The result is random, so mark the UDF nondeterministic; labels differ on every run.
df = df.withColumn('labels', udf(get_labels, types.ArrayType(types.StringType())).asNondeterministic()())
df.show()
===>>
+----+---+------------+
|name|age|      labels|
+----+---+------------+
|  p1| 56|      [D, E]|
|  p2| 23|   [A, E, C]|
|  p3| 11|[C, D, B, A]|
|  p4| 40|[D, A, C, B]|
|  p5| 29|      [D, B]|
+----+---+------------+
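Because get_labels draws from random, each run can produce different labels, so the table above will not be reproduced exactly. A sketch, assuming you want repeatable results when testing the Python logic: a hypothetical `rng` parameter lets a seeded random.Random be injected, while calling the function with no arguments keeps the original behaviour:

```python
import random

LABELS = ['A', 'B', 'C', 'D', 'E']


def get_labels(rng=None):
    """Same logic as get_labels above, with an optional injected random source."""
    rng = rng if rng is not None else random  # default: module-level random
    labels = LABELS[:]  # copy so the shared list is never shuffled in place
    rng.shuffle(labels)
    count = rng.randint(1, len(labels) - 1)  # 1 to 4 labels
    return labels[:count]


# With a seeded generator the result is reproducible, which helps in tests.
first = get_labels(random.Random(42))
second = get_labels(random.Random(42))
print(first == second)  # True
```

Called with no arguments, as the zero-column UDF call invokes it, the function falls back to the module-level random exactly like the original.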
Original article: https://www.cnblogs.com/aaronhoo/p/15471294.html