If you are on Spark >= 2.4 you don't need a UDF; see the example below.
Load the input data:
// Build a two-row sample DataFrame from an inline VALUES table. Each row holds a user id,
// the user's existing loans (an array of loan_date/loan_amount structs) and one new loan struct.
val inputSql =
  """
    |select user_id, user_loans_arr, new_loan
    |from values
    | ('u1', array(named_struct('loan_date', '2019-01-01', 'loan_amount', 100)), named_struct('loan_date',
    | '2020-01-01', 'loan_amount', 100)),
    | ('u2', array(named_struct('loan_date', '2020-01-01', 'loan_amount', 200)), named_struct('loan_date',
    | '2020-01-01', 'loan_amount', 100))
    | T(user_id, user_loans_arr, new_loan)
    |""".stripMargin
val df = spark.sql(inputSql)

// Print the rows without truncation, then the resulting schema.
df.show(false)
df.printSchema()
/**
 * Expected output:
 *
 * +-------+-------------------+-----------------+
 * |user_id|user_loans_arr     |new_loan         |
 * +-------+-------------------+-----------------+
 * |u1     |[[2019-01-01, 100]]|[2020-01-01, 100]|
 * |u2     |[[2020-01-01, 200]]|[2020-01-01, 100]|
 * +-------+-------------------+-----------------+
 *
 * root
 *  |-- user_id: string (nullable = false)
 *  |-- user_loans_arr: array (nullable = false)
 *  |    |-- element: struct (containsNull = false)
 *  |    |    |-- loan_date: string (nullable = false)
 *  |    |    |-- loan_amount: integer (nullable = false)
 *  |-- new_loan: struct (nullable = false)
 *  |    |-- loan_date: string (nullable = false)
 *  |    |-- loan_amount: integer (nullable = false)
 */
Process the data according to the following requirement:
take user_loans_arr and new_loan as inputs and append the new_loan struct to the existing user_loans_arr. Then remove from user_loans_arr every element whose loan_date is more than 12 months in the past.
Spark >= 2.4 (higher-order functions, no UDF needed):
// Append new_loan to user_loans_arr, then keep only loans whose loan_date is less than
// 12 months before current_date().
// concat is used instead of array_union on purpose: array_union removes duplicates, so a new
// loan identical to an existing one (same date and amount) would be silently dropped.
// concat on arrays and the FILTER higher-order function both require Spark 2.4+.
// With concat, a null user_loans_arr produces a null result; FILTER drops elements for
// which the predicate evaluates to null (e.g. a null loan_date).
df.withColumn("user_loans_arr",
  expr(
    """
      |FILTER(concat(user_loans_arr, array(new_loan)),
      |  x -> months_between(current_date(), to_date(x.loan_date)) < 12)
      |""".stripMargin))
  .show(false)
/**
 * Sample output. The 12-month window is relative to current_date(), so this reflects a run
 * during 2020; running it later drops more (or all) of these loans.
 *
 * +-------+--------------------------------------+-----------------+
 * |user_id|user_loans_arr                        |new_loan         |
 * +-------+--------------------------------------+-----------------+
 * |u1     |[[2020-01-01, 100]]                   |[2020-01-01, 100]|
 * |u2     |[[2020-01-01, 200], [2020-01-01, 100]]|[2020-01-01, 100]|
 * +-------+--------------------------------------+-----------------+
 */
Spark < 2.4 (no higher-order functions, so a UDF is required):
// spark < 2.4
// The UDF's declared return type is copied from the input column so the result keeps
// exactly the same array<struct<loan_date, loan_amount>> schema.
val outputSchema = df.schema("user_loans_arr").dataType
import java.time._

/**
 * Appends `loan` to `userLoansArr` and keeps only loans dated less than 12 whole months
 * before today.
 *
 * Null handling:
 *  - a null `userLoansArr` returns null instead of throwing;
 *  - a null loan struct, or one with a null loan_date, is dropped instead of throwing.
 *
 * Spark hands array<struct> columns to Scala UDFs as a Seq of Rows, so `Seq[Row]` is used
 * rather than the concrete `mutable.WrappedArray` type.
 *
 * NOTE(review): LocalDate.now() uses the executor JVM's default time zone, which may differ
 * from spark.sql.session.timeZone used by current_date() in the SQL version.
 */
val add_and_filter = udf((userLoansArr: Seq[Row], loan: Row) => {
  if (userLoansArr == null) {
    null
  } else {
    // Evaluate "today" once per row so every element is compared against the same date.
    val today = LocalDate.now()
    (userLoansArr :+ loan).filter { row =>
      val loanDate = if (row == null) null else row.getAs[String]("loan_date")
      loanDate != null &&
        Period.between(LocalDate.parse(loanDate), today).toTotalMonths < 12
    }
  }
}, outputSchema)

df.withColumn("user_loans_arr", add_and_filter($"user_loans_arr", $"new_loan"))
  .show(false)
/**
 * Sample output. The 12-month window is relative to the current date, so this reflects a run
 * during 2020; running it later drops more (or all) of these loans.
 *
 * +-------+--------------------------------------+-----------------+
 * |user_id|user_loans_arr                        |new_loan         |
 * +-------+--------------------------------------+-----------------+
 * |u1     |[[2020-01-01, 100]]                   |[2020-01-01, 100]|
 * |u2     |[[2020-01-01, 200], [2020-01-01, 100]]|[2020-01-01, 100]|
 * +-------+--------------------------------------+-----------------+
 */