import random
from collections import deque

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


class DQN(nn.Module):
    """Fully connected Q-network mapping a state to one Q-value per action."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


class Agent:
    def __init__(self, state_dim, action_dim, memory_size=10000,
                 batch_size=64, gamma=0.99, lr=1e-4):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.memory = deque(maxlen=memory_size)  # replay buffer
        self.batch_size = batch_size
        self.gamma = gamma
        self.lr = lr
        # The policy (online) network is trained every step; the target network
        # provides stable TD targets and is only refreshed via update_target().
        self.policy_net = DQN(state_dim, action_dim).to(self.device)
        self.target_net = DQN(state_dim, action_dim).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())  # start both nets in sync
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.lr)
        self.loss_fn = nn.MSELoss()
        self.steps = 0
        self.writer = SummaryWriter()

    def select_action(self, state, eps):
        """Epsilon-greedy action selection."""
        if random.random() < eps:
            return random.randint(0, self.action_dim - 1)
        state = torch.FloatTensor(state).to(self.device)
        with torch.no_grad():
            return self.policy_net(state).argmax().item()

    def store_transition(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def train(self):
        # Wait until the replay buffer holds at least one full batch.
        if len(self.memory) < self.batch_size:
            return
        transitions = random.sample(self.memory, self.batch_size)
        batch = list(zip(*transitions))

        state_batch = torch.FloatTensor(np.array(batch[0])).to(self.device)
        action_batch = torch.LongTensor(batch[1]).to(self.device)
        reward_batch = torch.FloatTensor(batch[2]).to(self.device)
        next_state_batch = torch.FloatTensor(np.array(batch[3])).to(self.device)
        done_batch = torch.FloatTensor(batch[4]).to(self.device)

        # Q(s, a) for the actions actually taken.
        q_values = self.policy_net(state_batch).gather(1, action_batch.unsqueeze(1)).squeeze(1)
        # TD target: r + gamma * max_a' Q_target(s', a'), zeroed out on terminal transitions.
        next_q_values = self.target_net(next_state_batch).max(1)[0]
        expected_q_values = reward_batch + self.gamma * next_q_values * (1 - done_batch)

        loss = self.loss_fn(q_values, expected_q_values.detach())

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.steps += 1
        self.writer.add_scalar("Loss", loss.item(), self.steps)

    def update_target(self):
        """Copy the online network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())


def train_epoch(env, agent, eps, max_steps=1000):
    """Run one episode (up to max_steps), training the agent after every step."""
    state = env.reset()  # classic Gym API: reset() returns only the observation
    total_reward = 0
    for _ in range(max_steps):
        action = agent.select_action(state, eps)
        next_state, reward, done, _ = env.step(action)  # classic Gym API: 4 return values
        agent.store_transition(state, action, reward, next_state, done)
        state = next_state
        agent.train()
        total_reward += reward
        if done:
            break
    return total_reward


def train_dqn(env, agent: Agent, eps_start=1.0, eps_end=0.1, eps_decay=0.995,
              max_episodes=1000, max_steps=1000):
    eps = eps_start
    for episode in tqdm(range(max_episodes)):
        reward = train_epoch(env, agent, eps, max_steps)
        agent.update_target()  # sync the target network once per episode
        eps = max(eps * eps_decay, eps_end)  # exponential epsilon decay, floored at eps_end
        tqdm.write(f"Episode {episode} --> reward {reward}")  # keeps the progress bar intact


if __name__ == "__main__":
    env = gym.make("CartPole-v1")
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = Agent(state_dim, action_dim)
    train_dqn(env, agent)