I'm doing some data processing with Spark via the Python API. Here's a simplified bit of the class I'm working with:
    import numpy as np
    from pyspark.sql.types import StructType, StructField, IntegerType, StringType

    # sc (SparkContext) and sqlContext are assumed to already exist globally

    class data_processor(object):
        def __init__(self, filepath):
            self.config = Config()  # this loads some config options from file
            self.type_conversions = {int: IntegerType, str: StringType}
            self.load_data(filepath)
            self.format_age()

        def load_data(self, filepath, delim='\x01'):
            cols = [...]  # list of column names
            types = [int, str, str, ... ]  # list of column types
            # split each delimited line and convert every field with its Python type
            user_data = sc.textFile(filepath, use_unicode=False).map(
                lambda row: [types[i](val) for i, val in enumerate(row.strip().split(delim))])
            fields = StructType([StructField(field_name, self.type_conversions[field_type]())
                                 for field_name, field_type in zip(cols, types)])
            self.user_data = user_data.toDF(fields)
            self.user_data.registerTempTable('data')

        def format_age(self):
            age_range = self.config.age_range  # tuple of (age_min, age_max)
            age_bins = self.config.age_bins    # list of bin boundaries

            def _format_age(age):
                if age < age_range[0] or age > age_range[1]:
                    return None
                else:
                    return np.digitize([age], age_bins)[0]

            sqlContext.udf.register('format_age', lambda x: _format_age(x), IntegerType())
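The binning logic inside `_format_age` can be exercised without Spark at all. Below is a minimal pure-Python sketch of what it computes, assuming the bin boundaries are increasing; `make_age_formatter` and the sample range and bins are hypothetical. It swaps `np.digitize` for `bisect.bisect_right`, which for increasing bins returns the same index as `np.digitize` with its default `right=False`, but as a built-in `int`. That difference may matter here: `np.digitize(...)[0]` is a NumPy `int64`, and returning NumPy scalars from a Python UDF is a commonly reported cause of `Py4JJavaError` / pickling failures in PySpark.

```python
import bisect


def make_age_formatter(age_range, age_bins):
    """Build a bin-assignment function that closes over plain values only.

    Hypothetical helper: age_range is (age_min, age_max) and age_bins is an
    increasing list of bin boundaries, mirroring the config values above.
    """
    age_min, age_max = age_range
    bins = list(age_bins)  # copy so the closure holds a plain list

    def _format_age(age):
        # Added assumption: pass NULL (None) ages through as None.
        if age is None or age < age_min or age > age_max:
            return None
        # For increasing bins, bisect_right(bins, age) equals
        # np.digitize([age], bins)[0] (default right=False), but it
        # returns a built-in int rather than a NumPy int64.
        return bisect.bisect_right(bins, age)

    return _format_age


fmt = make_age_formatter((18, 65), [18, 25, 35, 45, 55, 65])
print([fmt(a) for a in (10, 18, 30, 64, 65, 70)])  # [None, 1, 2, 5, 6, None]
```

If the return type is the culprit, the smallest change to try in the class would be `return int(np.digitize([age], age_bins)[0])`.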
Now, if I instantiate the class with data = data_processor(filepath), I can query the DataFrame just fine. This, for example, works:

    sqlContext.sql("select * from data limit 10").take(1)
But I'm clearly not setting up the UDF properly. If I try, for instance,

    sqlContext.sql("select age, format_age(age) from data limit 10").take(1)
I get an error:

    Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.

(followed by a long stack trace, typical of Spark, that's too long to include here).
So what exactly am I doing wrong? What is the proper way to define a UDF within a method like this (preferably as a class method)? I know Spark doesn't like passing class objects around, hence the nested structure of format_age (inspired by this question).
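The reasoning behind the nested function can be checked in plain Python: PySpark serializes the UDF's Python function and ships it to the executors, and a closure only carries the free variables it actually references. In this sketch, `Config` and `Processor` are hypothetical stand-ins for the question's classes. Copying the config values into locals before defining the inner function leaves only a tuple and a list in its `__closure__` cells, whereas referencing `self` inside the inner function makes the whole instance part of the closure.

```python
import bisect


class Config(object):
    """Hypothetical stand-in for the config object in the question."""
    age_range = (18, 65)
    age_bins = [18, 25, 35, 45, 55, 65]


class Processor(object):
    def __init__(self):
        self.config = Config()

    def build_func_local(self):
        # Copy config values into locals: the inner function closes over
        # age_range and age_bins only, not over self.
        age_range = self.config.age_range
        age_bins = self.config.age_bins

        def _format_age(age):
            if age is None or age < age_range[0] or age > age_range[1]:
                return None
            return bisect.bisect_right(age_bins, age)
        return _format_age

    def build_func_self(self):
        # Referencing self inside the inner function makes self a free
        # variable, so serializing the function drags the instance along.
        def _format_age(age):
            age_min, age_max = self.config.age_range
            if age is None or age < age_min or age > age_max:
                return None
            return bisect.bisect_right(self.config.age_bins, age)
        return _format_age


def captured(func):
    """Return the objects held in a function's closure cells."""
    return [cell.cell_contents for cell in (func.__closure__ or ())]


p = Processor()
print(any(obj is p for obj in captured(p.build_func_local())))  # False
print(any(obj is p for obj in captured(p.build_func_self())))   # True
```

The question's `format_age` already follows the local-copy pattern, so serializing `self` is probably not what Spark is rejecting in this case.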
Ideas?