I need to permute the elements of each column of a 3D array A along axis 0, using a 2D permutation array pi (obtained from argsort) that holds the new row indices for every column.
Applying pi directly to A (A[pi]) gives a 4D array with a new shape. For example, if A has shape (2,3,4), then A[pi] has shape (2,3,3,4).
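A small sketch of where the extra axis comes from. It uses random data with the same shapes as the example, and the sum-based sort key is only a stand-in (both are assumptions, not the poster's data or key). Indexing axis 0 with a (2, 3) integer array replaces that axis with the index array's shape, so each pi[i, j] pulls in a whole (3, 4) slice of A, and only the entries where the slice row equals j are actually wanted:

```python
import numpy as np

# Illustrative data with the same shapes as in the question (assumption).
rng = np.random.default_rng(0)
A = rng.integers(0, 10_000, size=(2, 3, 4))

# One permutation per column: argsort along axis 0 of a (2, 3) key array.
# The key used here (sum over the last axis) is only a stand-in.
pi = np.argsort(A.sum(axis=2), axis=0)   # shape (2, 3)

# Indexing axis 0 with a (2, 3) integer array replaces that axis with
# pi's shape, so each pi[i, j] selects a whole (3, 4) slice of A.
big = A[pi]
print(pi.shape, big.shape)   # (2, 3) (2, 3, 3, 4)

# The wanted values sit where the column index of pi matches the slice row:
# big[i, j, j, :] == A[pi[i, j], j, :]
assert np.array_equal(big[1, 2, 2], A[pi[1, 2], 2])
```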
I can extract the required sorted array from A[pi] with:
swapaxes(diagonal(A[pi], axis1=2, axis2=1), 1, 2)
But this seems too complicated and slow, because A[pi] first materializes an array A.shape[1] times larger than A and then keeps only its diagonal.
Is there a more elegant solution?
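One possibly simpler route, sketched under the assumption that pi has shape (A.shape[0], A.shape[1]) and holds, for each column j, the row order to use along axis 0: np.take_along_axis gathers B[i, j, :] = A[pi[i, j], j, :] directly, without building the 4D intermediate A[pi]. Plain advanced indexing with a broadcast column index does the same thing. The helper names are hypothetical:

```python
import numpy as np

def permute_columns(A, pi):
    """Return B with B[i, j, :] == A[pi[i, j], j, :] (hypothetical helper)."""
    # Add a trailing length-1 axis so pi has the same ndim as A;
    # take_along_axis broadcasts it over the last axis of A.
    return np.take_along_axis(A, pi[:, :, np.newaxis], axis=0)

def permute_columns_fancy(A, pi):
    """Same result via advanced indexing: pi (2-D) broadcasts against
    the column indices arange(A.shape[1])."""
    return A[pi, np.arange(A.shape[1])]

A = np.arange(24).reshape(2, 3, 4)
pi = np.array([[1, 0, 1],
               [0, 1, 0]])
B = permute_columns(A, pi)
print(B.shape)   # (2, 3, 4)
assert np.array_equal(B, permute_columns_fancy(A, pi))
```

Both versions produce an array with the same shape as A, so the memory cost stays proportional to A rather than to A[pi]; whether that is noticeably faster in practice depends on the array sizes.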
Example:
print(A)
[[[   73   701  2411  2414]
  [ 5515  8292  8414 16135]
  [  100  1241  2146  2931]]

 [[ 1335  1747  3418  6312]
  [ 3788  5449  5753  9738]
  [  565  3038  3800  5430]]]
pi=argsort(Norm_order(A),0)
print(pi)
[[1, 0, 1],
[0, 1, 0]]
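Norm_order is the poster's own function and is not shown. As a purely hypothetical stand-in (not necessarily what Norm_order does), ordering each column's rows by descending Euclidean norm of the length-4 vectors reproduces the pi printed above for this A:

```python
import numpy as np

A = np.array([[[  73,  701, 2411,  2414],
               [5515, 8292, 8414, 16135],
               [ 100, 1241, 2146,  2931]],
              [[1335, 1747, 3418,  6312],
               [3788, 5449, 5753,  9738],
               [ 565, 3038, 3800,  5430]]])

def norm_order_standin(A):
    # Hypothetical key: negated norm over the last axis, so argsort
    # along axis 0 puts the row with the larger norm first.
    return -np.linalg.norm(A, axis=2)   # shape (2, 3)

pi = np.argsort(norm_order_standin(A), axis=0)
print(pi)
# [[1 0 1]
#  [0 1 0]]
```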
print(swapaxes(diagonal(A[pi], axis1=2, axis2=1), 1, 2))
[[[ 1335  1747  3418  6312]
  [ 5515  8292  8414 16135]
  [  565  3038  3800  5430]]

 [[   73   701  2411  2414]
  [ 3788  5449  5753  9738]
  [  100  1241  2146  2931]]]
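To check that a direct gather along axis 0 gives exactly the output printed above, it can be compared with the swapaxes/diagonal expression on the example data. In this sketch pi is copied from the printout rather than recomputed, since Norm_order is not shown:

```python
import numpy as np

A = np.array([[[  73,  701, 2411,  2414],
               [5515, 8292, 8414, 16135],
               [ 100, 1241, 2146,  2931]],
              [[1335, 1747, 3418,  6312],
               [3788, 5449, 5753,  9738],
               [ 565, 3038, 3800,  5430]]])
pi = np.array([[1, 0, 1],
               [0, 1, 0]])

# Original approach: builds the (2, 3, 3, 4) array A[pi], then keeps its diagonal.
via_diagonal = np.swapaxes(np.diagonal(A[pi], axis1=2, axis2=1), 1, 2)

# Direct gather along axis 0: one source row index per (i, j) pair.
via_take = np.take_along_axis(A, pi[:, :, np.newaxis], axis=0)

print(np.array_equal(via_diagonal, via_take))   # True
print(via_take[0, 0])   # [1335 1747 3418 6312]
```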