
I need to write some custom code that uses multiple columns within each group of my data.

My custom code should set a flag when a value is over a threshold, but suppress the flag if it is within a certain time of a previous flag.
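A minimal pure-Python sketch of that rule for a single group, with no Spark involved. It assumes the rows are already sorted by time and that the raw flag comes from a hypothetical `value` column compared against a hypothetical `threshold`. The names `suppress_flags`, `values`, `threshold` and `window` are illustrative only.

```python
def suppress_flags(times, values, threshold, window):
    """Return 1 for each over-threshold row that is more than `window`
    after the previous *kept* flag, else 0. Assumes `times` is sorted."""
    out = []
    last_kept = None  # time of the most recent flag that was not suppressed
    for t, v in zip(times, values):
        raised = v > threshold  # hypothetical raw flag
        if raised and (last_kept is None or t - last_kept > window):
            out.append(1)
            last_kept = t
        else:
            out.append(0)
    return out


# Hypothetical readings for one group, ordered by time
times = [1, 2, 3, 4, 5, 6, 7, 8]
values = [0.2, 0.9, 0.8, 0.95, 0.7, 0.1, 0.85, 0.9]
print(suppress_flags(times, values, threshold=0.5, window=4))  # → [0, 1, 0, 0, 0, 0, 1, 0]
```

The made-up values are chosen so that the raw flags reproduce `flag_col` for group `a`, and the result matches the expected `supressed_flag_col` further down.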

Here is some sample code:

df = spark.createDataFrame(
    [
        ("a", 1, 0),
        ("a", 2, 1),
        ("a", 3, 1),
        ("a", 4, 1),
        ("a", 5, 1),
        ("a", 6, 0),
        ("a", 7, 1),
        ("a", 8, 1),
        ("b", 1, 0),
        ("b", 2, 1)
    ],
    ["group_col","order_col", "flag_col"]
)
df.show()
+---------+---------+--------+
|group_col|order_col|flag_col|
+---------+---------+--------+
|        a|        1|       0|
|        a|        2|       1|
|        a|        3|       1|
|        a|        4|       1|
|        a|        5|       1|
|        a|        6|       0|
|        a|        7|       1|
|        a|        8|       1|
|        b|        1|       0|
|        b|        2|       1|
+---------+---------+--------+

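from pyspark.sql.functions import udf, col, asc
from pyspark.sql.types import ArrayType, IntegerType
from pyspark.sql.window import Window

def _suppress(dates=None, alert_flags=None, window=2):
    # Copy the input so the caller's list is not mutated
    sup_alert_flag = list(alert_flags)
    last_alert_date = None
    for i, alert_flag in enumerate(alert_flags):
        current_date = dates[i]
        if alert_flag == 1:
            if last_alert_date is None:
                # The first flag in the group is always kept
                sup_alert_flag[i] = 1
                last_alert_date = current_date
            elif (current_date - last_alert_date) > window:
                # Far enough past the last kept flag: keep this one
                sup_alert_flag[i] = 1
                last_alert_date = current_date
            else:
                # Too close to the last kept flag: suppress it
                sup_alert_flag[i] = 0
        else:
            sup_alert_flag[i] = 0
    return sup_alert_flag

# The function returns a list, so the return type is an array
suppress_udf = udf(_suppress, ArrayType(IntegerType()))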

df_out = df.withColumn("supressed_flag_col", suppress_udf(dates=col("order_col"), alert_flags=col("flag_col"), window=4).Window.partitionBy(col("group_col")).orderBy(asc("order_col")))

df_out.show()

The above fails, but my expected output is the following:

+---------+---------+--------+------------------+
|group_col|order_col|flag_col|supressed_flag_col|
+---------+---------+--------+------------------+
|        a|        1|       0|                 0|
|        a|        2|       1|                 1|
|        a|        3|       1|                 0|
|        a|        4|       1|                 0|
|        a|        5|       1|                 0|
|        a|        6|       0|                 0|
|        a|        7|       1|                 1|
|        a|        8|       1|                 0|
|        b|        1|       0|                 0|
|        b|        2|       1|                 1|
+---------+---------+--------+------------------+
  • Can you explain what you are trying to achieve with your custom code? Commented Mar 19, 2018 at 12:17
  • The custom code looks through the data within a group, sorted by the order column. If the flag is 1, it sets the output to 1 if the flag has not been set in the previous n points. In my example, for group a: the 1st row has flag=0, so the output is 0; the 2nd row has flag=1 and there is no flag in the previous 5 rows, so output=1; the 3rd row has flag=1, but there has been a flag=1 in the previous 5 rows, so output=0. Once 5 rows have passed, the next flag=1 would also get output=1. Note that I used a simple order column, but in my use case it is a datetime. I hope that helps. Commented Mar 19, 2018 at 13:31
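Since the real order column is a datetime, subtracting two timestamps yields a `timedelta`, so the window has to be a `timedelta` too. A minimal pure-Python sketch under that assumption; it measures elapsed time, like the question's `current_date - last_alert_date` comparison, rather than counting rows. The function name and the hourly timestamps are hypothetical.

```python
from datetime import datetime, timedelta


def suppress_by_time(events, window):
    """events: (timestamp, flag) pairs for one group.
    Keep a flag only if it is more than `window` after the last kept flag."""
    out = []
    last_kept = None  # timestamp of the last flag that was kept
    for ts, flag in sorted(events, key=lambda e: e[0]):
        if flag == 1 and (last_kept is None or ts - last_kept > window):
            out.append((ts, 1))
            last_kept = ts
        else:
            out.append((ts, 0))
    return out


# Hypothetical hourly timestamps standing in for order_col 1..8 of group "a"
start = datetime(2018, 3, 19, 0, 0)
flags = [0, 1, 1, 1, 1, 0, 1, 1]
events = [(start + timedelta(hours=i), f) for i, f in enumerate(flags)]
result = suppress_by_time(events, window=timedelta(hours=4))
print([f for _, f in result])  # → [0, 1, 0, 0, 0, 0, 1, 0]
```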

2 Answers


Editing answer after more thought.

The general problem seems to be that the result of the current row depends on the result of the previous rows; in effect, there is a recurrence relation. I haven't found a good way to implement a recursive UDF in Spark. Spark assumes the data is distributed across partitions, and a row-level UDF sees one row at a time with no reliable way to pass a result from one row to the next, which makes this difficult to achieve, at least as far as I can see. The following solution should work, but it may not scale to large data sets: it collects each group into a single array, so every group has to fit in memory on one executor.
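As an aside (this is not the solution below), the recurrence can be made concrete by writing the rule as a scan, where each step receives state left behind by the previous step. A plain-Python sketch using `itertools.accumulate`; `make_step` and the `(last_kept, output)` state tuple are illustrative names. Because every step depends on the state produced by earlier steps, the rows cannot be evaluated independently of each other.

```python
from itertools import accumulate


def make_step(window):
    """Build one step of the recurrence for a fixed window size."""
    def step(state, row):
        # state = (last_kept, previous_output); only last_kept is carried forward
        last_kept, _ = state
        order, flag = row
        if flag == 1 and (last_kept is None or order - last_kept > window):
            return (order, 1)  # keep this flag and remember when it happened
        return (last_kept, 0)  # suppressed or no flag: state unchanged
    return step


# Group "a" from the sample data as (order_col, flag_col) pairs, already sorted
rows = [(1, 0), (2, 1), (3, 1), (4, 1), (5, 1), (6, 0), (7, 1), (8, 1)]

# accumulate threads the state through the rows in order (a "scan").
# initial= needs Python 3.8+; the first yielded value is the initial state, so drop it.
states = list(accumulate(rows, make_step(4), initial=(None, 0)))[1:]
print([out for _, out in states])  # → [0, 1, 0, 0, 0, 0, 1, 0]
```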

from pyspark.sql import Row
import pyspark.sql.functions as F
import pyspark.sql.types as T

suppress_flag_row = Row("order_col", "flag_col", "res_flag")

def suppress_flag(date_alert_flags, window_size):
    # date_alert_flags is the collected list of (order_col, flag_col) structs for one group.
    # collect_list does not guarantee order, so sort explicitly.
    sorted_alerts = sorted(date_alert_flags, key=lambda x: x["order_col"])

    res_flags = []
    last_alert_date = None  # order_col of the last flag that was kept
    for row in sorted_alerts:
        current_date = row["order_col"]
        aflag = row["flag_col"]
        # Keep the flag only if it is the first one or far enough past the last kept one
        if aflag == 1 and (last_alert_date is None or (current_date - last_alert_date) > window_size):
            res = suppress_flag_row(current_date, aflag, True)
            last_alert_date = current_date
        else:
            res = suppress_flag_row(current_date, aflag, False)

        res_flags.append(res)
    return res_flags

in_fields = [T.StructField("order_col", T.IntegerType(), nullable=True)]
in_fields.append(T.StructField("flag_col", T.IntegerType(), nullable=True))

# Copy so that appending res_flag does not also modify in_fields
out_fields = list(in_fields)
out_fields.append(T.StructField("res_flag", T.BooleanType(), nullable=True))
out_schema = T.StructType(out_fields)
suppress_udf = F.udf(suppress_flag, T.ArrayType(out_schema))

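window_size = 4
# One row per group: the group's (order_col, flag_col) pairs collected into an array
tmp = df.groupBy("group_col").agg(F.collect_list(F.struct(F.col("order_col"), F.col("flag_col"))).alias("date_alert_flags"))
# Apply the UDF to each group's array; the window size is passed in as a literal column
tmp2 = tmp.select(F.col("group_col"), suppress_udf(F.col("date_alert_flags"), F.lit(window_size)).alias("suppress_res"))

# Explode the result array back into one row per record and unpack the struct fields
expand_fields = [F.col("group_col")] + [F.col("res_expand")[f.name].alias(f.name) for f in out_fields]
final_df = tmp2.select(F.col("group_col"), F.explode(F.col("suppress_res")).alias("res_expand")).select(expand_fields)
final_df.orderBy("group_col", "order_col").show()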
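The groupBy / collect_list / UDF / explode pipeline above can be mirrored in plain Python to check the logic without a cluster. In this sketch, which is an illustration rather than how Spark executes it, dicts stand in for the Spark `Row` structs, a `defaultdict` stands in for `collect_list`, and a flattening comprehension stands in for `explode`. `suppress_group` is a hypothetical re-implementation of `suppress_flag`.

```python
from collections import defaultdict


def suppress_group(records, window_size):
    # Same rule as suppress_flag: sort the collected records, then keep a flag
    # only if it is more than window_size after the last kept one.
    out = []
    last = None
    for r in sorted(records, key=lambda x: x["order_col"]):
        keep = r["flag_col"] == 1 and (last is None or r["order_col"] - last > window_size)
        if keep:
            last = r["order_col"]
        out.append({**r, "res_flag": keep})
    return out


rows = [
    ("a", 1, 0), ("a", 2, 1), ("a", 3, 1), ("a", 4, 1),
    ("a", 5, 1), ("a", 6, 0), ("a", 7, 1), ("a", 8, 1),
    ("b", 1, 0), ("b", 2, 1),
]

# Stand-in for groupBy("group_col") + collect_list(struct(order_col, flag_col))
collected = defaultdict(list)
for g, order, flag in rows:
    collected[g].append({"order_col": order, "flag_col": flag})

# Apply the per-group function, then "explode" back to one row per record
final = [
    (g, r["order_col"], r["flag_col"], r["res_flag"])
    for g in sorted(collected)
    for r in suppress_group(collected[g], window_size=4)
]
for row in final:
    print(row)
```

`res_flag` comes out as a boolean. If a 0/1 column is needed, as in the expected output, the Spark result could be cast with `F.col("res_flag").cast("int")`.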

1 Comment

Thanks @putnampp, that seems to do the job!

I think you don't need a custom function for this. You can use the rowsBetween option on the window to look at the previous 5 rows. Please check and let me know if I missed something.

>>> from pyspark.sql import functions as F
>>> from pyspark.sql import Window

>>> w = Window.partitionBy('group_col').orderBy('order_col').rowsBetween(-5,-1)
>>> df = df.withColumn('supr_flag_col',F.when(F.sum('flag_col').over(w) == 0,1).otherwise(0))
>>> df.orderBy('group_col','order_col').show()
+---------+---------+--------+-------------+
|group_col|order_col|flag_col|supr_flag_col|
+---------+---------+--------+-------------+
|        a|        1|       0|            0|
|        a|        2|       1|            1|
|        a|        3|       1|            0|
|        b|        1|       0|            0|
|        b|        2|       1|            1|
+---------+---------+--------+-------------+
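To see where the window-sum approach parts ways with the expected output, its behaviour can be emulated in plain Python on the updated example. This sketch assumes `rowsBetween(-n, -1)` with n=4, matching the later window of 4. It also models the fact that Spark's sum over an empty frame is null, and a null condition in `when` falls through to `otherwise(0)`. `window_sum_rule` is a hypothetical name.

```python
def window_sum_rule(flags, n):
    """Emulate F.when(F.sum('flag_col').over(w) == 0, 1).otherwise(0) with
    w = ...rowsBetween(-n, -1) for one group already sorted by order_col."""
    out = []
    for i in range(len(flags)):
        frame = flags[max(0, i - n):i]  # the (up to) n rows before row i
        # An empty frame sums to null in Spark, and null == 0 is not true
        total = sum(frame) if frame else None
        out.append(1 if total == 0 else 0)
    return out


flags_a = [0, 1, 1, 1, 1, 0, 1, 1]  # group "a", ordered by order_col
print(window_sum_rule(flags_a, n=4))  # → [0, 1, 0, 0, 0, 0, 0, 0]
# Expected supressed_flag_col:          [0, 1, 0, 0, 0, 0, 1, 0]
```

Row 7 differs because its frame (rows 3 to 6) still contains the raw flags at rows 3 to 5, even though those flags were suppressed. The window counts raw flags rather than kept ones, which is the recurrence described in the other answer. Note also that the expression never checks the current row's own `flag_col`.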

3 Comments

Thank you for your help. I can use this as a workaround for now, but what it doesn't do is set the output flag to true after 5 rows in which supr_flag_col has not been set. I'll update the example with more rows to show what I mean.
Does supr_flag_col need to be set to true after 5 rows, or on the 5th row? Your output shows the 5th row set as True.
I updated the example to use a window of 4, so supr_flag_col needs to be set to 1 after 4 rows in which it has not been set. However, the point here is not so much my algorithm but rather: can one create a UDF with multiple DataFrame columns as input? If not, is there a workaround, for example using VectorAssembler to combine them into one column and then indexing them back out inside the UDF?
