
I have two 2D numpy square arrays, A and B. B is an array extracted from A where a certain number of columns and rows (with the same indices) have been stripped. Both of them are symmetric. For instance, A and B could be:

A = np.array([[1,2,3,4,5],
              [2,7,8,9,10],
              [3,8,13,14,15],
              [4,9,14,19,20],
              [5,10,15,20,25]])
B = np.array([[1,3,5],
              [3,13,15],
              [5,15,25]])

such that the missing indices are [1,3] and intersecting indices are [0,2,4].

Is there a "smart" way to extract the indices in A corresponding to the rows/columns present in B that involves advanced indexing and such? All I could come up with was:

        import numpy as np
        index = []
        n,m = len(A),len(B)
        for j in range(m):
            k = 0
            # advance k until row A[k] contains every value of row B[j]
            while k<n and set(np.intersect1d(B[j],A[k])) != set(B[j]):
                k+=1
            index.append(k)

which I'm aware is slow and resource-consuming when dealing with large arrays.

Thank you!

Edit: I did find a smarter way. I extract the diagonal from both arrays and perform the aforementioned loop on it with a simple equality check:

        index = []
        a = np.diag(A)
        b = np.diag(B)
        for j in range(len(b)):
            k = 0
            # check the bound first so a[j+k] never runs past the diagonal
            while j+k<len(a) and a[j+k] != b[j]:
                k+=1
            index.append(k+j)

Although it still doesn't use advanced indexing and still iterates over a potentially long list, this partial solution looks cleaner and I'm going to stick with it for the time being.

  • This should be similar to the matlab function ismember. Have a look at this SO. Commented Mar 19, 2015 at 16:27
  • Are there repeated values in A? Is there always one unique answer? (What if A is all ones, for example?) Commented Mar 19, 2015 at 16:31
  • @unutbu Yes there can be repeated values in A, and yes the answer should be unique. Both matrices should be large enough so that all column or row vectors are different. Commented Mar 19, 2015 at 16:36
  • @Bort Would I still be able to use np.searchsorted over a 2D array? Neither the columns nor the rows in B are equal to those of A. Commented Mar 19, 2015 at 16:48
  • How large is A? (The fastest method may depend on the size.) Commented Mar 19, 2015 at 16:49

1 Answer


Consider the easy case when all the values are distinct:

A = np.arange(25).reshape(5,5)
ans = [1,3,4]
B = A[np.ix_(ans, ans)]

In [287]: A
Out[287]: 
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24]])

In [288]: B
Out[288]: 
array([[ 6,  8,  9],
       [16, 18, 19],
       [21, 23, 24]])

If we test the first row of B with each row of A, we will eventually come to the comparison of [6, 8, 9] with [5, 6, 7, 8, 9] from which we can glean the candidate solution of indices [1, 3, 4].

We can generate a set of all possible candidate solutions by pairing the first row of B with each row of A.
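As a rough illustration of that pairing step in the distinct-value case, the sketch below tests one row of B against every row of A and records where the matching values sit. The helper name `row_candidates` is made up for this example, and it assumes no value repeats within a row of A.

```python
import numpy as np

def row_candidates(A, rowB):
    """Hypothetical helper: one candidate index tuple per row of A that
    contains every value of rowB (assumes distinct values within a row)."""
    candidates = []
    for rowA in A:
        if np.isin(rowB, rowA).all():
            # positions in rowA whose values also occur in rowB
            positions = np.flatnonzero(np.isin(rowA, rowB))
            candidates.append(tuple(positions.tolist()))
    return candidates

A = np.arange(25).reshape(5, 5)
ans = [1, 3, 4]
B = A[np.ix_(ans, ans)]

candidates = row_candidates(A, B[0])
print(candidates)  # [(1, 3, 4)]
```

Here only the row [5, 6, 7, 8, 9] contains all of [6, 8, 9], so a single candidate survives.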

If there is only one candidate, then we are done, since we are given that B is a submatrix of A and therefore there is always a solution.

If there is more than one candidate, then we can do the same thing with the second row of B and take the intersection of the two candidate sets. After all, a solution must be a solution for each and every row of B.

Thus we can loop through the rows of B and short-circuit once we find there is only one candidate. Again, we are assuming that B is always a submatrix of A.
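The intersect-and-stop logic can be sketched on its own with plain Python sets. The candidate sets below are invented purely for illustration; they are not computed from any particular A and B.

```python
# Hypothetical candidate sets, one per row of B (illustrative values only).
per_row = [
    {(0, 2, 4), (0, 3, 4)},
    {(0, 2, 4), (1, 2, 4)},
    {(0, 2, 4)},
]

candidates = set(per_row[0])
rows_used = 1
for row_set in per_row[1:]:
    if len(candidates) == 1:
        break  # unique candidate: no need to look at further rows
    candidates &= row_set
    rows_used += 1

print(candidates, rows_used)  # {(0, 2, 4)} 2
```

The third set is never consulted, because the intersection of the first two already leaves a single candidate.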

The find_idx function below implements the idea described above:

import itertools as IT
import numpy as np

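def find_idx_1d(rowA, rowB):
    # candidate index tuples: positions in rowA holding each value of rowB
    result = []
    # np.isin replaces np.in1d, which is deprecated in recent NumPy releases
    if np.isin(rowB, rowA).all():
        result = [tuple(sorted(idx))
                  for idx in IT.product(*[np.where(rowA==b)[0] for b in rowB])]
    return result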

def find_idx(A, B):
    candidates = set([idx for row in A for idx in find_idx_1d(row, B[0])])
    for Bi in B[1:]:
        if len(candidates) == 1:
            # stop when there is a unique candidate
            return candidates.pop()
        new = [idx for row in A for idx in find_idx_1d(row, Bi)]  
        candidates = candidates.intersection(new)
    if candidates:
        return candidates.pop()
    raise ValueError('no solution found')
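For small inputs, an exhaustive search can serve as an independent cross-check of whatever index set a faster function returns. This is not part of the approach above: `find_idx_bruteforce` is a hypothetical name, it assumes the kept indices are in increasing order, and its cost grows combinatorially, so it is only a sketch for testing.

```python
from itertools import combinations

import numpy as np

def find_idx_bruteforce(A, B):
    """Hypothetical reference check: try every increasing index tuple."""
    n, m = len(A), len(B)
    for idx in combinations(range(n), m):
        if np.array_equal(A[np.ix_(idx, idx)], B):
            return idx
    raise ValueError('no solution found')

# The arrays from the question.
A = np.array([[1, 2, 3, 4, 5],
              [2, 7, 8, 9, 10],
              [3, 8, 13, 14, 15],
              [4, 9, 14, 19, 20],
              [5, 10, 15, 20, 25]])
B = np.array([[1, 3, 5],
              [3, 13, 15],
              [5, 15, 25]])

result = find_idx_bruteforce(A, B)
print(result)  # (0, 2, 4)
```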

Correctness: The two solutions you've proposed may not always return the correct result, particularly when there are repeated values. For example,

def is_solution(A, B, idx):
    return np.allclose(A[np.ix_(idx, idx)], B)

def find_idx_orig(A, B):
    index = []
    for j in range(len(B)):
        k = 0
        while k<len(A) and set(np.intersect1d(B[j],A[k])) != set(B[j]):
            k+=1
        index.append(k)
    return index

def find_idx_diag(A, B):
    index = []
    a = np.diag(A)
    b = np.diag(B)
    for j in range(len(b)):
        k = 0
        while j+k<len(a) and a[j+k] != b[j]:
            k+=1
        index.append(k+j)
    return index

def counterexample():
    """
    Show find_idx_diag, find_idx_orig may not return the correct result
    """
    A = np.array([[1,2,0],
                  [2,1,0],
                  [0,0,1]])
    ans = [0,1]
    B = A[np.ix_(ans, ans)]
    assert not is_solution(A, B, find_idx_orig(A, B))
    assert is_solution(A, B, find_idx(A, B))

    A = np.array([[1,2,0],
                  [2,1,0],
                  [0,0,1]])
    ans = [1,2]
    B = A[np.ix_(ans, ans)]

    assert not is_solution(A, B, find_idx_diag(A, B))
    assert is_solution(A, B, find_idx(A, B))

counterexample()
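One way to see why the set-based row test goes wrong on the first matrix: a set comparison ignores where the values sit, so more than one row of A passes for the same row of B. A small sketch (the variable names are only for illustration):

```python
import numpy as np

A = np.array([[1, 2, 0],
              [2, 1, 0],
              [0, 0, 1]])
B = A[np.ix_([0, 1], [0, 1])]   # [[1, 2], [2, 1]]

# For each row of B, list every row of A that contains all of its values.
matches = [[k for k, rowA in enumerate(A.tolist()) if set(rowB) <= set(rowA)]
           for rowB in B.tolist()]
print(matches)  # [[0, 1], [0, 1]]
```

Both rows of B match rows 0 and 1 of A, so taking the first match each time yields [0, 0] rather than [0, 1].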

Benchmark: Ignoring at our peril the issue of correctness, out of curiosity let's compare these functions on the basis of speed.

def make_AB(n, m):
    A = symmetrize(np.random.random((n, n)))
    ans = np.sort(np.random.choice(n, m, replace=False))
    B = A[np.ix_(ans, ans)]
    return A, B

def symmetrize(a):
    "http://stackoverflow.com/a/2573982/190597 (EOL)"
    return a + a.T - np.diag(a.diagonal())

if __name__ == '__main__':
    counterexample()
    A, B = make_AB(500, 450)
    assert is_solution(A, B, find_idx(A, B))

In [283]: %timeit find_idx(A, B)
10 loops, best of 3: 74 ms per loop

In [284]: %timeit find_idx_orig(A, B)
1 loops, best of 3: 14.5 s per loop

In [285]: %timeit find_idx_diag(A, B)
100 loops, best of 3: 2.93 ms per loop

So find_idx is much faster than find_idx_orig, but not as fast as find_idx_diag.
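If the diagonal of A happens to contain no repeated values, the diagonal comparison can be written without a Python loop. The following is only a sketch under that assumption, not something benchmarked here, and it shares the correctness caveat of find_idx_diag whenever diagonal values repeat.

```python
import numpy as np

# The arrays from the question.
A = np.array([[1, 2, 3, 4, 5],
              [2, 7, 8, 9, 10],
              [3, 8, 13, 14, 15],
              [4, 9, 14, 19, 20],
              [5, 10, 15, 20, 25]])
B = np.array([[1, 3, 5],
              [3, 13, 15],
              [5, 15, 25]])

diagA, diagB = np.diag(A), np.diag(B)

# Assumption: the diagonal values of A are pairwise distinct.
assert len(np.unique(diagA)) == len(diagA)

# Positions on A's diagonal whose values also appear on B's diagonal.
idx = np.flatnonzero(np.isin(diagA, diagB))
print(idx)                                     # [0 2 4]
print(np.array_equal(A[np.ix_(idx, idx)], B))  # True
```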


1 Comment

Thanks for the detailed answer! Your solution is only marginally slower than mine but always correct, so I'll switch to it.
