You need arrays that broadcast correctly to the output shape you want. You can add the missing dimension back using np.expand_dims:
index = np.expand_dims(np.argmin(array, axis=2), axis=2)
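Here is a small sketch of the shape change on a made-up 2x2 image. It also shows the keepdims argument of np.argmin. That argument exists in NumPy 1.22 and later, and it gives the same (rows, cols, 1) index without calling expand_dims.

```python
import numpy as np

# Made-up 2x2 image with 3 values per pixel.
array = np.array([[[3, 1, 2], [5, 5, 4]],
                  [[0, 9, 7], [6, 8, 2]]], dtype=np.uint8)

# argmin removes axis 2; expand_dims puts it back with length 1.
index_expand = np.expand_dims(np.argmin(array, axis=2), axis=2)

# NumPy >= 1.22: keepdims=True keeps the reduced axis directly.
index_keep = np.argmin(array, axis=2, keepdims=True)

print(np.argmin(array, axis=2).shape)  # (2, 2)
print(index_expand.shape)              # (2, 2, 1)
print(index_keep.shape)                # (2, 2, 1)
```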
This makes it easy to set or extract the elements that you want to remove:
index = list(np.indices(array.shape, sparse=True))
index[-1] = np.expand_dims(np.argmin(array, axis=2), axis=2)
minima = array[tuple(index)]
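To show what np.indices(..., sparse=True) contributes, here is the same construction on a made-up 2x2 image. The sketch prints the shape of each open grid before the last grid is replaced with the argmin positions.

```python
import numpy as np

array = np.array([[[3, 1, 2], [5, 5, 4]],
                  [[0, 9, 7], [6, 8, 2]]], dtype=np.uint8)

# One open grid per axis, each varying only along its own axis.
index = list(np.indices(array.shape, sparse=True))
print([ix.shape for ix in index])  # [(2, 1, 1), (1, 2, 1), (1, 1, 3)]

# Replace the last-axis grid with the argmin positions, shape (2, 2, 1).
index[-1] = np.expand_dims(np.argmin(array, axis=2), axis=2)

# The three index arrays broadcast to (2, 2, 1): one minimum per pixel.
minima = array[tuple(index)]
print(minima[..., 0].tolist())  # [[1, 4], [0, 2]]
```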
np.indices with sparse=True returns a set of ranges shaped to broadcast the index correctly in each dimension. A nicer alternative is to use np.take_along_axis:
index = np.expand_dims(np.argmin(array, axis=2), axis=2)
minima = np.take_along_axis(array, index, axis=2)
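The following is a quick equivalence check on random data. The seed and shape are arbitrary. np.take_along_axis with the expanded argmin index returns the same minima as the explicit np.indices construction, and both match array.min(axis=2, keepdims=True).

```python
import numpy as np

rng = np.random.default_rng(0)
array = rng.integers(10, size=(4, 5, 3), dtype=np.uint8)

index = np.expand_dims(np.argmin(array, axis=2), axis=2)

# take_along_axis builds the broadcasting grids for the other axes itself.
minima_take = np.take_along_axis(array, index, axis=2)

# Explicit version with np.indices, for comparison.
grid = list(np.indices(array.shape, sparse=True))
grid[-1] = index
minima_fancy = array[tuple(grid)]

print(np.array_equal(minima_take, minima_fancy))                      # True
print(np.array_equal(minima_take, array.min(axis=2, keepdims=True)))  # True
```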
You can use these results to create a mask, e.g. with np.put_along_axis:
mask = np.ones(array.shape, dtype=bool)
np.put_along_axis(mask, index, 0, axis=2)
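On a made-up 2x2 image, the mask ends up with exactly one False per pixel, at the argmin position. This sketch writes False instead of 0. In a boolean array, both values are stored as False.

```python
import numpy as np

array = np.array([[[3, 1, 2], [5, 5, 4]],
                  [[0, 9, 7], [6, 8, 2]]], dtype=np.uint8)
index = np.expand_dims(np.argmin(array, axis=2), axis=2)

mask = np.ones(array.shape, dtype=bool)
# Clear the argmin position of every pixel.
np.put_along_axis(mask, index, False, axis=2)

print(mask.sum(axis=2).tolist())  # [[2, 2], [2, 2]]
print(mask[0, 0].tolist())        # [True, False, True]
```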
Indexing the array with the mask gives you:
result = array[mask].reshape(*array.shape[:2], -1)
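The reshape works because boolean indexing returns the selected elements in row-major (C) order, and the mask removes exactly one of the three elements of every pixel. Each pixel therefore contributes exactly two consecutive elements to the flat result, in their original order. That regularity is unusual for masking operations, which in general select a different number of elements from each row.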
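This sketch checks the reshape against a plain per-pixel Python loop, using random data with an arbitrary seed. It also applies the mask to a Fortran-ordered copy of the data. Boolean indexing yields elements in row-major order regardless of memory layout, so that copy gives the same result.

```python
import numpy as np

rng = np.random.default_rng(1)
array = rng.integers(10, size=(3, 4, 3), dtype=np.uint8)

mask = np.ones(array.shape, dtype=bool)
np.put_along_axis(mask, np.argmin(array, axis=2)[..., None], False, axis=2)
result = array[mask].reshape(*array.shape[:2], -1)

# Reference: drop the argmin element of each pixel with a plain loop.
expected = np.empty((*array.shape[:2], 2), dtype=array.dtype)
for r in range(array.shape[0]):
    for c in range(array.shape[1]):
        pixel = list(array[r, c])
        del pixel[int(np.argmin(array[r, c]))]
        expected[r, c] = pixel

# Same data stored in Fortran order: boolean indexing still walks it in C order.
array_f = np.asfortranarray(array)
result_f = array_f[mask].reshape(*array.shape[:2], -1)

print(np.array_equal(result, expected))    # True
print(np.array_equal(result_f, expected))  # True
```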
Another alternative is to use np.delete with a raveled array and np.ravel_multi_index:
i = np.indices(array.shape[:2], sparse=True)
index = np.ravel_multi_index((*i, np.argmin(array, axis=2)), array.shape)
result = np.delete(array.ravel(), index).reshape(*array.shape[:2], -1)
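Here is a sketch of what np.ravel_multi_index computes in this case, on random data with an arbitrary seed. For a C-ordered array of shape (rows, cols, 3), the flat offset of element (r, c, k) is (r * cols + c) * 3 + k. The sketch passes the offsets to np.delete as a 1-D array via .ravel(), then confirms that the result matches the mask-based version.

```python
import numpy as np

rng = np.random.default_rng(2)
array = rng.integers(10, size=(3, 4, 3), dtype=np.uint8)
rows, cols = array.shape[:2]

# Open grids of shape (3, 1) and (1, 4) broadcast against argmin's (3, 4).
i = np.indices(array.shape[:2], sparse=True)
amin = np.argmin(array, axis=2)
flat = np.ravel_multi_index((*i, amin), array.shape)

# Same offsets computed by hand for a C-ordered (rows, cols, 3) array.
r, c = np.meshgrid(np.arange(rows), np.arange(cols), indexing="ij")
print(np.array_equal(flat, (r * cols + c) * 3 + amin))  # True

# Remove those offsets from the flattened array and restore the pixel grid.
result = np.delete(array.ravel(), flat.ravel()).reshape(rows, cols, -1)

# Compare with the mask-based approach.
mask = np.ones(array.shape, dtype=bool)
np.put_along_axis(mask, amin[..., None], False, axis=2)
print(np.array_equal(result, array[mask].reshape(rows, cols, -1)))  # True
```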
Just for fun, you can use the fact that you only have three elements per pixel to create a full index of the elements you want to keep. The three indices 0, 1 and 2 sum to 3, so 3 - np.argmin(array, axis=2) - np.argmax(array, axis=2) is the index of the median element. If you stack the median index and the max index, you get an index similar to what sorting gives you:
amax = np.argmax(array, axis=2)
amin = np.argmin(array, axis=2)
index = np.stack((np.clip(3 - (amin + amax), 0, 2), amax), axis=2)
result = np.take_along_axis(array, index, axis=2)
The call to np.clip is necessary to handle the case where all elements are equal. In that case both argmax and argmin return zero, and 3 - 0 would be the out-of-range index 3.
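Here is a worked example of the median trick on a made-up 2x2 image that includes ties and one all-equal pixel. Before clipping, the all-equal pixel gets index 3. After clipping, the values match np.sort(array, axis=2)[..., 1:] for every pixel, including the tied ones.

```python
import numpy as np

# Includes ties and an all-equal pixel, where argmin == argmax == 0.
array = np.array([[[3, 1, 2], [1, 1, 2]],
                  [[2, 1, 2], [5, 5, 5]]], dtype=np.uint8)

amax = np.argmax(array, axis=2)
amin = np.argmin(array, axis=2)
raw = 3 - (amin + amax)
print(raw.tolist())  # [[2, 1], [2, 3]]  (3 is out of bounds)

# Clip the median index into range, then pair it with the max index.
index = np.stack((np.clip(raw, 0, 2), amax), axis=2)
result = np.take_along_axis(array, index, axis=2)
print(result.tolist())  # [[[2, 3], [1, 2]], [[2, 2], [5, 5]]]

# Same values as sorting each pixel and dropping the smallest element.
print(np.array_equal(result, np.sort(array, axis=2)[..., 1:]))  # True
```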
Timing
Comparing the approaches:
import numpy as np

def remove_min_indices(array):
    # Clear the argmin of each pixel in a boolean mask via fancy indexing.
    index = list(np.indices(array.shape, sparse=True))
    index[-1] = np.expand_dims(np.argmin(array, axis=2), axis=2)
    mask = np.ones(array.shape, dtype=bool)
    mask[tuple(index)] = False
    return array[mask].reshape(*array.shape[:2], -1)

def remove_min_put(array):
    # Same mask, built with np.put_along_axis.
    mask = np.ones(array.shape, dtype=bool)
    np.put_along_axis(mask, np.expand_dims(np.argmin(array, axis=2), axis=2), 0, axis=2)
    return array[mask].reshape(*array.shape[:2], -1)

def remove_min_delete(array):
    # Delete the flat offsets of the minima from the raveled array.
    i = np.indices(array.shape[:2], sparse=True)
    index = np.ravel_multi_index((*i, np.argmin(array, axis=2)), array.shape)
    return np.delete(array.ravel(), index).reshape(*array.shape[:2], -1)

def remove_min_sort_c(array):
    # Sorted copy: the remaining elements come out in ascending order.
    return np.sort(array, axis=2)[..., 1:]

def remove_min_sort_i(array):
    # Sorts the caller's array in place (modifies the input).
    array.sort(axis=2)
    return array[..., 1:]

def remove_min_median(array):
    # Median and max of each pixel, in that order.
    amax = np.argmax(array, axis=2)
    amin = np.argmin(array, axis=2)
    index = np.stack((np.clip(3 - (amin + amax), 0, 2), amax), axis=2)
    return np.take_along_axis(array, index, axis=2)
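Tested for arrays made like array = np.random.randint(10, size=(N, N, 3), dtype=np.uint8) for N in {100, 1K, 10K}. The columns are IND = remove_min_indices, PUT = remove_min_put, DEL = remove_min_delete, S_C = remove_min_sort_c, S_I = remove_min_sort_i and MED = remove_min_median: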
N | IND | PUT | DEL | S_C | S_I | MED |
-----+---------+---------+---------+---------+---------+---------+
100 | 648. µs | 658. µs | 765. µs | 497. µs | 246. µs | 905. µs |
1K | 67.9 ms | 68.1 ms | 85.7 ms | 51.7 ms | 24.0 ms | 123. ms |
10K | 6.86 s | 6.86 s | 8.72 s | 5.17 s | 2.39 s | 13.2 s |
-----+---------+---------+---------+---------+---------+---------+
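Times scale with N^2, as expected. Sorting is clearly the fastest approach, especially in place. However, sorting returns a different result from the masking and deletion approaches: the two remaining elements of each pixel come out in ascending order rather than in their original order. The median trick has the same property. The two masking approaches (IND and PUT) are practically tied. Raw indexing is marginally faster at N = 100 and 1K, and the two take the same time at 10K. np.delete is roughly 20-30% slower than masking, and the median trick is the slowest, close to twice the masking time at the larger sizes.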
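In short, if the order of the two remaining elements doesn't matter and modifying the input is acceptable, sort in place and drop the first element:
array.sort(axis=2); array[:, :, 1:]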
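Below is a minimal timing sketch, assuming you want to reproduce comparisons like the ones above on your own machine. It is not the script that produced the table. It redefines three of the functions so that it runs on its own, and the best_time/benchmark helpers are hypothetical names built on time.perf_counter. It also uses np.random.default_rng instead of np.random.randint. Because remove_min_sort_i sorts its argument in place, best_time hands every call a fresh copy of the input, so later calls don't time an already-sorted array. The copy itself is not timed.

```python
import time

import numpy as np


def remove_min_put(array):
    # Mask out the argmin of each pixel, then restore the (rows, cols, 2) shape.
    mask = np.ones(array.shape, dtype=bool)
    np.put_along_axis(mask, np.argmin(array, axis=2)[..., None], False, axis=2)
    return array[mask].reshape(*array.shape[:2], -1)


def remove_min_sort_c(array):
    # Sorted copy; drops the smallest element.
    return np.sort(array, axis=2)[..., 1:]


def remove_min_sort_i(array):
    # Sorts the caller's array in place.
    array.sort(axis=2)
    return array[..., 1:]


def best_time(func, base, repeat=5):
    """Hypothetical helper: best time (seconds) of func over fresh copies of base."""
    best = float("inf")
    for _ in range(repeat):
        data = base.copy()  # fresh, unsorted input; the copy is not timed
        start = time.perf_counter()
        func(data)
        best = min(best, time.perf_counter() - start)
    return best


def benchmark(funcs, n, repeat=5, seed=0):
    """Hypothetical helper: time each function on one random (n, n, 3) uint8 array."""
    rng = np.random.default_rng(seed)
    base = rng.integers(10, size=(n, n, 3), dtype=np.uint8)
    return {name: best_time(func, base, repeat) for name, func in funcs.items()}


FUNCS = {"PUT": remove_min_put, "S_C": remove_min_sort_c, "S_I": remove_min_sort_i}

if __name__ == "__main__":
    for n in (100, 300):
        timings = benchmark(FUNCS, n)
        print(n, {name: f"{t * 1e3:.3f} ms" for name, t in timings.items()})
```

Timing a fresh copy for every function keeps the comparison fair between the copying and in-place variants.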