# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os

import gym
import numpy as np
import parl
from parl.utils import logger
from parl.env import CompatWrapper, is_gym_version_ge

from cartpole_model import CartpoleModel
from cartpole_agent import CartpoleAgent

LEARNING_RATE = 1e-3

# train an episode
def run_train_episode(agent, env):
    obs_list, action_list, reward_list = [], [], []
    obs = env.reset()
    while True:
        obs_list.append(obs)
        action = agent.sample(obs)
        action_list.append(action)
        obs, reward, done, info = env.step(action)
        reward_list.append(reward)
        if done:
            break
    return obs_list, action_list, reward_list

# evaluate 5 episodes
def run_evaluate_episodes(agent, eval_episodes=5, render=False):
    # Compatible with different versions of gym
    if is_gym_version_ge("0.26.0") and render:  # if gym version >= 0.26.0
        env = gym.make('CartPole-v1', render_mode="human")
    else:
        env = gym.make('CartPole-v1')
    env = CompatWrapper(env)

    eval_reward = []
    for i in range(eval_episodes):
        obs = env.reset()
        episode_reward = 0
        while True:
            action = agent.predict(obs)
            obs, reward, isOver, _ = env.step(action)
            episode_reward += reward
            if render:
                env.render()
            if isOver:
                break
        eval_reward.append(episode_reward)
    return np.mean(eval_reward)
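
# Note (an assumption about PARL's agent API, not stated in this file):
# agent.predict above selects the greedy action, while agent.sample in
# run_train_episode draws a stochastic action, so evaluation measures the
# deterministic policy rather than the exploration behavior.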

def calc_reward_to_go(reward_list, gamma=1.0):
    for i in range(len(reward_list) - 2, -1, -1):
        # G_i = r_i + γ·G_{i+1}
        reward_list[i] += gamma * reward_list[i + 1]  # G_t
    return np.array(reward_list)
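
# A worked example of the recursion above (illustrative note, not part of the
# original script): with gamma=1.0 and rewards [1.0, 1.0, 1.0],
#   G_2 = 1.0
#   G_1 = 1.0 + 1.0 * G_2 = 2.0
#   G_0 = 1.0 + 1.0 * G_1 = 3.0
# so calc_reward_to_go([1.0, 1.0, 1.0]) returns array([3., 2., 1.]).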

def main():
    env = gym.make('CartPole-v1')
    # Compatible with different versions of gym
    env = CompatWrapper(env)
    # env = env.unwrapped  # Cancel the minimum score limit

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n
    logger.info('obs_dim {}, act_dim {}'.format(obs_dim, act_dim))

    # build an agent
    model = CartpoleModel(obs_dim=obs_dim, act_dim=act_dim)
    alg = parl.algorithms.PolicyGradient(model, lr=LEARNING_RATE)
    agent = CartpoleAgent(alg)

    # load model and evaluate
    # if os.path.exists('./model.ckpt'):
    #     agent.restore('./model.ckpt')
    #     run_evaluate_episodes(agent, render=True)
    #     exit()

    for i in range(args.max_episodes):
        obs_list, action_list, reward_list = run_train_episode(agent, env)
        if i % 10 == 0:
            logger.info("Episode {}, Reward Sum {}.".format(
                i, sum(reward_list)))

        batch_obs = np.array(obs_list)
        batch_action = np.array(action_list)
        batch_reward = calc_reward_to_go(reward_list)

        agent.learn(batch_obs, batch_action, batch_reward)
        if (i + 1) % 100 == 0:
            total_reward = run_evaluate_episodes(agent, render=False)
            logger.info('Test reward: {}'.format(total_reward))

    # save the parameters to ./model.ckpt
    agent.save('./model.ckpt')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Environment
    parser.add_argument(
        '--max_episodes',
        type=int,
        default=1000,
        help='stop condition: number of episodes')
    args = parser.parse_args()

    main()
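
# Example usage (assumes cartpole_model.py and cartpole_agent.py from PARL's
# QuickStart example sit in the same directory):
#   python train.py --max_episodes 1000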