Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Memory issues when running GYM environment #166

Open
toshima1051 opened this issue Dec 2, 2024 · 0 comments
Open

Memory issues when running GYM environment #166

toshima1051 opened this issue Dec 2, 2024 · 0 comments

Comments

@toshima1051
Copy link

I am using DreamerV3 to train a reinforcement-learning agent in a custom Gym environment, but training is forcibly stopped at around 300,000 steps.
How can I solve this problem?

I am running on WSL2, python 3.10, RTX4070 ti SUPER GPU, 32GB memory.
Since I am using Gymnasium (the successor to OpenAI Gym) rather than OpenAI Gym itself, from_gym.py may need to be modified slightly to run this code (specifically, the import statements, etc.).

The code is as attached.
Execution stops at the same point even when I run only a single environment instance, and I have not been able to resolve it.

Thank you in advance for your help.

Code of the GYM environment:

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import pygame
import random
import time
import os
from torch.utils.tensorboard import SummaryWriter

class DiscreteWithAttr(spaces.Discrete):
    """Discrete space that additionally exposes ``discrete`` and ``classes``.

    ``discrete`` is always ``True`` and ``classes`` mirrors ``n``.
    NOTE(review): presumably these attributes are what the agent-side gym
    wrapper looks for on an action space -- confirm against the caller.
    """

    # Read-only flag: this space is always discrete.
    discrete = property(lambda self: True)
    # Read-only alias for the number of discrete values.
    classes = property(lambda self: self.n)

