You can create a zero array with the expected shape, fill the desired sliding-window indices with the b values, and finally fill the remaining zero entries with tiled copies of the a array:
import numpy as np

# Example inputs (inferred from the printed outputs below)
a = np.array([1, 2, 3, 4])
b = np.array([5, 6])

shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
ind = np.lib.stride_tricks.sliding_window_view(np.arange(size), b.shape[0])
# [[0 1]
# [1 2]
# [2 3]
# [3 4]
# [4 5]]
zero_arr = np.zeros((shape, size), dtype=np.float64)
zero_arr[np.arange(shape)[:, None], ind] = b
# [[5 6 0 0 0 0]
# [0 5 6 0 0 0]
# [0 0 5 6 0 0]
# [0 0 0 5 6 0]
# [0 0 0 0 5 6]]
zero_arr[zero_arr == 0] = np.tile(a, shape)
# [[5 6 1 2 3 4]
# [1 5 6 2 3 4]
# [1 2 5 6 3 4]
# [1 2 3 5 6 4]
# [1 2 3 4 5 6]]
Update
This method will beat the tax evader method in terms of performance on larger arrays, provided zero_arr is built from a value that does not appear in b. E.g., if b contains 0, the mask zero_arr == 0 will mislead the solution. One option is to use -np.ones together with zero_arr == -1 when b contains only positive values. More generally, the placeholder array can be created with np.full (or ndarray.fill) once we know an arbitrary value that is not in b; such a value can be found by checking candidates from np.arange or np.random against b (a sketch of this sentinel-value variant follows the next snippet). But a more comprehensive approach, which avoids the placeholder value entirely, is as below:
shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
# positions of b within each row
ind_0 = np.lib.stride_tricks.sliding_window_view(np.arange(size), b.shape[0])
# (wrapped) positions of a within each row
ind_1 = np.lib.stride_tricks.sliding_window_view(np.arange(b.shape[0], size + shape - 1), a.shape[0])
ind_1 = ind_1 % size
arr = np.zeros((shape, size), dtype=np.float64)
arange_ = np.arange(shape)[:, None]
arr[arange_, ind_0] = b
arr[arange_, np.sort(ind_1)] = np.broadcast_to(a, (shape, a.shape[0]))  # or use np.tile
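For completeness, here is a minimal sketch of the sentinel-value variant described above (the function name fill_with_sentinel and the choice of b.min() - 1 as the sentinel are illustrative, not part of the original code):

# Sketch of the sentinel-value variant: fill with a value guaranteed not to be
# in b, place b via the sliding-window indices, then overwrite the sentinels
# with tiled copies of a (the boolean mask fills in row-major order).
def fill_with_sentinel(a, b):
    shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
    sentinel = b.min() - 1  # any value not contained in b works here
    ind = np.lib.stride_tricks.sliding_window_view(np.arange(size), b.shape[0])
    arr = np.full((shape, size), sentinel, dtype=np.float64)
    arr[np.arange(shape)[:, None], ind] = b
    arr[arr == sentinel] = np.tile(a, shape)
    return arr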
If you have no limitation on using other libraries (as you accepted the numba answer) and can run the code on a GPU, I believe the JAX library will beat numba on larger arrays. I converted the NumPy code above into JAX jitted form to see how this library handles such matrix problems in terms of performance. Besides the benchmarks comparing JAX and numba on this problem, these codes have a learning aspect: how to use jax.numpy.where (jnp.where) under the JAX jit decorator, where the output size must be specified statically for it to work, and how to build an equivalent of np.lib.stride_tricks.sliding_window_view or np.lib.stride_tricks.as_strided inside a jax jitted function. The evader code is converted too, with some changes that I think are needed for JAX; I don't know if it could be written in a more compact (shorter) form. I also think this code could be rewritten in a more optimized form to get even faster results.
from functools import partial
import jax
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit, vmap
@partial(jit, static_argnums=(1,))
def moving_window(a, size: int):
    # JAX replacement for np.lib.stride_tricks.sliding_window_view
    starts = jnp.arange(len(a) - size + 1)
    return vmap(lambda start: jax.lax.dynamic_slice(a, (start,), (size,)))(starts)

@jit
def jax_initial(a, b):
    shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
    ind = moving_window(jnp.arange(size), b.shape[0])
    arr = jnp.zeros((shape, size), dtype=jnp.float64)
    arr = arr.at[jnp.arange(shape)[:, None], ind].set(b)
    broad = jnp.ravel(jnp.broadcast_to(a, (shape, a.shape[0])))
    # jnp.where requires a static size inside jit
    idx = jnp.where(arr == 0, size=broad.size)
    return arr.at[idx].set(broad)

@jit
def jax_comp(a, b):
    shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
    ind_0 = moving_window(jnp.arange(size), b.shape[0])
    ind_1 = moving_window(jnp.arange(b.shape[0], size + shape - 1), a.shape[0])
    ind_1 = jnp.remainder(ind_1, size)  # ind_1 = ind_1 % size
    arr = jnp.zeros((shape, size), dtype=jnp.float64)
    arange_ = jnp.arange(shape)[:, None]
    arr = arr.at[arange_, ind_0].set(b)
    arr = arr.at[arange_, jnp.sort(ind_1)].set(jnp.broadcast_to(a, (shape, a.shape[0])))
    return arr

@jit
def jax_evader(a, b):
    shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
    res = jnp.empty([shape, size])
    frame = jnp.reshape(jnp.arange(shape * size), (shape, size))
    diag_mask = (frame % (res.shape[1] + 1)) < (b.shape[0])
    res_0 = jnp.ravel(jnp.broadcast_to(b, (shape, b.shape[0])))  # jnp.tile(b, res.shape[0])
    res_1 = jnp.ravel(jnp.broadcast_to(a, (shape, a.shape[0])))  # jnp.tile(a, res.shape[0])
    idx = jnp.where(diag_mask, size=res_0.size)
    idx_v = jnp.where(~diag_mask, size=res_1.size)
    res = res.at[idx].set(res_0)
    return res.at[idx_v].set(res_1)
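A minimal usage sketch with the same small example arrays (an assumption; the benchmarks were run on larger arrays), where block_until_ready forces JAX's asynchronous computation to finish:

# Usage sketch: small example inputs (assumed); note that jax_initial relies on
# b not containing 0, as explained above.
a_j = jnp.array([1., 2., 3., 4.])
b_j = jnp.array([5., 6.])
print(jax_initial(a_j, b_j).block_until_ready())
print(jax_comp(a_j, b_j).block_until_ready())
print(jax_evader(a_j, b_j).block_until_ready())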
Some benchmarks (temporary link)


Fastest code:
We can parallelize the numba code using explicit signatures; on my system this ran at least 3 times faster than Pig's version:
import numpy as np
import numba as nb

a = np.random.rand(10000)
b = np.array([5, 6, 7, 8, 9, 10, 11], dtype=np.int64)

@nb.njit('float64[::1], int64[::1]', parallel=True)
def fill_parallel(a, b):
    rows = a.size + 1
    cols = a.size + b.size
    res = np.empty((rows, cols))
    for i in nb.prange(rows):
        # place b on the shifting diagonal block, then the two pieces of a around it
        res[i, i:b.size + i] = b
        res[i, b.size + i:] = a[i:]
        res[i, :i] = a[:i]
    return res
The parallelized numba code is the fastest code so far.
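For reference, the timing can be reproduced roughly as below (a sketch; the first call is a warm-up that triggers numba's compilation, and exact numbers depend on the machine):

import timeit

fill_parallel(a, b)  # warm-up call: triggers numba compilation
t = timeit.timeit(lambda: fill_parallel(a, b), number=100)
print(f"fill_parallel: {t / 100 * 1e3:.3f} ms per call")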