These are some of my notes on creating UDFs (user defined functions) in PySpark.

UDFs are super useful for anyone doing feature engineering or ETL work. They help break down the workflow by keeping your PySpark code modular. This makes it easy to perform unit testing (since you’re working with modular components that build up to the entire ETL workflow).
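As a rough sketch of the unit-testing point: the logic that will become a UDF is just a Python function, so it can be checked with plain assertions before Spark is involved. The name `scale_value` and the test cases are made up for illustration.

```python
def scale_value(v1, a):
    """Multiply a column value by a parameter; return None on bad input."""
    try:
        return v1 * float(a)
    except (TypeError, ValueError):
        return None


def test_scale_value():
    assert scale_value(1.5, '2') == 3.0
    assert scale_value(None, '2') is None     # null column value
    assert scale_value(1.5, 'abc') is None    # parameter is not numeric


test_scale_value()
```

Only once the function behaves as expected does it need to be wrapped with udf().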

Here I show how to create PySpark UDFs that use:

  1. a single column
  2. multiple columns
  3. an external library function
  4. aggregation by packing columns into a list.

(All code shown below uses PySpark 3.0.1.)

To start, I generate some data that I’ll use to illustrate all of the UDF constructions I mentioned above.

import pandas as pd
import numpy as np
from pyspark.sql import SparkSession
# needed for udfs
from pyspark.sql.functions import udf, lit, collect_list
from pyspark.sql.types import *

spark = SparkSession.builder.appName('udf_tutorial').getOrCreate()
N = 100
data = {'identifier': np.random.choice([1,2,3,4,5], N, replace=True),
        'var1': np.random.normal(3, 3, N),
        'var2': np.random.choice(['Y','N'], N, replace=True)}
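df = pd.DataFrame(data)
# convert the pandas data frame to a Spark data frame (referred to as sdf below)
sdf = spark.createDataFrame(df)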

This produces the following table.

    identifier      var1 var2
0            2 -0.015862    N
1            5  3.792773    N
2            4  2.653766    Y
3            2 -2.231594    Y
4            3  4.900761    N
..         ...       ...  ...
95           2  5.026429    Y
96           5  3.123079    Y
97           1  1.880323    N
98           3 -6.005374    Y
99           5  1.074175    N

Before diving into the different UDFs it’s worth outlining what the workflow will look like. Typically the process of creating/using a UDF goes something like this:

  1. Define your function in Python.
  2. Register your function with Spark and specify the return type. (See this code for possible return types.)
  3. Apply your UDF to your Spark data frame.
  4. Breathe.
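The first three steps can be sketched as follows. This is a minimal illustration, assuming a Spark data frame with a numeric column; `double_it` and `apply_double` are hypothetical names, and the pyspark imports sit inside the second function so that the first one stays usable without Spark.

```python
def double_it(x):
    """Step 1: an ordinary Python function; None (a Spark null) passes through."""
    return None if x is None else 2.0 * x


def apply_double(sdf, column='var1'):
    """Steps 2 and 3 for a Spark data frame `sdf` with a numeric `column`."""
    from pyspark.sql.functions import udf
    from pyspark.sql.types import DoubleType

    # step 2: register the function and declare its return type
    double_udf = udf(double_it, DoubleType())
    # step 3: apply the UDF to the data frame
    return sdf.withColumn(column + '_x2', double_udf(column))
```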

Single Column

Here I’m creating a UDF that takes a single column and an external parameter as arguments. First I define the function. In this example, the function returns the product of each element of the column and the parameter we provide. I’m using try/except statements for better error handling.

The lit function will be used to provide the parameter argument as a string, which means it needs to be converted to a float in order to be used for the mathematical operation.

def udf_sc(v1, a):
    try:
        a = float(a)
        return v1 * a
    except (TypeError, ValueError):
        return None

Now I register the UDF and specify the return type as FloatType().

udf_sc_reg = udf(udf_sc, FloatType())

Now I can use the UDF with the Spark data frame created above.

sdf_transformed = sdf.withColumn('var3', udf_sc_reg('var1', lit('2')))
|identifier|                var1|var2|        var3|
|         2|-0.01586150202926495|   N|-0.031723004|
|         5|  3.7927732716658635|   N|   7.5855465|
|         4|   2.653765608323036|   Y|   5.3075314|
|         2|  -2.231594107013259|   Y|   -4.463188|
|         3|   4.900761231566397|   N|    9.801522|
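The shorter values in var3 follow from the return type: FloatType() is a 32-bit float, while var1 is held as a 64-bit double. The sketch below imitates that narrowing in plain Python with the standard struct module. The helper name `to_float32` is mine and is not part of PySpark.

```python
import struct


def to_float32(x):
    """Round a 64-bit Python float to the nearest 32-bit float."""
    return struct.unpack('f', struct.pack('f', x))[0]


full = 2 * 3.7927732716658635   # what the Python function computes
narrow = to_float32(full)       # roughly what a FloatType() column can hold

print(full, narrow)
print(to_float32(0.5) == 0.5)   # exactly representable values are unchanged
```

Declaring DoubleType() as the return type instead would keep the full precision.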

Multiple Columns

The approach is similar if multiple columns are needed in the UDF. Here the UDF has more complicated conditional statements.

def udf_mc(v1, v2):
    try:
        if v2 == 'N':
            a = -0.5
        elif v2 == 'Y':
            a = 0.5
        else:
            a = None
        return v1 + a
    except TypeError:
        return None

udf_mc_reg = udf(udf_mc, FloatType())

sdf_transformed = sdf.withColumn('var3', udf_mc_reg('var1','var2'))
|identifier|                var1|var2|      var3|
|         2|-0.01586150202926495|   N|-0.5158615|
|         5|  3.7927732716658635|   N| 3.2927732|
|         4|   2.653765608323036|   Y| 3.1537657|
|         2|  -2.231594107013259|   Y|-1.7315941|
|         3|   4.900761231566397|   N|  4.400761|
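For what it's worth, the same branching can be written as a dictionary lookup, which makes the fallback for unexpected values explicit. This is an alternative sketch rather than the function used above; `OFFSETS` and `shift_by_flag` are names I made up.

```python
# assumed mapping from the flag column to the offset added to var1
OFFSETS = {'N': -0.5, 'Y': 0.5}


def shift_by_flag(v1, v2):
    """Add the offset for flag v2 to v1; None for nulls or unknown flags."""
    offset = OFFSETS.get(v2)
    if v1 is None or offset is None:
        return None   # shows up as null in the Spark column
    return v1 + offset


print(shift_by_flag(1.0, 'Y'), shift_by_flag(1.0, 'N'), shift_by_flag(1.0, '?'))
```

It would be registered the same way, with udf() and a FloatType() return type.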

External Library

Sometimes it’s useful to use a function provided by an external library (e.g. numpy, scipy, etc). It’s pretty straightforward to do this. You just provide the function in your UDF. However, if you’re working in a distributed setup you’ll need to have the library installed on every node in the cluster (not just the main node).

Here I use the expit (inverse logit) function in the scipy library to transform the var1 column. The result is returned as an array (just a lil flex) containing both the original value and the transformed value. The type returned by expit (numpy.float64) is converted to a base Python float since it needs to match PySpark’s FloatType() type.

from scipy.special import expit

def udf_el(v1):
    inv_logit = float(expit(v1))
    return [v1, inv_logit]

udf_el_reg = udf(udf_el, ArrayType(FloatType()))

sdf_transformed = sdf.withColumn('var3', udf_el_reg('var1'))
|identifier|                var1|var2|                var3|
|         2|-0.01586150202926495|   N|[-0.015861502, 0....|
|         5|  3.7927732716658635|   N|[3.7927732, 0.977...|
|         4|   2.653765608323036|   Y|[2.6537657, 0.934...|
|         2|  -2.231594107013259|   Y|[-2.231594, 0.096...|
|         3|   4.900761231566397|   N|[4.900761, 0.9926...|
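For reference, expit(x) is 1 / (1 + exp(-x)). The sketch below reproduces it with only the standard library, which can be handy for sanity-checking the second element of var3. `inv_logit` is a stand-in I wrote, not the scipy implementation.

```python
import math


def inv_logit(x):
    """Inverse logit, branching on the sign of x to avoid overflow in exp."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def value_and_inv_logit(v1):
    """Return [original, transformed] as plain Python floats, or None for null."""
    if v1 is None:
        return None
    return [float(v1), inv_logit(v1)]


print(value_and_inv_logit(0.0))   # [0.0, 0.5]
```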

Custom Aggregate

At the time of writing this, it’s not possible to implement UDAFs (user defined aggregate functions) in PySpark the way one can in Spark’s Scala API. The workaround is to collect all the elements of a group into a list using collect_list and then apply a UDF to each group’s list.
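To make the idea concrete without Spark, here is a plain-Python imitation of the two stages: gather each group's values into a list, then run an ordinary function on each list. The rows and the `summarize` helper are invented for illustration.

```python
from collections import defaultdict

# made-up (identifier, var1) rows
rows = [(1, 2.0), (2, 5.0), (1, 4.0), (2, 7.0), (2, 9.0)]

# stage 1: conceptually what groupBy(...).agg(collect_list(...)) produces
groups = defaultdict(list)
for identifier, var1 in rows:
    groups[identifier].append(var1)


# stage 2: the function a UDF would apply to each group's list
def summarize(values):
    return (len(values), sum(values) / len(values))


result = {identifier: summarize(values) for identifier, values in groups.items()}
print(result)   # {1: (2, 3.0), 2: (3, 7.0)}
```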

In the code below I take things a little further (another flex) and return a struct whose fields have different types (using StructType() and StructField()). Another option would be to return an array, force everything to a single type, and convert it back to the appropriate type when you flatten that array to a data frame later on.

def udf_ca(x):
    return [int(len(x)), float(np.mean(x))]

schema = StructType([StructField('length', IntegerType(), False),
                     StructField('mean', FloatType(), False)])

udf_ca_reg = udf(udf_ca, schema)

sdf_transformed = (sdf.groupBy('identifier')
                      .agg(collect_list('var1').alias('var1_list'))
                      .select(udf_ca_reg('var1_list').alias('gpd_metrics')))
|gpd_metrics    |
|[18, 2.6149726]|
|[18, 2.82972]  |
|[23, 2.8389266]|
|[17, 2.7891133]|
|[24, 3.4916883]|
 |-- gpd_metrics: struct (nullable = true)
 |    |-- length: integer (nullable = false)
 |    |-- mean: float (nullable = false)
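If separate columns are easier to work with downstream, the struct can be flattened by selecting its fields. A minimal sketch, assuming a data frame shaped like the one above (a struct column named gpd_metrics with fields length and mean). The function name `flatten_metrics` is hypothetical.

```python
def flatten_metrics(sdf_grouped):
    """Expand the gpd_metrics struct column into top-level columns."""
    # 'struct_column.*' selects every field of the struct as its own column
    return sdf_grouped.select('gpd_metrics.*')
```

Calling flatten_metrics(sdf_transformed) should give a data frame with a length column and a mean column.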