I have two arrays, one of shape (200000, 28, 28) and the other of shape (10000, 28, 28), so effectively two arrays whose elements are 28x28 matrices.
Now I want to count and extract all the elements (as an array of shape (N, 28, 28)) that occur in both arrays. With plain for loops this is far too slow, so I tried NumPy's intersect1d function, but I don't know how to apply it to arrays like these.
- Yeah, sorry, I will edit my question to make that clear. (Luca Thiede, Jan 1, 2017 at 15:43)
- Are those numbers in the arrays in some interval maybe? (Divakar, Jan 1, 2017 at 15:53)
- Yeah, they range from 0 to 255. (Luca Thiede, Jan 1, 2017 at 16:00)
- You can use the approach in this question. (Eric, Jan 1, 2017 at 17:10)
- Basically this is a variation on the common "unique row" question. (hpaulj, Jan 1, 2017 at 17:13)
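Building on the comment that the values range from 0 to 255: if they are whole numbers (an assumption, since the question does not state the dtype), casting to uint8 is lossless, and each image can then be hashed by its raw bytes in a Python set. This is a swapped-in alternative to the void-view technique, not the answer's method: it replaces the nested loops with one pass over each array, although the pass over the larger array still runs in Python. The function name `common_images_via_set` is hypothetical.

```python
import numpy as np


def common_images_via_set(a, b):
    """Return the distinct images (shape (N, 28, 28)) found in both a and b.

    Assumes both arrays hold whole numbers in 0..255, so casting to uint8
    loses nothing. Hypothetical helper, shown as a sketch only.
    """
    a8 = a.astype(np.uint8)
    b8 = b.astype(np.uint8)
    # Hash every image of the smaller array by its raw bytes: O(len(b)).
    seen = {img.tobytes() for img in b8}
    common, added = [], set()
    # One pass over the larger array: O(len(a)) set lookups.
    for img in a8:
        key = img.tobytes()
        if key in seen and key not in added:
            added.add(key)  # keep each shared image only once
            common.append(img)
    if not common:
        return np.empty((0,) + a8.shape[1:], dtype=np.uint8)
    return np.stack(common)
```

The count is then simply `len(common_images_via_set(a, b))`.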
1 Answer
Using the approach from this question about unique rows:
```python
import numpy as np

def intersect_along_first_axis(a, b):
    # check that casting to void will create equal size elements
    assert a.shape[1:] == b.shape[1:]
    assert a.dtype == b.dtype

    # compute dtypes: one opaque void item per (28, 28) element,
    # and the original element type for converting back
    void_dt = np.dtype((np.void, a.dtype.itemsize * int(np.prod(a.shape[1:]))))
    orig_dt = np.dtype((a.dtype, a.shape[1:]))

    # convert each element to a single void item (the view needs C-contiguous data)
    a = np.ascontiguousarray(a)
    b = np.ascontiguousarray(b)
    a_void = a.reshape(a.shape[0], -1).view(void_dt)
    b_void = b.reshape(b.shape[0], -1).view(void_dt)

    # intersect, then convert back to shape (N, 28, 28)
    return np.intersect1d(b_void, a_void).view(orig_dt)
```
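Note that using void is unsafe with floats: the comparison is bytewise, so -0.0 and 0.0 are treated as different values. Also, as with intersect1d in general, each shared element appears only once in the result, so len(result) counts distinct common elements.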
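One way to work around the -0.0 issue, sketched here as an assumption rather than as part of the original answer: add 0.0 to float inputs before taking the void view, because under IEEE 754 round-to-nearest, -0.0 + 0.0 gives +0.0. NaNs are still compared by bit pattern. The function name `intersect_images_float_safe` is hypothetical; the rest is the same void-view plus `np.intersect1d` technique as above, converting back with a plain `view` and `reshape`.

```python
import numpy as np


def intersect_images_float_safe(a, b):
    """Distinct items along axis 0 present in both a and b (bytewise match).

    Hypothetical variant: float inputs get + 0.0 first so -0.0 becomes +0.0.
    NaNs still compare by their bit patterns.
    """
    assert a.shape[1:] == b.shape[1:] and a.dtype == b.dtype
    if np.issubdtype(a.dtype, np.floating):
        a = a + 0.0  # -0.0 + 0.0 == +0.0; dtype is unchanged
        b = b + 0.0
    item_shape = a.shape[1:]
    n = int(np.prod(item_shape))  # values per item, e.g. 28 * 28
    a2 = np.ascontiguousarray(a).reshape(len(a), n)
    b2 = np.ascontiguousarray(b).reshape(len(b), n)
    # One void item per row, so each (28, 28) element compares as a unit.
    void_dt = np.dtype((np.void, a2.dtype.itemsize * n))
    common = np.intersect1d(a2.view(void_dt).ravel(), b2.view(void_dt).ravel())
    # Reinterpret the void bytes as the original dtype, then restore shape.
    return common.view(a2.dtype).reshape((-1,) + item_shape)
```

For example, an element containing -0.0 in `a` now matches the same element with 0.0 in `b`, and `len(intersect_images_float_safe(a, b))` gives the number of distinct shared elements.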