I have the following Python code that uses PySpark to mock a credit-card fraud detection system:

from pyspark.sql import SparkSession
from pyspark.sql.functions import from_json, col, unix_timestamp
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    DoubleType,
    TimestampType,
    ArrayType,
    LongType,
)
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle

import pandas as pd


class MySP(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle):
        self._state = handle.getValueState(
            "state",
            StructType([
                StructField("locations", ArrayType(StringType())),
                StructField("timestamps", ArrayType(LongType())),
            ])
        )

    def handleInputRows(self, key, rows, timerValues):
        if not self._state.exists():
            current_state = {"locations": [], "timestamps": []}
        else:
            current_state = {"locations": self._state.get()[0], "timestamps": self._state.get()[1]}

        new_locations = []
        new_timestamps = []
        for pdf in rows:
            new_locations.extend(pdf["location"].tolist())
            new_timestamps.extend(pdf["unix_timestamp"].tolist())

        current_state["locations"].extend(new_locations)
        current_state["timestamps"].extend(new_timestamps)

      
        self._state.update((current_state["locations"], current_state["timestamps"]))

        yield pd.DataFrame()


def main():
    spark = (
        SparkSession.builder.appName("RealTimeFraudDetector")
        .config(
            "spark.sql.streaming.stateStore.providerClass",
            "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider",
        )
        .getOrCreate()
    )
    spark.sparkContext.setLogLevel("WARN")

    schema = StructType(
        [
            StructField("transaction_id", StringType(), True),
            StructField("card_number", StringType(), True),
            StructField("card_holder", StringType(), True),
            StructField("amount", DoubleType(), True),
            StructField("currency", StringType(), True),
            StructField("location", StringType(), True),
            StructField("timestamp", TimestampType(), True),
        ]
    )

    output_schema = StructType(
        [
            StructField("card_number", StringType(), True),
            StructField("is_fraud", StringType(), True),
            StructField("message", StringType(), True),
        ]
    )

    kafka_df = (
        spark.readStream.format("kafka")
        .option("kafka.bootstrap.servers", "broker:29092")
        .option("subscribe", "transaction")
        .load()
    )

    transaction_df = (
        kafka_df.select(from_json(col("value").cast("string"), schema).alias("data"))
        .select("data.*")
        .withColumn("unix_timestamp", unix_timestamp(col("timestamp")))
    )

    filtered_df = (
        transaction_df.withWatermark("timestamp", "10 minutes")
        .groupBy("card_number")
        .transformWithStateInPandas(
            MySP(), outputStructType=output_schema, outputMode="append", timeMode="None"
        )
    )

    query = filtered_df.writeStream.outputMode("append").format("console").start()

    query.awaitTermination()


if __name__ == "__main__":
    main()

The first batch is processed fine, but as soon as it starts processing the second batch, I get the following error:

consumer  | 25/09/05 07:27:40 WARN TaskSetManager: Lost task 128.0 in stage 5.0 (TID 530) (987409f72424 executor driver): TaskKilled (Stage cancelled: Job aborted due to stage failure: Task 120 in stage 5.0 failed 1 times, most recent failure: Lost task 120.0 in stage 5.0 (TID 522) (987409f72424 executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
consumer  |   File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/worker.py", line 2044, in main
consumer  |     process()
consumer  |   File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/worker.py", line 2036, in process
consumer  |     serializer.dump_stream(out_iter, outfile)
consumer  |   File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 1236, in dump_stream
consumer  |     super().dump_stream(flatten_iterator(), stream)
consumer  |   File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 564, in dump_stream
consumer  |     return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
consumer  |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
consumer  |   File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 115, in dump_stream
consumer  |     for batch in iterator:
consumer  |                  ^^^^^^^^
consumer  |   File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 557, in init_stream_yield_batches
consumer  |     for series in iterator:
consumer  |                   ^^^^^^^^
consumer  |   File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 1233, in flatten_iterator
consumer  |     for pdf in iter_pdf:
consumer  |                ^^^^^^^^
consumer  |   File "/app/detector.py", line 41, in handleInputRows
consumer  |     current_state = {"locations": self._state.get()[0], "timestamps": self._state.get()[1]}
consumer  |                                   ^^^^^^^^^^^^^^^^^
consumer  |   File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/streaming/stateful_processor.py", line 62, in get
consumer  |     return self._valueStateClient.get(self._stateName)
consumer  |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
consumer  |   File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/streaming/value_state_client.py", line 78, in get
consumer  |     raise PySparkRuntimeError(f"Error getting value state: " f"{response_message[1]}")
consumer  | pyspark.errors.exceptions.base.PySparkRuntimeError: Error getting value state: couldn't introspect javabean: java.lang.IllegalArgumentException: wrong number of arguments
consumer  | 
consumer  |     at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:581)
consumer  |     at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:117)
consumer  |     at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:532)
consumer  |     at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
consumer  |     at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:601)
consumer  |     at scala.collection.Iterator$$anon$9.hasNext(Iterator.scala:583)
consumer  |     at org.apache.spark.util.CompletionIterator.hasNext(CompletionIterator.scala:31)
consumer  |     at scala.collection.Iterator$$anon$9.hasNext(Iterator.scala:583)
consumer  |     at org.apache.spark.sql.execution.datasources.v2.WritingSparkTask$IteratorWithMetrics.hasNext(WriteToDataSourceV2Exec.scala:545)
consumer  |     at org.apache.spark.sql.connector.write.DataWriter.writeAll(DataWriter.java:107)
consumer  |     at org.apache.spark.sql.execution.streaming.sources.PackedRowDataWriter.writeAll(PackedRowWriterFactory.scala:53)
consumer  |     at org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask$.write(WriteToDataSourceV2Exec.scala:587)
consumer  |     at org.apache.spark.sql.execution.datasources.v2.WritingSparkTask.$anonfun$run$5(WriteToDataSourceV2Exec.scala:483)
consumer  |     at org.apache.spark.util.Utils$.tryWithSafeFinallyAndFailureCallbacks(Utils.scala:1323)
consumer  |     at org.apache.spark.sql.execution.datasources.v2.WritingSparkTask.run(WriteToDataSourceV2Exec.scala:535)
consumer  |     at org.apache.spark.sql.execution.datasources.v2.WritingSparkTask.run$(WriteToDataSourceV2Exec.scala:466)
consumer  |     at org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask$.run(WriteToDataSourceV2Exec.scala:584)
consumer  |     at org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec.$anonfun$writeWithV2$2(WriteToDataSourceV2Exec.scala:427)
consumer  |     at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
consumer  |     at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:171)
consumer  |     at org.apache.spark.scheduler.Task.run(Task.scala:147)
consumer  |     at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$5(Executor.scala:647)
consumer  |     at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:80)
consumer  |     at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:77)
consumer  |     at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:99)
consumer  |     at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:650)
consumer  |     at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
consumer  |     at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
consumer  |     at java.base/java.lang.Thread.run(Unknown Source)

There must be an error in how I save the data into the state, since it seems that PyArrow cannot deserialize it, but from this article I know that saving an ArrayType of LongType and StringType should be possible. What could be the cause? I'm pretty new to PySpark; I built this project to learn how to use it, and I've been stuck on this for days, trying multiple solutions to no avail.
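For reference, one workaround I have thought about but not verified is to avoid ArrayType in the value state altogether and store the two arrays as JSON strings in plain StringType fields. A rough sketch (class name is made up, reusing the imports from the script above plus import json):

import json


class JsonStateSP(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle):
        # Hypothetical variant: the state holds two JSON strings instead of two arrays.
        self._state = handle.getValueState(
            "state",
            StructType([
                StructField("locations_json", StringType()),
                StructField("timestamps_json", StringType()),
            ]),
        )

    def handleInputRows(self, key, rows, timerValues):
        # Decode the previous state, if any.
        if self._state.exists():
            stored = self._state.get()
            locations = json.loads(stored[0])
            timestamps = json.loads(stored[1])
        else:
            locations, timestamps = [], []

        # Collect the rows of this micro-batch.
        for pdf in rows:
            locations.extend(pdf["location"].tolist())
            timestamps.extend(pdf["unix_timestamp"].tolist())

        # Write the state back as two plain strings.
        self._state.update((json.dumps(locations), json.dumps(timestamps)))

        yield pd.DataFrame()

I don't know whether this actually dodges the javabean introspection error, so I'd still like to understand the root cause.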

  • I believe you need to return the schema you have specified (outputStructType=output_schema), but you are returning yield pd.DataFrame(). Commented Sep 5 at 12:56
  • @Frank the error is on the retrieval of the state after it has been set. Plus, this is the simplified version of the code; it fails even if I don't return an empty dataframe. Commented Sep 5 at 13:25
  • I think what @Frank means is that the schema of the data frame you return should match the output schema you have specified (see the sketch below). Commented Sep 8 at 10:01
  • We currently have the exact same problem. It's poorly documented and it's hard to find anything about it. Commented Nov 12 at 13:26
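To make the schema point above concrete, a yield inside handleInputRows that matches the declared output_schema would look something like this (illustrative only; the values and message text are made up, and key is the grouping-key tuple passed to handleInputRows):

# The yielded pandas DataFrame must carry exactly the columns declared in
# output_schema: card_number, is_fraud, message (all strings here).
yield pd.DataFrame(
    {
        "card_number": [key[0]],
        "is_fraud": ["false"],
        "message": ["no rule triggered"],
    }
)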

1 Answer


I've resorted to a workaround that uses two list states instead of a single value state:

class MySP(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle):
        # One list state per array: each list element is a single-field row,
        # so no ArrayType columns are needed in the state schema.
        list_timestamp_schema = StructType([StructField("timestamp", LongType(), True)])
        list_location_schema = StructType([StructField("location", StringType(), True)])
        self._timestamp_state = handle.getListState("timestampState", list_timestamp_schema)
        self._location_state = handle.getListState("locationState", list_location_schema)

This way I can save and load the state without deserialization errors.
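For completeness, handleInputRows then looks roughly like this. It's a simplified sketch: it assumes the ListState methods exists(), get() (which yields one tuple per stored element), and appendList(), and the fraud rule and message are just placeholders:

    def handleInputRows(self, key, rows, timerValues):
        # Load what has been stored so far; get() yields one tuple per element.
        locations = [t[0] for t in self._location_state.get()] if self._location_state.exists() else []
        timestamps = [t[0] for t in self._timestamp_state.get()] if self._timestamp_state.exists() else []

        # Collect the rows of this micro-batch.
        new_locations, new_timestamps = [], []
        for pdf in rows:
            new_locations.extend(pdf["location"].tolist())
            new_timestamps.extend(pdf["unix_timestamp"].tolist())

        # Append to the list states; each element is a one-field tuple that
        # matches the single-column state schemas defined in init().
        if new_locations:
            self._location_state.appendList([(loc,) for loc in new_locations])
            self._timestamp_state.appendList([(ts,) for ts in new_timestamps])

        locations.extend(new_locations)
        timestamps.extend(new_timestamps)

        # Placeholder rule, just to emit something that matches output_schema.
        is_fraud = len(set(locations)) > 3
        yield pd.DataFrame(
            {
                "card_number": [key[0]],
                "is_fraud": [str(is_fraud).lower()],
                "message": [f"{len(set(locations))} distinct locations in state"],
            }
        )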
