diff --git a/bin/plotting_test.py b/bin/plotting_test.py
new file mode 100755
index 0000000..e730c85
--- /dev/null
+++ b/bin/plotting_test.py
@@ -0,0 +1,177 @@
+#!/usr/bin/env python
+import numpy as np
+import math
+from driving_gridworld.matplotlib import Simulator
+from driving_gridworld.matplotlib import Bumps
+from driving_gridworld.matplotlib import Crashes
+from driving_gridworld.matplotlib import Ditch
+from driving_gridworld.matplotlib import Progress
+from driving_gridworld.matplotlib import add_decorations, remove_labels_and_ticks
+from driving_gridworld.matplotlib import new_plot_frame_with_text, plot_frame_no_text, new_rollout
+from driving_gridworld.matplotlib import align_text_top_image, make_rows_equal_spacing
+from driving_gridworld.gridworld import DrivingGridworld
+from driving_gridworld.human_ui import observation_to_img, obs_to_rgb
+from driving_gridworld.road import Road
+from driving_gridworld.car import Car
+from driving_gridworld.actions import NO_OP
+from driving_gridworld.obstacles import Pedestrian, Bump
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import matplotlib.animation as animation
+from matplotlib import rc
+rc('animation', html='jshtml')
+import os
+
+
+def ensure_dir(dir_name):
+    try:
+        os.mkdir(dir_name)
+    except FileExistsError:
+        return
+
+
+# Define path where files will be saved:
+my_path = os.path.dirname(os.path.realpath(__file__))
+dir_name = my_path + '/../tmp'
+ensure_dir(dir_name)
+
+# Set up formatting for the movie files
+Writer = animation.writers['ffmpeg']
+
+
+def new_road(headlight_range=2):
+    return Road(
+        headlight_range,
+        Car(2, 2),
+        obstacles=[Bump(0, 2), Pedestrian(1, 1)],
+        allowed_obstacle_appearance_columns=[{2}, {1}],
+        allow_crashing=True)
+
+
+def test_still_image_with_no_text():
+    game = DrivingGridworld(new_road)
+    observation = game.its_showtime()[0]
+    img = observation_to_img(observation, obs_to_rgb)
+    fig, ax = plt.subplots(figsize=(6, 6))
+    ax = add_decorations(img, remove_labels_and_ticks(ax))
+    ax.imshow(img, aspect=1.8)
+    fig.savefig(dir_name + '/img_no_text.pdf')
+
+
+def test_still_image_with_text():
+    game = DrivingGridworld(new_road)
+    observation = game.its_showtime()[0]
+    img = observation_to_img(observation, obs_to_rgb)
+    fig, ax = plt.subplots(figsize=(6, 6))
+
+    reward_function_list = [Progress(), Bumps(), Ditch(), Crashes()]
+    info_lists = []
+    info_lists.append([f.new_info() for f in reward_function_list])
+    frame, ax_texts = new_plot_frame_with_text(
+        img, 0, *info_lists[0], fig=fig, ax=ax)[:2]
+    fig.savefig(dir_name + '/img_with_text.pdf')
+
+
+def test_video_with_text():
+    hlr = new_road()._headlight_range
+    height_of_figure = make_rows_equal_spacing(hlr)
+    vertical_shift = align_text_top_image(hlr)
+    # fig, ax_list = plt.subplots(
+    #     1, figsize=(12, height_of_figure), squeeze=False)
+    # ax_list = ax_list.reshape([1])
+    frames, fig, ax_list, actions, rollout_info_lists = new_rollout(
+        Simulator(lambda x: NO_OP, DrivingGridworld(new_road)),
+        plotting_function=new_plot_frame_with_text,
+        reward_function_list=[Progress(),
+                              Bumps(), Ditch(),
+                              Crashes()],
+        num_steps=10)
+
+    ani = animation.ArtistAnimation(fig, frames)
+    writer = Writer(fps=1, metadata=dict(title="video_with_text"))
+    ani.save(dir_name + '/video_with_text.mp4', writer=writer)
+
+
+def test_video_with_no_text():  # Should maybe pass the policy as an argument?
+    frames, fig, ax_list, actions, rollout_info_lists = new_rollout(
+        Simulator(lambda x: NO_OP, DrivingGridworld(new_road)),
+        plotting_function=plot_frame_no_text,
+        reward_function_list=[Progress(),
+                              Bumps(), Ditch(),
+                              Crashes()],
+        num_steps=10)
+    ani = animation.ArtistAnimation(fig, frames)
+    writer = Writer(fps=1, metadata=dict(title="video_no_text"))
+    ani.save(dir_name + '/video_no_text.mp4', writer=writer)
+
+
+def test_video_multiple_agents_with_text():
+    hlr = new_road()._headlight_range
+    height_of_figure = make_rows_equal_spacing(hlr)
+    vertical_shift = align_text_top_image(hlr)
+
+    simulators = [
+        Simulator(lambda x: NO_OP, DrivingGridworld(new_road)),
+        Simulator(lambda x: NO_OP, DrivingGridworld(new_road)),
+        Simulator(lambda x: NO_OP, DrivingGridworld(new_road)),
+        Simulator(lambda x: NO_OP, DrivingGridworld(new_road))
+    ]
+
+    fig, ax_list = plt.subplots(
+        2, len(simulators) // 2, figsize=(12, height_of_figure), squeeze=False)
+    ax_list = ax_list.reshape([len(simulators)])
+
+    frames, fig, ax_list, actions, rollout_info_lists = new_rollout(
+        *simulators,
+        plotting_function=new_plot_frame_with_text,
+        reward_function_list=[Progress(),
+                              Bumps(), Ditch(),
+                              Crashes()],
+        num_steps=10,
+        fig=fig,
+        ax_list=ax_list,
+        vertical_shift=vertical_shift)
+
+    ani = animation.ArtistAnimation(fig, frames)
+    writer = Writer(
+        fps=1, metadata=dict(title="video_multiple_agents_with_text"))
+    ani.save(dir_name + '/video_multiple_agents_with_text.mp4', writer=writer)
+
+
+def test_video_multiple_agents_no_text():
+    hlr = new_road()._headlight_range
+    height_of_figure = make_rows_equal_spacing(hlr)
+    simulators = [
+        Simulator(lambda x: NO_OP, DrivingGridworld(new_road)),
+        Simulator(lambda x: NO_OP, DrivingGridworld(new_road)),
+        Simulator(lambda x: NO_OP, DrivingGridworld(new_road)),
+        Simulator(lambda x: NO_OP, DrivingGridworld(new_road))
+    ]
+
+    fig, ax_list = plt.subplots(
+        2, len(simulators) // 2, figsize=(12, height_of_figure), squeeze=False)
+    ax_list = ax_list.reshape([len(simulators)])
+
+    frames, fig, ax_list, actions, rollout_info_lists = new_rollout(
+        *simulators,
+        plotting_function=plot_frame_no_text,
+        reward_function_list=[Progress(),
+                              Bumps(), Ditch(),
+                              Crashes()],
+        num_steps=10,
+        fig=fig,
+        ax_list=ax_list)
+
+    ani = animation.ArtistAnimation(fig, frames)
+    writer = Writer(
+        fps=1, metadata=dict(title="video_multiple_agents_no_text"))
+    ani.save(dir_name + '/video_multiple_agents_no_text.mp4', writer=writer)
+
+
+if __name__ == '__main__':
+    test_still_image_with_no_text()
+    test_still_image_with_text()
+    test_video_with_text()
+    test_video_with_no_text()
+    test_video_multiple_agents_with_text()
+    test_video_multiple_agents_no_text()
diff --git a/driving_gridworld/matplotlib.py b/driving_gridworld/matplotlib.py
index 807ba17..3c76949 100644
--- a/driving_gridworld/matplotlib.py
+++ b/driving_gridworld/matplotlib.py
@@ -1,60 +1,283 @@
+import matplotlib as mpl
 import matplotlib.pyplot as plt
 import numpy as np
+import math
 from driving_gridworld.actions import ACTION_NAMES
 from driving_gridworld.rollout import Rollout
-from driving_gridworld.human_ui import observation_to_img
-
-
-def plot_frame_with_text(img,
-                         reward,
-                         discounted_return,
-                         action,
-                         fig=None,
-                         ax=None,
-                         animated=False,
-                         show_grid=False):
-    white_matrix = np.ones(img.shape)
-    extended_img = np.concatenate((img, white_matrix), axis=1)
+from driving_gridworld.human_ui import observation_to_img, obs_to_rgb
+from driving_gridworld.obstacles import Bump, Pedestrian
+from driving_gridworld.actions import NO_OP

-    text_list = [
-        'Action: {}'.format(ACTION_NAMES[action]),
-        'Reward: {:0.2f}'.format(reward),
-        'Return: {:0.2f}'.format(discounted_return)
-    ]

-    if fig is None:
-        fig = plt.figure()

+class RewardInfo(object):
+    def __init__(self, name, string_format, discount=1.0):
+        self.name = name
+        self.discount = discount
+        self.g = 0
+        self._t = 0
+        self._string_format = string_format
+        self.r = 0
+
+    def next(self, reward_value):
+        if self._t > 0:
+            if self.discount < 1.0:
+                self.g += self.discount**(self._t - 1) * self.r
+            else:
+                self.g += self.r
+        self.r = reward_value
+        self._t += 1
+        return self
+
+    def reward_to_s(self):
+        return self._string_format.format(self.r)
+
+    def return_to_s(self):
+        return self._string_format.format(self.g)
+
+
+class RewardFunction(object):
+    def __init__(self, name, discount=1.0):
+        self.name = name
+        self.discount = discount
+
+    def new_info(self):
+        return RewardInfo(
+            self.name, self.string_format, discount=self.discount)
+
+
+class Bumps(RewardFunction):
+    def __init__(self):
+        super().__init__('Bumps', 1.0)
+
+    def __call__(self, s, a, s_p):
+        return s.count_obstacle_collisions(
+            s_p, lambda obs: 1 if isinstance(obs, Bump) else None)[0]
+
+    @property
+    def string_format(self):
+        return '{:d}'
+
+
+class Crashes(RewardFunction):
+    def __init__(self):
+        super().__init__('Crashes', 1.0)
+
+    def __call__(self, s, a, s_p):
+        return s.count_obstacle_collisions(
+            s_p, lambda obs: 1 if isinstance(obs, Pedestrian) else None)[0]
+
+    @property
+    def string_format(self):
+        return '{:d}'
+
+
+class Ditch(RewardFunction):
+    def __init__(self):
+        super().__init__('Ditch', 1.0)
+
+    def __call__(self, s, a, s_p):
+        return int(s.is_in_a_ditch() or s_p.is_in_a_ditch()) * s.car.speed
+
+    @property
+    def string_format(self):
+        return '{:d}'
+
+
+class Progress(RewardFunction):
+    def __init__(self):
+        super().__init__('Progress', 1.0)
+
+    def __call__(self, s, a, s_p):
+        return s.car.progress_toward_destination(a)
+
+    @property
+    def string_format(self):
+        return '{:d}'
+
+
+def remove_labels_and_ticks(ax=None):
     if ax is None:
-        ax = fig.add_subplot(111)
+        ax = plt.gca()

-    ax.grid(show_grid)
+    ax.grid(False)
+    ax.axis('off')

-    # Remove ticks and tick labels
     ax.set_xticklabels([])
     ax.set_yticklabels([])
     for tic in ax.xaxis.get_major_ticks():
         tic.tick1On = tic.tick2On = False
     for tic in ax.yaxis.get_major_ticks():
         tic.tick1On = tic.tick2On = False
+    return ax
+
+
+def add_decorations(img, ax=None):
+    if ax is None:
+        ax = plt.gca()
+
+    incr = 0.55
+    y = -0.60
+    for i in range(2 * img.shape[0]):
+        ax.add_patch(
+            mpl.patches.Rectangle(
+                (2.47, y), 0.03, 0.33, 0, color='yellow', alpha=0.8))
+        y += incr
+
+    direction_offsets = np.array([(-0.5, -0.5), (0, -0.5 - 0.5 / 3), (0.5,
+                                                                      -0.5)])
+    for i in range(img.shape[0] - 1):
+        ax.add_patch(
+            mpl.patches.Polygon(
+                np.array([6, i + 1]) + direction_offsets,
+                closed=True,
+                alpha=0.8,
+                facecolor='grey'))
+    return ax
+
+
+def new_plot_frame_with_text(img,
+                             action,
+                             *reward_info_list,
+                             fig=None,
+                             ax=None,
+                             animated=False,
+                             show_grid=False,
+                             vertical_shift=1.0):
+
+    if fig is None:
+        fig = plt.figure()
+
+    if ax is None:
+        ax = fig.add_subplot(111)
+
+    ax = add_decorations(img, remove_labels_and_ticks(ax))
+    num_text_columns = 5
+    white_matrix = np.ones([img.shape[0], num_text_columns, img.shape[2]])
+    extended_img = np.concatenate((img, white_matrix), axis=1)
+
+    text_list = [ACTION_NAMES[action]]
+    for info in reward_info_list:
+        text_list.append('{:8s} {:>5s} + {:>1s}'.format(
+            info.name, info.return_to_s(), info.reward_to_s()))
+
+    column = img.shape[1] - 0.1
+    font = mpl.font_manager.FontProperties()
+    font.set_family('monospace')
+    ax_texts = [
+        ax.text(
+            column,
+            math.ceil(img.shape[0] // 2) + vertical_shift,
+            '\n\n'.join(text_list[0:]),
+            horizontalalignment='left',
+            fontproperties=font)
+    ]
+
+    return ax.imshow(
+        extended_img, animated=animated, aspect=1.5), ax_texts, fig, ax
+
+
+def plot_frame_no_text(img,
+                       action,
+                       *reward_info_list,
+                       fig=None,
+                       ax=None,
+                       animated=False,
+                       show_grid=False,
+                       vertical_shift=0.0):
+    white_matrix = np.ones([img.shape[0], 0, img.shape[2]])
+    extended_img = np.concatenate((img, white_matrix), axis=1)
+
+    if fig is None:
+        fig = plt.figure()
+
+    if ax is None:
+        ax = fig.add_subplot(111)
+
+    remove_labels_and_ticks(ax)
+    add_decorations(img, ax)
+
+    return ax.imshow(extended_img, animated=animated, aspect=1.5), [], fig, ax
+
+
+def new_rollout(*simulators,
+                plotting_function=plot_frame_no_text,
+                reward_function_list=[],
+                num_steps=100,
+                fig=None,
+                ax_list=None,
+                vertical_shift=1.0):
+
+    if fig is None or ax_list is None:
+        fig, ax_list = plt.subplots(
+            len(simulators), figsize=(6, 6), squeeze=False)
+        ax_list = ax_list.reshape([len(simulators)])
+
+    info_lists = []
+    frames = [[]]
+    for i, sim in enumerate(simulators):
+        observation, d = sim.start()
+        img = observation_to_img(observation, obs_to_rgb)
+        info_lists.append([f.new_info() for f in reward_function_list])
+        frame, ax_texts = plotting_function(
+            img,
+            sim.a,
+            *info_lists[i],
+            fig=fig,
+            ax=ax_list[i],
+            vertical_shift=vertical_shift)[:2]
+        frames[0] += [frame] + ax_texts
+
+    actions = [[] for _ in simulators]
+    for t in range(num_steps):
+        frames.append([])
+        for i, sim in enumerate(simulators):
+            a, observation, _ = sim.step()
+            actions[i].append(a)
+
+            for j, info in enumerate(info_lists[i]):
+                info.next(reward_function_list[j](*sim.sas()))
+
+            frame, ax_texts = plotting_function(
+                observation_to_img(observation, obs_to_rgb),
+                a,
+                *info_lists[i],
+                fig=fig,
+                ax=ax_list[i],
+                vertical_shift=vertical_shift)[:2]
+            frames[-1] += [frame] + ax_texts
+    return frames, fig, ax_list, actions, info_lists
+
+
+def align_text_top_image(headlight_range):
+    output = [
+        1.25, 0.25, -0.75, -1.65, -2.5, -3.35
+    ]  # assuming we will only need to consider a headlight range <= 12, for our purposes
+    return output[math.ceil(headlight_range / 2) - 1]
+
-    column = img.shape[1] - 0.4
-    ax_texts = [ax.annotate(t, (column, i)) for i, t in enumerate(text_list)]

+def make_rows_equal_spacing(headlight_range):
+    y = headlight_range + 5
+    return y

-    return ax.imshow(extended_img, animated=animated), ax_texts, fig, ax

+class Simulator(object):
+    def __init__(self, policy, game):
+        self.policy = policy
+        self.game = game

-def plot_rollout(policy, game, num_steps=100, policy_on_game=False):
-    rollout = Rollout(policy, game, policy_on_game=policy_on_game)
-    frames = []
+    def start(self):
+        self.prev_state = self.game.road.copy()
+        self.observation, _, d = self.game.its_showtime()
+        self.a = NO_OP
+        self.d = 1
+        return self.observation, d

-    fig = None
-    ax = None
-    for t, o, a, r, d, o_prime, dr in rollout:
-        if t >= num_steps:
-            break
+    def step(self):
+        if self.d > 0:
+            self.prev_state = self.game.road.copy()
+            self.a = self.policy(self.game.road)
+            self.observation, _, self.d = self.game.play(self.a)
+        return self.a, self.observation, self.d

-        frame, ax_texts, fig, ax = plot_frame_with_text(
-            observation_to_img(o), r, dr, a, fig=fig, ax=ax)
-        frames.append([frame] + ax_texts)
-    return frames, fig, ax
+    def sas(self):
+        return self.prev_state, self.a, self.game.road
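
Note for reviewers, outside the patch itself: the sketch below is a minimal, illustrative use of the RewardFunction/Simulator/new_rollout API introduced above, showing how an extra reward column could be added to the text overlay. The Speed class and the inline new_road factory are assumptions for this example only; every call they make (s.car.speed, Road, Car, DrivingGridworld, new_rollout) already appears in this diff.

# Illustrative sketch (not part of this change): plug a custom reward column
# into the rollout/plotting API introduced above.
from driving_gridworld.actions import NO_OP
from driving_gridworld.car import Car
from driving_gridworld.gridworld import DrivingGridworld
from driving_gridworld.matplotlib import (Progress, RewardFunction, Simulator,
                                          new_plot_frame_with_text, new_rollout)
from driving_gridworld.obstacles import Bump, Pedestrian
from driving_gridworld.road import Road


class Speed(RewardFunction):
    # Hypothetical reward column: reports the car's speed each step,
    # mirroring how the Ditch reward reads s.car.speed above.
    def __init__(self):
        super().__init__('Speed', 1.0)

    def __call__(self, s, a, s_p):
        return s.car.speed

    @property
    def string_format(self):
        return '{:d}'


def new_road(headlight_range=2):
    # Same road construction used by bin/plotting_test.py above.
    return Road(
        headlight_range,
        Car(2, 2),
        obstacles=[Bump(0, 2), Pedestrian(1, 1)],
        allowed_obstacle_appearance_columns=[{2}, {1}],
        allow_crashing=True)


# One NO_OP-policy agent, five steps, with Progress and the custom Speed
# columns rendered in the frame text.
frames, fig, ax_list, actions, info_lists = new_rollout(
    Simulator(lambda road: NO_OP, DrivingGridworld(new_road)),
    plotting_function=new_plot_frame_with_text,
    reward_function_list=[Progress(), Speed()],
    num_steps=5)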