
Commit e506be3d authored by Morelle Tanguy

Updated DQN

parent 2af30a56
@@ -6,10 +6,8 @@ Created on Wed Mar 20 11:28:33 2019
 """
 import torch
 import random
 import numpy as np
 from AiGame import Game
 from collections import deque
 from utils.Plotter import plot_learning
 from network.DQNNet import DQNetwork
@@ -23,7 +21,7 @@ class DQLearning():
     def __init__(self, player_params, action_dic,
                  gamma, epsilon, alpha,
                  state_size, epsEnd=0.05,
-                 replace=10000, actionSpace=[0,1,2,3,4,5]):
+                 max_steps_iter = 100, actionSpace=[0,1,2,3,4,5]):
         self.player_params = player_params
         self.action_dic = action_dic
@@ -33,6 +31,7 @@ class DQLearning():
         self.EPS_END = epsEnd
         self.ALPHA = alpha
         self.actionSpace = actionSpace
+        self.max_steps_iter = max_steps_iter
         self.steps = 0
         self.learn_step_counter = 0
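chooseAction, which the training loop below relies on, is not touched by this commit. For orientation only, an epsilon-greedy selection consistent with the EPSILON and actionSpace attributes set above would look roughly like this; every detail here is an assumption rather than code taken from the repository:

# Assumed epsilon-greedy action selection (sketch, not from the commit):
# explore with probability EPSILON, otherwise take the highest predicted Q.
def chooseAction(self, frames):
    if random.random() < self.EPSILON:
        return random.choice(self.actionSpace)
    with torch.no_grad():
        q_values = self.DQNet.forward(torch.from_numpy(frames).to(self.device))[0]
    return int(torch.argmax(q_values).item())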
@@ -48,7 +47,6 @@ class DQLearning():
         numGames = 1000
         batch_size=32
         self.DQNet.train()
-        # uncomment the line below to record every episode.
         for i in range(numGames):
             print('starting game ', i+1, 'epsilon: %.4f' % self.EPSILON)
             done = False
@@ -57,7 +55,8 @@ class DQLearning():
             frames = [observation, observation, observation]
             new_frames = [observation, observation, observation]
             score = 0
-            while not done:
+            steps = 0
+            while not done and steps<self.max_steps_iter:
                 frames_ = np.array([frames]).astype(np.float)
                 action = self.chooseAction(frames_)
                 observation, reward, done = env.step(self.action_dic[action])
@@ -69,6 +68,7 @@ class DQLearning():
                 new_frames_ = np.array([new_frames]).astype(np.float)
                 self.learn(batch_size, frames_, new_frames_, reward)
                 frames = new_frames
+                steps += 1
             scores.append(score)
             print('score:',score)
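The point of the new counter is to bound episode length: a game that never reaches done now ends after max_steps_iter interactions. Stripped of the learning calls, the pattern this commit introduces is simply the following (agent, env and state are placeholder names):

# Episode cap: the loop exits either when the game signals done or after
# max_steps_iter steps, whichever comes first.
steps = 0
done = False
while not done and steps < max_steps_iter:
    action = agent.chooseAction(state)
    state, reward, done = env.step(action)
    steps += 1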
@@ -82,7 +82,6 @@ class DQLearning():
     def save_net(self, iteration):
-        #base_dir = 'C:/Users/Tanguy Morelle/Desktop/3A/Deep Learning/Game/Reinforcement Learning/DQN'
         PATH = 'network/ckpt/'+str(iteration)+'.pth'
         torch.save(self.DQNet.state_dict(), PATH)
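Since save_net stores only the state_dict, restoring a checkpoint takes the usual two-step PyTorch route: rebuild the network, then load the weights. A sketch, where the DQNetwork constructor arguments are an assumption rather than taken from the repository:

# Sketch: reload a checkpoint written by save_net for a given iteration index.
net = DQNetwork(alpha=0.003)                                   # assumed constructor args
net.load_state_dict(torch.load('network/ckpt/' + str(iteration) + '.pth'))
net.eval()                                                     # inference mode for playing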
@@ -127,7 +126,6 @@ class DQLearning():
         next_frames = Variable(torch.from_numpy(next_frames)).to(self.device)
         reward = torch.tensor(reward, dtype = torch.double).to(self.device)
-        # convert to list because memory is an array of numpy objects
         Qpred = self.DQNet.forward(frames)[0]
         Qnext = self.DQNet.forward(next_frames)[0]
@@ -140,8 +138,7 @@ class DQLearning():
             self.EPSILON -= 1e-4
         else:
             self.EPSILON = self.EPS_END
-        #Qpred.requires_grad_()
         loss = self.DQNet.loss(Qtarget, Qpred).to(self.device)
         loss.backward()
         self.DQNet.optimizer.step()
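Qtarget itself is built in unchanged lines that this view does not show. In a standard DQN update it would be derived from Qpred and Qnext roughly as below; both the construction and the self.GAMMA attribute name (mirroring EPSILON and ALPHA above) are assumptions, not part of the commit:

# Assumed standard DQN target: only the taken action's value is pulled toward
# reward + gamma * max_a' Q(next_state, a'); other entries keep their predictions.
Qtarget = Qpred.clone().detach()
Qtarget[action] = reward + self.GAMMA * torch.max(Qnext).detach()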
@@ -166,6 +163,7 @@ if __name__ == '__main__':
         }
     dqlearner = DQLearning(player_params, action_dic,
+                           max_steps_iter = 100,
                            gamma=0.95,
                            epsilon=1.0,
                            alpha=0.003,
...