diff --git a/examples/QuickStart/train.py b/examples/QuickStart/train.py index 5c464b55d..2d7dd7806 100644 --- a/examples/QuickStart/train.py +++ b/examples/QuickStart/train.py @@ -89,10 +89,10 @@ def main(): agent = CartpoleAgent(alg) # load model and evaluate - # if os.path.exists('./model.ckpt'): - # agent.restore('./model.ckpt') - # run_evaluate_episodes(agent, env, render=True) - # exit() + if os.path.exists('./model.ckpt') and args.eval: + agent.restore('./model.ckpt') + run_evaluate_episodes(agent, render=True) + exit() for i in range(args.max_episodes): obs_list, action_list, reward_list = run_train_episode(agent, env) @@ -121,5 +121,10 @@ def main(): type=int, default=1000, help='stop condition: number of episodes') + parser.add_argument( + '--eval', + action='store_true', + default=False, + help='whether to evaluate') args = parser.parse_args() main()