I have a PySpark DataFrame containing a collection of books, where each book can have one or more titles. Each title is classed as either an original title (OT) or an alternative title (AT). For simplicity, I'm omitting other title types. My validation needs to ensure that each book has exactly one OT title and can have any number of AT titles.
What I'm trying to do is clean up the data so that:

- If a book has more than one OT title, keep the first and change the rest to AT
- If a book has no OT titles, change the first AT title to OT
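To make the two rules concrete, here is a sketch of the cleanup logic in plain Python, under the assumption that each title arrives as a `(Title, Type)` pair and that OT/AT are the only types in play (the Spark wiring, e.g. via a UDF, is a separate step):

```python
# Sketch of the two cleanup rules (assumes titles are (Title, Type) pairs
# and OT/AT are the only title types, as in the question).

def fix_title_types(titles):
    seen_ot = False
    out = []
    for title, ttype in titles:
        if ttype == 'OT' and seen_ot:
            # Rule 1: OT titles after the first become AT
            out.append((title, 'AT'))
        else:
            out.append((title, ttype))
            seen_ot = seen_ot or ttype == 'OT'
    if not seen_ot:
        # Rule 2: no OT at all -> promote the first AT title to OT
        for i, (title, ttype) in enumerate(out):
            if ttype == 'AT':
                out[i] = (title, 'OT')
                break
    return out

print(fix_title_types([('Title 1', 'OT'), ('Title 2', 'OT')]))
# -> [('Title 1', 'OT'), ('Title 2', 'AT')]
```

From there the function could be wrapped as a UDF whose return type matches the `Titles` column's `ArrayType(StructType(...))` schema.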
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, ArrayType
from pyspark.sql.functions import collect_list, col, struct, udf
data = [
    (1, 'Title 1', 'OT'),
    (1, 'Title 2', 'OT'),
    (2, 'Title 3', 'AT'),
    (2, 'Title 4', 'OT'),
    (3, 'Title 5', 'AT'),
]
schema = StructType([
    StructField("BookID", IntegerType(), False),
    StructField("Title", StringType(), True),
    StructField("Type", StringType(), True),
])
df = spark.createDataFrame(data, schema)
df = df.groupby('BookID').agg(collect_list(struct(col('Title'), col('Type'))).alias('Titles'))
display(df)
It sounds like it should be easy but I'm at a bit of a loss as to how to do it. Any help would be greatly appreciated.
I have tried using a UDF like the one below, but so far that approach isn't working: I'm getting an error saying a lambda cannot contain assignment.
def process_titles(titles):
    x = list(filter(lambda t: t.Type == 'OT', titles))[1::]
    map(lambda t: t.Type = 'AT', x)
    return x

process_titles_udf = udf(lambda x: process_titles(x), titles)
df = df.withColumn('test', process_titles_udf('Titles'))
where the udf returns an object of type:
titles = ArrayType(StructType([
    StructField("Title", StringType(), True),
    StructField("Type", StringType(), True)
]))
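For what it's worth, the assignment error comes from `lambda t: t.Type = 'AT'`: a lambda body must be a single expression, and assignment is a statement. Even if it parsed, Spark `Row` objects are immutable, the result of the lazy `map` is discarded, and the `[1::]` slice means only the surplus OT rows would be returned. A small stand-in sketch (using a `namedtuple` as a hypothetical substitute for a Spark `Row`) showing a statement-free rewrite of the UDF body that keeps every title:

```python
from collections import namedtuple

# Hypothetical stand-in for a Spark Row with Title/Type fields
TitleRow = namedtuple('TitleRow', ['Title', 'Type'])

def process_titles(titles):
    # Rows are immutable, so rebuild a row instead of assigning to t.Type;
    # keep every title, demoting only the OT rows after the first.
    out, seen_ot = [], False
    for t in titles:
        if t.Type == 'OT' and seen_ot:
            out.append(TitleRow(t.Title, 'AT'))
        else:
            out.append(t)
            seen_ot = seen_ot or t.Type == 'OT'
    return out

# A named function also removes the need for a wrapping lambda, e.g.:
# process_titles_udf = udf(process_titles, titles)
```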