
I am simulating the epsilon-greedy algorithm on a bandit problem with 3 arms and Bernoulli rewards. After running the experiment, I want to plot the return of each arm: at each time slot, the arm that was chosen records the reward it received, and the other 2 arms record -1. Now I would like to plot the return of one arm against the time slot (the value is -1, 0 or 1).

import matplotlib.pyplot as plt
import numpy as np

class greedy():
    def __init__(self, epsilon, n):
        self.epsilon = epsilon
        self.n = n
        self.value = [0, 0, 0]  # estimated mean reward of each arm
        self.count = [0, 0, 0]  # number of times each arm has been pulled
        self.prob = [0.4, 0.6, 0.8]  # true Bernoulli success probability of each arm
        # greedy_reward[arm][t] = reward of `arm` at time t, or -1 if it was not chosen
        self.greedy_reward = [[0 for x in range(n)] for y in range(3)]
    def exploration(self, i):
        max_index = np.random.choice([0, 1, 2])  # pick an arm uniformly at random
        r = np.random.choice([0, 1], p=(1 - self.prob[max_index], self.prob[max_index]))  # do experiment, return r
        self.count[max_index] += 1
        for arm in range(3):
            self.greedy_reward[arm][i] = -1
        self.greedy_reward[max_index][i] = r
        # incremental update of the sample mean of the chosen arm
        self.value[max_index] = self.value[max_index] + (1 / self.count[max_index]) * (r - self.value[max_index])
    def exploitation(self, i):
        max_index = self.value.index(max(self.value))  # arm with the highest estimate
        r = np.random.choice([0, 1], p=(1 - self.prob[max_index], self.prob[max_index]))
        self.count[max_index] += 1
        for arm in range(3):
            self.greedy_reward[arm][i] = -1
        self.greedy_reward[max_index][i] = r
        self.value[max_index] = self.value[max_index] + (1 / self.count[max_index]) * (r - self.value[max_index])
    def EE_choice(self, i):
        output = np.random.choice(  # 0 is exploitation, 1 is exploration
            [0, 1],
            p=[1 - self.epsilon, self.epsilon]
        )
        if output == 1:
            self.exploration(i)
        else:
            self.exploitation(i)
    def exp(self):
        for i in range(0, self.n):
            self.EE_choice(i)

greedy_1 = greedy(epsilon=0.1, n=10000)  # epsilon=0.1 is an assumed value
greedy_1.exp()

Then I extract the return of one arm, for example arm 3:

import matplotlib.pyplot as plt
x = list(range(1, 10001))  # time slots 1..10000
arm_3_y = greedy_1.greedy_reward[2]  # row 2 holds arm 3's returns
plt.scatter(x, arm_3_y, marker='o')
plt.ylim([-1, 1])
plt.show()


In the resulting scatter plot, the points overlap so heavily that individual points cannot be distinguished. Is there any way to avoid this?

  • Your output is too dense for the differences between individual steps to be seen. – Commented Dec 28, 2019 at 7:45
  • @methfux I agree with you. Could you give me some advice on how to plot it better? – Commented Dec 28, 2019 at 7:51
  • Use a box plot; I think that would be better. – Commented Dec 28, 2019 at 7:52

1 Answer


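Depending on what you want to visualize, there are many ways to solve this. If you want to see the distribution but don't need the individual points, use a box plot. It shows the median, the quartiles and the whiskers (by default extending to the most extreme data point within 1.5 × IQR of the box), with any outliers drawn as separate points.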
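A box plot needs groups of values, so one option is to split arm 3's return series into consecutive time windows and draw one box per window. This is a minimal sketch under stated assumptions: the window length of 1000 slots is an arbitrary choice, `split_into_windows` and `boxplot_by_window` are hypothetical helper names, and randomly generated stand-in data replaces `greedy_1.greedy_reward[2]`. Because the series only takes the values -1, 0 and 1, the boxes will be coarse.

```python
import numpy as np
import matplotlib.pyplot as plt

def split_into_windows(values, window):
    """Split a 1-D series into consecutive windows of `window` time slots.

    Any leftover slots at the end (fewer than `window`) are dropped.
    """
    values = np.asarray(values)
    n_windows = len(values) // window
    return values[: n_windows * window].reshape(n_windows, window)

def boxplot_by_window(values, window=1000):
    windows = split_into_windows(values, window)
    # plt.boxplot draws one box per column of a 2-D array, so transpose:
    # each window becomes one column, i.e. one box
    plt.boxplot(windows.T)
    plt.xlabel(f"time window ({window} slots each)")
    plt.ylabel("return of arm 3")
    plt.show()

if __name__ == "__main__":
    # Stand-in data: replace with greedy_1.greedy_reward[2] from the simulation
    rng = np.random.default_rng(0)
    fake_arm_3 = rng.choice([-1, 0, 1], size=10000, p=[0.2, 0.16, 0.64])
    boxplot_by_window(fake_arm_3, window=1000)
```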

If you really need a scatter plot that shows the individual points, add a small amount of random noise ("jitter") to every point, for the visualization only. This reduces overlap, so you can see where the points are more densely clustered.

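import numpy as np
def randomize(arr, scale=0.05):
    # Gaussian noise with std `scale`, small compared with the gap of 1 between -1, 0 and 1
    arr = np.asarray(arr, dtype=float)
    return arr + np.random.randn(len(arr)) * scale
plt.scatter(x, randomize(arm_3_y), marker='o')
plt.ylim([-1.5, 1.5])  # leave room so jittered points near -1 and 1 are not clipped
plt.show()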

This should help the visualization. Try adjusting the scaling coefficient to control how much jitter is added.


2 Comments

This idea can be combined with a low alpha value, so that places where many dots overlap appear more strongly colored.
Yes, that makes the jitter work even better.
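The jitter from the answer and the low-alpha suggestion from the comment can be combined in one plot. This is a minimal sketch, assuming stand-in data in place of `greedy_1.greedy_reward[2]`; the helper names `jitter` and `plot_jittered`, the jitter scale of 0.1, the marker size and the alpha of 0.05 are illustrative choices, not values from the answer.

```python
import numpy as np
import matplotlib.pyplot as plt

def jitter(values, scale=0.1, rng=None):
    """Return a float copy of `values` with Gaussian noise of std `scale` added.

    Use only for plotting; never feed jittered values back into an analysis.
    """
    rng = np.random.default_rng() if rng is None else rng
    values = np.asarray(values, dtype=float)
    return values + rng.normal(0.0, scale, size=values.shape)

def plot_jittered(values, scale=0.1, alpha=0.05, seed=0):
    rng = np.random.default_rng(seed)
    t = np.arange(1, len(values) + 1)
    # Small, semi-transparent markers: where many points overlap, the colour gets darker
    plt.scatter(t, jitter(values, scale, rng), s=4, alpha=alpha)
    plt.yticks([-1, 0, 1])
    plt.ylim(-1.5, 1.5)  # leave room so jittered points near -1 and 1 are not clipped
    plt.xlabel("time slot")
    plt.ylabel("return of arm 3 (jittered)")
    plt.show()

if __name__ == "__main__":
    # Stand-in data: replace with greedy_1.greedy_reward[2] from the simulation
    rng = np.random.default_rng(1)
    fake_arm_3 = rng.choice([-1, 0, 1], size=10000, p=[0.2, 0.16, 0.64])
    plot_jittered(fake_arm_3)
```

Lowering `alpha` further makes sparse regions fainter and dense regions stand out more; raising `scale` spreads each band more widely.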
