I am trying to calculate the hours after sunrise over a data array that has a length of ca. 300k (chunk size ca. 900). The resulted array is a dask.array with no problem using xr.apply_ufunc and astroplan functions. However, it appears to be extremely slow when I use this dask.array for filtering data using xr.where(). Where can I improve it?
Here is my workstream:
from astropy.time import Time
from astroplan import Observer
import astropy.units as u
import xarray as xr
import numpy as np
def cal_sunrise_h(lat, lon, mjd):
points = Observer(longitude=lon*u.deg, latitude=lat*u.deg, elevation=89*u.km)
times = Time(mjd, format='mjd')
sunrise = points.sun_rise_time(times, which="previous")
hours_after_sunrise = (times-sunrise).sec/3600
return hours_after_sunrise
# some fake dataset for reproducing the problem
total_len = 300000
chunk_size = 900
mjd = np.linspace(0, 0.1, total_len) + 5.45559e4
latitude = xr.DataArray(np.linspace(-80, 80, total_len), dims='mjd', coords=[mjd])
longitude = xr.DataArray(np.linspace(-180, 180, total_len), dims='mjd', coords=[mjd])
ds = xr.Dataset({'latitude':latitude, 'longitude':longitude}).chunk({'mjd': chunk_size})
# calculate hours after sunrise
hours_after_sunrise = xr.apply_ufunc(cal_sunrise_h, ds.latitude, ds.longitude, ds.mjd,
output_dtypes=[float], dask='parallelized') #dask.array
# make a filter
sunrise_filter = (hours_after_sunrise>5) #dask.array
# mask out with filter
ds.where(sunrise_filter, drop=True) #super slow!