class VampireSurvivorEnv(gym.Env):
    """Gymnasium environment for a top-down survival shooter drawn with pygame.

    The action space has five discrete actions (up, down, left, right,
    confirm). Observations are dicts holding a 64x64x3 uint8 ``image`` plus
    the ``is_first`` / ``is_last`` / ``is_terminal`` flags.
    """
    # Supported render modes and the frame rate used by render() in human mode.
    metadata = {
        'render_modes': ['human', 'rgb_array'],
        'render_fps': 30
    }

    def __init__(self, render_mode='rgb_array', learning_rate=0.001):
        """Create the spaces, start pygame and initialise the game state.

        Args:
            render_mode: ``'human'`` opens a pygame window; any other value
                draws to an off-screen surface using SDL's dummy video driver.
            learning_rate: Stored on the instance; not used elsewhere in this
                class.
        """
        super().__init__()

        # Action and observation spaces. Actions: up, down, left, right, confirm.
        self.action_space = DiscreteWithAttr(5)
        image_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
        self.observation_space = spaces.Dict({
            'image': image_space,
            'is_terminal': spaces.Discrete(2),
            'is_first': spaces.Discrete(2),
            'is_last': spaces.Discrete(2),
        })

        # Pygame set-up; only the target surface differs between the modes.
        self.render_mode = render_mode
        self.screen_width = 500
        self.screen_height = 500
        surface_size = (self.screen_width, self.screen_height)
        if self.render_mode == 'human':
            pygame.init()
            self.screen = pygame.display.set_mode(surface_size)
            pygame.display.set_caption('Vampire Survivor Gym Environment')
        else:
            # Headless: select the dummy video driver before pygame starts.
            os.environ["SDL_VIDEODRIVER"] = "dummy"
            pygame.init()
            self.screen = pygame.Surface(surface_size)
        self.clock = pygame.time.Clock()
        self.font = pygame.font.Font(None, 36)
        self.small_font = pygame.font.Font(None, 24)

        # TensorBoard writer.
        self.writer = SummaryWriter('runs/vampire_survivor')
        self.learning_rate = learning_rate

        # Initialise the per-episode game state.
        self._initialize_game_state()

    def _initialize_game_state(self):
        """Set every per-episode attribute back to its starting value."""
        # --- Player ---------------------------------------------------------
        centre = [self.screen_width // 2, self.screen_height // 2]
        self.player_pos = centre
        self.previous_position = list(centre)
        self.player_speed = 5
        self.player_max_health = 100
        self.player_health = self.player_max_health
        self.direction = [0, -1]  # initially facing up
        self.stationary_steps = 0
        self.dps = 0

        # --- Combat timers (wall-clock based) -------------------------------
        self.player_attack_interval = 1.0
        self.player_last_attack_time = time.time()
        self.enemy_spawn_interval = 1.0
        self.last_enemy_spawn_time = time.time()

        # --- World contents -------------------------------------------------
        self.bullets = []
        self.enemies = []
        self.items = []
        self.active_effects = {}
        self.item_picked = False
        self.hit_enemy = False

        # --- Upgrades (insertion order matters: other methods iterate it) ---
        self.upgrades = dict(
            speed=0,
            attack_power=1,
            health=0,
            bullet_speed=0,
            area=0,
            evasion=0,
            fire_level=0,
            wind_level=0,
            water_level=0,
            ice_level=0,
            earth_level=0,
        )
        stat_keys = ('speed', 'attack_power', 'health', 'bullet_speed', 'area', 'evasion')
        weapon_keys = ('fire_level', 'wind_level', 'water_level', 'ice_level', 'earth_level')
        self.max_upgrade_levels = {
            **dict.fromkeys(stat_keys, 10),
            **dict.fromkeys(weapon_keys, 5),
        }
        self.upgrade_menu_open = False
        self.selected_upgrade = 0

        # --- Progression ----------------------------------------------------
        self.score = 0
        self.enemies_defeated = 0
        self.total_enemies_defeated = 0
        self.level = 1
        self.enemies_to_defeat_for_upgrade = 5  # kills needed for the first level-up

        # --- Episode bookkeeping --------------------------------------------
        self.steps_in_episode = 0
        self.reward = 0
        self.game_over = False

        # --- Timed events ---------------------------------------------------
        self.base_event_interval = 20
        self.min_event_interval = 10
        self.event_interval = self.base_event_interval
        self.last_event_time = time.time()

        # --- Difficulty / drops ---------------------------------------------
        self.enemy_health_multiplier = 1.0
        self.item_drop_probability = 0.005

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Forwarded to ``gym.Env.reset``.
            options: Accepted for API compatibility; unused.

        Returns:
            ``(observation, info)`` where the observation holds an all-zero
            64x64x3 image with ``is_first`` set, and ``info`` is empty.
        """
        super().reset(seed=seed)
        self._initialize_game_state()

        blank_frame = np.zeros((64, 64, 3), dtype=np.uint8)
        first_obs = dict(image=blank_frame, is_terminal=0, is_first=1, is_last=0)
        return first_obs, {}

    def step(self, action):
        """Advance the game by one agent action.

        While the upgrade menu is open the action only navigates or confirms
        the menu and the world is frozen (reward 0). Otherwise the player
        moves, the wall-clock driven attack / spawn / event timers fire,
        collisions are resolved and the accumulated reward is returned.

        Args:
            action: 0=up, 1=down, 2=left, 3=right, 4=confirm.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``truncated`` is
            always ``False``. ``info`` carries ``episode_length`` and
            ``enemies_defeated`` (both taken before any end-of-episode reset).
        """
        if self.render_mode != 'human':
            # render() is the only other place the clock is ticked and it is
            # only called in human mode. Without a tick, Clock.get_time() stays
            # at 0 and every duration-based effect (earth bullets, slow,
            # shield, power-up) would never expire.
            self.clock.tick()

        if self.upgrade_menu_open:
            self._handle_upgrade_selection(action)
            if self.render_mode == 'human':
                # Keep the window (and menu highlight) in sync with the selection.
                self.render()
            next_obs = {
                'image': self.get_obs_image(),
                'is_terminal': 0,
                'is_first': 0,
                'is_last': 0
            }
            return next_obs, 0, False, False, {}

        previous_position = self.player_pos.copy()
        reward = 0

        if action in [0, 1, 2, 3]:
            self._move_player(action)

        if previous_position == self.player_pos:
            self.stationary_steps += 1
        else:
            self.stationary_steps = 0
        # (The stationary penalty based on stationary_steps is currently disabled.)

        current_time = time.time()
        if current_time - self.player_last_attack_time >= self.player_attack_interval:
            self._shoot_bullet()
            self.player_last_attack_time = current_time

        if current_time - self.last_enemy_spawn_time >= self.enemy_spawn_interval:
            self._spawn_enemy()
            self.last_enemy_spawn_time = current_time

        self._move_bullets()
        self._update_enemies()
        self._attack_enemies()
        self._check_player_hit()
        self._check_item_pickup()

        # Count down the player's active effects and drop the expired ones.
        elapsed = self.clock.get_time() / 1000
        for effect in list(self.active_effects):
            self.active_effects[effect]['duration'] -= elapsed
            if self.active_effects[effect]['duration'] <= 0:
                del self.active_effects[effect]

        # Periodic event; every event also makes future enemies tougher.
        if current_time - self.last_event_time >= self.event_interval:
            self._trigger_event()
            self.last_event_time = current_time
            # NOTE(review): 'attack_speed' is never set in self.upgrades in
            # this file, so this currently always yields base_event_interval.
            self.event_interval = max(
                self.min_event_interval,
                self.base_event_interval - (self.upgrades.get('attack_speed', 0) * 0.5)
            )
            self.enemy_health_multiplier += 0.15
        # (The proximity-to-enemy reward is currently disabled.)

        # Level up once enough enemies have been defeated.
        if self.enemies_defeated >= self.enemies_to_defeat_for_upgrade:
            self.enemies_defeated = 0
            self.upgrade_menu_open = True
            self._open_upgrade_menu()
            self.level += 1
            # Each level needs 1.5x as many kills as the previous one.
            self.enemies_to_defeat_for_upgrade = int(self.enemies_to_defeat_for_upgrade * 1.5)
            reward += 10
        # (The per-hit reward based on self.hit_enemy is currently disabled.)

        # Collect the reward accumulated by the helpers during this step.
        reward += self.reward
        self.reward = 0

        done = self.player_health <= 0 or self.score >= 10000

        # Snapshot the statistics before the death reset below wipes them;
        # previously info['enemies_defeated'] was always 0 on a death step.
        episode_length = self.steps_in_episode
        enemies_defeated = self.total_enemies_defeated
        if done:
            self.steps_in_episode = 0
            if self.player_health <= 0:
                self._reset_on_death()
        else:
            self.steps_in_episode += 1
            episode_length = self.steps_in_episode

        if self.render_mode == 'human':
            self.render()

        next_obs = {
            'image': self.get_obs_image(),
            'is_terminal': 1 if done else 0,
            'is_first': 0,
            'is_last': 1 if done else 0
        }
        info = {
            'episode_length': episode_length,
            'enemies_defeated': enemies_defeated
        }
        return next_obs, reward, done, False, info

    def render(self, mode='human'):
        """Draw the current frame.

        In ``'human'`` mode the frame is shown in the window and the clock is
        ticked at ``metadata['render_fps']``. In ``'rgb_array'`` mode the frame
        is drawn to the off-screen surface and returned. Previously the
        ``'rgb_array'`` branch returned the surface without ever drawing on it.

        Args:
            mode: Unused; kept for backward compatibility. ``self.render_mode``
                decides what happens.

        Returns:
            ``None`` in human mode (or for an unknown mode); an
            ``(height, width, 3)`` array in ``'rgb_array'`` mode.
        """
        if self.render_mode not in ('human', 'rgb_array'):
            return None
        if self.render_mode == 'human' and self.screen is None:
            self._initialize_rendering()

        self.screen.fill((0, 0, 0))
        self._draw_player()
        self._draw_enemies()
        self._draw_bullets()
        self._draw_items()
        self._draw_health_bar()
        self._draw_upgrade_count()
        self._draw_enemy_defeated_count()
        self._draw_upgrades_and_weapons()

        if self.upgrade_menu_open:
            self._draw_upgrade_menu()
        if self.game_over:
            self._draw_game_over()

        if self.render_mode == 'human':
            pygame.display.flip()
            self.clock.tick(self.metadata['render_fps'])
            return None

        # array3d is indexed (x, y, channel); swap to (row, column, channel)
        # so the frame has the same orientation as get_obs_image().
        frame = pygame.surfarray.array3d(self.screen)
        return np.transpose(frame, (1, 0, 2))

    def close(self):
        """Release the resources held by the environment.

        Closes the TensorBoard writer created in ``__init__`` (it was never
        closed before) and shuts pygame down in human mode. Safe to call more
        than once.
        """
        writer = getattr(self, 'writer', None)
        if writer is not None:
            writer.close()
            self.writer = None
        if self.render_mode == 'human':
            pygame.quit()

    def get_obs_image(self):
        """Return the current frame as a ``(64, 64, 3)`` uint8 observation.

        In human mode ``step`` has already drawn the frame via ``render``. In
        every other mode nothing else paints the off-screen surface, so the
        scene is drawn here first; otherwise each observation would be an
        all-black image.
        """
        if self.render_mode != 'human':
            self.screen.fill((0, 0, 0))
            self._draw_player()
            self._draw_enemies()
            self._draw_bullets()
            self._draw_items()
            self._draw_health_bar()
            self._draw_upgrade_count()
            self._draw_enemy_defeated_count()
            self._draw_upgrades_and_weapons()
            if self.upgrade_menu_open:
                self._draw_upgrade_menu()
            if self.game_over:
                self._draw_game_over()

        scaled_screen = pygame.transform.scale(self.screen, (64, 64))
        pixels = pygame.surfarray.array3d(scaled_screen).astype(np.uint8)
        # array3d is indexed (x, y, channel); swap to (row, column, channel).
        return np.transpose(pixels, (1, 0, 2))

    # ゲームロジックのメソッド
    def _move_player(self, action):
        """Move the player one step and update the facing direction.

        The position is clamped to the screen. Actions other than 0-3 are
        ignored.
        """
        # (action code, axis index, sign, upper bound on that axis)
        moves = (
            (0, 1, -1, self.screen_height - 1),  # up
            (1, 1, 1, self.screen_height - 1),   # down
            (2, 0, -1, self.screen_width - 1),   # left
            (3, 0, 1, self.screen_width - 1),    # right
        )
        for code, axis, sign, upper in moves:
            if action == code:
                shifted = self.player_pos[axis] + sign * self.player_speed
                if sign < 0:
                    self.player_pos[axis] = max(0, shifted)
                else:
                    self.player_pos[axis] = min(upper, shifted)
                facing = [0, 0]
                facing[axis] = sign
                self.direction = facing
                return

    def _shoot_bullet(self):
        """Fire every unlocked elemental weapon, then the basic shot."""
        for key, level in self.upgrades.items():
            if key.endswith('_level') and level > 0:
                self._create_bullets(key)

        # Basic bullet: travels in the facing direction at a speed of at least 10.
        basic_shot = dict(
            pos=self.player_pos.copy(),
            dir=self.direction.copy(),
            speed=max(self.upgrades['bullet_speed'] * 4, 10),
            owner='player',
            damage=(self.upgrades['attack_power'] + 1) * 5,
            size=10,
            piercing=False,
            weapon='basic',
        )
        self.bullets.append(basic_shot)

    def _create_bullets(self, weapon):
        """Append the bullets produced by ``weapon`` to the active bullet list."""
        self.bullets += self.get_bullet_params(weapon)

    def get_bullet_params(self, weapon):
        """Build the bullet dicts fired by one elemental weapon.

        Fire, wind and water shoot a ring of bullets around the player, ice
        shoots a spread centred on the facing direction, and earth places one
        stationary bullet that lasts for ``duration`` seconds.

        The ice spread now uses its computed bullet count (3 at level 1, one
        more per level, spread evenly over +/-15 degrees). The count used to be
        computed and then ignored in favour of a fixed three-bullet spread.

        Args:
            weapon: Upgrade key such as ``'fire_level'``.

        Returns:
            A list of bullet dicts; empty for an unrecognised key.
        """
        base_speed = (self.upgrades['bullet_speed'] + 1) * 4
        base_damage = (self.upgrades['attack_power'] + 1) * 5
        area = (self.upgrades['area'] + 1) * 10
        level = self.upgrades.get(weapon, 0)

        def make(direction, name, speed, damage, size, piercing=False, **extra):
            # One bullet dict; every bullet gets its own copy of the position.
            shot = {
                'pos': self.player_pos.copy(),
                'dir': direction,
                'speed': speed,
                'owner': 'player',
                'damage': damage,
                'size': size,
                'piercing': piercing,
                'weapon': name,
            }
            shot.update(extra)
            return shot

        def ring(count):
            # Unit vectors evenly spaced around the full circle.
            directions = []
            for i in range(count):
                rad = np.deg2rad(i * (360 / count))
                directions.append([np.cos(rad), np.sin(rad)])
            return directions

        if weapon == 'fire_level':
            # 8 bullets at level 1, two more per level.
            return [
                make(d, 'fire', base_speed, base_damage * 0.4, area * 0.8,
                     slow_effect=True)
                for d in ring(6 + level * 2)
            ]
        if weapon == 'wind_level':
            # 3 piercing bullets at level 1, one more per level.
            return [
                make(d, 'wind', base_speed * 0.5, base_damage * 0.2, area,
                     piercing=True)
                for d in ring(2 + level)
            ]
        if weapon == 'water_level':
            # 4 bullets at level 1, one more per level.
            return [
                make(d, 'water', base_speed * 0.6, base_damage * 0.5, area,
                     slow_effect=True)
                for d in ring(3 + level)
            ]
        if weapon == 'ice_level':
            # 3 bullets at level 1, one more per level, within a +/-15 degree cone.
            count = 2 + level
            angles = np.linspace(-15, 15, count) if count > 1 else [0]
            facing_x, facing_y = self.direction
            shots = []
            for angle in angles:
                rad = np.deg2rad(angle)
                # Rotate the facing direction by the spread angle.
                dir_x = facing_x * np.cos(rad) - facing_y * np.sin(rad)
                dir_y = facing_x * np.sin(rad) + facing_y * np.cos(rad)
                shots.append(
                    make([dir_x, dir_y], 'ice', base_speed * 0.8,
                         base_damage * 0.7, area, slow_effect=True)
                )
            return shots
        if weapon == 'earth_level':
            # A single stationary bullet that lasts for one second.
            return [
                make([0, 0], 'earth', 0, base_damage * 1.5, area * 1.2,
                     duration=1.0)
            ]
        return []

    def _move_bullets(self):
        """Advance every bullet and drop those that expired or left the screen.

        Earth bullets do not move; their ``duration`` is reduced by the
        clock's last frame time instead.
        """
        survivors = []
        for shot in self.bullets:
            if shot['weapon'] == 'earth':
                shot['duration'] -= self.clock.get_time() / 1000
                gone = shot['duration'] <= 0
            else:
                shot['pos'][0] += shot['dir'][0] * shot['speed']
                shot['pos'][1] += shot['dir'][1] * shot['speed']
                x, y = shot['pos'][0], shot['pos'][1]
                gone = (x < 0 or x > self.screen_width or
                        y < 0 or y > self.screen_height)
            if not gone:
                survivors.append(shot)
        # Update in place so the list object itself stays the same.
        self.bullets[:] = survivors

    def _spawn_enemy(self):
        """Spawn a wave of enemies whose size grows with the player level.

        The wave size is ``max(1, int(level * 0.5))``. The level is read from
        ``self.level``; the previous lookup of ``self.upgrades['level']`` used a
        key that is never set, so the wave size was always 1.
        """
        base_health_by_type = {
            'weak': 10, 'normal': 20, 'strong': 40,
            'fast': 15, 'boss': 200, 'swarm': 8,
        }
        speed_by_type = {
            'weak': 1.5, 'normal': 1.0, 'strong': 0.7,
            'fast': 2.5, 'boss': 0.5, 'swarm': 1.0,
        }
        enemy_count = max(1, int(self.level * 0.5))
        for _ in range(enemy_count):
            enemy_type = self._determine_enemy_type()
            # Health scales with the multiplier that step() raises on each event.
            health = base_health_by_type.get(enemy_type, 20) * self.enemy_health_multiplier
            speed = speed_by_type.get(enemy_type, 1.0)
            x, y = self._random_spawn_position()
            self.enemies.append({
                'pos': [x, y],
                'health': health,
                'max_health': health,
                'type': enemy_type,
                'speed': speed,
                'effects': {}
            })

    def _determine_enemy_type(self):
        """Pick an enemy type appropriate for the current player level.

        Below level 5 only ``'weak'`` enemies appear, below level 10 weak or
        normal ones, and from level 10 the full roster. The level is read from
        ``self.level``; the previous lookup of ``self.upgrades['level']`` used a
        key that is never set, so the result was always ``'weak'``.
        """
        level = self.level
        if level < 5:
            return 'weak'
        if level < 10:
            return random.choice(['weak', 'normal'])
        return random.choice(['normal', 'strong', 'fast', 'swarm', 'boss'])

    def _random_spawn_position(self):
        """Return a random on-screen ``(x, y)`` more than 50 px from the player."""
        spot = None
        # Rejection sampling: redraw until the spot is far enough away.
        while spot is None or not self._distance(spot, self.player_pos) > 50:
            spot = [
                random.randint(0, self.screen_width - 1),
                random.randint(0, self.screen_height - 1),
            ]
        return spot[0], spot[1]

    def _update_enemies(self):
        """Move every enemy toward the player and age its status effects.

        Enemies carrying a ``'slow'`` effect move at half speed. Effect timers
        are reduced by the clock's last frame time and removed once they
        reach zero.
        """
        elapsed = self.clock.get_time() / 1000
        target = np.array(self.player_pos)
        for foe in self.enemies:
            offset = target - np.array(foe['pos'])
            gap = np.linalg.norm(offset)
            # Normalise unless the enemy is exactly on top of the player.
            heading = offset / gap if gap != 0 else offset
            pace = foe['speed'] * 0.5 if 'slow' in foe['effects'] else foe['speed']
            foe['pos'][0] += heading[0] * pace
            foe['pos'][1] += heading[1] * pace

            # Update the remaining effect durations.
            timers = foe['effects']
            for name in list(timers):
                timers[name] -= elapsed
                if timers[name] <= 0:
                    del timers[name]

    def _attack_enemies(self):
        """Resolve bullet/enemy collisions.

        Each bullet damages at most one enemy per call (the first one within
        its ``size``). Kills add 10 to both the pending reward and the score
        and may drop an item. Non-piercing bullets are removed on impact.
        """
        self.hit_enemy = False
        for shot in tuple(self.bullets):
            target = next(
                (foe for foe in tuple(self.enemies)
                 if self._distance(shot['pos'], foe['pos']) < shot['size']),
                None,
            )
            if target is None:
                continue

            target['health'] -= shot['damage']
            self.hit_enemy = True
            if shot.get('slow_effect'):
                target['effects']['slow'] = 2.0

            if target['health'] <= 0:
                self.enemies.remove(target)
                self.enemies_defeated += 1
                self.total_enemies_defeated += 1
                self.reward += 10
                self.score += 10
                if random.random() < self.item_drop_probability:
                    self._spawn_item(target['pos'])

            if not shot['piercing'] and shot in self.bullets:
                self.bullets.remove(shot)

    def _check_player_hit(self):
        """Handle the first enemy touching the player (within 15 px).

        The enemy is removed in every case. Unless a shield is active the
        player takes type-dependent damage, and ``game_over`` is set when
        health reaches zero. At most one enemy is processed per call.
        """
        attacker = next(
            (foe for foe in self.enemies
             if self._distance(self.player_pos, foe['pos']) < 15),
            None,
        )
        if attacker is None:
            return

        if 'shield' not in self.active_effects:
            damage_by_type = {'weak': 5, 'normal': 10, 'strong': 15, 'fast': 7, 'boss': 20, 'swarm': 3}
            self.player_health -= damage_by_type.get(attacker['type'], 10)
            if self.player_health <= 0:
                self.game_over = True
        self.enemies.remove(attacker)

    def _spawn_item(self, position):
        """Drop a random item (heal, power-up or shield) at ``position``."""
        kinds = ['heal', 'power_up', 'shield']
        self.items.append(dict(pos=position, type=random.choice(kinds)))

    def _check_item_pickup(self):
        """Collect every item within 15 px of the player.

        Each pickup applies the item's effect, removes the item, sets
        ``item_picked`` and adds 20 to the pending reward.
        """
        self.item_picked = False
        for item in tuple(self.items):
            if not self._distance(self.player_pos, item['pos']) < 15:
                continue

            kind = item['type']
            if kind == 'heal':
                self.player_health = min(self.player_max_health, self.player_health + 30)
            elif kind == 'power_up':
                self.active_effects['power_up'] = {'duration': 10.0, 'value': 2}
            elif kind == 'shield':
                self.active_effects['shield'] = {'duration': 5.0}

            self.items.remove(item)
            self.item_picked = True
            self.reward += 20

    def _trigger_event(self):
        """Spawn a random special wave: fast enemies, one boss, or a swarm.

        Wave sizes grow with ``self.level``. The previous lookup of
        ``self.upgrades['level']`` used a key that is never set, so the sizes
        never grew.
        """
        # event name -> (enemy type, base health, speed, base count or None for one)
        waves = {
            'fast_enemies': ('fast', 15, 2.5, 5),
            'big_enemy': ('boss', 200, 0.5, None),
            'enemy_swarm': ('swarm', 8, 1.0, 10),
        }
        event_type = random.choice(['fast_enemies', 'big_enemy', 'enemy_swarm'])
        enemy_type, base_health, speed, base_count = waves[event_type]

        if base_count is None:
            enemy_num = 1
        else:
            enemy_num = max(base_count, base_count + int(self.level * 1.5))

        for _ in range(enemy_num):
            x, y = self._random_spawn_position()
            health = base_health * self.enemy_health_multiplier
            self.enemies.append({
                'pos': [x, y],
                'health': health,
                'max_health': health,
                'type': enemy_type,
                'speed': speed,
                'effects': {}
            })

    def _distance(self, pos1, pos2):
        """Return the Euclidean distance between two positions."""
        delta = np.array(pos1) - np.array(pos2)
        return np.linalg.norm(delta)

    def _get_closest_enemy_distance(self):
        """Return the distance to the nearest enemy, or ``None`` if there are none."""
        return min(
            (self._distance(self.player_pos, foe['pos']) for foe in self.enemies),
            default=None,
        )

    def _open_upgrade_menu(self):
        """Open the upgrade menu with a fresh set of options, first one selected."""
        choices = self._get_available_upgrades()
        self.available_upgrade_options = choices
        self.selected_upgrade = 0
        self.upgrade_menu_open = True

    def _get_available_upgrades(self):
        """Return up to three distinct upgrade keys that are not yet maxed out.

        Weapon keys already present in ``self.upgrades`` used to be added to
        the pool twice, which could put the same weapon into the menu more
        than once. Each key is now added at most once.

        Returns:
            The whole pool when it has three or fewer entries, otherwise a
            random sample of three.
        """
        options_pool = [
            key for key, level in self.upgrades.items()
            if level < self.max_upgrade_levels.get(key, 5)
        ]
        # Also offer weapons that are missing from self.upgrades altogether.
        for weapon in ['fire', 'ice', 'wind', 'earth', 'water']:
            weapon_key = f"{weapon}_level"
            if weapon_key in options_pool:
                continue
            if (weapon_key not in self.upgrades
                    or self.upgrades.get(weapon_key, 0) < self.max_upgrade_levels.get(weapon_key, 5)):
                options_pool.append(weapon_key)

        # Pick three at random when there are more than three candidates.
        if len(options_pool) <= 3:
            return options_pool
        return random.sample(options_pool, 3)

    def _handle_upgrade_selection(self, action):
        """Apply one menu action: 0 = previous, 1 = next, 4 = confirm.

        Confirming raises the selected upgrade by one level (if it is below
        its maximum) and closes the menu. Other actions are ignored.

        When there are no options at all, the menu is closed immediately.
        Previously actions 0 and 1 raised ``ZeroDivisionError`` in that case
        and the menu could never be closed.
        """
        options = getattr(self, 'available_upgrade_options', None)
        if not options:
            self.upgrade_menu_open = False
            return

        if action == 0:
            self.selected_upgrade = (self.selected_upgrade - 1) % len(options)
        elif action == 1:
            self.selected_upgrade = (self.selected_upgrade + 1) % len(options)
        elif action == 4:
            upgrade_key = options[self.selected_upgrade]
            max_level = self.max_upgrade_levels.get(upgrade_key, 5)
            current_level = self.upgrades.get(upgrade_key, 0)
            if current_level < max_level:
                self.upgrades[upgrade_key] = current_level + 1
                self._apply_upgrade_effect(upgrade_key)
            self.upgrade_menu_open = False

    def _apply_upgrade_effect(self, upgrade_key):
        """Apply the immediate side effect of a newly bought upgrade.

        Only ``'health'`` (+10 max health and a full heal) and ``'speed'``
        (+0.5 movement speed) act immediately. Attack power, bullet speed and
        the weapon levels are read when bullets are generated, and evasion is
        not implemented yet, so those keys need no work here.
        """
        if upgrade_key == 'health':
            self.player_max_health += 10
            self.player_health = self.player_max_health
            return
        if upgrade_key == 'speed':
            self.player_speed += 0.5

    def _spawn_item(self, position):
        """Drop a random item (heal, power-up or shield) at ``position``.

        NOTE(review): this method is defined twice in the class with the same
        body; this later definition is the one that ends up bound.
        """
        kinds = ['heal', 'power_up', 'shield']
        self.items.append(dict(pos=position, type=random.choice(kinds)))

    # 描画関連のメソッド
    def _draw_player(self):
        """Draw the player as a circle: cyan while shielded, green otherwise."""
        shielded = 'shield' in self.active_effects
        colour = (0, 255, 255) if shielded else (0, 255, 0)
        centre = [int(self.player_pos[0]), int(self.player_pos[1])]
        pygame.draw.circle(self.screen, colour, centre, 10)

    def _draw_enemies(self):
        """Draw each enemy as a type-coloured circle with a health bar above it."""
        palette = {
            'weak': (255, 100, 100),
            'normal': (255, 0, 0),
            'strong': (200, 0, 0),
            'fast': (255, 150, 0),
            'boss': (255, 0, 255),
            'swarm': (255, 50, 50)
        }
        for foe in self.enemies:
            x, y = foe['pos'][0], foe['pos'][1]
            colour = palette.get(foe['type'], (255, 255, 255))
            pygame.draw.circle(self.screen, colour, [int(x), int(y)], 10)

            # Health bar: red background with a green foreground scaled by health.
            health_ratio = foe['health'] / foe['max_health']
            left, top = x - 10, y - 15
            pygame.draw.rect(self.screen, (255, 0, 0), (left, top, 20, 3))
            pygame.draw.rect(self.screen, (0, 255, 0), (left, top, 20 * health_ratio, 3))

    def _draw_bullets(self):
        """Draw each bullet as a weapon-coloured circle of radius ``size / 5``."""
        palette = {
            'basic': (255, 255, 0),
            'fire': (255, 165, 0),
            'ice': (0, 191, 255),
            'wind': (173, 216, 230),
            'earth': (139, 69, 19),
            'water': (0, 0, 255)
        }
        for shot in self.bullets:
            colour = palette.get(shot['weapon'], (255, 255, 255))
            centre = [int(shot['pos'][0]), int(shot['pos'][1])]
            radius = int(shot['size'] / 5)
            pygame.draw.circle(self.screen, colour, centre, radius)

    def _draw_items(self):
        """Draw each dropped item as a 10x10 square coloured by item type."""
        palette = {
            'heal': (0, 255, 0),
            'power_up': (255, 0, 0),
            'shield': (0, 255, 255)
        }
        for item in self.items:
            colour = palette.get(item['type'], (255, 255, 255))
            box = pygame.Rect(item['pos'][0]-5, item['pos'][1]-5, 10, 10)
            pygame.draw.rect(self.screen, colour, box)

    def _draw_health_bar(self):
        """Draw the player's health bar in the top-left corner."""
        fraction = self.player_health / self.player_max_health
        # Red background, then the remaining health in green on top.
        pygame.draw.rect(self.screen, (255, 0, 0), (10, 10, 100, 20))
        pygame.draw.rect(self.screen, (0, 255, 0), (10, 10, 100 * fraction, 20))

    def _draw_upgrade_count(self):
        """Draw the current level and the progress toward the next upgrade."""
        label = self.font.render(
            f"Level: {self.level}  Exp: {self.enemies_defeated}/{self.enemies_to_defeat_for_upgrade}",
            True,
            (255, 255, 255),
        )
        self.screen.blit(label, (10, 40))

    def _draw_enemy_defeated_count(self):
        """Draw the running total of defeated enemies."""
        caption = f'Enemies Defeated: {self.total_enemies_defeated}'
        label = self.font.render(caption, True, (255, 255, 255))
        self.screen.blit(label, (10, 70))

    def _draw_upgrades_and_weapons(self):
        """Draw three status lines at the bottom: upgrades, weapon levels, weapons."""
        white = (255, 255, 255)
        # Weapon keys end in '_level'; everything else is a plain stat upgrade.
        owned_weapons = [
            key for key, level in self.upgrades.items()
            if key.endswith('_level') and level > 0
        ]
        upgrades_text = "Upgrades: " + ", ".join(
            f"{k.capitalize()}({v})"
            for k, v in self.upgrades.items()
            if v > 0 and not k.endswith('_level')
        )
        weapon_upgrades_text = ", ".join(
            f"{weapon.replace('_level', '').capitalize()} Lv{self.upgrades[weapon]}"
            for weapon in owned_weapons
        )
        weapons_text = "Weapons: " + ", ".join(
            weapon.replace('_level', '').capitalize() for weapon in owned_weapons
        )

        # (text, distance from the bottom edge)
        rows = (
            (upgrades_text, 90),
            (f"Weapon Upgrades: {weapon_upgrades_text}", 60),
            (weapons_text, 30),
        )
        for text, from_bottom in rows:
            surface = self.small_font.render(text, True, white)
            self.screen.blit(surface, (10, self.screen_height - from_bottom))

    def _draw_upgrade_menu(self):
        """Draw the upgrade menu over a translucent backdrop.

        The selected option is highlighted in yellow. Options at level 0 are
        shown as "Unlock ..." entries.
        """
        white, highlight = (255, 255, 255), (255, 255, 0)
        centre_x = self.screen_width // 2
        centre_y = self.screen_height // 2

        backdrop = pygame.Surface((self.screen_width, self.screen_height), pygame.SRCALPHA)
        backdrop.fill((0, 0, 0, 180))
        self.screen.blit(backdrop, (0, 0))

        for row, option in enumerate(self.available_upgrade_options):
            display_name = option.replace('_', ' ').capitalize()
            level = self.upgrades.get(option, 0)
            max_level = self.max_upgrade_levels.get(option, 5)
            if level > 0:
                caption = f"{display_name} (Lv{level}/{max_level})"
            else:
                caption = f"Unlock {display_name}"
            tint = highlight if row == self.selected_upgrade else white
            label = self.font.render(caption, True, tint)
            self.screen.blit(label, (centre_x - 150, centre_y - 60 + row * 40))

        hint = self.font.render("Press ENTER to upgrade", True, white)
        self.screen.blit(hint, (centre_x - 100, centre_y + 80))

    def _draw_game_over(self):
        """Draw the game-over summary around the centre of the screen."""
        white = (255, 255, 255)
        centre_x = self.screen_width // 2
        centre_y = self.screen_height // 2

        # (text, pixels left of centre, vertical offset from centre)
        rows = (
            ("Game Over", 60, -80),
            (f"Time Survived: {self.steps_in_episode // 60} seconds", 100, -40),
            (f"Total Reward: {self.reward:.2f}", 100, 0),
            (f"Total Enemies Defeated: {self.total_enemies_defeated}", 150, 40),
            ("Press 'R' to Retry", 100, 80),
        )
        for text, left_of_centre, y_offset in rows:
            label = self.font.render(text, True, white)
            self.screen.blit(label, (centre_x - left_of_centre, centre_y + y_offset))

    def _handle_events(self):
        """Process pending pygame events (window close and keyboard input).

        A quit event closes the environment and terminates the process. On
        the game-over screen ``R`` resets the episode. While the upgrade menu
        is open, UP / DOWN move the selection and ENTER confirms it.

        Raises:
            SystemExit: When the window is closed.
        """
        # Key -> menu action code understood by _handle_upgrade_selection.
        menu_keys = {pygame.K_RETURN: 4, pygame.K_UP: 0, pygame.K_DOWN: 1}

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                gym.logger.info("Quit event received")
                self.close()
                # The bare exit() helper is injected by the `site` module and
                # is not guaranteed to exist; raise SystemExit directly.
                raise SystemExit
            if event.type != pygame.KEYDOWN:
                continue

            if self.game_over:
                if event.key == pygame.K_r:
                    self.reset()
            elif self.upgrade_menu_open:
                menu_action = menu_keys.get(event.key)
                if menu_action is not None:
                    self._handle_upgrade_selection(menu_action)

    def _reset_on_death(self):
        """Clear the kill counters and return every upgrade to its initial level.

        The rebuilt dict now includes the ``'area'`` key, matching
        ``_initialize_game_state``. It used to be dropped here, which would
        make ``get_bullet_params`` raise ``KeyError`` if a weapon fired before
        the next full reset.
        """
        self.enemies_defeated = 0
        self.total_enemies_defeated = 0
        self.upgrades = {
            'speed': 0,
            'attack_power': 1,
            'health': 0,
            'bullet_speed': 0,
            'area': 0,
            'evasion': 0,
            'fire_level': 0,
            'wind_level': 0,
            'water_level': 0,
            'ice_level': 0,
            'earth_level': 0
        }

    def _initialize_rendering(self):
        """Create the window, clock and fonts if human mode has no screen yet."""
        if self.render_mode != 'human' or self.screen is not None:
            return

        window_size = (self.screen_width, self.screen_height)
        self.screen = pygame.display.set_mode(window_size)
        pygame.display.set_caption('Vampire Survivor Gym Environment')
        self.clock = pygame.time.Clock()
        self.font = pygame.font.Font(None, 36)
        self.small_font = pygame.font.Font(None, 24)

    # Register the environment
from gymnasium.envs.registration import register

register(
    id='VampireSurvivor-v1',
    entry_point='vampgym:VampireSurvivorEnv',
)

Training Code:

import warnings
from functools import partial as bind
import os
import datetime
import dreamerv3
import embodied
import gymnasium as gym
import vampgym  # Vampire Survivor 環境をインポート

warnings.filterwarnings('ignore', '.*truncated to dtype int32.*')

def main():
    """Configure and launch DreamerV3 training on 'VampireSurvivor-v1'.

    Builds the config from the DreamerV3 defaults, the ``size100m`` preset
    and the local overrides below, lets command-line flags override it,
    prepares the log directory, and hands the agent / replay / env / logger
    factories to ``embodied.run.train``.
    """
    # Timestamp for a per-run log directory. It is currently unused because
    # the run resumes into the fixed directory below; for a fresh run, use
    # f'/home/user1/logdir/{timestamp}-vampire_survivor' instead.
    timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
    logdir = '/home/user1/logdir/20241202T060612-vampire_survivor'

    # Start from the default config, then layer the size preset and the
    # local overrides on top (later entries win).
    base_config = embodied.Config(dreamerv3.Agent.configs.get('defaults', {}))
    size_preset = dreamerv3.Agent.configs.get('size100m', {})
    overrides = {
        'logdir': logdir,
        'run.train_ratio': 32,
        'run.steps': 1e10,
        'batch_size': 2,          # batch size
        'run.num_envs': 1,        # number of training environments
        'run.num_envs_eval': 1,   # number of evaluation environments
        'replay.size': 5e5,       # replay buffer capacity
        # Optional overrides that are currently disabled:
        #   batch_length (33), filter ('all'), replay.online (True)
    }
    config = base_config.update({**size_preset, **overrides})

    # Parse command-line flags into the config.
    config = embodied.Flags(config).parse()

    # Create the log directory (if missing) and store the final config in it.
    print('Logdir:', config.logdir)
    run_dir = embodied.Path(config.logdir)
    os.makedirs(str(run_dir), exist_ok=True)
    config.save(run_dir / 'config.yaml')

    def make_env(cfg, env_id=0, render_mode='human'):
        """Build the Gymnasium environment and wrap it for DreamerV3."""
        from embodied.envs import from_gym
        # Original author's note: set render_mode to 'rgb_array' while training.
        raw_env = gym.make('VampireSurvivor-v1', render_mode=render_mode)
        adapted_env = from_gym.FromGym(raw_env)
        return dreamerv3.wrap_env(adapted_env, cfg)

    def make_agent(cfg):
        """Build the agent from the observation/action spaces of a throwaway env."""
        probe_env = make_env(cfg)
        agent = dreamerv3.Agent(probe_env.obs_space, probe_env.act_space, cfg)
        probe_env.close()
        return agent

    def make_replay(cfg):
        """Build the replay buffer, stored under ``<logdir>/replay``."""
        replay_dir = embodied.Path(cfg.logdir) / 'replay'
        return embodied.replay.Replay(
            length=cfg.batch_length,
            capacity=cfg.replay.size,
            directory=replay_dir,
            online=cfg.replay.online,
        )

    def make_logger(cfg):
        """Build a logger writing to the terminal, a JSONL file and TensorBoard."""
        out_dir = embodied.Path(cfg.logdir)
        outputs = [
            embodied.logger.TerminalOutput(cfg.filter),
            embodied.logger.JSONLOutput(out_dir, 'metrics.jsonl'),
            embodied.logger.TensorBoardOutput(out_dir),
        ]
        return embodied.Logger(embodied.Counter(), outputs)

    # Arguments for the training loop: everything under `run` plus the
    # batch/replay settings it needs.
    run_args = embodied.Config(
        **config.run,
        logdir=config.logdir,
        batch_size=config.batch_size,
        batch_length=config.batch_length,
        batch_length_eval=config.batch_length_eval,
        replay_context=config.replay_context,
        resume=True,  # the resume flag is set here
    )

    # Run training; each factory is bound to the final config.
    agent_fn = bind(make_agent, config)
    replay_fn = bind(make_replay, config)
    env_fn = bind(make_env, config)
    logger_fn = bind(make_logger, config)
    embodied.run.train(agent_fn, replay_fn, env_fn, logger_fn, run_args)


if __name__ == '__main__':
    main()





Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant