diff --git a/DQN/DoubleDQNPrioritized_Solution.ipynb b/DQN/DoubleDQNPrioritized_Solution.ipynb new file mode 100644 index 000000000..db0d22482 --- /dev/null +++ b/DQN/DoubleDQNPrioritized_Solution.ipynb @@ -0,0 +1,713 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "import gym\n", + "from gym.wrappers import Monitor\n", + "import itertools\n", + "import numpy as np\n", + "import os\n", + "import random\n", + "import sys\n", + "import tensorflow as tf\n", + "\n", + "if \"../\" not in sys.path:\n", + " sys.path.append(\"../\")\n", + "\n", + "from lib import plotting\n", + "from collections import deque, namedtuple" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "env = gym.envs.make(\"Breakout-v0\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Atari Actions: 0 (noop), 1 (fire), 2 (left) and 3 (right) are valid actions\n", + "VALID_ACTIONS = [0, 1, 2, 3]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "class StateProcessor():\n", + " \"\"\"\n", + " Processes a raw Atari image. Resizes it and converts it to grayscale.\n", + " \"\"\"\n", + " def __init__(self, scope=\"state_processor\"):\n", + " # Build the Tensorflow graph\n", + " with tf.variable_scope(scope):\n", + " self.input_state = tf.placeholder(shape=[210, 160, 3], dtype=tf.uint8)\n", + " self.output = tf.image.rgb_to_grayscale(self.input_state)\n", + " self.output = tf.image.crop_to_bounding_box(self.output, 34, 0, 160, 160)\n", + " self.output = tf.image.resize_images(\n", + " self.output, size=[84, 84], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", + " self.output = tf.squeeze(self.output)\n", + " \n", + " def process(self, sess, state):\n", + " \"\"\"\n", + " Args:\n", + " sess: A Tensorflow session object\n", + " state: A [210, 160, 3] Atari RGB State\n", + "\n", + " Returns:\n", + " A processed [84, 84] state representing grayscale values.\n", + " \"\"\"\n", + " return sess.run(self.output, { self.input_state: state })" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "class Estimator():\n", + " \"\"\"Q-Value Estimator neural network.\n", + "\n", + " This network is used for both the Q-Network and the Target Network.\n", + " \"\"\"\n", + "\n", + " def __init__(self, scope=\"estimator\", summaries_dir=None):\n", + " self.scope = scope\n", + " # Writes Tensorboard summaries to disk\n", + " self.summary_writer = None\n", + " with tf.variable_scope(scope):\n", + " # Build the graph\n", + " self._build_model()\n", + " if summaries_dir:\n", + " summary_dir = os.path.join(summaries_dir, \"summaries_{}\".format(scope))\n", + " if not os.path.exists(summary_dir):\n", + " os.makedirs(summary_dir)\n", + " self.summary_writer = tf.summary.FileWriter(summary_dir)\n", + "\n", + " def _build_model(self):\n", + " \"\"\"\n", + " Builds the Tensorflow graph.\n", + " \"\"\"\n", + "\n", + " # Placeholders for our input\n", + " # Our input is 4 grayscale frames of shape 84, 84 each\n", + " self.X_pl = tf.placeholder(shape=[None, 84, 84, 4], dtype=tf.uint8, name=\"X\")\n", + " # The TD target value\n", + " self.y_pl = tf.placeholder(shape=[None], dtype=tf.float32, name=\"y\")\n", + " # Integer id of which action was selected\n", + " self.actions_pl = tf.placeholder(shape=[None], 
dtype=tf.int32, name=\"actions\")\n", + " self.importance_weights = tf.placeholder(shape=[None], dtype=tf.float32, name=\"importance\")\n", + " X = tf.to_float(self.X_pl) / 255.0\n", + " batch_size = tf.shape(self.X_pl)[0]\n", + "\n", + " # Three convolutional layers\n", + " conv1 = tf.contrib.layers.conv2d(\n", + " X, 32, 8, 4, activation_fn=tf.nn.relu)\n", + " conv2 = tf.contrib.layers.conv2d(\n", + " conv1, 64, 4, 2, activation_fn=tf.nn.relu)\n", + " conv3 = tf.contrib.layers.conv2d(\n", + " conv2, 64, 3, 1, activation_fn=tf.nn.relu)\n", + "\n", + " # Fully connected layers\n", + " flattened = tf.contrib.layers.flatten(conv3)\n", + " fc1 = tf.contrib.layers.fully_connected(flattened, 512)\n", + " self.predictions = tf.contrib.layers.fully_connected(fc1, len(VALID_ACTIONS))\n", + "\n", + " # Get the predictions for the chosen actions only\n", + " gather_indices = tf.range(batch_size) * tf.shape(self.predictions)[1] + self.actions_pl\n", + " self.action_predictions = tf.gather(tf.reshape(self.predictions, [-1]), gather_indices)\n", + "\n", + " # Calculate the loss\n", + " self.losses = tf.squared_difference(self.y_pl, self.action_predictions) * self.importance_weights\n", + " self.loss = tf.reduce_mean(self.losses)\n", + "\n", + " # Optimizer Parameters from original paper\n", + " self.optimizer = tf.train.RMSPropOptimizer(0.00025, 0.99, 0.0, 1e-6)\n", + " self.train_op = self.optimizer.minimize(self.loss, global_step=tf.contrib.framework.get_global_step())\n", + "\n", + " # Summaries for Tensorboard\n", + " self.summaries = tf.summary.merge([\n", + " tf.summary.scalar(\"loss\", self.loss),\n", + " tf.summary.histogram(\"loss_hist\", self.losses),\n", + " tf.summary.histogram(\"q_values_hist\", self.predictions),\n", + " tf.summary.scalar(\"max_q_value\", tf.reduce_max(self.predictions))\n", + " ])\n", + "\n", + " def predict(self, sess, s):\n", + " \"\"\"\n", + " Predicts action values.\n", + "\n", + " Args:\n", + " sess: Tensorflow session\n", + " s: State input of shape [batch_size, 84, 84, 4]\n", + "\n", + " Returns:\n", + " Tensor of shape [batch_size, NUM_VALID_ACTIONS] containing the estimated \n", + " action values.\n", + " \"\"\"\n", + " return sess.run(self.predictions, { self.X_pl: s })\n", + "\n", + " def update(self, sess, s, a, y, importance):\n", + " \"\"\"\n", + " Updates the estimator towards the given targets.\n", + "\n", + " Args:\n", + " sess: Tensorflow session object\n", + " s: State input of shape [batch_size, 84, 84, 4]\n", + " a: Chosen actions of shape [batch_size]\n", + " y: Targets of shape [batch_size]\n", + "\n", + " Returns:\n", + " The calculated loss on the batch.\n", + " \"\"\"\n", + " feed_dict = { self.X_pl: s, self.y_pl: y, self.actions_pl: a, self.importance_weights: importance}\n", + " summaries, global_step, _, loss = sess.run(\n", + " [self.summaries, tf.contrib.framework.get_global_step(), self.train_op, self.loss],\n", + " feed_dict)\n", + " if self.summary_writer:\n", + " self.summary_writer.add_summary(summaries, global_step)\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "# # For Testing....\n", + "\n", + "# tf.reset_default_graph()\n", + "# global_step = tf.Variable(0, name=\"global_step\", trainable=False)\n", + "\n", + "# e = Estimator(scope=\"test\")\n", + "# sp = StateProcessor()\n", + "\n", + "# with tf.Session() as sess:\n", + "# sess.run(tf.global_variables_initializer())\n", + " \n", + "# # Example observation batch\n", + "# observation = 
env.reset()\n", + "# observation_p = sp.process(sess, observation)\n", + "# observation = np.stack([observation_p] * 4, axis=2)\n", + "# observations = np.array([observation] * 2)\n", + " \n", + "# # Test Prediction\n", + "# print(e.predict(sess, observations))\n", + "\n", + "# # Test training step\n", + "# y = np.array([10.0, 10.0])\n", + "# a = np.array([1, 3])\n", + "# importance = np.array([0.01, 0.01])\n", + "# print(e.update(sess, observations, a, y, importance))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "class BinarySumTree():\n", + " def __init__(self, capacity):\n", + " self.capacity = capacity\n", + " self.tree = np.zeros(2*self.capacity - 1)\n", + " self.transitions = np.zeros(self.capacity, dtype=object)\n", + " self.n_transitions = 0\n", + " self.transition_index = 0\n", + " \n", + " # propagate the change up the tree to the root\n", + " def propagate_error(self, index, delta):\n", + " # get the parent of current node\n", + " parent = (index - 1) // 2\n", + " self.tree[parent] += delta\n", + " if parent != 0:\n", + " self.propagate_error(parent, delta) \n", + " \n", + " # Retrieve sample from the tree according to given priority\n", + " def retrieve_transition(self, index, priority):\n", + " # get children \n", + " left_child_index = 2 * index + 1\n", + " right_child_index = left_child_index + 1\n", + " \n", + " if left_child_index >= len(self.tree):\n", + " return index\n", + " \n", + " if priority <= self.tree[left_child_index]:\n", + " return self.retrieve_transition(left_child_index, priority)\n", + " else:\n", + " return self.retrieve_transition(right_child_index, priority - self.tree[left_child_index])\n", + " \n", + " # append error to tree and transition to transitions\n", + " def add(self, error, transition):\n", + " # get the tree index to store error\n", + " index = self.transition_index + self.capacity - 1\n", + " \n", + " # append transition to data\n", + " self.transitions[self.transition_index] = transition\n", + " \n", + " # update the tree with the error\n", + " self.update(index, error)\n", + " \n", + " # update transition_index\n", + " self.transition_index += 1\n", + " if self.transition_index >= self.capacity:\n", + " self.transition_index = 0 \n", + " if self.n_transitions < self.capacity:\n", + " self.n_transitions += 1\n", + " \n", + " def update(self, index, error):\n", + " delta = error - self.tree[index]\n", + " self.tree[index] = error\n", + " \n", + " self.propagate_error(index, delta)\n", + " \n", + " def fetch_transition(self, sample_priority):\n", + " index = self.retrieve_transition(0, sample_priority)\n", + " transition_index = index - self.capacity + 1\n", + " transition = self.transitions[transition_index]\n", + " return (index, self.tree[index], transition)\n", + " \n", + " # get total error in the tree\n", + " def cumulative_error(self):\n", + " return self.tree[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# https://arxiv.org/pdf/1511.05952.pdf\n", + "class PrioritizedReplayMemory(object):\n", + " \"\"\"\n", + " Stochastic proportional Prioritized Experience Replay memory. Adds\n", + " transitions to memory and uses a SumTree to keep track of their errors.\n", + " Samples one transition from each error segment.\n", + " Transitions are fetched from replay memory in proportion to the magnitude \n", + " of their error. 
See the paper for more details.\n", + " \"\"\"\n", + " e = 0.01\n", + " a = 0.6\n", + " beta = 0.4\n", + " beta_increment_per_sampling = 0.001\n", + "\n", + " def __init__(self, capacity):\n", + " self.tree = BinarySumTree(capacity)\n", + " self.capacity = capacity\n", + "\n", + " def _get_priority(self, error):\n", + " return (np.abs(error) + self.e) ** self.a\n", + "\n", + " def add(self, error, sample):\n", + " p = self._get_priority(error)\n", + " self.tree.add(p, sample)\n", + "\n", + " def sample(self, n):\n", + " \"\"\"\n", + " Samples a segment, and then samples uniformly among the transitions \n", + " within it. This works particularly well in conjunction with a \n", + " minibatch based learning algorithm: choose k to be the size of the \n", + " minibatch, and sample exactly one transition from each segment – this \n", + " is a form of stratified sampling that has the added advantage of \n", + " balancing out the minibatch.\n", + " \"\"\"\n", + " batch = []\n", + " idxs = []\n", + " segment = self.tree.cumulative_error() / n\n", + " priorities = []\n", + "\n", + " self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])\n", + "\n", + " for i in range(n):\n", + " a = segment * i\n", + " b = segment * (i + 1)\n", + "\n", + " s = random.uniform(a, b)\n", + " (idx, p, data) = self.tree.fetch_transition(s)\n", + " priorities.append(p)\n", + " batch.append(data)\n", + " idxs.append(idx)\n", + "\n", + " sampling_probabilities = priorities / self.tree.cumulative_error()\n", + " is_weight = np.power(self.tree.n_transitions * sampling_probabilities, -self.beta)\n", + " is_weight /= is_weight.max()\n", + "\n", + " return batch, idxs, is_weight\n", + "\n", + " def update(self, idx, error):\n", + " p = self._get_priority(error)\n", + " self.tree.update(idx, p)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "def copy_model_parameters(sess, estimator1, estimator2):\n", + " \"\"\"\n", + " Copies the model parameters of one estimator to another.\n", + "\n", + " Args:\n", + " sess: Tensorflow session instance\n", + " estimator1: Estimator to copy the parameters from\n", + " estimator2: Estimator to copy the parameters to\n", + " \"\"\"\n", + " e1_params = [t for t in tf.trainable_variables() if t.name.startswith(estimator1.scope)]\n", + " e1_params = sorted(e1_params, key=lambda v: v.name)\n", + " e2_params = [t for t in tf.trainable_variables() if t.name.startswith(estimator2.scope)]\n", + " e2_params = sorted(e2_params, key=lambda v: v.name)\n", + "\n", + " update_ops = []\n", + " for e1_v, e2_v in zip(e1_params, e2_params):\n", + " op = e2_v.assign(e1_v)\n", + " update_ops.append(op)\n", + "\n", + " sess.run(update_ops)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "def make_epsilon_greedy_policy(estimator, nA):\n", + " \"\"\"\n", + " Creates an epsilon-greedy policy based on a given Q-function approximator and epsilon.\n", + "\n", + " Args:\n", + " estimator: An estimator that returns q values for a given state\n", + " nA: Number of actions in the environment.\n", + "\n", + " Returns:\n", + " A function that takes (sess, observation, epsilon) as arguments and returns\n", + " the probabilities for each action in the form of a numpy array of length nA.\n", + "\n", + " \"\"\"\n", + " def policy_fn(sess, observation, epsilon):\n", + " A = np.ones(nA, dtype=float) * epsilon / nA\n", + " q_values = estimator.predict(sess, np.expand_dims(observation, 
0))[0]\n", + " best_action = np.argmax(q_values)\n", + " A[best_action] += (1.0 - epsilon)\n", + " return A\n", + " return policy_fn" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "def deep_q_learning(sess,\n", + " env,\n", + " q_estimator,\n", + " target_estimator,\n", + " state_processor,\n", + " num_episodes,\n", + " experiment_dir,\n", + " replay_memory_size=500000,\n", + " replay_memory_init_size=500,\n", + " update_target_estimator_every=10000,\n", + " discount_factor=0.99,\n", + " epsilon_start=1.0,\n", + " epsilon_end=0.1,\n", + " epsilon_decay_steps=500000,\n", + " batch_size=32,\n", + " record_video_every=50):\n", + " \"\"\"\n", + " Q-Learning algorithm for off-policy TD control using Function Approximation.\n", + " Finds the optimal greedy policy while following an epsilon-greedy policy.\n", + "\n", + " Args:\n", + " sess: Tensorflow Session object\n", + " env: OpenAI environment\n", + " q_estimator: Estimator object used for the q values\n", + " target_estimator: Estimator object used for the targets\n", + " state_processor: A StateProcessor object\n", + " num_episodes: Number of episodes to run for\n", + " experiment_dir: Directory to save Tensorflow summaries in\n", + " replay_memory_size: Size of the replay memory\n", + " replay_memory_init_size: Number of random experiences to sample when initializing \n", + " the replay memory.\n", + " update_target_estimator_every: Copy parameters from the Q estimator to the \n", + " target estimator every N steps\n", + " discount_factor: Gamma discount factor\n", + " epsilon_start: Chance to sample a random action when taking an action.\n", + " Epsilon is decayed over time and this is the start value\n", + " epsilon_end: The final minimum value of epsilon after decaying is done\n", + " epsilon_decay_steps: Number of steps to decay epsilon over\n", + " batch_size: Size of batches to sample from the replay memory\n", + " record_video_every: Record a video every N episodes\n", + "\n", + " Returns:\n", + " An EpisodeStats object with two numpy arrays for episode_lengths and episode_rewards.\n", + " \"\"\"\n", + "\n", + " Transition = namedtuple(\"Transition\", [\"state\", \"action\", \"reward\", \"next_state\", \"done\"])\n", + "\n", + " # The replay memory\n", + " replay_memory = PrioritizedReplayMemory(replay_memory_size)\n", + "\n", + " # Keeps track of useful statistics\n", + " stats = plotting.EpisodeStats(\n", + " episode_lengths=np.zeros(num_episodes),\n", + " episode_rewards=np.zeros(num_episodes))\n", + "\n", + " # Create directories for checkpoints and summaries\n", + " checkpoint_dir = os.path.join(experiment_dir, \"checkpoints\")\n", + " checkpoint_path = os.path.join(checkpoint_dir, \"model\")\n", + " monitor_path = os.path.join(experiment_dir, \"monitor\")\n", + "\n", + " if not os.path.exists(checkpoint_dir):\n", + " os.makedirs(checkpoint_dir)\n", + " if not os.path.exists(monitor_path):\n", + " os.makedirs(monitor_path)\n", + "\n", + " saver = tf.train.Saver()\n", + " # Load a previous checkpoint if we find one\n", + " latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)\n", + " if latest_checkpoint:\n", + " print(\"Loading model checkpoint {}...\\n\".format(latest_checkpoint))\n", + " saver.restore(sess, latest_checkpoint)\n", + " \n", + " # Get the current time step\n", + " total_t = sess.run(tf.contrib.framework.get_global_step())\n", + "\n", + " # The epsilon decay schedule\n", + " epsilons = np.linspace(epsilon_start, epsilon_end, 
epsilon_decay_steps)\n", + "\n", + " # The policy we're following\n", + " policy = make_epsilon_greedy_policy(\n", + " q_estimator,\n", + " len(VALID_ACTIONS))\n", + " \n", + " def push_sample_to_memory(state, reward, action, next_state, done):\n", + " target = q_estimator.predict(sess, np.expand_dims(state, 0))\n", + " q_values_next = q_estimator.predict(sess, np.expand_dims(next_state, 0))\n", + " q_values_next_target = target_estimator.predict(sess, np.expand_dims(next_state, 0))\n", + " old_val = target[0][action].copy()\n", + " if done:\n", + " target[0][action] = reward\n", + " else:\n", + " # Double DQN\n", + " target[0][action] = reward + discount_factor * q_values_next_target[0][np.argmax(q_values_next)]\n", + " error = abs(old_val - target[0][action])\n", + " replay_memory.add(error, Transition(state, action, reward, next_state, done))\n", + " \n", + " # Populate the replay memory with initial experience\n", + " print(\"Populating replay memory...\")\n", + " state = env.reset()\n", + " state = state_processor.process(sess, state)\n", + " state = np.stack([state] * 4, axis=2)\n", + " for i in range(replay_memory_init_size):\n", + " action_probs = policy(sess, state, epsilons[total_t])\n", + " action = np.random.choice(np.arange(len(action_probs)), p=action_probs)\n", + " next_state, reward, done, _ = env.step(VALID_ACTIONS[action])\n", + " next_state = state_processor.process(sess, next_state)\n", + " next_state = np.append(state[:,:,1:], np.expand_dims(next_state, 2), axis=2)\n", + " push_sample_to_memory(state, reward, action, next_state, done)\n", + " if done:\n", + " state = env.reset()\n", + " state = state_processor.process(sess, state)\n", + " state = np.stack([state] * 4, axis=2)\n", + " else:\n", + " state = next_state\n", + "\n", + "# Record videos\n", + " # Add env Monitor wrapper\n", + " env = Monitor(env, \n", + " directory=monitor_path, \n", + " video_callable=lambda count: count % record_video_every == 0, \n", + " resume=True)\n", + "\n", + " for i_episode in range(num_episodes):\n", + "\n", + " # Save the current checkpoint\n", + " saver.save(tf.get_default_session(), checkpoint_path)\n", + "\n", + " # Reset the environment\n", + " state = env.reset()\n", + " state = state_processor.process(sess, state)\n", + " state = np.stack([state] * 4, axis=2)\n", + " loss = None\n", + "\n", + " # One step in the environment\n", + " for t in itertools.count():\n", + "\n", + " # Epsilon for this time step\n", + " epsilon = epsilons[min(total_t, epsilon_decay_steps-1)]\n", + "\n", + " # Add epsilon to Tensorboard\n", + " episode_summary = tf.Summary()\n", + " episode_summary.value.add(simple_value=epsilon, tag=\"epsilon\")\n", + " q_estimator.summary_writer.add_summary(episode_summary, total_t)\n", + "\n", + " # Maybe update the target estimator\n", + " if total_t % update_target_estimator_every == 0:\n", + " copy_model_parameters(sess, q_estimator, target_estimator)\n", + " print(\"\\nCopied model parameters to target network.\")\n", + "\n", + " # Print out which step we're on, useful for debugging.\n", + " print(\"\\rStep {} ({}) @ Episode {}/{}, loss: {}\".format(\n", + " t, total_t, i_episode + 1, num_episodes, loss), end=\"\")\n", + " sys.stdout.flush()\n", + "\n", + " # Take a step\n", + " action_probs = policy(sess, state, epsilon)\n", + " action = np.random.choice(np.arange(len(action_probs)), p=action_probs)\n", + " next_state, reward, done, _ = env.step(VALID_ACTIONS[action])\n", + " next_state = state_processor.process(sess, next_state)\n", + " next_state = 
np.append(state[:,:,1:], np.expand_dims(next_state, 2), axis=2)\n", + "\n", + " # Save transition to replay memory\n", + " push_sample_to_memory(state, reward, action, next_state, done)\n", + "\n", + " # Update statistics\n", + " stats.episode_rewards[i_episode] += reward\n", + " stats.episode_lengths[i_episode] = t\n", + "\n", + " # Sample a minibatch from the replay memory\n", + " samples, indices, importance = replay_memory.sample(batch_size)\n", + " states_batch, action_batch, reward_batch, next_states_batch, done_batch = map(np.array, zip(*samples))\n", + " \n", + " # Calculate q values and targets\n", + " # This is where Double Q-Learning comes in!\n", + " q_values_next = q_estimator.predict(sess, next_states_batch)\n", + " best_actions = np.argmax(q_values_next, axis=1)\n", + " q_values_next_target = target_estimator.predict(sess, next_states_batch)\n", + " targets_batch = reward_batch + np.invert(done_batch).astype(np.float32) * \\\n", + " discount_factor * q_values_next_target[np.arange(batch_size), best_actions]\n", + " \n", + " # update errors\n", + " # get current action value\n", + " predictions = q_estimator.predict(sess, states_batch)\n", + " gather_indices = np.arange(batch_size) * np.shape(predictions)[1] + action_batch\n", + " action_predictions = np.take(np.reshape(predictions, [-1]), gather_indices)\n", + " \n", + " errors = np.abs(action_predictions - targets_batch)\n", + " \n", + " # update priority\n", + " for i in range(batch_size):\n", + " index = indices[i]\n", + " replay_memory.update(index, errors[i])\n", + "\n", + " # Perform gradient descent update\n", + " states_batch = np.array(states_batch)\n", + " importance_batch = np.array(importance)\n", + " loss = q_estimator.update(sess, states_batch, action_batch, targets_batch, importance_batch)\n", + "\n", + " if done:\n", + " break\n", + "\n", + " state = next_state\n", + " total_t += 1\n", + "\n", + " # Add summaries to tensorboard\n", + " episode_summary = tf.Summary()\n", + " episode_summary.value.add(simple_value=stats.episode_rewards[i_episode], node_name=\"episode_reward\", tag=\"episode_reward\")\n", + " episode_summary.value.add(simple_value=stats.episode_lengths[i_episode], node_name=\"episode_length\", tag=\"episode_length\")\n", + " q_estimator.summary_writer.add_summary(episode_summary, total_t)\n", + " q_estimator.summary_writer.flush()\n", + "\n", + " yield total_t, plotting.EpisodeStats(\n", + " episode_lengths=stats.episode_lengths[:i_episode+1],\n", + " episode_rewards=stats.episode_rewards[:i_episode+1])\n", + "\n", + " env.close()\n", + " return stats" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "tf.reset_default_graph()\n", + "\n", + "# Where we save our checkpoints and graphs\n", + "experiment_dir = os.path.abspath(\"./experiments/{}\".format(env.spec.id))\n", + "\n", + "# Create a global step variable\n", + "global_step = tf.Variable(0, name='global_step', trainable=False)\n", + " \n", + "# Create estimators\n", + "q_estimator = Estimator(scope=\"q_estimator\", summaries_dir=experiment_dir)\n", + "target_estimator = Estimator(scope=\"target_q\")\n", + "\n", + "# State processor\n", + "state_processor = StateProcessor()\n", + "\n", + "# Run it!\n", + "with tf.Session() as sess:\n", + " sess.run(tf.global_variables_initializer())\n", + " for t, stats in deep_q_learning(sess,\n", + " env,\n", + " q_estimator=q_estimator,\n", + " target_estimator=target_estimator,\n", + " state_processor=state_processor,\n", + " 
experiment_dir=experiment_dir,\n", + " num_episodes=10000,\n", + " replay_memory_size=500000,\n", + " replay_memory_init_size=50000,\n", + " update_target_estimator_every=10000,\n", + " epsilon_start=1.0,\n", + " epsilon_end=0.1,\n", + " epsilon_decay_steps=500000,\n", + " discount_factor=0.99,\n", + " batch_size=32):\n", + "\n", + " print(\"\\nEpisode Reward: {}\".format(stats.episode_rewards[-1]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}
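
For reference, the sketch below is a minimal, self-contained NumPy sanity check of the proportional-prioritization arithmetic that PrioritizedReplayMemory in the notebook above implements: priority p_i = (|error| + e)^a, sampling probability P(i) = p_i / sum_k p_k, and importance-sampling weight w_i = (N * P(i))^-beta normalized by its maximum. It does not use the notebook's classes, and the TD-error values are made up purely for illustration.

import numpy as np

# Same constants as the PrioritizedReplayMemory class above; TD errors are illustrative only.
e, a, beta = 0.01, 0.6, 0.4
td_errors = np.array([0.5, 0.05, 2.0, 0.0])

priorities = (np.abs(td_errors) + e) ** a      # p_i = (|delta_i| + e)^a
probs = priorities / priorities.sum()          # P(i) = p_i / sum_k p_k
weights = (len(td_errors) * probs) ** -beta    # w_i = (N * P(i))^-beta
weights /= weights.max()                       # normalize so weights only scale updates down

print("priorities:", priorities)
print("sampling probabilities:", probs)
print("importance weights:", weights)

Transitions with larger TD error get higher sampling probability, and the importance weights compensate for that bias by down-weighting their loss contribution, which is exactly how the importance_weights placeholder is used in the Estimator loss above.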