
I'm filling a 3D array with a function that depends on the values of several 1D arrays, as illustrated in the code below. With my real data the code takes forever, because my 1D arrays (and hence each dimension of my 3D array) have lengths of the order of 1 million. Is there any way to do this much faster, for example without using loops in Python?

This idea might seem stupid, but I'm still wondering whether it would be faster to fill this array with C++ code imported into my program. I'm new to C++, so I haven't tried it.

import numpy as np
import time

start_time = time.time()
kx = np.linspace(0,400,100)
ky = np.linspace(0,400,100)
kz = np.linspace(0,400,100)

Kh = np.empty((len(kx),len(ky),len(kz)))

for i in range(len(kx)):
    for j in range(len(ky)):
        for k in range(len(kz)):
            if np.sqrt(kx[i]**2+ky[j]**2) != 0:
                Kh[i][j][k] = np.sqrt(kx[i]**2+ky[j]**2+kz[k]**2)
            else:
                Kh[i][j][k] = 1


print('Finished in %s seconds' % (time.time() - start_time))
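Before choosing a method, it may help to estimate how much memory the full array needs. A rough sketch, assuming the default float64 dtype that np.empty uses (8 bytes per element); the helper name `kh_nbytes` is hypothetical:

```python
import numpy as np

def kh_nbytes(n, dtype=np.float64):
    """Bytes needed for an n x n x n array of the given dtype (hypothetical helper)."""
    return n ** 3 * np.dtype(dtype).itemsize

for n in (100, 1000, 1_000_000):
    print(f"n = {n:>9}: {kh_nbytes(n):.3e} bytes")
```

By this estimate, 100 points per axis needs 8 MB and 1000 points per axis about 8 GB, while 10^6 points per axis would need about 8e18 bytes, so at that scale the full array probably cannot be held in memory at once.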



You can use the @njit decorator from Numba, a high-performance JIT compiler. It reduces the time by more than an order of magnitude; the comparison and the code are below. All you need to do is import njit and put @njit as a decorator on your function. See the official Numba website for details.

I also timed 1000*1000*1000 data points using njit, and it took only about 17.9 seconds. Using the parallel version, @njit(parallel=True), reduced the time further to about 9.4 seconds. Doing the same with a normal Python function would take several minutes.

I also compared the timing of njit with the matrix operations suggested by @Bily in his answer below. The times are comparable for up to about 700 points per axis, but for more than 700 points the njit method clearly wins, as the figure below shows.

import numpy as np
import time
from numba import njit

kx = np.linspace(0,400,100)
ky = np.linspace(0,400,100)
kz = np.linspace(0,400,100)

Kh = np.empty((len(kx),len(ky),len(kz)))

@njit  # <----- Decorating your function here
def func_njit(kx, ky, kz, Kh):
    for i in range(len(kx)):
        for j in range(len(ky)):
            for k in range(len(kz)):
                if np.sqrt(kx[i]**2+ky[j]**2) != 0:
                    Kh[i][j][k] = np.sqrt(kx[i]**2+ky[j]**2+kz[k]**2)
                else:
                    Kh[i][j][k] = 1
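    return Kh

# Note: this first call also includes Numba's one-time JIT compilation,
# so the reported time is compilation plus execution.
start_time = time.time()
Kh = func_njit(kx, ky, kz, Kh)
print('NJIT Finished in %s seconds' % (time.time() - start_time))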

# Plain Python version of the same loops, for comparison
def func_normal(kx, ky, kz, Kh):
    for i in range(len(kx)):
        for j in range(len(ky)):
            for k in range(len(kz)):
                if np.sqrt(kx[i]**2+ky[j]**2) != 0:
                    Kh[i][j][k] = np.sqrt(kx[i]**2+ky[j]**2+kz[k]**2)
                else:
                    Kh[i][j][k] = 1
    return Kh 

start_time = time.time()
Kh = func_normal(kx, ky, kz, Kh)
print('Normal function Finished in %s seconds' % (time.time() - start_time))

NJIT Finished in 0.36797094345092773 seconds
Normal function Finished in 5.540749788284302 seconds

[Figure: Comparison of njit and the matrix method]


5 Comments

Woah, this is clearly the way to go! I just tried it out on a reduced portion of my data and it works really well, thanks a lot =)
Glad to help. I recently came across it and found it very useful. Check my recent edit to the answer with parallel=True; it speeds things up even further.
Nice one. Anyway, this way of measuring elapsed time is odd (at the very least you should reset start_time once the first method finishes).
@New2Python: Thanks for the heads up. I will update the time thing soon :)
@Bazingaa You are likely measuring the compilation time in your example. If you add a call to Kh = func_njit(kx, ky, kz, Kh) before the time measurement, you will get the runtime of the function itself. You can also avoid the compilation overhead when starting a new interpreter instance by adding cache=True to the njit decorator, but this won't work on a parallel version.
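Following the last comment, a fairer measurement calls the function once before timing, so that Numba's JIT compilation is not counted, and uses time.perf_counter, which is intended for measuring short intervals. A minimal sketch; the helper name `time_call` and the default number of repeats are arbitrary choices:

```python
import time
import numpy as np

def time_call(func, *args, repeats=3):
    """Return the best wall-clock time in seconds over `repeats` runs.

    The first call is a warm-up: for an @njit function it triggers
    compilation, so the timed runs measure only execution.
    """
    func(*args)  # warm-up call, not timed
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        func(*args)
        best = min(best, time.perf_counter() - start)
    return best

kx = np.linspace(0, 400, 100)
print("np.add.outer: %.6f s" % time_call(np.add.outer, kx ** 2, kx ** 2))
```

With the answer's code, the call would be time_call(func_njit, kx, ky, kz, Kh). Taking the best of several runs reduces noise from other processes; the standard timeit module is an alternative.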

A basic rule for using NumPy: use whole-array (matrix) operations instead of for loops whenever possible.

import numpy as np
import time

kx = np.linspace(0,400,100)
ky = np.linspace(0,400,100)
kz = np.linspace(0,400,100)
Kh = np.empty((len(kx),len(ky),len(kz)))

def func_matrix_operation(kx, ky, kz, _):
    kx_ = np.expand_dims(kx ** 2, 1)  # shape: (100, 1)
    ky_ = np.expand_dims(ky ** 2, 0)  # shape: (1, 100)
    # Make use of broadcasting such that kxy[i, j] = kx[i] ** 2 + ky[j] ** 2
    kxy = kx_ + ky_  # shape: (100, 100)
    kxy_ = np.expand_dims(kxy, 2)  # shape: (100, 100, 1)
    kz_ = np.reshape(kz ** 2, (1, 1, len(kz)))  # shape: (1, 1, 100)
    kxyz = kxy_ + kz_  # kxyz[i, j, k] = kx[i] ** 2 + ky[j] ** 2 + kz[k] ** 2

    kh = np.sqrt(kxyz)
    # kxy is 2-D, so this mask sets kh[i, j, :] = 1 wherever kx[i]**2 + ky[j]**2 == 0
    kh[kxy == 0] = 1
    return kh

start_time = time.time()
Kh1 = func_matrix_operation(kx, ky, kz, Kh)
print('Matrix operation Finished in %s seconds' % (time.time() - start_time))

def func_normal(kx, ky, kz, Kh):
    for i in range(len(kx)):
        for j in range(len(ky)):
            for k in range(len(kz)):
                if np.sqrt(kx[i] ** 2 + ky[j] ** 2) != 0:
                    Kh[i][j][k] = np.sqrt(kx[i] ** 2 + ky[j] ** 2 + kz[k] ** 2)
                else:
                    Kh[i][j][k] = 1
    return Kh

start_time = time.time()
Kh2 = func_normal(kx, ky, kz, Kh)
print('Normal function Finished in %s seconds' % (time.time() - start_time))

assert np.array_equal(Kh1, Kh2)

The output is:

Matrix operation Finished in 0.018651008606 seconds
Normal function Finished in 5.78078794479 seconds
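To illustrate what the expand_dims and reshape calls do, here is a small sketch with short arrays of different lengths so the shapes are easy to tell apart (the lengths 4, 3 and 2 are arbitrary). Indexing with np.newaxis is an equivalent way to insert a length-1 axis; for a 1-D array of length n, np.expand_dims(a, 1) and np.expand_dims(a, -1) both give shape (n, 1), and np.expand_dims(a, 0) gives (1, n):

```python
import numpy as np

kx = np.linspace(0, 400, 4)
ky = np.linspace(0, 400, 3)
kz = np.linspace(0, 400, 2)

# Same shapes as the answer's expand_dims/reshape steps, written with np.newaxis
kx2 = kx[:, np.newaxis] ** 2          # (4, 1), like np.expand_dims(kx ** 2, 1)
ky2 = ky[np.newaxis, :] ** 2          # (1, 3), like np.expand_dims(ky ** 2, 0)
kxy = kx2 + ky2                       # (4, 3): kxy[i, j] = kx[i]**2 + ky[j]**2
kz2 = kz.reshape(1, 1, -1) ** 2       # (1, 1, 2), like np.reshape(kz ** 2, (1, 1, len(kz)))
kxyz = kxy[:, :, np.newaxis] + kz2    # (4, 3, 1) + (1, 1, 2) broadcasts to (4, 3, 2)

kh = np.sqrt(kxyz)
kh[kxy == 0] = 1  # the 2-D mask selects whole rows kh[i, j, :]

print(kx2.shape, ky2.shape, kxy.shape, kz2.shape, kxyz.shape)
# (4, 1) (1, 3) (4, 3) (1, 1, 2) (4, 3, 2)
```

Broadcasting stretches every length-1 axis to match the other operand, so adding a column (4, 1) to a row (1, 3) yields the full (4, 3) table of sums without any Python loop.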

3 Comments

Nice solution. Perhaps you could briefly explain how expand_dims works with arguments like 1, 0 and -1, what you are doing with np.reshape, and what its argument (1, 1, len(kz)) means. Currently it's not very clear.
I tried your method and found that for 1000*1000*1000 data points your code took about 54 seconds and njit about 15 seconds. For 100*100*100, your code was much quicker. I will do further timing analysis and plot some results in my answer.
@Bazingaa I've updated the code with more comments; please check whether it is clear now. Numba is generally faster than NumPy code. However, on my computer the performance gap is not as large as yours: 10 s (NumPy) vs 8 s (Numba) with 1000 data points per axis, and 29 s vs 24 s with 1400.
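If memory traffic from temporary arrays is part of the cost at 1000 points per axis (an assumption, not something measured in this thread), one option is to reuse a buffer through the out= argument of NumPy ufuncs, so that only one full-size 3-D array is allocated. A sketch under that assumption; `func_matrix_inplace` is a hypothetical name:

```python
import numpy as np

def func_matrix_inplace(kx, ky, kz):
    kxy = np.add.outer(kx ** 2, ky ** 2)            # (nx, ny): kx[i]**2 + ky[j]**2
    kh = np.empty((kx.size, ky.size, kz.size))      # the only full-size allocation
    np.add(kxy[:, :, np.newaxis], kz ** 2, out=kh)  # broadcast sum written into kh
    np.sqrt(kh, out=kh)                             # square root computed in place
    kh[kxy == 0] = 1                                # rows where kx[i]**2 + ky[j]**2 == 0
    return kh

kx = ky = kz = np.linspace(0, 400, 100)
Kh = func_matrix_inplace(kx, ky, kz)
print(Kh.shape)  # (100, 100, 100)
```

The original func_matrix_operation keeps two full-size arrays alive at once (kxyz and the result of np.sqrt); this version keeps one. Whether that translates into a speedup depends on the machine.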
