
I'm trying to optimize some code, and one of the time-consuming operations is the following:

import numpy as np
survivors = np.where(a > 0)[0]
pos = len(survivors)
a[:pos] = a[survivors]
b[:pos] = b[survivors]
c[:pos] = c[survivors]

In my code a is a very large NumPy array of floats (more than 100,000 elements). Many of them will be 0.

Is there a way to speed this up?

  • What are b and c? Are you using the indices of a != 0 to index them? Commented Aug 10, 2017 at 10:49
  • Yes. So the idea is that a, b, and c represent different characteristics of the same object. Then I want to select the objects where a > 0, and get a new a, b, and c with only these objects. Commented Aug 10, 2017 at 10:54
  • a[:pos] implies that everything from a[pos:] onward is unused, and likewise with b and c. Is this true? Is this a genetic algorithm where you want to keep the survivors for the next generation? Commented Aug 10, 2017 at 11:00
  • Exactly. From there on I only care about a[:pos] instead of the whole a. By the way, how can I format inline code in the comments? Commented Aug 10, 2017 at 11:01
  • So a has to stay in its original shape? Commented Aug 10, 2017 at 12:21

1 Answer


As far as I can see, there's nothing that would speed this up in pure NumPy. However, if you have numba, you could write your own version of this "selection" using a jitted function:

import numba as nb

@nb.njit
def selection(a, b, c):
    insert_idx = 0
    for idx, item in enumerate(a):
        if item > 0:
            a[insert_idx] = a[idx]
            b[insert_idx] = b[idx]
            c[insert_idx] = c[idx]
            insert_idx += 1

In my test runs this was roughly a factor of 2 faster than your NumPy code. However, numba may be a heavy dependency if you're not using conda.

Example:

>>> import numpy as np
>>> a = np.array([0., 1., 2., 0.])
>>> b = np.array([1., 2., 3., 4.])
>>> c = np.array([1., 2., 3., 4.])
>>> selection(a, b, c)
>>> a, b, c
(array([ 1.,  2.,  2.,  0.]),
 array([ 2.,  3.,  3.,  4.]),
 array([ 2.,  3.,  3.,  4.]))

Timing:

It's hard to time this accurately because all approaches work in-place, so I used timeit.repeat with number=1 (which avoids skewed timings caused by the in-place modification of the arrays) and took the min of the resulting list of timings, because that's advertised as the most useful quantitative measure in the documentation:

Note

It’s tempting to calculate mean and standard deviation from the result vector and report these. However, this is not very useful. In a typical case, the lowest value gives a lower bound for how fast your machine can run the given code snippet; higher values in the result vector are typically not caused by variability in Python’s speed, but by other processes interfering with your timing accuracy. So the min() of the result is probably the only number you should be interested in. After that, you should look at the entire vector and apply common sense rather than statistics.

Numba solution

import timeit

min(timeit.repeat("""selection(a, b, c)""",
              """import numpy as np
from __main__ import selection

a = np.arange(1000000) % 3
b = a.copy()
c = a.copy()
""", repeat=100, number=1))

0.007700118746939211

Original solution

import timeit

min(timeit.repeat("""survivors = np.where(a > 0)[0]
pos = len(survivors)
a[:pos] = a[survivors]
b[:pos] = b[survivors]
c[:pos] = c[survivors]""",
              """import numpy as np
a = np.arange(1000000) % 3
b = a.copy()
c = a.copy()
""", repeat=100, number=1))

0.028622144571883723

Alexander McFarlane's solution (now deleted)

import timeit

min(timeit.repeat("""survivors = comb_array[:, 0].nonzero()[0]
comb_array[:len(survivors)] = comb_array[survivors]""",
              """import numpy as np
a = np.arange(1000000) % 3
b = a.copy()
c = a.copy()

comb_array = np.vstack([a,b,c]).T""", repeat=100, number=1))

0.058305527038669425

So the Numba solution can actually speed this up by a factor of 3-4, while Alexander McFarlane's solution is actually about 2x slower than the original approach. However, the small number of repeats may bias the timings somewhat.


7 Comments

@AlexanderMcFarlane I'm not so sure that your approach is working correctly. There's basically no way I could explain a 1000 times speedup over a vectorized numpy operation. I guess 2-5 times is the limit you could expect to be faster if one avoids temporary arrays or one uses a more efficient operation.
I'm going to test it now, but does your solution work if a, b and c are views on other arrays? Will they be modified as I need? As an aside, I've been testing Numba on other bottlenecks of the code, specifically operations with NumPy arrays (same shape, summing, dividing, getting the array sum), and found it slower than the original NumPy code... Any general advice?
Yes, it will work on views (you can verify this by running selection(a[0:2], b, c) instead of selection(a, b, c) in my example).
@MSeifert thanks for looking into my approach! I've learned a fair bit about the weaknesses of %%timeit magic command with your help - I'll delete my solution as it will thoroughly confuse someone looking at it - feel free to incorporate it in your answer as an approach to avoid
@DiogoSantos It's hard to give general guidelines on the performance of numba. One rule of thumb is: write all the loops yourself and try to avoid calling "complicated" NumPy functions in the numba function (e.g. advanced indexing or operations that create a temporary array). There are also different ways to iterate over NumPy arrays; sometimes it's faster to use for element in array, for idx in range(len(array)), or even for element in np.nditer(array). One has to experiment a bit to get the fastest numba function, and sometimes the fastest way depends on the numba version.
