
I have an array

import numpy as np
X = np.array([[0.7513,  0.6991,    0.5472,    0.2575],
          [0.2551,  0.8909,    0.1386,    0.8407],
          [0.5060,  0.9593,    0.1493,    0.2543],
          [0.5060,  0.9593,    0.1493,    0.2543]])
y = np.array([[1,2,3,4]])

How can I replace the diagonal of X with y? I can write a loop, but is there a faster way?

3 Answers


A fast and reliable method is np.einsum:

>>> diag_view = np.einsum('ii->i', X)

This creates a view of the diagonal:

>>> diag_view
array([0.7513, 0.8909, 0.1493, 0.2543])

This view is writable:

>>> diag_view[None] = y
>>> X
array([[1.    , 0.6991, 0.5472, 0.2575],
       [0.2551, 2.    , 0.1386, 0.8407],
       [0.506 , 0.9593, 3.    , 0.2543],
       [0.506 , 0.9593, 0.1493, 4.    ]])

This works for contiguous and non-contiguous arrays and is very fast:

contiguous:
loop          21.146424998732982
diag_indices  2.595232878000388
einsum        1.0271988900003635
flatten       1.5372659160002513

non contiguous:
loop          20.133818001340842
diag_indices  2.618005960001028
einsum        1.0305795049989683
Traceback (most recent call last): <- flatten does not work here
...
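The crash is easy to reproduce in isolation: assigning to `.shape` reshapes in place, which requires that the data be reinterpretable without a copy. A strided slice view is not, so NumPy refuses. A minimal sketch:

```python
import numpy as np

arr = np.zeros((8, 8))[::2, ::2]   # non-contiguous view (every other row/col)

try:
    arr.shape = -1                 # in-place reshape needs a compatible layout
except AttributeError as e:
    print(e)                       # NumPy suggests .reshape(), which copies
```

A `.reshape(-1)` would succeed, but it returns a copy here, so writing into it would not touch the original array; that is why the flatten trick cannot be rescued for non-contiguous inputs.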

How does it work? Under the hood, einsum does an advanced version of @Julien's trick: it adds the two strides of arr:

>>> arr.strides
(3200, 16)
>>> np.einsum('ii->i', arr).strides
(3216,)

One can convince oneself that this will always work as long as arr is organized in strides, which is the case for numpy arrays.
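The same diagonal view can be built by hand with `np.lib.stride_tricks.as_strided`, which makes the mechanism explicit: one axis whose stride is the sum of the row and column strides, so each step moves one row down and one column right. A sketch (the shapes here are illustrative):

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

arr = np.zeros((100, 100))
s0, s1 = arr.strides                       # e.g. (800, 8) for float64

# Diagonal view: stepping by s0 + s1 lands on (0,0), (1,1), (2,2), ...
diag = as_strided(arr, shape=(min(arr.shape),), strides=(s0 + s1,))
diag[:] = 1.0                              # writes through to arr's diagonal

assert (np.diag(arr) == 1.0).all()
```

This works for non-contiguous arrays too, since their strides are just different numbers; that is exactly why the einsum view is layout-agnostic.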

While this use of einsum is pretty neat it is also almost impossible to find if one doesn't know. So spread the word!

Code to recreate the timings and the crash:

import numpy as np

n = 100
arr = np.zeros((n, n))
replace = np.ones(n)

def loop():
    for i in range(len(arr)):
        arr[i,i] = replace[i]

def other():
    l = len(arr)
    arr.shape = -1
    arr[::l+1] = replace
    arr.shape = l,l

def di():
    arr[np.diag_indices(arr.shape[0])] = replace

def es():
    np.einsum('ii->i', arr)[...] = replace

from timeit import timeit
print('\ncontiguous:')
# The loop is slow, so time 1000 runs and scale by 1000 to match
# timeit's default of 1,000,000 runs used for the other methods.
print('loop         ', timeit(loop, number=1000)*1000)
print('diag_indices ', timeit(di))
print('einsum       ', timeit(es))
print('flatten      ', timeit(other))

arr = np.zeros((2*n, 2*n))[::2, ::2]
print('\nnon contiguous:')
print('loop         ', timeit(loop, number=1000)*1000)
print('diag_indices ', timeit(di))
print('einsum       ', timeit(es))
print('flatten      ', timeit(other))


This should be pretty fast, especially for bigger arrays (for your 4×4 example it's about twice as slow):

arr = np.zeros((4,4))
replace = [1,2,3,4]

l = len(arr)
arr.shape = -1
arr[::l+1] = replace
arr.shape = l,l

Test on bigger array:

n = 100
arr = np.zeros((n,n))
replace = np.ones(n)

def loop():
    for i in range(len(arr)):
        arr[i,i] = replace[i]

def other():
    l = len(arr)
    arr.shape = -1
    arr[::l+1] = replace
    arr.shape = l,l

%timeit loop()
%timeit other()

14.7 µs ± 112 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
1.55 µs ± 24.1 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
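For what it's worth, NumPy ships this flat-slice idea as a ready-made function: `np.fill_diagonal` assigns through the flattened array with the same `n + 1` step for square contiguous inputs, so you get the speed without mutating `arr.shape` yourself:

```python
import numpy as np

arr = np.zeros((4, 4))

# Equivalent to the flat-slice trick, but leaves arr's shape alone
# and accepts either a scalar or a sequence of diagonal values.
np.fill_diagonal(arr, [1, 2, 3, 4])

print(arr)   # diagonal is now 1, 2, 3, 4
```

Note that, like the hand-rolled version, `np.fill_diagonal` works in place and returns `None`.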

2 Comments

Is there any difference if replace = array([[1], [2], [3], [4]]) instead of a list? Thanks
I believe not, but if you see anything unexpected you could just do replace.shape = -1 too...
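I believe not, but if you see anything unexpected you could just do replace.shape = -1 too...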

Use diag_indices for a vectorized solution:

X[np.diag_indices(X.shape[0])] = y

array([[1.    , 0.6991, 0.5472, 0.2575],
       [0.2551, 2.    , 0.1386, 0.8407],
       [0.506 , 0.9593, 3.    , 0.2543],
       [0.506 , 0.9593, 0.1493, 4.    ]])
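To see why this is a fancy-indexing solution: `np.diag_indices(n)` simply returns the same `arange(n)` twice, and indexing with that pair selects the positions (0,0), (1,1), ..., (n-1,n-1). A small sketch:

```python
import numpy as np

rows, cols = np.diag_indices(4)
# rows and cols are both array([0, 1, 2, 3])

X = np.zeros((4, 4))
X[rows, cols] = [1, 2, 3, 4]   # fancy-indexed diagonal assignment
```

Because NumPy has to gather the index pairs, this is slower than a plain strided slice, but it generalizes: e.g. `rows + 1, cols[:-1]` would address a sub-diagonal.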

2 Comments

Any idea why this is slower than my hacked method? (3 times slower for n=1000...)
@Julien Fancy indexing is comparatively expensive. Plain slicing is much cheaper.
