
NumPy has the random.choice function, which lets you sample from a categorical distribution. How would you repeat this over an axis? To illustrate what I mean, here is my current code:

import numpy as np

categorical_distributions = np.array([
    [.1, .3, .6],
    [.2, .4, .4],
])
_, n = categorical_distributions.shape
np.array([np.random.choice(n, p=row)
          for row in categorical_distributions])

Ideally, I would like to eliminate the for loop.

  • Looks like a job for map. Commented Dec 8, 2017 at 21:40
  • @Galen The performance numbers would be comparable to the posted loopy solution, if not worse. Commented Dec 8, 2017 at 21:41
  • @Divakar Agreed. Commented Dec 8, 2017 at 22:58
  • See stackoverflow.com/questions/34187130/… Commented Dec 9, 2017 at 6:33

2 Answers


Here's one vectorized way to get the random indices per row, with a as the 2D array of probabilities -

(a.cumsum(1) > np.random.rand(a.shape[0])[:,None]).argmax(1)
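To see how the one-liner works, here is the same expression with the uniform draws fixed by hand instead of coming from np.random.rand (the values in r are arbitrary, chosen only for illustration). Each row's cumulative sum is a discrete CDF; comparing it with a draw u in [0, 1) gives False before the chosen bin and True from there on, and argmax returns the position of the first True. Because u falls between consecutive cumulative sums F[j-1] and F[j] with probability p[j], this is inverse transform sampling.

```python
import numpy as np

a = np.array([
    [.1, .3, .6],
    [.2, .4, .4],
])
# Hand-picked "uniform draws", one per row (illustrative values only)
r = np.array([0.05, 0.65])

cdf = a.cumsum(1)         # approx. [0.1, 0.4, 1.0] and [0.2, 0.6, 1.0]
mask = cdf > r[:, None]   # [True, True, True] and [False, False, True]
idx = mask.argmax(1)      # position of the first True in each row
print(idx)                # [0 2]
```

One caveat of this formulation: if rounding makes a row's total fall slightly below 1 and u lands above it, the mask row is all False and argmax returns 0.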

Generalizing to sample along either the rows or the columns of a 2D array -

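import numpy as np

def random_choice_prob_index(a, axis=1):
    # One uniform draw per distribution, shaped to broadcast against a.cumsum(axis)
    r = np.expand_dims(np.random.rand(a.shape[1-axis]), axis=axis)
    return (a.cumsum(axis=axis) > r).argmax(axis=axis)  # first index where CDF > r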
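The same idea extends beyond 2D. Below is a sketch (the name random_choice_prob_index_nd and its rng parameter are hypothetical, not from the answer) that treats every 1-D slice along axis as a distribution. It draws from a np.random.Generator and keeps axis as a length-1 dimension so the draws broadcast against the cumulative sums.

```python
import numpy as np


def random_choice_prob_index_nd(a, axis=-1, rng=None):
    """Sample one index per distribution, where each 1-D slice of `a` along
    `axis` is a PMF. Hypothetical N-D variant of random_choice_prob_index."""
    rng = np.random.default_rng() if rng is None else rng
    # One uniform draw in [0, 1) per distribution; `axis` is kept with size 1
    # so that `r` broadcasts against the cumulative sums.
    shape = list(a.shape)
    shape[axis] = 1
    r = rng.random(tuple(shape))
    # First position along `axis` where the running CDF exceeds the draw.
    return (a.cumsum(axis=axis) > r).argmax(axis=axis)


# Usage: a (2, 4, 3) array of distributions over the last axis.
rng = np.random.default_rng(0)
probs = rng.random((2, 4, 3))
probs /= probs.sum(axis=-1, keepdims=True)
samples = random_choice_prob_index_nd(probs, axis=-1, rng=rng)
print(samples.shape)  # (2, 4)
```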

Let's verify with the given sample by running it a million times -

In [589]: a = np.array([
     ...:     [.1, .3, .6],
     ...:     [.2, .4, .4],
     ...: ])

In [590]: choices = [random_choice_prob_index(a)[0] for i in range(1000000)]

# This should be close to first row of given sample
In [591]: np.bincount(choices)/float(len(choices))
Out[591]: array([ 0.099781,  0.299436,  0.600783])
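Calling the function a million times in a Python loop only checks the first row, and slowly. A sketch of a vectorized check of both rows at once, assuming np.tile to stack many copies of a so that a single call draws from every copy (the seed and N are arbitrary choices):

```python
import numpy as np


def random_choice_prob_index(a, axis=1):
    r = np.expand_dims(np.random.rand(a.shape[1-axis]), axis=axis)
    return (a.cumsum(axis=axis) > r).argmax(axis=axis)


np.random.seed(0)  # arbitrary seed, for reproducibility only
a = np.array([
    [.1, .3, .6],
    [.2, .4, .4],
])
N = 200_000
big = np.tile(a, (N, 1))  # rows alternate: a[0], a[1], a[0], a[1], ...
# Column i of `choices` holds the N draws made from a[i]
choices = random_choice_prob_index(big).reshape(N, a.shape[0])
freqs = np.array([np.bincount(choices[:, i], minlength=a.shape[1]) / N
                  for i in range(a.shape[0])])
print(freqs)  # each row should be close to the matching row of `a`
```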

Runtime test

Original loopy way -

def loopy_app(categorical_distributions):
    m, n = categorical_distributions.shape
    out = np.empty(m, dtype=int)
    for i,row in enumerate(categorical_distributions):
        out[i] = np.random.choice(n, p=row)
    return out

Timings on a bigger array -

In [593]: a = np.array([
     ...:     [.1, .3, .6],
     ...:     [.2, .4, .4],
     ...: ])

In [594]: a_big = np.repeat(a,100000,axis=0)

In [595]: %timeit loopy_app(a_big)
1 loop, best of 3: 2.54 s per loop

In [596]: %timeit random_choice_prob_index(a_big)
100 loops, best of 3: 6.44 ms per loop

2 Comments

Excellent answer. How would you implement choice without replacement?
Thank you for the answer. In case it's relevant for anyone, this method is based on the idea of inverse transform sampling, see e.g. stephens999.github.io/fiveMinuteStats/…
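On the first comment's question (sampling without replacement): the cumsum trick does not carry over directly, because each pick changes the remaining probabilities. One swapped-in technique that does vectorize is the Gumbel-top-k trick: add independent standard Gumbel noise to log p and keep the k largest keys per row, which corresponds to drawing k items one at a time without replacement, each in proportion to its probability. A minimal sketch (the function name and parameters are hypothetical):

```python
import numpy as np


def choice_without_replacement_rows(p, k, rng=None):
    """Hypothetical helper: k distinct column indices per row of `p`, using the
    Gumbel-top-k trick. Each row of `p` is a PMF with at least k nonzero entries."""
    rng = np.random.default_rng() if rng is None else rng
    # Perturb log-probabilities with independent standard Gumbel noise.
    with np.errstate(divide="ignore"):  # log(0) -> -inf, so zero-prob items sort last
        keys = np.log(p) + rng.gumbel(size=p.shape)
    # Indices of the k largest keys in each row, largest first.
    return np.argsort(-keys, axis=1)[:, :k]


p = np.array([
    [.1, .3, .6],
    [.2, .4, .4],
])
print(choice_without_replacement_rows(p, k=2, rng=np.random.default_rng(0)))
```

With k=1 this reduces to the Gumbel-max trick, which samples from the categorical distribution itself and is an alternative to the cumsum approach for sampling with replacement.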

We can use Numba to implement it:

import numba
import numpy as np

@numba.njit(fastmath=True)
def _sample_rows_from_pr(pr, rnd):
    """
    Args:
        pr: (n, k) rowwise PMF, should sum to 1
        rnd: (n,) uniform in [0, 1)

    Returns:
        (n,) chosen column for each row
    """
    n, k = pr.shape
    out = np.empty(n, np.int64)
    for i in range(n):
        u = rnd[i]
        s = 0.0
        # Fallback to account for cases where `pr` sums
        # to less than 1 due to a numerical error:
        chosen = 0
        for j in range(k):
            s += pr[i, j]
            if u < s:
                chosen = j
                break
        out[i] = chosen
    return out

def choice_along_axis1(np_rng, pr):
    """
    Row-wise (batched) version of np.random.Generator.choice.

    Assumes `pr` is a 2D array of probs and we want to sample along axis 1.

    Args:
        np_rng: a np.random.Generator, e.g. from np.random.default_rng
        pr: (n, k) probabilities, rows sum to 1

    Returns:
        (n,) ints in [0, k)
    """
    return _sample_rows_from_pr(pr, np_rng.random(pr.shape[0]))

This is faster than @Divakar's solution on my machine:

# Example data:
import numpy as np
import scipy.special

np_rng = np.random.default_rng(10)
n = 2000
k = 10
hh = np.sin(np.arange(n * k).reshape((n, k)))
pr = scipy.special.softmax(hh, axis=1)

# @Divakar solution:
def choice_along_axis1_v1(np_rng, pr):
    return (pr.cumsum(1) > np_rng.uniform(size=pr.shape[0])[:,None]).argmax(1)

Timing results:

  • choice_along_axis1_v1: 129 μs ± 1.13 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
  • choice_along_axis1: 34.5 μs ± 715 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
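Both answers return, for each row, the first index whose running sum exceeds the uniform draw u, so when fed the same u they should agree draw for draw. A minimal check, with the loop kernel copied without the @numba.njit decorator so it runs as plain Python where Numba is not installed (slow, but fine for a test; the name _sample_rows_plain is hypothetical). When timing the Numba version, keep in mind that the first call to an @njit function also includes JIT compilation.

```python
import numpy as np


def _sample_rows_plain(pr, rnd):
    # Same logic as _sample_rows_from_pr above, minus @numba.njit (plain Python).
    n, k = pr.shape
    out = np.empty(n, np.int64)
    for i in range(n):
        s = 0.0
        chosen = 0  # fallback if the row sums to no more than rnd[i]
        for j in range(k):
            s += pr[i, j]
            if rnd[i] < s:
                chosen = j
                break
        out[i] = chosen
    return out


rng = np.random.default_rng(10)
pr = rng.random((500, 10))
pr /= pr.sum(axis=1, keepdims=True)  # random row-stochastic matrix
u = rng.random(pr.shape[0])          # shared uniform draws

loop_idx = _sample_rows_plain(pr, u)
cumsum_idx = (pr.cumsum(1) > u[:, None]).argmax(1)
print(np.array_equal(loop_idx, cumsum_idx))  # True
```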
