
I'm trying to produce a stem plot using the 'matplotlib.pyplot.stem' function. The code works, but it takes over 5 minutes to run.

I have a similar code within Matlab that produces the same plot with the same input data almost instantly.

Is there a way to optimize this code for speed or a better function I could be using?

The arguments to the stem plot, 'H' and 'plotdata', are 1-D arrays with 16384 (8*N) elements each.

def stemplot():

    import numpy as np
    from scipy.fftpack import fft
    import matplotlib.pyplot as plt

    ################################################
    # Code to set up the plot data

    N=2048
    dr = 100

    k = np.arange(0,N)

    cos = np.cos
    pi = np.pi

    w = 1-1.932617*cos(2*pi*k/(N-1))+1.286133*cos(4*pi*k/(N-1))-0.387695*cos(6*pi*k/(N-1))+0.0322227*cos(8*pi*k/(N-1))

    y = np.concatenate([w, np.zeros((7*N))])

    H = abs(fft(y, axis = 0))
    H = np.fft.fftshift(H)
    H = H/max(H)
    H = 20*np.log10(H)
    H = dr+H 
    H[H < 0] = 0        # Set all negative values in dr+H to 0

    plotdata = ((np.arange(1,(8*N)+1,1))-1-4*N)/8
    #################################################

    # Plotting Code

    plt.figure()                     # create a new figure (without () nothing is called)
    plt.stem(plotdata,H,markerfmt = " ")

    plt.axis([(-4*N)/8, (4*N)/8, 0, dr])    
    plt.grid()
    plt.ylabel('decibels')
    plt.xlabel('DFT bins')
    plt.title('Frequency response (Flat top)')
    plt.show()


    return

Here is also the Matlab code for reference:

N=2048;
dr = 100;
k=0:N-1;

w = 1 - 1.932617*cos(2*pi*k/(N-1)) + 1.286133*cos(4*pi*k/(N-1)) -0.387695*cos(6*pi*k/(N-1)) +0.0322227*cos(8*pi*k/(N-1));

H = abs(fft([w zeros(1,7*N)]));
H = fftshift(H);
H = H/max(H);
H = 20*log10(H);
H = max(0,dr+H); % Sets negative numbers in dr+H to 0


figure
stem(([1:(8*N)]-1-4*N)/8,H,'-');
set(findobj('Type','line'),'Marker','none','Color',[.871 .49 0])
xlim([-4*N 4*N]/8)
ylim([0 dr])
set(gca,'YTickLabel','-100|-90|-80|-70|-60|-50|-40|-30|-20|-10|0')
grid on
ylabel('decibels')
xlabel('DFT bins')
title('Frequency response (Flat top)')
  • Can you make this a minimal reproducible example? Without knowing what your w is, this code cannot be run. Commented Jan 17, 2018 at 12:26
  • Also, return is a statement, not a function. No need for the () after it. Commented Jan 17, 2018 at 12:33
  • No need for it at all in this particular case ;-) Commented Jan 17, 2018 at 12:35
  • Have a look at this and this GitHub issue. This seems to be a known issue. Commented Jan 17, 2018 at 13:13
  • I could repeat one of the comments in the second issue: With 16384 points, are you sure that a stem plot is what you want? Commented Jan 17, 2018 at 13:15
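Picking up that last comment: if a true stem plot isn't required, every stem can be drawn as part of a single Line2D by separating the vertical segments with NaN values, which matplotlib leaves unconnected. This technique is not from the thread; it is a minimal sketch that assumes x and h are equal-length 1-D arrays, and nan_stem_xy is a hypothetical helper name.

```python
import numpy as np
import matplotlib.pyplot as plt


def nan_stem_xy(x, h, bottom=0.0):
    """Hypothetical helper: x/y arrays that trace every stem in one polyline.

    Each stem contributes three points: (x, bottom), (x, h) and a NaN break,
    so matplotlib draws the vertical segments without joining neighbours.
    """
    x = np.asarray(x, dtype=float)
    h = np.asarray(h, dtype=float)
    xs = np.empty(3 * x.size)
    ys = np.empty(3 * x.size)
    xs[0::3] = x
    xs[1::3] = x
    xs[2::3] = np.nan
    ys[0::3] = bottom
    ys[1::3] = h
    ys[2::3] = np.nan
    return xs, ys


# Small synthetic example; with the question's data, pass plotdata and H instead
x = np.arange(-4, 4) / 2.0
h = np.abs(x) + 1.0
xs, ys = nan_stem_xy(x, h)

fig, ax = plt.subplots()
(stem_line,) = ax.plot(xs, ys, "-", color="C0")  # a single Line2D for all stems
ax.axhline(0, color="r")                          # baseline at y = 0
```

Call plt.show() afterwards to display the figure. The trade-off is three points per stem in one path, instead of one artist per stem.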

2 Answers


You can simulate a stem plot in the format you want using ax.vlines, which draws all of the vertical lines as a single LineCollection. First write a small helper function:

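def make_stem(ax, x, y, **kwargs):
    # Baseline at y = 0 spanning the x range of the data
    ax.hlines(0, x[0], x[-1], color='r')

    # One vertical line per sample, from 0 up to y; any extra keyword
    # arguments (e.g. linewidth) are passed on to vlines
    kwargs.setdefault('color', 'b')
    ax.vlines(x, 0, y, **kwargs)

    # Scale the y-limits by 1.05 to leave a little headroom
    ax.set_ylim([1.05*y.min(), 1.05*y.max()])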

Then alter the relevant lines in your example as follows:

    # Plotting Code

##    plt.figure
##    plt.stem(plotdata,H,markerfmt = " ")

##    plt.axis([(-4*N)/8, (4*N)/8, 0, dr])

    fig, ax = plt.subplots()
    make_stem(ax, plotdata, H)

produces the plot more or less instantly. I don't, however, know whether this is faster or slower than the answer of @ImportanceOfBeingErnest.
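For anyone who wants to try the vlines approach without the surrounding function, here is a self-contained sketch. It rebuilds H and plotdata as in the question, but with numpy.fft in place of scipy.fftpack (an assumption made only to drop the SciPy dependency), and draws a plain y = 0 baseline with axhline.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection

# Flat-top spectrum as in the question (numpy.fft used instead of scipy.fftpack)
N = 2048
dr = 100
k = np.arange(N)
c = 2 * np.pi * k / (N - 1)
w = (1 - 1.932617 * np.cos(c) + 1.286133 * np.cos(2 * c)
     - 0.387695 * np.cos(3 * c) + 0.0322227 * np.cos(4 * c))
y = np.concatenate([w, np.zeros(7 * N)])
H = np.fft.fftshift(np.abs(np.fft.fft(y)))
with np.errstate(divide="ignore"):      # log10(0) -> -inf, clipped to 0 below
    H = np.maximum(0, dr + 20 * np.log10(H / H.max()))
plotdata = (np.arange(8 * N) - 4 * N) / 8  # same values as the question's expression

fig, ax = plt.subplots()
stems = ax.vlines(plotdata, 0, H, color="C0")  # all 16384 stems in one LineCollection
ax.axhline(0, color="r")                        # baseline
ax.set_xlim(-4 * N / 8, 4 * N / 8)
ax.set_ylim(0, dr)
ax.set_xlabel("DFT bins")
ax.set_ylabel("decibels")
```

The speed comes from vlines returning one LineCollection rather than creating one line object per stem.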


2 Comments

The solution from my answer below takes 0.25 seconds, while this solution takes 0.23 seconds on my computer. Not a large difference, but apparently still a bit faster.
@Thomas Thanks for your response, I tested this code and it works just as fast as the one above, either method is suitable.
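To compare timings like these on your own machine, one option is to render each variant off-screen and time it with time.perf_counter. This is a rough sketch: it assumes the Agg backend, uses random placeholder data instead of the FFT spectrum, and draw_vlines, draw_linecollection and time_draw are hypothetical helper names.

```python
import time

import matplotlib
matplotlib.use("Agg")  # off-screen rendering, so no GUI event loop is timed
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection

rng = np.random.default_rng(0)
x = np.arange(16384) / 8.0          # placeholder x values, same length as plotdata
h = rng.uniform(0, 100, x.size)     # placeholder heights, not the real spectrum


def draw_vlines(ax):
    ax.vlines(x, 0, h)


def draw_linecollection(ax):
    # Segment list built with a Python loop, as in the LineCollection answer
    lines = [((xi, 0), (xi, hi)) for xi, hi in zip(x, h)]
    ax.add_collection(LineCollection(lines))
    ax.autoscale_view()


def time_draw(plot_func, repeats=3):
    """Return the best wall-clock time (seconds) to build and render one figure."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fig, ax = plt.subplots()
        plot_func(ax)
        fig.canvas.draw()  # force the actual rendering
        best = min(best, time.perf_counter() - start)
        plt.close(fig)
    return best


for name, func in [("vlines", draw_vlines), ("LineCollection", draw_linecollection)]:
    print(f"{name}: {time_draw(func):.3f} s")
```

Absolute numbers depend on hardware and matplotlib version, so only the relative difference is meaningful.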

There seems to be no need for a stem plot here, since the markers are made invisible anyway and would not make sense with this many points.

Instead, a LineCollection may make sense. This is how matplotlib will do it in a future version anyway (see this PR). The code below runs in about 0.25 seconds for me, which is still slightly longer than using plot, because of the large number of individual lines.

import numpy as np
from scipy.fftpack import fft
import matplotlib.pyplot as plt
import matplotlib.collections as mcoll

N=2048
k = np.arange(0,N)
dr = 100

cos = np.cos
pi = np.pi

w = 1-1.932617*cos(2*pi*k/(N-1))+1.286133*cos(4*pi*k/(N-1))-0.387695*cos(6*pi*k/(N-1))+0.0322227*cos(8*pi*k/(N-1))

y = np.concatenate([w, np.zeros((7*N))])

H = abs(fft(y, axis = 0))
H = np.fft.fftshift(H)
H = H/max(H)
H = 20*np.log10(H)
H = dr+H 
H[H < 0] = 0        # Set all negative values in dr+H to 0

plotdata = ((np.arange(1,(8*N)+1,1))-1-4*N)/8


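# One segment ((x, 0), (x, H)) per sample; a single LineCollection
# draws all of them, instead of one Line2D per stem
lines = []
for thisx, thisy in zip(plotdata,H):
    lines.append(((thisx, 0), (thisx, thisy)))
stemlines = mcoll.LineCollection(lines, linestyles="-",
                    colors="C0", label='_nolegend_')
plt.gca().add_collection(stemlines)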


plt.axis([(-4*N)/8, (4*N)/8, 0, dr])    
plt.grid()
plt.ylabel('decibels')
plt.xlabel('DFT bins')
plt.title('Frequency response (Flat top)')

plt.show()
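The for loop that builds lines visits all 16384 points in Python. LineCollection also accepts the segments as a single (n, 2, 2) NumPy array, which can be built in one vectorized step. A sketch, assuming x and h are equal-length 1-D arrays (plotdata and H in the code above); stem_segments is a hypothetical helper name and the data below are placeholders.

```python
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.collections as mcoll


def stem_segments(x, h, bottom=0.0):
    """Hypothetical helper: return an (n, 2, 2) array of stem segments.

    segs[i, 0] is the foot (x[i], bottom) and segs[i, 1] the tip (x[i], h[i]).
    """
    x = np.asarray(x, dtype=float)
    h = np.asarray(h, dtype=float)
    foot = np.column_stack([x, np.full_like(x, bottom)])
    tip = np.column_stack([x, h])
    return np.stack([foot, tip], axis=1)


# Placeholder data with the same length as plotdata/H in the question
x = (np.arange(16384) - 8192) / 8
h = np.abs(np.sin(x / 50)) * 100

segs = stem_segments(x, h)
stemlines = mcoll.LineCollection(segs, linestyles="-", colors="C0")
fig, ax = plt.subplots()
ax.add_collection(stemlines)
ax.set_xlim(x[0], x[-1])  # add_collection does not rescale the view by itself
ax.set_ylim(0, 100)
```

The result is the same set of segments as the loop; only the construction moves from Python into NumPy.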

4 Comments

You beat me once more :) I had a solution using vlines which is also very fast.
@Thomas If your solution uses something other than a LineCollection and if it is significantly faster than the 5 seconds from the question, why not post it here as well?
Done. I don't know about faster, though.
Thanks for your reply, the code works much better than the .stem method! I think initially I was just trying to translate the matlab code directly, in this instance a direct comparison was not the best solution.
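Regarding the PR mentioned in this answer: later matplotlib releases gave stem a LineCollection-based mode (exposed for a while as a use_line_collection option), and as far as I know that is now the only behaviour, so on a recent version plain plt.stem should no longer take minutes. Rather than relying on version numbers, the sketch below checks what the installed version does. It uses placeholder data, not the FFT spectrum, and on an old matplotlib the stem call itself may be slow.

```python
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection

x = np.arange(16384) / 8.0
h = np.abs(np.cos(x / 40)) * 100   # placeholder heights

fig, ax = plt.subplots()
container = ax.stem(x, h)                 # StemContainer(markerline, stemlines, baseline)
container.markerline.set_visible(False)   # hide the markers, as the question intended

uses_collection = isinstance(container.stemlines, LineCollection)
print(f"matplotlib {matplotlib.__version__}: stems in one LineCollection -> {uses_collection}")
```

If this prints False, the vlines or LineCollection workarounds above are still the way to go.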
