I'm currently learning UDFs and wrote the PostgreSQL UDF below to calculate the mean absolute deviation (MAD): the average absolute difference between each value in a window and the mean of that window. In Python with pandas/numpy, the MAD could be computed like this:
series_mad = abs(series - series.mean()).mean()
Here, series is a set of numbers and series_mad is a single numeric value representing the MAD of the series.
I'm trying to write this in PostgreSQL using window functions and a user-defined aggregate. So far, this is what I have:
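To make the definition concrete, here is a minimal pure-Python sketch of the same calculation. The function name mad and the sample numbers are only illustrative, and returning None for an empty input is an assumption I made for the sketch, not something the pandas expression does.

```python
def mad(series):
    """Mean absolute deviation: average of |x - mean| over the series."""
    values = [float(x) for x in series]
    if not values:
        return None  # assumption: no data yields no result
    mean = sum(values) / len(values)
    # Average the absolute distance of each value from the mean.
    return sum(abs(x - mean) for x in values) / len(values)


example = [1, 2, 3, 4, 10]  # mean is 4, absolute deviations are 3, 2, 1, 0, 6
result = mad(example)
print(result)  # 2.4
```

The mean has to be known before any deviation can be computed, which is why the calculation needs two passes over the data.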
CREATE TYPE misc_tuple AS (
    arr_store numeric[],
    ma_period integer
);
CREATE OR REPLACE FUNCTION mad_func(prev misc_tuple, curr numeric, ma_period integer)
RETURNS misc_tuple AS $$
BEGIN
    IF curr IS NULL THEN
        RETURN (NULL::numeric[], -1);
    ELSEIF prev.arr_store IS NULL THEN
        RETURN (ARRAY[curr]::numeric[], ma_period);
    ELSE
        -- accumulate new values in array
        prev.arr_store := array_append(prev.arr_store, curr);
        RETURN prev;
    END IF;
END;
$$ LANGUAGE plpgsql;
CREATE OR REPLACE FUNCTION mad_final(prev misc_tuple)
RETURNS numeric AS $$
DECLARE
    total_len integer;
    count numeric;
    mad_val numeric;
    mean_val numeric;
BEGIN
    count := 0;
    mad_val := 0;
    mean_val := 0;
    total_len := array_length(prev.arr_store, 1);
    -- first loop to find the mean of the series
    FOR i IN greatest(1,total_len-prev.ma_period+1)..total_len
    LOOP
        mean_val := mean_val + prev.arr_store[i];
        count := count + 1;
    END LOOP;
    mean_val := mean_val/NULLIF(count,0);
    -- second loop to subtract mean from each value
    FOR i IN greatest(1,total_len-prev.ma_period+1)..total_len
    LOOP
        mad_val := mad_val + abs(prev.arr_store[i]-mean_val);
    END LOOP;
    RETURN mad_val/NULLIF(count, 0);
END;
$$ LANGUAGE plpgsql;
CREATE OR REPLACE AGGREGATE mad(numeric, integer) (
    SFUNC = mad_func,
    STYPE = misc_tuple,
    FINALFUNC = mad_final
);
This is how I'm testing the performance:
-- find rolling 12-period MAD
SELECT x,
       mad(x, 12) OVER (ROWS 12-1 PRECEDING)
FROM generate_series(0,1000000) AS g(x);
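For reference, this is a small pure-Python sketch of the values I expect from the query above. It assumes the frame ROWS 11 PRECEDING covers the current row plus up to 11 rows before it, so the first rows use a shorter window. The name rolling_mad is hypothetical and only used for this illustration.

```python
def rolling_mad(values, period):
    """Rolling MAD over the current value and up to period-1 preceding values."""
    out = []
    for i in range(len(values)):
        # The window is shorter than `period` for the first few rows.
        window = values[max(0, i - period + 1): i + 1]
        mean = sum(window) / len(window)
        out.append(sum(abs(x - mean) for x in window) / len(window))
    return out


# Same shape as the test query, on a small range of consecutive integers.
expected = rolling_mad(list(range(20)), 12)
print(expected[0])   # 0.0 (single-value window)
print(expected[11])  # 3.0 (first full 12-value window: 0..11)
print(expected[19])  # 3.0 (window 8..19)
```

With consecutive integers, every full 12-row window gives the same MAD of 3, which makes it easy to check that a faster version still returns correct results.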
Currently, the query takes ~45-50 seconds on my desktop (i5 4670, 3.4 GHz, 16 GB RAM). I'm still learning UDFs, so I'm not sure what else I could change in my function to make it faster. I have a few other similar UDFs that don't use arrays, and they take under 15 seconds on the same 1M rows. My guess is that I'm not looping over the arrays efficiently, or that something could be done about the two loops in the final function.
Are there any changes I can make to this UDF to make it faster?
generate_series already does that for you. The problem I'm trying to solve is speeding up the query in my post. I've also described what the UDF calculates and my assumptions about why the code is slow. Since I'm just learning UDFs, I wasn't sure what other techniques could be used to speed up my code, so I asked for suggestions; that is the question. The expected results are the same as what my code outputs, but with a much faster runtime.