
I have a spark dataframe that looks like this:

import pandas as pd
dt = pd.DataFrame({'id': ['a','a','a','a','a','a','b','b','b','b'], 
                   'delta': [1,2,3,4,5,6,7,8,9,10],
                   'pos': [2,0,0,1,2,1,2,0,0,1],
                   'index': [1,2,3,4,5,6,1,2,3,4]})

I would like to sum the deltas from pos==2 until pos==1, for all the times that this occurs, by id

So I would like a column to the spark dataframe that will look like this:

[6, 0, 0, 0, 5, 0, 24, 0, 0, 0]

Explanation of result:

  • 6 -> for id 'a', find the first pos==2 and sum all the deltas until (not including) the next pos==1, so 1+2+3 = 6
  • 0 -> this row is not a pos==2
  • 0 -> this row is not a pos==2
  • 0 -> this row is not a pos==2
  • 5 -> for id 'a', find the next pos==2 and sum all the deltas until (not including) the next pos==1, so just the single delta 5
  • 0 -> this row is not a pos==2
  • 24 -> for id 'b', find the first pos==2 and sum all the deltas until (not including) the next pos==1, so 7+8+9 = 24

Any ideas how I can do that efficiently in PySpark?
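For reference, the intended result can be pinned down with a tiny pure-Python loop (not PySpark; the function name is made up, purely to illustrate the semantics). Note that by the stated rule the second pos==2 segment for id 'a' sums to 5, the single delta of that row, which is also what the accepted answer below computes.

```python
def segment_sums(ids, deltas, pos):
    """Sum deltas from each pos==2 row up to (not including) the next
    pos==1 row within the same id; write the total on the pos==2 row."""
    out = [0] * len(ids)
    start = None  # index of the currently open pos==2 row
    for i, (g, d, p) in enumerate(zip(ids, deltas, pos)):
        if start is not None and (p == 1 or g != ids[start]):
            start = None  # close the segment at pos==1 or on id change
        if p == 2:
            start = i  # open a new segment
        if start is not None:
            out[start] += d
    return out

ids = ['a'] * 6 + ['b'] * 4
deltas = list(range(1, 11))
pos = [2, 0, 0, 1, 2, 1, 2, 0, 0, 1]
print(segment_sums(ids, deltas, pos))
# [6, 0, 0, 0, 5, 0, 24, 0, 0, 0]
```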

EDIT

The dataframe is ordered by index and id.

  • window functions? Commented Jul 10, 2020 at 13:23
  • @chlebek I am confused about how to use a window function in this context Commented Jul 10, 2020 at 13:36
  • Can I order the dataframe by the delta column? Commented Jul 10, 2020 at 14:03
  • @cronoik I edited my question. The dataframe is ordered by index and id Commented Jul 10, 2020 at 14:20

1 Answer


As already mentioned in the comments, you are looking for a way to apply a window function. You need to create an additional ID that lets you partition your data into the segments whose deltas should be summed. Please have a look at the commented example below:

import pyspark.sql.functions as F
from pyspark.sql import Window

df = spark.createDataFrame([
                            ('a', 1, 2, 1),
                            ('a', 2, 0, 2),
                            ('a', 3, 0, 3),
                            ('a', 4, 1, 4),
                            ('a', 5, 2, 5),
                            ('a', 6, 1, 6),
                            ('b', 7, 2, 1),
                            ('b', 8, 0, 2),
                            ('b', 9, 0, 3),
                            ('b',10,1,4)
                        ],
                        ['id', 'delta', 'pos', 'index']
                    )

w1 = Window.partitionBy('id').orderBy('index')
w2 = Window.partitionBy('id', 'subgroup2').orderBy(F.desc('index'))

# assign a unique subgroup ID to each row where pos == 2
df = df.withColumn("subgroup", F.when(F.col('pos') == 2, F.monotonically_increasing_id()))
df.show()

# forward-fill the subgroup ID onto the following rows, except rows where pos == 1
df = df.withColumn('subgroup2', F.when(F.col('pos') != 1, F.last('subgroup', True).over(w1.rowsBetween(Window.unboundedPreceding,0))))
df.show()

# sum the deltas of each subgroup and write the total on the pos == 2 row
df = df.withColumn('deltaSum', F.when(F.col('pos') == 2, F.sum('delta').over(w2)))
df.sort('id', 'index').show()

Output:

+---+-----+---+-----+----------+
| id|delta|pos|index|  subgroup|
+---+-----+---+-----+----------+
|  a|    1|  2|    1|         0|
|  a|    2|  0|    2|      null|
|  a|    3|  0|    3|      null|
|  a|    4|  1|    4|      null|
|  a|    5|  2|    5|         1|
|  a|    6|  1|    6|      null|
|  b|    7|  2|    1|8589934592|
|  b|    8|  0|    2|      null|
|  b|    9|  0|    3|      null|
|  b|   10|  1|    4|      null|
+---+-----+---+-----+----------+

+---+-----+---+-----+----------+----------+
| id|delta|pos|index|  subgroup| subgroup2|
+---+-----+---+-----+----------+----------+
|  b|    7|  2|    1|8589934592|8589934592|
|  b|    8|  0|    2|      null|8589934592|
|  b|    9|  0|    3|      null|8589934592|
|  b|   10|  1|    4|      null|      null|
|  a|    1|  2|    1|         0|         0|
|  a|    2|  0|    2|      null|         0|
|  a|    3|  0|    3|      null|         0|
|  a|    4|  1|    4|      null|      null|
|  a|    5|  2|    5|         1|         1|
|  a|    6|  1|    6|      null|      null|
+---+-----+---+-----+----------+----------+

+---+-----+---+-----+----------+----------+--------+
| id|delta|pos|index|  subgroup| subgroup2|deltaSum|
+---+-----+---+-----+----------+----------+--------+
|  a|    1|  2|    1|         0|         0|       6|
|  a|    2|  0|    2|      null|         0|    null|
|  a|    3|  0|    3|      null|         0|    null|
|  a|    4|  1|    4|      null|      null|    null|
|  a|    5|  2|    5|         1|         1|       5|
|  a|    6|  1|    6|      null|      null|    null|
|  b|    7|  2|    1|8589934592|8589934592|      24|
|  b|    8|  0|    2|      null|8589934592|    null|
|  b|    9|  0|    3|      null|8589934592|    null|
|  b|   10|  1|    4|      null|      null|    null|
+---+-----+---+-----+----------+----------+--------+

2 Comments

Awesome solution!! I am slightly confused about one scenario: what if there is no matching 1 for a 2 in a group? For example, if I change the last row in the sample dataframe to `('b', 10, 9, 4)`, I see that it still gives 34. If the sum should happen only when there is a matching 1, is there a simple way?
Yes, there is. You can simply forward-fill the rows which contain 1 as well (just quick and dirty; there are plenty of ways) after you have calculated deltaSum, and check whether each partition contains both 2 and 1 --> if not, deltaSum = null.
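The check described in this comment can be illustrated with a small pure-Python sketch (a hypothetical helper, not PySpark): a segment's sum is only committed once a closing pos==1 row is actually seen in the same id, so an unclosed trailing segment stays at 0.

```python
def segment_sums_matched(ids, deltas, pos):
    """Like the plain segment sum, but only keep a segment's total
    if it is closed by a pos==1 row in the same id."""
    out = [0] * len(ids)
    start, acc = None, 0
    for i, (g, d, p) in enumerate(zip(ids, deltas, pos)):
        if start is not None and g != ids[start]:
            start = None  # id changed: segment never closed, discard it
        if p == 1 and start is not None:
            out[start] = acc  # segment closed: commit the sum
            start = None
        if p == 2:
            start, acc = i, 0  # open a new segment
        if start is not None:
            acc += d
    return out

ids = ['a'] * 6 + ['b'] * 4
deltas = list(range(1, 11))
# last row changed as in the comment: pos 9 instead of 1, so
# the 'b' segment is never closed and its sum is dropped
pos = [2, 0, 0, 1, 2, 1, 2, 0, 0, 9]
print(segment_sums_matched(ids, deltas, pos))
# [6, 0, 0, 0, 5, 0, 0, 0, 0, 0]
```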
