
I am using an Aggregator to apply a custom merge on a DataFrame after grouping its records by their primary key:

case class Player(
  pk: String, 
  ts: String, 
  first_name: String, 
  date_of_birth: String
)

case class PlayerProcessed(
  var ts: String, 
  var first_name: String, 
  var date_of_birth: String
)

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Custom Aggregator - this is just for the example; the actual one is more complex
object BatchDedupe extends Aggregator[Player, PlayerProcessed, PlayerProcessed] {

  def zero: PlayerProcessed = PlayerProcessed("0", null, null)

  def reduce(bf: PlayerProcessed, in : Player): PlayerProcessed = {
    bf.ts = in.ts
    bf.first_name = in.first_name
    bf.date_of_birth = in.date_of_birth
    bf
  }

  def merge(bf1: PlayerProcessed, bf2: PlayerProcessed): PlayerProcessed = {
    bf1.ts = bf2.ts
    bf1.first_name = bf2.first_name
    bf1.date_of_birth = bf2.date_of_birth
    bf1
  }

  def finish(reduction: PlayerProcessed): PlayerProcessed = reduction
  def bufferEncoder: Encoder[PlayerProcessed] = Encoders.product
  def outputEncoder: Encoder[PlayerProcessed] = Encoders.product
}


val ply1 = Player("12121212121212", "10000001", "Rogger", "1980-01-02")
val ply2 = Player("12121212121212", "10000002", "Rogg", null)
val ply3 = Player("12121212121212", "10000004", null, "1985-01-02")
val ply4 = Player("12121212121212", "10000003", "Roggelio", "1982-01-02")

val seq_users = sc.parallelize(Seq(ply1, ply2, ply3, ply4)).toDF.as[Player]

val grouped = seq_users.groupByKey(_.pk)

val non_sorted = grouped.agg(BatchDedupe.toColumn.name("deduped"))
non_sorted.show(false)

This returns:

+--------------+--------------------------------+
|key           |deduped                         |
+--------------+--------------------------------+
|12121212121212|{10000003, Roggelio, 1982-01-02}|
+--------------+--------------------------------+

Now, I would like to order the records based on ts before aggregating them. From here I understand that .sortBy("ts") does not guarantee the order after the .groupByKey(_.pk), so I was trying to apply the sorting between the .groupByKey and the .agg.

The output of .groupByKey(_.pk) is a KeyValueGroupedDataset[String,Player], where the values of each group are exposed as an Iterator. So, to apply some sorting logic there, I convert it into a Seq:

val sorted = grouped.mapGroups{case(k, iter) => (k, iter.toSeq.sortBy(_.ts))}.agg(BatchDedupe.toColumn.name("deduped"))
sorted.show(false)

However, after adding the sorting logic, the output of .mapGroups is a Dataset[(String, Seq[Player])], so when I try to invoke .agg on it I get the following exception:

Caused by: ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to $line050e0d37885948cd91f7f7dd9e3b4da9311.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Player

How could I convert the output of my .mapGroups(...) back into a KeyValueGroupedDataset[String,Player]?

I tried converting the sorted Seq back to an Iterator as follows:

val sorted = grouped.mapGroups{case(k, iter) => (k, iter.toSeq.sortBy(_.ts).toIterator)}.agg(BatchDedupe.toColumn.name("deduped"))

But this approach produced the following exception:

UnsupportedOperationException: No Encoder found for Iterator[Player]
- field (class: "scala.collection.Iterator", name: "_2")
- root class: "scala.Tuple2"

How else can I add the sort logic between the .groupByKey and .agg methods?
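
For context, a minimal sketch (not part of the original attempt) of one workaround that sidesteps .agg entirely: sort and fold each group inside mapGroups, reusing the BatchDedupe methods defined above. This gives up the Aggregator's map-side combine.

val dedupedBySort = grouped
  .mapGroups { case (k, iter) =>
    // sort the group by ts, then apply the same reduce logic sequentially
    (k, iter.toSeq.sortBy(_.ts).foldLeft(BatchDedupe.zero)(BatchDedupe.reduce))
  }
  .toDF("key", "deduped")

dedupedBySort.show(false)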

  • The great benefit of Spark aggregators is that data is reduced (in parallel) on the mapper side instead of sending all individual rows over the wire. What you are trying to achieve is not going to work with aggregators... Why do you need the data sorted, is there no alternative way? E.g. min/max_by(value, timestamp) patterns? Commented Jun 27, 2022 at 15:36
  • If you're willing to shuffle all data around (without the benefit of a map side reduce) you can use the secondary sort pattern and then apply a reduce function to your partition iterator. Commented Jun 27, 2022 at 15:40
  • @Moritz Thanks for your advice. The size of the dataframe to aggregate is not that big (since it comes in small batches), so shuffling could be tolerable. The required logic basically picks, for each field, the last value (order is defined by ts) as long as it is not null. That is why I was developing a udf. Commented Jun 27, 2022 at 16:22
  • In that case you really don't have to bother about sorting. I'd recommend timestamping every field in your aggregation buffer and simply always keeping the latest non-null value (see the sketch after these comments). Alternatively you can use the built-in aggregation function max_by to do the same: input.groupBy(...).agg(columns.map(name -> max_by(col(name), col("ts")))).as(encoder) Commented Jun 28, 2022 at 9:07
  • @Moritz timestamping every field separately is an option I considered, but the actual number of columns is rather big. However, would it be possible for you to write the second option as an answer? I believe it can work, and I will not even need the Aggregator if that max_by ignores the null values. Commented Jun 28, 2022 at 10:50
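
For reference, a minimal sketch of the "timestamp every field in the buffer" idea from the comment above, reusing the imports and case classes from the question (the buffer and helper names here are illustrative, not from the original exchange). Each buffer slot remembers the ts it was last updated from; reduce and merge keep the non-null value with the larger ts, so no sorting is needed and the map-side combine is preserved:

case class Timestamped(ts: String, value: String)
case class PlayerBuffer(ts: String, first_name: Timestamped, date_of_birth: Timestamped)

object LatestNonNull extends Aggregator[Player, PlayerBuffer, PlayerProcessed] {

  // keep the non-null value seen at the larger ts; the example ts values are
  // fixed-width strings, so a lexicographic max is assumed to be good enough
  private def pick(a: Timestamped, b: Timestamped): Timestamped = {
    val nonNulls = Seq(a, b).filter(_.value != null)
    if (nonNulls.isEmpty) a else nonNulls.maxBy(_.ts)
  }

  def zero: PlayerBuffer =
    PlayerBuffer("", Timestamped("", null), Timestamped("", null))

  def reduce(b: PlayerBuffer, in: Player): PlayerBuffer = PlayerBuffer(
    Seq(b.ts, in.ts).max,
    pick(b.first_name, Timestamped(in.ts, in.first_name)),
    pick(b.date_of_birth, Timestamped(in.ts, in.date_of_birth))
  )

  def merge(b1: PlayerBuffer, b2: PlayerBuffer): PlayerBuffer = PlayerBuffer(
    Seq(b1.ts, b2.ts).max,
    pick(b1.first_name, b2.first_name),
    pick(b1.date_of_birth, b2.date_of_birth)
  )

  def finish(b: PlayerBuffer): PlayerProcessed =
    PlayerProcessed(b.ts, b.first_name.value, b.date_of_birth.value)

  def bufferEncoder: Encoder[PlayerBuffer] = Encoders.product
  def outputEncoder: Encoder[PlayerProcessed] = Encoders.product
}

With this, grouped.agg(LatestNonNull.toColumn.name("deduped")) needs no sorting step at all.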

1 Answer


Based on the discussion above, the purpose of the Aggregator is to get the latest field values per Player by ts, ignoring null values.

This can be achieved fairly easily by aggregating all fields individually using max_by. With that, there is no need for a custom Aggregator or a mutable aggregation buffer.

import org.apache.spark.sql.functions._

val players: Dataset[Player] = ...

// aggregate all columns except the key individually, picking the value with the latest ts
// NULL values are skipped by nulling out the ts used for ordering
val aggColumns = players.columns
   .filterNot(_ == "pk")
   .map(colName => expr(s"max_by($colName, if(isNotNull($colName), ts, null))").as(colName))

val aggregatedPlayers = players
   .groupBy(col("pk"))
   .agg(aggColumns.head, aggColumns.tail: _*)
   .as[Player]

On the most recent versions of Spark you can also use the built-in max_by function directly:

import org.apache.spark.sql.functions._

val players: Dataset[Player] = ...

// aggregate all columns except the key individually, picking the value with the latest ts
// NULL values are skipped by nulling out the ts used for ordering
val aggColumns = players.columns
   .filterNot(_ == "pk")
   .map(colName => max_by(col(colName), when(col(colName).isNotNull, col("ts"))).as(colName))

val aggregatedPlayers = players
   .groupBy(col("pk"))
   .agg(aggColumns.head, aggColumns.tail: _*)
   .as[Player]
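
For the example players from the question, this should yield roughly the following (per field, the latest non-null value ordered by ts), which matches the behaviour described in the comments below:

+--------------+--------+----------+-------------+
|pk            |ts      |first_name|date_of_birth|
+--------------+--------+----------+-------------+
|12121212121212|10000004|Roggelio  |1985-01-02   |
+--------------+--------+----------+-------------+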

5 Comments

Thanks for the answer Moritz, but I am getting: Caused by: NotSerializableException: org.apache.spark.sql.Column Serialization stack: - object not serializable (class: org.apache.spark.sql.Column, value: max_by(ts, ts) AS ts) - element of array (index: 0) - array (class [Ljava.lang.Object;, size 3) - field (class: scala.collection.mutable.ArrayBuffer, name: array, type: class [Ljava.lang.Object;) - object (class scala.collection.mutable.ArrayBuffer, ArrayBuffer(max_by(ts, ts) AS ts, max_by(first_name, ts) AS first_name, max_by(date_of_birth, ts) AS date_of_birth)) - field (class:
This looks like your code contains a closure that depends on the current scope where aggColumns is defined and attempts to serialize everything. Column not being serializable is just a symptom of the actual problem here... In any case, you could just use def aggColumns instead of val aggColumns to prevent the issue (see the sketch after these comments).
Thanks again, the exception is resolved and the last value based on ts is selected. However, null values are not ignored. For example, for the Dataset I used as an example, date_of_birth is "1985-01-02", but first_name is null
Strange, I don't think that's SQL-conformant... anyway, you have to add an if/when clause then. I've updated the answer above
Perfect, thanks a lot Moritz
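
To illustrate the def-over-val suggestion from the comment above, a sketch (same aggregation as in the answer; nothing here goes beyond the comment's advice): a def is not stored as a field of the enclosing (REPL) object, so the Column array does not end up inside a serialized task closure.

import org.apache.spark.sql.functions.{col, max_by, when}

// evaluated on the driver each time it is referenced; nothing is kept as a
// serializable field of the surrounding scope
def aggColumns = players.columns
  .filterNot(_ == "pk")
  .map(c => max_by(col(c), when(col(c).isNotNull, col("ts"))).as(c))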
