I have a multidimensional array of shape (500000, 3, 2, 3); let's call it data. The data is 500000 sets of 3 points, with each point separated into its x and y coordinates (hence the 2). The last axis, of size 3, holds different rotations of the 3 points. I also have a 1D array of 500000 integers between 0 and 2 that tells me which rotation to keep for each set; let's call it rot_index. I would like to construct an array of shape (500000, 3, 2) that keeps only the correctly rotated points. Any ideas on how to extract the data with the correct index from the original data array? I tried the following, but it didn't work:

data[:,:,:,rot_index]

Edit:

Here is some example data (10 sets of points instead of 500000):

data = 
[[[[0.70846822 0.98552876 0.66736535]
   [0.         0.         0.        ]]

  [[0.66736535 0.70846822 0.98552876]
   [1.54545219 2.39798549 2.33974762]]

  [[0.98552876 0.66736535 0.70846822]
   [3.88519982 3.94343768 4.73773311]]]


 [[[0.8132551  1.18845796 1.53004225]
   [0.         0.         0.        ]]

  [[1.18845796 1.53004225 0.8132551 ]
   [1.43211754 2.58720625 2.26386152]]

  [[1.53004225 0.8132551  1.18845796]
   [4.01932379 4.85106777 3.69597906]]]


 [[[0.66123513 0.93651048 0.83170562]
   [0.         0.         0.        ]]

  [[0.93651048 0.83170562 0.66123513]
   [2.09747072 2.38383457 1.80188002]]

  [[0.83170562 0.66123513 0.93651048]
   [4.48130529 4.18571459 3.89935074]]]


 [[[1.31047414 0.67740955 1.42020073]
   [0.         0.         0.        ]]

  [[0.67740955 1.42020073 1.31047414]
   [1.66061575 1.97600777 2.64656179]]

  [[1.42020073 1.31047414 0.67740955]
   [3.63662352 4.62256956 4.30717753]]]


 [[[1.4085555  1.64177102 0.27708893]
   [0.         0.         0.        ]]

  [[0.27708893 1.4085555  1.64177102]
   [0.62154257 3.04315813 2.61848461]]

  [[1.64177102 0.27708893 1.4085555 ]
   [3.24002718 3.6647007  5.66164274]]]


 [[[0.48080385 0.85910831 0.52342904]
   [0.         0.         0.        ]]

  [[0.52342904 0.48080385 0.85910831]
   [1.08970318 2.57102289 2.62245924]]

  [[0.85910831 0.52342904 0.48080385]
   [3.71216242 3.66072607 5.19348213]]]


 [[[1.13610207 1.51237019 0.47256909]
   [0.         0.         0.        ]]

  [[1.51237019 0.47256909 1.13610207]
   [2.92304081 2.59328103 0.76686347]]

  [[0.47256909 1.13610207 1.51237019]
   [5.51632184 3.3601445  3.68990428]]]


 [[[1.08397801 1.16506242 0.84703646]
   [0.         0.         0.        ]]

  [[1.16506242 0.84703646 1.08397801]
   [2.37250664 2.04419242 1.86648625]]

  [[0.84703646 1.08397801 1.16506242]
   [4.41669906 3.91067866 4.23899289]]]


 [[[0.98734317 1.11177984 0.90283297]
   [0.         0.         0.        ]]

  [[1.11177984 0.90283297 0.98734317]
   [2.25981006 2.13666143 1.88671382]]

  [[0.90283297 0.98734317 1.11177984]
   [4.39647149 4.02337525 4.14652387]]]


 [[[1.94118244 1.14738719 1.98251535]
   [0.         0.         0.        ]]

  [[1.14738719 1.98251535 1.94118244]
   [1.83291888 1.90183408 2.54843234]]

  [[1.98251535 1.94118244 1.14738719]
   [3.73475296 4.45026642 4.38135123]]]]

And here is a list of the indices I want to keep:

rot_index = np.array([1, 2, 1, 1, 1, 1, 1, 2, 1, 1])

So just as an example, if you consider

data[0,:,:,0] = [[0.70846822 0.]
 [0.66736535 1.54545219]
 [0.98552876 3.88519982]]
data[0,:,:,1] = [[0.98552876 0.]
 [0.70846822 2.39798549]
 [0.66736535 3.94343768]]
data[0,:,:,2] = [[0.66736535 0.]
 [0.98552876 2.33974762]
 [0.70846822 4.73773311]]

These are 3 different "rotations" of the same sample, and if we look at the first element of rot_index, it is a 1. So I only want to keep

data[0,:,:,1] = [[0.98552876 0.]
 [0.70846822 2.39798549]
 [0.66736535 3.94343768]]
  • Could you please provide some example data? Commented Oct 18, 2020 at 19:39
  • @Toby_TheBlock Sure, I added some! Commented Oct 18, 2020 at 20:49
  • @Zach - Pls check out my answer Commented Oct 20, 2020 at 7:34

1 Answer

You can do this with NumPy advanced indexing, specifically by combining advanced and basic indexing. The following should work (where data_array is a NumPy ndarray holding your data):

result = data_array[np.arange(len(rot_index)), ..., rot_index]

For your sample data, this produces:

[[[0.98552876 0.        ]
  [0.70846822 2.39798549]
  [0.66736535 3.94343768]]

 [[1.53004225 0.        ]
  [0.8132551  2.26386152]
  [1.18845796 3.69597906]]

 [[0.93651048 0.        ]
  [0.83170562 2.38383457]
  [0.66123513 4.18571459]]

 [[0.67740955 0.        ]
  [1.42020073 1.97600777]
  [1.31047414 4.62256956]]

 [[1.64177102 0.        ]
  [1.4085555  3.04315813]
  [0.27708893 3.6647007 ]]

 [[0.85910831 0.        ]
  [0.48080385 2.57102289]
  [0.52342904 3.66072607]]

 [[1.51237019 0.        ]
  [0.47256909 2.59328103]
  [1.13610207 3.3601445 ]]

 [[0.84703646 0.        ]
  [1.08397801 1.86648625]
  [1.16506242 4.23899289]]

 [[1.11177984 0.        ]
  [0.90283297 2.13666143]
  [0.98734317 4.02337525]]

 [[1.14738719 0.        ]
  [1.98251535 1.90183408]
  [1.94118244 4.45026642]]]
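
To see why this works, and why the attempt in the question does not, here is a minimal sketch on random stand-in data (the seed, the size n = 10, and the names alt and outer are assumptions for illustration, not part of the original data). The two index arrays on the first and last axes are broadcast together, so row i is paired with rot_index[i], and the sliced axes (3, 2) follow. Indexing only the last axis with rot_index instead applies every entry of rot_index to every row.

```python
import numpy as np

rng = np.random.default_rng(0)           # arbitrary seed, illustration only
n = 10                                   # small stand-in for 500000
data_array = rng.random((n, 3, 2, 3))    # random stand-in for the real data
rot_index = rng.integers(0, 3, size=n)   # one rotation index (0..2) per set

# Pair row i with rot_index[i]; the Ellipsis keeps the middle axes whole.
result = data_array[np.arange(n), ..., rot_index]
print(result.shape)

# An equivalent formulation with np.take_along_axis: the index array is
# reshaped to (n, 1, 1, 1) so it broadcasts over the two middle axes.
alt = np.take_along_axis(
    data_array, rot_index[:, None, None, None], axis=-1
)[..., 0]
print(np.array_equal(result, alt))

# The attempt from the question applies all n indices to every row,
# so it grows a new axis of length n instead of selecting one rotation.
outer = data_array[:, :, :, rot_index]
print(outer.shape)
```

With the full-size array, that last expression would ask for a result of shape (500000, 3, 2, 500000), which is far too large to allocate, so the paired form above is the one to use.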