I am having trouble correctly using a multiprocessing pool with the function I have defined here:
def run_ilqr(x0, N, max_iter, regu_init, alpha_init, x_dim, u_dim, n_agents, x_ref, x_ref_T):
    # First forward rollout
    u_trj = np.random.randn(N-1, n_agents*n_inputs)*0.0001
    x_trj = rollout(x0, u_trj, x_dim, u_dim)
    total_cost = cost_sum(x_trj, u_trj, x_dim, x_ref, x_ref_T)
    regu = regu_init
    max_regu = 10000
    min_regu = 0.01
    alpha = alpha_init
    max_alpha = 1.0
    min_alpha = 0.0
    # Setup traces
    cost_trace = [total_cost]
    expected_cost_redu_trace = []
    redu_ratio_trace = [1]
    redu_trace = []
    regu_trace = [regu]
    alpha_trace = [alpha]
    # Run main loop
    for it in range(max_iter):
        # Backward and forward pass
        k_trj, K_trj, expected_cost_redu = backward_pass(x_trj, u_trj, regu, alpha, x_dim, u_dim, x_ref_T, x_ref)
        x_trj_new, u_trj_new = forward_pass(x_trj, u_trj, k_trj, K_trj, expected_cost_redu, total_cost, alpha, x_dim, u_dim)
        # Evaluate new trajectory
        total_cost = cost_sum(x_trj_new, u_trj_new, x_dim, x_ref, x_ref_T)
        cost_redu = cost_trace[-1] - total_cost
        redu_ratio = cost_redu / abs(expected_cost_redu)
        # Accept or reject iteration
        if redu_ratio >= 1e-4 and redu_ratio <= 10:
            # Improvement! Accept new trajectories and lower regularization
            redu_ratio_trace.append(redu_ratio)
            cost_trace.append(total_cost)
            x_trj = x_trj_new
            u_trj = u_trj_new
            regu *= 0.7
            # alpha doesn't change if accepted
        else:
            # Reject new trajectories and increase regularization
            regu *= 2.0
            alpha = alpha * 0.5  # a scaling factor of 0.5 for alpha is a typical value
            cost_trace.append(cost_trace[-1])
            redu_ratio_trace.append(0)
        regu = min(max(regu, min_regu), max_regu)
        regu_trace.append(regu)
        redu_trace.append(cost_redu)
        alpha = min(max(alpha, min_alpha), max_alpha)
        alpha_trace.append(alpha)
        # Early termination if expected improvement is small
        if expected_cost_redu <= 1e-6:
            break
    return x_trj, u_trj, cost_trace, regu_trace, redu_ratio_trace, redu_trace, alpha_trace
The details of the function above do not matter much, but it takes 10 input arguments: x0, x_ref and x_ref_T are 1-dimensional vectors (e.g. shape (12,)); N, max_iter, regu_init, alpha_init and n_agents are scalars; and x_dim and u_dim are list-like, such as [4,4,4] or [2,2,2]. x_trj and u_trj are the most important return values here, and they are 2-dimensional arrays.
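Given that description, the arguments for a single run form a 10-element tuple in the order of the signature. One way to sanity-check such a tuple before handing many of them to a pool is `inspect.signature(...).bind`, sketched here with a hypothetical stub `run_ilqr_stub` that only copies the parameter list, and made-up placeholder values (plain lists stand in for the NumPy vectors):

```python
import inspect


# Hypothetical stub with the same 10-parameter signature as run_ilqr;
# its body is irrelevant, only the parameter list matters here.
def run_ilqr_stub(x0, N, max_iter, regu_init, alpha_init,
                  x_dim, u_dim, n_agents, x_ref, x_ref_T):
    return None


# Made-up placeholder values following the description: 1-D vectors of
# length 12 (plain lists instead of NumPy arrays), scalars, and
# list-like per-agent dimensions.
x0 = [0.0] * 12
x_ref = [1.0] * 12
x_ref_T = [1.0] * 12
one_call = (x0, 50, 100, 1.0, 1.0, [4, 4, 4], [2, 2, 2], 3, x_ref, x_ref_T)

sig = inspect.signature(run_ilqr_stub)
bound = sig.bind(*one_call)          # succeeds: 10 values for 10 parameters
print(len(sig.parameters))           # 10
print(bound.arguments["n_agents"])   # 3

# A tuple of 150 values (e.g. one whole per-run list) cannot be bound.
try:
    sig.bind(*([0.0] * 150))
    too_many_ok = True
except TypeError:
    too_many_ok = False
print(too_many_ok)                   # False
```

`bind` raises the same kind of `TypeError` that a real call with the wrong argument count would, without executing the function body.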
I tried to call run_ilqr through a Pool, and here is what I have done:
import time
import multiprocessing
from multiprocessing import Pool

with Pool(multiprocessing.cpu_count()-1) as p:
    p.starmap(run_ilqr,[first_args,N,max_iter,regu_init,alpha_init,second_args,third_args,fourth_args,fifth_args,sixth_args])
where first_args, second_args, etc. are lists, each of length 150. For example, first_args holds the x0 arguments, since I want to call the function with 150 different x0 vectors. However, the following error pops up:
Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/pool.py", line 121, in worker
    result = (True, func(*args, **kwds))
  File "/usr/lib/python3.7/multiprocessing/pool.py", line 47, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
TypeError: run_ilqr() takes 10 positional arguments but 150 were given
"""
The above exception was the direct cause of the following exception:
TypeError                                 Traceback (most recent call last)
<ipython-input-88-628be8ab875b> in <module>()
      1 with Pool(multiprocessing.cpu_count()-1) as p:
----> 2     p.starmap(run_ilqr,[first_args,N,max_iter,regu_init,alpha_init,second_args,third_args,fourth_args,fifth_args,sixth_args])

/usr/lib/python3.7/multiprocessing/pool.py in starmap(self, func, iterable, chunksize)
    274         `func` and (a, b) becomes func(a, b).
    275         '''
--> 276         return self._map_async(func, iterable, starmapstar, chunksize).get()
    277 
    278     def starmap_async(self, func, iterable, chunksize=None, callback=None,

/usr/lib/python3.7/multiprocessing/pool.py in get(self, timeout)
    655             return self._value
    656         else:
--> 657             raise self._value
    658 
    659     def _set(self, i, obj):

TypeError: run_ilqr() takes 10 positional arguments but 150 were given
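The remote part of the traceback shows the pool worker running `itertools.starmap(args[0], args[1])` on its share of the iterable, so the same failure can be reproduced without any processes. In this sketch, `f` is a hypothetical 3-parameter stand-in for the 10-parameter run_ilqr, and the per-run lists have length 4 instead of 150:

```python
import itertools


# Hypothetical 3-parameter stand-in for the 10-parameter run_ilqr.
def f(a, b, c):
    return a + b + c


per_run_a = [1, 2, 3, 4]      # one value per run (4 runs)
shared_b = 100                # same value for every run
per_run_c = [10, 20, 30, 40]  # one value per run

# Same layout as the starmap call above: one entry per *parameter*.
wrong = [per_run_a, shared_b, per_run_c]

error_message = None
try:
    # The first "call" unpacks the whole per-run list: f(1, 2, 3, 4).
    list(itertools.starmap(f, wrong))
except TypeError as exc:
    error_message = str(exc)
print(error_message)  # f() takes 3 positional arguments but 4 were given
```

The argument count in the message is the length of the first per-run list, which matches the "150 were given" above.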
I looked up some examples for Pool.starmap, and it seemed that I just needed to pass in a list of input arguments for each variable, which is what I tried to do. Any help is much appreciated!
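For comparison, the Pool.starmap docstring quoted in the traceback says each element of the iterable is unpacked into one call (`func` and (a, b) becomes func(a, b)), so the iterable should hold 150 tuples of 10 arguments, not 10 lists of 150. A minimal sketch under that reading follows. It assumes second_args through sixth_args correspond to x_dim, u_dim, n_agents, x_ref and x_ref_T (as their positions in the original call suggest); `toy_ilqr` and `build_arg_tuples` are hypothetical names, and all data values are made-up placeholders:

```python
import itertools
import multiprocessing
from multiprocessing import Pool


# Hypothetical stand-in for run_ilqr with the same 10-parameter signature;
# it only echoes a few inputs so the argument routing is easy to check.
def toy_ilqr(x0, N, max_iter, regu_init, alpha_init,
             x_dim, u_dim, n_agents, x_ref, x_ref_T):
    return x0, N, n_agents


def build_arg_tuples(first_args, N, max_iter, regu_init, alpha_init,
                     second_args, third_args, fourth_args, fifth_args, sixth_args):
    """Return one 10-tuple per run, in run_ilqr's parameter order.

    Per-run lists are zipped element-wise; shared scalars are repeated with
    itertools.repeat, and zip stops at the length of the shortest list.
    """
    return list(zip(first_args,
                    itertools.repeat(N),
                    itertools.repeat(max_iter),
                    itertools.repeat(regu_init),
                    itertools.repeat(alpha_init),
                    second_args, third_args, fourth_args, fifth_args, sixth_args))


if __name__ == "__main__":
    n_runs = 150
    first_args = [[float(i)] * 12 for i in range(n_runs)]  # 150 different x0 vectors
    second_args = [[4, 4, 4]] * n_runs                      # x_dim (assumed)
    third_args = [[2, 2, 2]] * n_runs                       # u_dim (assumed)
    fourth_args = [3] * n_runs                              # n_agents (assumed)
    fifth_args = [[1.0] * 12] * n_runs                      # x_ref (assumed)
    sixth_args = [[1.0] * 12] * n_runs                      # x_ref_T (assumed)

    arg_tuples = build_arg_tuples(first_args, 50, 100, 1.0, 1.0,
                                  second_args, third_args, fourth_args,
                                  fifth_args, sixth_args)

    # max(1, ...) avoids asking for 0 processes on a single-core machine.
    with Pool(max(1, multiprocessing.cpu_count() - 1)) as p:
        results = p.starmap(toy_ilqr, arg_tuples)  # one result per run, in input order

    print(len(results))  # 150
```

With the real function, each element of `results` would be the 7-tuple returned by run_ilqr, so the trajectories of run i could be unpacked with `x_trj, u_trj = results[i][:2]`. The `if __name__ == "__main__":` guard matters when the start method is spawn (the default on Windows and, since Python 3.8, on macOS), because worker processes re-import the main module.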