1

I recently decided to give matplotlib.pyplot a try, having used gnuplot for scientific data plotting for years. I started out by simply reading a data file and plotting two columns, as gnuplot would do with plot 'datafile' u 1:2. For this to be comfortable, my requirements are:

  • Skip lines beginning with a # and skip empty lines.
  • Allow arbitrary numbers of spaces before and between the actual numbers.
  • Allow arbitrary numbers of columns.
  • Be fast.

Now, the following code is my solution to the problem. However, compared to gnuplot, it is really not as fast. This is a bit odd, since I read that one big advantage of py(plot/thon) over gnuplot is its speed.

import numpy as np
import matplotlib.pyplot as plt
import sys

datafile = sys.argv[1]
data = []
for line in open(datafile,'r'):
    if line and line[0] != '#':
        cols = filter(lambda x: x!='',line.split(' '))
        for index,col in enumerate(cols):
            if len(data) <= index:
                data.append([])
            data[index].append(float(col))

plt.plot(data[0],data[1])
plt.show()

What could I do to make the data reading faster? I had a quick look at the csv module, but it didn't seem very flexible with comments in files, and one still needs to iterate over all the lines in the file.

1 Comment
  • It's interesting that you read that pyplot is faster; I had the opposite impression (unless you mean development speed). E.g. stackoverflow.com/questions/911655

2 Answers

6

Since you have matplotlib installed, you must also have numpy installed. numpy.genfromtxt meets all your requirements and should be much faster than parsing the file yourself in a Python loop:

import numpy as np
import matplotlib.pyplot as plt

import textwrap
fname='/tmp/tmp.dat'
with open(fname,'w') as f:
    f.write(textwrap.dedent('''\
        id col1 col2 col3 col4
        2010 1 2 3 4
        # Foo

        2011 5 6 7 8
        # Bar        
        # Baz
        2012 8 7 6 5
        '''))

data = np.genfromtxt(fname,
                     comments='#',   # skip comment lines
                     dtype=None,     # guess the dtype of each column
                     names=True)     # use the first line as column names
print(data)
plt.plot(data['id'],data['col2'])
plt.show()

9 Comments

Remove the names=True parameter to get a plain numpy array, or use names=('col1','col2',...) to supply headers. See the docs linked above for more details.
Thank you for the help. A quick tip: when one does not have column names, data[0] is the first row, not the first column. To fix this, I used data = np.genfromtxt(...).T, which transposes the returned ndarray (see the sketch after these comments). However, this solution is still much slower than gnuplot, which reads a file of 4x10000 numbers immediately while it takes Python about 1/4 s.
@janoliver: gnuplot is a specialized tool written in C, while pyplot is Python-based. matplotlib/numpy/Python is more versatile than gnuplot, but I would not assume it is faster than gnuplot in the domain of what gnuplot does.
Of course gnuplot is extremely efficient in doing what it's supposed to do. But I thought that numpy was also optimized pretty well. I also thought that most of the "famous" packages interface with compiled versions of the functions, so that loadtxt or genfromtxt would call some C program themselves to read the file into memory.
@janoliver: I generated some test data of shape (10000,4) and tried to compare the speed of genfromtxt+plt.plot and gnuplot's plot. They were both able to display a scatter plot with no discernable delay. Are you sure your Python program isn't doing something else to cause a 1/4s delay?
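
For a headerless file, a minimal sketch of the transpose variant mentioned in the comments above could look like this (the data file name is a placeholder, not from the thread):

import numpy as np
import matplotlib.pyplot as plt

# Without names=True, genfromtxt returns one row per line of the file,
# so transpose to get one array per column.
cols = np.genfromtxt('datafile.dat', comments='#').T   # placeholder file name

plt.plot(cols[0], cols[1])
plt.show()
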
2

You really need to profile your code to find out what the bottleneck is.
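
For example, one quick way to see where the time goes (a sketch; the script and data file names are placeholders) is the built-in profiler, or timeit for just the reading step:

# Profile the whole script from the command line:
#   python -m cProfile -s cumulative plot_script.py datafile.dat
#
# Or time just the file-reading step in isolation:
import timeit

setup = "import numpy as np"
stmt = "np.genfromtxt('datafile.dat')"   # assumed data file name
print(timeit.timeit(stmt, setup=setup, number=10))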

Here are some micro-optimizations:

import numpy as np
import matplotlib.pyplot as plt
import sys

datafile = sys.argv[1]
data = []
# use with to auto-close the file
for line in open(datafile,'r'):
    # line will never be False because it will always have at least a newline
    # maybe you mean line.rstrip()?
    # you can also try line.startswith('#') instead of line[0] != '#'
    if line and line[0] != '#':
        # not sure of the point of this
        # just line.split() will allow any number of spaces
        # if you do need it, use a list comprehension
        # cols = [col for col in line.split(' ') if col]
        # filter on a user-defined function is slow
        cols = filter(lambda x: x!='',line.split(' '))

        for index,col in enumerate(cols):
            # just make data a collections.defaultdict,
            # initialized as data = defaultdict(list),
            # and you can skip this 'if' statement entirely
            # (see the sketch after this code block)
            if len(data) <= index:
                data.append([])
            data[index].append(float(col))

plt.plot(data[0],data[1])
plt.show()
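
To make the defaultdict suggestion concrete, the loop could look roughly like this (my sketch, not tested against your data):

import sys
from collections import defaultdict

import matplotlib.pyplot as plt

datafile = sys.argv[1]
data = defaultdict(list)                   # unknown column indices start as an empty list

with open(datafile) as f:                  # 'with' closes the file for us
    for line in f:
        if line.strip() and not line.startswith('#'):
            for index, col in enumerate(line.split()):
                data[index].append(float(col))

plt.plot(data[0], data[1])
plt.show()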

You may be able to do something like:

with open(datafile) as f:
    lines = (line.split() for line in f
             if line.rstrip() and not line.startswith('#'))
    data = list(zip(*[[float(col) for col in line] for line in lines]))

This will give you a list of tuples instead of an int-keyed dict of lists, but it is otherwise equivalent. It could be written as a one-liner, but I split it up to make it a little easier to read.
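
As a quick self-contained check of that approach (made-up data, just to show the shape of the result):

import io

f = io.StringIO("# a comment\n1 10\n\n2 20\n3 30\n")
lines = (line.split() for line in f
         if line.rstrip() and not line.startswith('#'))
data = list(zip(*[[float(col) for col in line] for line in lines]))
print(data)   # [(1.0, 2.0, 3.0), (10.0, 20.0, 30.0)]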

2 Comments

Thank you too, for the advice on python in general.
@janoliver Glad to help. Thanks for your comment on the other answer, I didn't know about that :).
