-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathplot_rewards.py
More file actions
42 lines (32 loc) · 1.28 KB
/
plot_rewards.py
File metadata and controls
42 lines (32 loc) · 1.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#!/usr/bin/env python
'''
Created on September 2, 2014
@author: Jonathan Scholz <jonathan.scholz@gmail.com>
'''
from matplotlib import pyplot as plt
from scipy.interpolate import UnivariateSpline
import pickle
def plot_sarsa_vs_qlearning(sarsa_rewards, qlearning_rewards,
                            sarsa_smoothing=245000,
                            qlearning_smoothing=300000):
    '''
    Generates a smoothed plot of sarsa and q-learning rewards,
    using scipy's UnivariateSpline.

    Only the smoothed curves are drawn; the raw reward series are not
    plotted. Figure 0 is cleared and reused, and plt.show() is called
    at the end.

    @param sarsa_rewards: sequence of rewards, one entry per episode,
        for the SARSA run
    @param qlearning_rewards: sequence of rewards, one entry per episode,
        for the Q-learning run
    @param sarsa_smoothing: value passed as the `s` smoothing factor of
        the spline fitted to sarsa_rewards
    @param qlearning_smoothing: value passed as the `s` smoothing factor
        of the spline fitted to qlearning_rewards
    '''
    # plt.interactive(True)
    plt.figure(0)
    plt.clf()
    plt.ylabel('Reward per episodes')
    plt.xlabel('Episodes')

    series = (
        ('SARSA', sarsa_rewards, sarsa_smoothing),
        ('Q-Learning', qlearning_rewards, qlearning_smoothing),
    )
    for label, rewards, smoothing in series:
        # Each series gets its own x-axis so that runs with different
        # episode counts can still be fitted and drawn on one figure.
        episodes = range(len(rewards))
        smoothed = UnivariateSpline(episodes, rewards, s=smoothing)
        # Attaching the label to the curve keeps the legend entry tied to
        # the right line regardless of plotting order.
        plt.plot(episodes, smoothed(episodes), label=label)

    plt.legend(loc=0)
    plt.show()
if __name__ == '__main__':
    # Reward logs to compare; both paths are relative to the current
    # working directory.
    sarsa_path = 'SARSAAgent_rewards_alpha-0.25_gamma-0.9_epsilon-0.99_epsilon_decay-0.99_plot-True_max_episodes-500.pkl'
    qlearning_path = 'QLearningAgent_rewards_alpha-0.25_gamma-0.9_epsilon-0.99_epsilon_decay-0.99_plot-True_max_episodes-500.pkl'

    # NOTE(review): pickle.load can execute arbitrary code while
    # unpickling -- only point these paths at files you generated yourself.
    # The `with` blocks make sure both files are closed after loading.
    with open(sarsa_path, 'rb') as sarsa_file:
        sarsa_episode_rewards = pickle.load(sarsa_file)
    with open(qlearning_path, 'rb') as qlearning_file:
        qlearning_episode_rewards = pickle.load(qlearning_file)

    plot_sarsa_vs_qlearning(sarsa_episode_rewards, qlearning_episode_rewards)