Skip to content

Latest commit

 

History

History
72 lines (49 loc) · 6.17 KB

README_jp.md

File metadata and controls

72 lines (49 loc) · 6.17 KB

Accelerated-TD

加速勾配時間差学習アルゴリズム(Accelerated Gradient Temporal Difference Learning algorithm, ATD)のPython実装。

GitHub code size in bytes GitHub GitHub issues Gitee issues GitHub Pull requests Contribution GitHub Repo stars GitHub forks Gitee Repo stars Gitee forks

紹介

エージェント

PlainATDAgent 行列を直接更新し、SVDATDAgentDiagonalizedSVDATDAgent はそれぞれ特異値分解を更新します。これは、論文の著者(以下の論文へのリンク)にはそれほど複雑ではないように見えます。 SVDATDAgentDiagonalizedSVDATDAgent の違いは、 SVDATDAgent がここで説明されている方法を採用していることです:Brand 2006 および DiagonalizedSVDATDAgent は、ここで説明した方法Gahring 2015 を採用して対角化します 行列を使用して、行列の疑似逆行列を計算しやすくします。私はこの方法を完全には理解していませんが。 また、 TDAgentと呼ばれる従来の勾配時間差エージェントを実装しました。以下に説明するいくつかの環境でそれらをテストしました。

バックエンドのサポート

PyTorch(CPU)のバックエンドをサポートして、 numpy.ndarrayからtorch.Tensorへの変換プロセスをスキップします。このサポートを使用するには、 atdモジュールをインポートする前に次のコードを追加できます:

import os
os.environ["ATD_BACKEND"] = "NumPy"  # 又は "PyTorch"

自分でテストを実行する場合は、このリポジトリのクローンを作成して、 python algorithm_test/<random_walk 又は boyans_chain>.pyを実行します。:)

必須

  • Python>=3.9
  • NumPy>=1.19
  • PyTorchをバックエンドとして使用する場合は、Torch>=1.10も必要です。
  • テストスクリプトを実行する場合は、Matplotlib>=3.3.3も必要です。
  • テストスクリプトを実行する場合は、Tqdmも必要です。

テスト

ランダムウォーク(Random Walk)

この環境はサットンの本からのものです。

コードファイルはこれで、結果はここrandom_walk

ボヤンのチェーン(Boyan's Chain)

この環境は、Boyan 1999 に示されています。

コードファイルはこれで、結果はここboyans_chain

手順

アルゴリズムの実装をプロジェクトにインポートするには、慣れていない場合は次の手順に従ってください。

  1. リポジトリのクローンを作成し、 atd.pyを目的の場所にコピーします。 GitHub から .zip ファイルをダウンロードした場合は、忘れずに解凍してください。
  2. 次のコードをPythonスクリプトの先頭に追加します。
    from atd import TDAgent, SVDATDAgent, DiagonalizedSVDATDAgent, PlainATDAgent # または任意のエージェントから
  3. ターゲットディレクトリが、実行するPythonメインプログラムファイルが配置されているディレクトリと異なる場合は、手順2のコードスニペットの代わりにこのコードスニペットを使用して、適切なディレクトリを環境変数に追加し、Pythonが通訳はそれを見つけることができます。または、Pythonの新しいバージョンで提供されている importlibを検討することもできます。
    import sys
    
    sys.path.append("<atd.pyを配置したディレクトリ>")
    from atd import TDAgent, SVDATDAgent, DiagonalizedSVDATDAgent, PlainATDAgent # または任意のエージェントから
  4. 次のようなエージェントを初期化すると、次のように使用できます!
    agent = TDAgent(lr=0.01, lambd=0, observation_space_n=4, action_space_n=2)

参照:Gahring 2016

プルリクエストについてお気軽にご連絡ください。コメントをお待ちしております。