I am trying to perform a Fourier analysis on a randomly generated linear combination of sinusoids. I am operating on a timescale from 0 min to 860 min (i.e. a total of 861 timepoints), and my function produces a sine wave at a default sampling rate of 1 sample/min (this is intended to replicate real timecourse data that I have gathered). In order to avoid aliasing, I am interpolating my data 30X to a rate of 30 samples/min and running an FFT on the interpolated data. I am then trying to pull the peak frequencies from that FFT analysis and compare them (visually) to the original known frequencies of the randomly generated sinusoid.
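For reference, the Nyquist limit for a sampling rate f_s is f_s/2, so at 1 sample/min it is 0.5 cycles/min. A real sinusoid above that limit produces the same samples as a sinusoid (possibly with a different phase) folded back into [0, f_s/2]. The sketch below computes that folded (alias) frequency. `alias_frequency` is a hypothetical helper written for illustration, not a NumPy or SciPy function, and it assumes noise-free real sinusoids.

```python
import numpy as np

def alias_frequency(f, fs):
    """Hypothetical helper: apparent frequency (same units as fs) of a real
    sinusoid at frequency f after sampling at rate fs, folded into [0, fs/2]."""
    nyquist = fs / 2
    # Shift f into [0, fs), then reflect anything above Nyquist back down.
    f_mod = np.mod(f, fs)
    return np.where(f_mod > nyquist, fs - f_mod, f_mod)

fs = 1.0  # samples/min, the default rate in make_sample_sin
for f in [0.3, 3.3, 7.8, 14.6]:
    print(f"{f:5.2f} cycles/min -> appears at {float(alias_frequency(f, fs)):.2f} cycles/min")
```

Passing `fs = 30.0` instead shows which frequencies would be left unchanged at the post-interpolation rate.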
This is my code:
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import detrend, find_peaks
from scipy.fft import fft, rfft, fftfreq, rfftfreq, fftshift

def make_sample_sin(time_length, max_freq = 15, sampling_rate = 1, num_waves = 3):
    """
    Generates randomized linear combinations of sine waves to use when testing Fourier analysis.

    Parameters
    ----------
    time_length : int
        Number of timepoints to generate (e.g. 861 for 0 to 860 min at 1 sample/min)
    max_freq : flt, optional
        Maximum frequency to randomly sample. The default is 15 cycles/min.
    sampling_rate : flt, optional
        Frequency of sample collection (samples/min). The default is 1/min.
    num_waves : int, optional
        Number of sine waves to include in the linear combination. The default is 3.

    Returns
    -------
    sin_dict : dict
        Dictionary of flattened 2D arrays of timecourse sinusoid data corresponding to each treatment
    freqs_dict : dict
        Dictionary of lists of dicts containing the original frequencies of the component sine functions.
        Entries are of the form amp: freq.
    time_length : int
        Number of timepoints (passed through unchanged)
    max_freq : flt
        Maximum sampled frequency (passed through unchanged)
    """
    sin_dict = dict.fromkeys(["-PGN", "1X", "10X", "100X"])
    freqs_dict = dict.fromkeys(["-PGN", "1X", "10X", "100X"])
    times = np.arange(time_length) / sampling_rate
    # print(times)
    for key in sin_dict:
        sin_array = np.zeros((0, time_length))
        freqs_all = [] # frequencies in cycles/min!!!
        for cell in range(10):
            signal = np.zeros(time_length)
            freqs_cell = {}
            for _ in range(num_waves):
                freq = np.random.uniform(0.1, max_freq) # set random frequency
                amp = np.random.uniform(0.5, 5) # set random amplitude
                phase = np.random.uniform(0, 2 * np.pi) # set random phase
                # add random sine wave to signal
                signal += amp * np.sin(2 * np.pi * freq * times + phase)
                # record frequency, keyed by amplitude
                freqs_cell[amp] = freq
            # reorder freqs_cell in descending order of amplitude
            sorted_freqs_cell = {k: v for k, v in sorted(freqs_cell.items(), key=lambda item: item[0], reverse=True)}
            # update
            sin_array = np.vstack((sin_array, signal))
            freqs_all.append(sorted_freqs_cell)
        sin_dict[key] = sin_array
        freqs_dict[key] = freqs_all
    return sin_dict, freqs_dict, time_length, max_freq
```
```python
def run_fft(subcluster_traces, test = False, interpolate = False, **kwargs):
    """
    Runs FFT on smoothed trace data.

    Parameters
    ----------
    subcluster_traces : dict
        Dictionary of ratio time trace values for cells in each subcluster within clusters for all treatments.
    test : bool, optional
        Whether the provided data are test data (e.g. random sinusoids) or experimental data. The default is False.
    interpolate : bool, optional
        Whether or not to perform additional interpolation. The default is False.
    **sampling_rate : flt
        Frequency of sample collection (samples/min). Must be provided.
    **max_freq : flt
        Maximum frequency of randomly sampled sines. Must be provided if test is True.
    **interp_rate : int
        Degree of interpolation to perform. Must be provided if interpolate is True.
    **time_length : int
        Number of timepoints in the test data. Must be provided if test is True.

    Raises
    ------
    ValueError
        Occurs if sampling_rate is not provided.
        Occurs if interpolate is True and interp_rate is not provided.
        Occurs if test is True and max_freq or time_length are not provided.
        Occurs if sampling_rate < 2 * max_freq.
        Occurs if timesteps are not uniform.

    Returns
    -------
    data_dict : dict
        Dictionary of flattened 2D arrays of timecourse data contained in subcluster_traces corresponding to each treatment
        (only returned if test is False)
    detrended_dict : dict
        Dictionary of flattened 2D arrays of smoothed, detrended data corresponding to each treatment
    y_dict : dict
        Dictionary of lists of cell names corresponding to each treatment (only returned if test is False)
    fft_dict : dict
        Dictionary of complex ndarrays containing the 1D n-point DFT calculated using the FFT algorithm corresponding to each treatment
    fft_freq_dict : dict
        Dictionary of arrays containing the DFT sample frequencies corresponding to each treatment
    fft_amp_dict : dict
        Dictionary of arrays containing the amplitude of the complex points in fft_dict corresponding to each treatment
    fft_phase_dict : dict
        Dictionary of arrays containing the phase angles of the complex points in fft_dict corresponding to each treatment
    fft_peak_dict : dict
        Dictionary of dfs containing the peak amplitude and the frequency at which it occurs for each cell in each treatment
    sampling_rate : flt
        Frequency of sample collection post-interpolation (if applicable) (1/min)
    """
    # check to make sure sampling_rate is provided
    if "sampling_rate" not in kwargs:
        raise ValueError("sampling_rate must be provided.")
    sampling_rate = kwargs["sampling_rate"] # sampling rate in samples/min

    # check to make sure interp_rate is provided if interpolate is True
    if interpolate:
        if "interp_rate" not in kwargs:
            raise ValueError("interp_rate must be provided when interpolate is True.")
        interp_rate = kwargs["interp_rate"]
        sampling_rate *= interp_rate # post-interpolation sampling rate in samples/min

    # check to make sure max_freq and time_length are provided if test is True
    if test:
        if "max_freq" not in kwargs:
            raise ValueError("max_freq must be provided when test is True.")
        max_freq = kwargs["max_freq"]
        if "time_length" not in kwargs:
            raise ValueError("time_length must be provided when test is True.")
        time_length = kwargs["time_length"]
        # check to make sure sampling rate is >= 2 * max_freq
        if sampling_rate < 2 * max_freq:
            raise ValueError("sampling_rate insufficiently high to prevent aliasing. Additional interpolation required.")

    # flatten trace data (each row of data_dict is a cell and each column is a timepoint)
    # x_dict is times and y_dict is cells
    if test is False:
        x_dict, y_dict, data_dict = flatten_subcluster_traces(subcluster_traces)
    else:
        data_dict = subcluster_traces
    treatments = data_dict.keys()
    step_size = 0
    detrended_dict = dict.fromkeys(treatments)
    fft_dict = dict.fromkeys(treatments)
    fft_freq_dict = dict.fromkeys(treatments)
    fft_amp_dict = dict.fromkeys(treatments)
    fft_phase_dict = dict.fromkeys(treatments)
    fft_peak_dict = dict.fromkeys(treatments)

    for treatment, treatment_data in data_dict.items():
        print(f"Sampling rate (post-interpolation) ({treatment}): {sampling_rate}/min")
        # calculate number of samples (post-interpolation)
        # print(x_dict[treatment][-1])
        num_times = sampling_rate * (x_dict[treatment][-1]) if test is False else sampling_rate * (time_length - 1)
        num_cells = treatment_data.shape[0]
        print(f"# timepoints ({treatment}): {num_times}")
        print(f"# cells ({treatment}): {num_cells}")

        # calculate step size
        if test is False:
            step_sizes = np.diff(x_dict[treatment])
            if np.all(step_sizes == step_sizes[0]):
                step_size = step_sizes[0] / sampling_rate # step size in minutes
            else:
                raise ValueError("Timesteps are not uniform.")
            # print(treatment, step_size)
        else:
            step_size = 1 / sampling_rate # step size in minutes
        print(f"Step size (post-interpolation) ({treatment}): {step_size} min")

        if interpolate is True: # interpolate
            times = list(range(treatment_data.shape[1])) # original times, in min
            times_int = np.arange(times[0], times[-1], 1 / sampling_rate) # interpolated times, in min
            # print(times_int)
            int_data = np.array([np.interp(times_int, times, row) for row in treatment_data])
            # print(int_data.shape[1])
            # # plot random row to check interpolation (UNCOMMENT IF DESIRED)
            # rand_int = np.random.randint(0, len(treatment_data))
            # rand_data_org = treatment_data[rand_int]
            # rand_data_int = int_data[rand_int]
            # fig, axs = plt.subplots(1, 2, figsize = (16, 6), sharex = True, sharey = True)
            # axs[0].plot(times, rand_data_org, linewidth = 0.5, color = "blue")
            # axs[0].set_xlabel('Time')
            # axs[0].set_ylabel('Nuclear Relish fraction (fold change)')
            # axs[1].plot(times_int, rand_data_int, linewidth = 0.5, color = "red")
            # plt.suptitle(f'Comparing interpolated data ({treatment}, interp_rate = {interp_rate})')
            # plt.tight_layout()
        else:
            int_data = treatment_data
        # print(f"Step size (post-interpolation) ({treatment}): {step_size} min\n")
        # print(f"Number of timepoints (post-interpolation) ({treatment}): {num_times}")
        freq_bins = np.arange(0, num_times // 2) * (sampling_rate / num_times)
        freq_resolution = sampling_rate / num_times
        print(f"Frequency bins ({treatment}): n = {len(freq_bins)} bins")
        print(f"Frequency resolution ({treatment}): {freq_resolution} cycles/min per bin\n")

        # detrend data
        detrended_data = detrend(int_data, axis = 1)
        detrended_dict[treatment] = detrended_data
        # run rfft
        fft_array = rfft(detrended_data)
        fft_dict[treatment] = fft_array
        # # perform Hilbert transform
        # fft_array_shift = fftshift(fft_array)
        # # print(detrended_data.shape[1]//2)
        # fft_array_shift[len(detrended_data) // 2 :] *= 2
        # get frequencies
        # window_length = fft_array.shape[1]
        fft_freq_array = rfftfreq(detrended_data.shape[1], d = step_size)
        fft_freq_dict[treatment] = fft_freq_array
        # get amplitudes and scale to signal length
        fft_amp_array = np.abs(fft_array) / num_times
        fft_amp_dict[treatment] = fft_amp_array
        # get phase angles
        fft_phase_array = np.angle(fft_array)
        fft_phase_dict[treatment] = fft_phase_array

        # make df of peak amplitudes and frequencies
        fft_peak_amp = np.max(fft_amp_array, axis = 1)
        fft_peak_idx = np.argmax(fft_amp_array, axis = 1)
        # print(f"Peak indices ({treatment}): ", len(fft_peak_idx))
        fft_peak_freq = [fft_freq_array[idx] for idx in fft_peak_idx]
        # print(f"Peak amplitudes ({treatment}): ", len(fft_peak_amp))
        # print(f"Peak frequencies ({treatment}): ", len(fft_peak_freq))
        if test is False:
            amp_df = pd.DataFrame(index = y_dict[treatment], columns = ["Peak amp.", "Peak freq."])
        else:
            amp_df = pd.DataFrame(columns = ["Peak amp.", "Peak freq."])
        amp_df["Peak amp."] = fft_peak_amp
        amp_df["Peak freq."] = fft_peak_freq
        fft_peak_dict[treatment] = amp_df

    if test is False:
        return data_dict, detrended_dict, y_dict, fft_dict, fft_freq_dict, fft_amp_dict, fft_phase_dict, fft_peak_dict, sampling_rate
    else:
        return detrended_dict, fft_dict, fft_freq_dict, fft_amp_dict, fft_phase_dict, fft_peak_dict, sampling_rate
```
```python
def plot_component_sinusoids(df_dict_smooth_org, detrended_dict, fft_dict, fft_freq_dict, fft_amp_dict, fft_phase_dict, sampling_rate, interp_rate = 1, test = False, **kwargs):
    """
    Plots component sinusoids corresponding to major peaks in Fourier spectrum.

    Parameters
    ----------
    df_dict_smooth_org : dict
        Dictionary containing dataframes with smoothed ratio values for PGN, 1X, 10X, and 100X
    detrended_dict : dict
        Dictionary of flattened 2D arrays of smoothed, detrended data corresponding to each treatment
    fft_dict : dict
        Dictionary of complex ndarrays containing the 1D n-point DFT calculated using the FFT algorithm corresponding to each treatment
    fft_freq_dict : dict
        Dictionary of arrays containing the DFT sample frequencies corresponding to each treatment
    fft_amp_dict : dict
        Dictionary of arrays containing the amplitude of the complex points in fft_dict corresponding to each treatment
    fft_phase_dict : dict
        Dictionary of arrays containing the phase angles of the complex points in fft_dict corresponding to each treatment
    sampling_rate : flt
        Frequency of sample collection post-interpolation (if applicable) (1/min)
    interp_rate : int, optional
        Degree of interpolation performed (used in plot titles). The default is 1.
    test : bool, optional
        Whether the provided data are test data (e.g. random sinusoids) or experimental data. The default is False.
    **cells_dict : dict
        Dictionary of lists of cell names corresponding to each treatment. Must be provided if test is False.
    **sin_freqs : dict
        Dictionary of lists of dicts containing the original frequencies of the component sine functions. Must be provided if test is True.
        Entries are of the form amp: freq.

    Raises
    ------
    ValueError
        Occurs if test is False and cells_dict is not provided.
        Occurs if test is True and sin_freqs is not provided.

    Returns
    -------
    None.
    """
    print(f"Post-interpolation sampling rate: {sampling_rate:.1f}/min")
    # check to make sure cells_dict is provided if test is False
    if test is False:
        if "cells_dict" not in kwargs:
            raise ValueError("cells_dict must be provided when test is False.")
        cells_dict = kwargs["cells_dict"]
    # check to make sure sin_freqs is provided if test is True
    else:
        if "sin_freqs" not in kwargs:
            raise ValueError("sin_freqs must be provided when test is True.")
        sin_freqs = kwargs["sin_freqs"]

    for i, (treatment, fft_array) in enumerate(fft_dict.items()): # iterate over treatments
        # retrieve frequencies and cell names
        cells = cells_dict[treatment] if test is False else None
        freqs = fft_freq_dict[treatment]
        # retrieve original frequencies if running test
        sin_treat = sin_freqs[treatment] if test is True else None
        # randomly select cell to plot
        cell_idx = np.random.randint(0, fft_array.shape[0])
        # print(treatment, cell_idx)
        cell_name = cells[cell_idx] if test is False else f"Test curve {i + 1}-{cell_idx}"
        # print(treatment, cell_name)
        cell_data = fft_array[cell_idx]
        cell_amps = fft_amp_dict[treatment][cell_idx]
        cell_angs = fft_phase_dict[treatment][cell_idx]

        # retrieve peak amplitudes and frequencies
        # # limit to positive frequencies (wholly real signals) - NOT NECESSARY WHEN rfft() IS USED IN run_fft()
        # pos_idxs = np.where(freqs >= 0)[0]
        # # print(pos_idxs)
        # pos_freqs = freqs[pos_idxs]
        # # print("Initial: ", type(pos_freqs))
        # pos_freqs = np.array([freq / (60 * sampling_rate) for freq in pos_freqs]) # transform FFT frequencies to real frequencies (in Hz)
        # # print("Final: ", type(pos_freqs))
        # pos_data = cell_data[pos_idxs]
        # pos_amps = cell_amps[pos_idxs]
        # pos_angs = cell_angs[pos_idxs]
        # # check to make sure the correct number of data points have been extracted
        # if len(set([len(pos_freqs), len(pos_data), len(pos_amps), len(pos_angs)])) != 1:
        #     raise ValueError("Positive frequency indexing incorrect.")

        # find peaks among positive frequencies
        max_amp = cell_amps.max()
        # nyquist = sampling_rate / 2
        # threshold = max_amp / 20
        # height = max_amp / 20
        prom = max_amp / 10
        # print(f"Threshold ({treatment}): ", threshold)
        peak_idxs = find_peaks(cell_amps, prominence = prom, distance = 5)[0]
        # peak_idxs = findpeaks(pos_data, method = "peakdetect")
        peak_freqs = freqs[peak_idxs]
        peak_amps = cell_amps[peak_idxs]
        peak_angs = cell_angs[peak_idxs]
        # print(f"Peak amplitudes ({treatment}): ", peak_amps)
        # print(f"Peak frequencies ({treatment}): ", peak_freqs)

        # sort peaks in descending order of amplitude
        num_peaks = len(peak_idxs)
        # print(f"Number of peaks ({treatment}): ", num_peaks)
        # sort_idxs = np.argsort(peak_amps)[::-1]
        # print(sort_idxs)
        peak_tuples = zip(peak_freqs, peak_amps, peak_angs)
        peak_dict = {idx: [freq, amp, ang] for idx, (freq, amp, ang) in enumerate(peak_tuples)}
        # print(peak_dict)
        sorted_dict = sorted(peak_dict.items(), key = lambda x: x[1][1], reverse=True)
        # print(sorted_dict)
        peak_freqs = [item[1][0] for item in sorted_dict]
        peak_amps = [item[1][1] for item in sorted_dict]
        peak_angs = [item[1][2] for item in sorted_dict]
        # print(f"Peak amplitudes ({treatment}): ", peak_amps)
        # print(f"Peak frequencies ({treatment}): ", peak_freqs)
        # print(f"Peak phase angles ({treatment}): ", peak_angs)

        # plot original trace on first plot
        if test is False:
            # initialize figure
            fig, axs = plt.subplots(2, max(3, num_peaks), figsize=(8 * (1 + min(num_peaks, len(fft_array))), 16))
            # print(treatment, cell_name)
            trace = df_dict_smooth_org[treatment][cell_name]
        else:
            # initialize figure
            fig, axs = plt.subplots(3, max(3, num_peaks), figsize=(8 * (1 + min(num_peaks, len(fft_array))), 24))
            trace = df_dict_smooth_org[treatment][cell_idx, :]
            sines = sin_treat[cell_idx]
        times = np.linspace(0, len(trace) - 1, len(trace))
        # print(cell_name, f" ({treatment}): ", trace)
        axs[0, 0].plot(times, trace)
        axs[0, 0].set_title("Original trace", fontsize = 20)
        axs[0, 0].set_xlabel("Time (min)", fontsize = 15)
        axs[0, 0].set_ylabel("Nuclear Relish fraction (fold change)", fontsize = 15)

        # plot detrended and interpolated trace on second plot
        times_detrended = np.arange(0, (detrended_dict[treatment].shape[1]) / sampling_rate, 1 / sampling_rate)
        # print(f"Original times (len = {len(times)}): {times}")
        # print(f"Detrended/interpolated times (len = {len(times_detrended)}): {times_detrended}")
        detrended_trace = detrended_dict[treatment][cell_idx]
        axs[0, 1].plot(times_detrended, detrended_trace)
        axs[0, 1].axhline(y = 0, color = "grey", linestyle = "--", linewidth = 0.5)
        axs[0, 1].set_title(f"Detrended and interpolated trace (interp = {interp_rate}X)", fontsize = 20)
        axs[0, 1].set_xlabel("Time (min)", fontsize = 15)
        axs[0, 1].set_ylabel("Detrended nuclear Relish fraction (fold change)", fontsize = 15)

        # plot original FFT amplitude on third plot
        axs[0, 2].plot(freqs, cell_amps)
        axs[0, 2].set_title("Fast Fourier Transform", fontsize = 20)
        axs[0, 2].set_xlabel("Frequency (cycles/min)", fontsize = 15)
        axs[0, 2].set_ylabel("Amplitude", fontsize = 15)

        for i, peak_freq in enumerate(peak_freqs): # plot sinusoid corresponding to each of the peak frequencies
            # retrieve sine curve
            print(f"{treatment} ({cell_name}): Peak frequency ({i + 1}) = {peak_freq:.3e} cycles/min")
            amp = peak_amps[i]
            phase = peak_angs[i]
            sine = amp * np.sin(2 * np.pi * peak_freq * times + phase)
            # plot
            axs[1, i].plot(times, sine)
            if peak_freq < 0.01:
                axs[1, i].set_title(f"{peak_freq:.3e} 1/min (Amplitude = {amp:.3f})", fontsize = 20)
            else:
                axs[1, i].set_title(f"{peak_freq:.3f} 1/min (Amplitude = {amp:.3f})", fontsize = 20)
            axs[1, i].set_xlabel("Time (min)", fontsize = 15)
            axs[1, i].set_ylabel("Amplitude", fontsize = 15)

        if test is True: # original frequencies (and the third row of axes) only exist for test data
            sines_round = {round(amp, 3): round(freq, 3) for amp, freq in sines.items()}
            freqs_round = ', '.join(map(str, sines_round.values()))
            print(f"{treatment} ({cell_name}): {{original_amp: original_freq}} = {sines_round}")
            print("\n")
            for i, (org_amp, org_freq) in enumerate(sines.items()): # plot sinusoid corresponding to each of the original frequencies
                sine = org_amp * np.sin(2 * np.pi * org_freq * times)
                # plot
                axs[2, i].plot(times, sine)
                axs[2, i].set_title(f"Original frequency: {org_freq:.3f} 1/min", fontsize = 20)
                axs[2, i].set_xlabel("Time (min)", fontsize = 15)
                axs[2, i].set_ylabel("Amplitude", fontsize = 15)

        if test is False:
            plt.suptitle(f"{cell_name} ({treatment}): Sampling rate = {sampling_rate:.1f}/min", fontsize = 25)
        else:
            plt.suptitle(f"{cell_name}: Frequencies = {freqs_round} 1/min\nSampling rate = {sampling_rate:.1f}/min", fontsize = 25)
        plt.tight_layout()
        plt.subplots_adjust(top = 0.92)
        plt.show()
```
```python
# test calls
test_sin_dict, test_sin_freqs, test_sin_time_length, test_sin_max_freq = make_sample_sin(861, max_freq = 15)

(test_sin_detrended_dict, test_sin_fft_dict, test_sin_fft_freq_dict, test_sin_fft_amp_dict,
 test_sin_fft_phase_dict, test_sin_fft_peak_dict, test_sampling_rate) = run_fft(
    test_sin_dict, test = True, interpolate = True, sampling_rate = 1, interp_rate = 30,
    time_length = test_sin_time_length, max_freq = test_sin_max_freq)

plot_component_sinusoids(
    test_sin_dict, test_sin_detrended_dict, test_sin_fft_dict, test_sin_fft_freq_dict,
    test_sin_fft_amp_dict, test_sin_fft_phase_dict, test_sampling_rate,
    interp_rate = 30, test = True, sin_freqs = test_sin_freqs)
```
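In run_fft, amplitudes are scaled as `np.abs(fft_array) / num_times`. For a real input analysed with `rfft`, a sinusoid of amplitude A that falls exactly on a bin has magnitude A·N/2 in the one-sided spectrum (the other half sits at the negative frequency that `rfft` drops), so 2|X_k|/N recovers A; the DC bin, and the Nyquist bin when N is even, are not doubled. A minimal check, assuming a noise-free sine with a whole number of cycles in the window (the sampling rate, length, amplitude and frequency below are illustrative choices, not values from the post):

```python
import numpy as np
from scipy.fft import rfft, rfftfreq

fs = 30.0       # samples/min (illustrative)
n = 600         # number of samples, i.e. a 20 min window (illustrative)
t = np.arange(n) / fs
A, f0 = 2.5, 1.5  # f0 * n / fs = 30 whole cycles, so f0 sits exactly on a bin

x = A * np.sin(2 * np.pi * f0 * t)
X = rfft(x)
freqs = rfftfreq(n, d=1 / fs)

k = np.argmax(np.abs(X))
amp_run_fft = np.abs(X[k]) / n               # scaling used in run_fft
amp_single_sided = 2 * np.abs(X[k]) / n      # one-sided amplitude estimate
print(freqs[k], amp_run_fft, amp_single_sided)
```

Under these assumptions the first estimate comes out at A/2 and the second at A. Frequencies that do not land on a bin leak into neighboring bins, so both estimates become approximate.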
Please make sure test = True in all function calls. (The test calls at the bottom of the code snippet should let you run the functions as intended.) Certain artifacts (e.g. the test = False branches, including the call to flatten_subcluster_traces, which is not shown, and the use of scipy.signal.detrend) are there for processing the real timecourse data.
Here is an example of one of the plots produced by my function:

[Figure: FFT analysis of a randomly generated linear combination of sinusoids]
Note that the subplots are arranged as follows:
- Row 1: original trace; interpolated trace; FFT frequency/amplitude graph
- Row 2: component sinusoids corresponding to the peak frequencies isolated by the FFT
- Row 3: component sinusoids corresponding to the original frequencies generated by make_sample_sin
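For row 2, note that the phase returned by `np.angle` for a DFT bin is the phase of a cosine at that frequency, since the DFT correlates the signal with complex exponentials. A minimal sketch, assuming a noise-free sine that lands exactly on a bin (all parameter values are illustrative), comparing a cosine and a sine reconstruction built from the same bin:

```python
import numpy as np
from scipy.fft import rfft, rfftfreq

fs, n = 30.0, 600                 # illustrative sampling rate (samples/min) and length
t = np.arange(n) / fs
A, f0, phi = 1.7, 1.5, 0.8        # illustrative component; f0 sits exactly on a bin
x = A * np.sin(2 * np.pi * f0 * t + phi)

X = rfft(x)
freqs = rfftfreq(n, d=1 / fs)
k = np.argmax(np.abs(X))

amp = 2 * np.abs(X[k]) / n        # one-sided amplitude
ang = np.angle(X[k])              # phase of a *cosine* at freqs[k]

recon_cos = amp * np.cos(2 * np.pi * freqs[k] * t + ang)  # reproduces x
recon_sin = amp * np.sin(2 * np.pi * freqs[k] * t + ang)  # shifted by a quarter period
print(np.max(np.abs(recon_cos - x)), np.max(np.abs(recon_sin - x)))
```

Under these assumptions the cosine version matches the input to floating-point precision, while the sine version is offset by a quarter period. Off-bin frequencies leak into neighboring bins, so the recovered amplitude and phase would only be approximate.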
Since the sine waves contain no noise, I expect these values to match; however, the FFT peak frequencies (peak_freq in plot_component_sinusoids) are visibly different from the original frequencies (org_freq in plot_component_sinusoids). How can I resolve this discrepancy? Or have I failed to extract the desired frequency values correctly, or made some other silly mistake?
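One way to narrow the discrepancy down is to process a single known sinusoid two ways: sampled at 1 sample/min and then linearly interpolated 30X (mirroring run_fft), versus generated directly at 30 samples/min. Interpolation only connects samples that already exist, so it cannot restore content above the original Nyquist frequency. A minimal sketch, assuming a noise-free 3.3 cycles/min sine over 0 to 860 min; `peak_frequency` is a hypothetical helper, not part of SciPy:

```python
import numpy as np
from scipy.fft import rfft, rfftfreq

def peak_frequency(x, fs):
    """Hypothetical helper: frequency of the largest rfft magnitude, ignoring DC."""
    mags = np.abs(rfft(x - x.mean()))
    freqs = rfftfreq(len(x), d=1 / fs)
    return freqs[1 + np.argmax(mags[1:])]

f_true = 3.3   # cycles/min (assumed test frequency, above 0.5 cycles/min)
T = 860        # minutes, as in the post

# Case 1: sample at 1/min, then linearly interpolate 30X (as run_fft does)
t_coarse = np.arange(T + 1)                    # 0..860 min, 861 samples
x_coarse = np.sin(2 * np.pi * f_true * t_coarse)
t_fine = np.arange(0, T, 1 / 30)               # same grid as times_int in run_fft
x_interp = np.interp(t_fine, t_coarse, x_coarse)

# Case 2: generate the same sine directly at 30 samples/min
x_direct = np.sin(2 * np.pi * f_true * t_fine)

print(peak_frequency(x_interp, 30.0))  # lands near the 0.3 cycles/min alias
print(peak_frequency(x_direct, 30.0))  # lands near the true 3.3 cycles/min
```

Under these assumptions, the interpolated version peaks at the folded frequency rather than at 3.3 cycles/min, while the directly generated version peaks at 3.3 cycles/min.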
I am relatively new to using the FFT for computational analysis, so any help would be much appreciated. Thank you!