Skip to content

Commit

Permalink
Refactor threading
Browse files Browse the repository at this point in the history
  • Loading branch information
zakki committed Aug 20, 2017
1 parent 3a6f21a commit 524b9d0
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 65 deletions.
139 changes: 75 additions & 64 deletions src/UctSearch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ int playout = CONST_PLAYOUT;
double default_remaining_time = ALL_THINKING_TIME;

// 各スレッドに渡す引数
thread_arg_t t_arg[THREAD_MAX];
vector<thread_arg_t> t_arg;

// プレイアウトの統計情報
statistic_t statistic[BOARD_MAX];
Expand Down Expand Up @@ -152,15 +152,17 @@ bool pondered = false;

double time_limit;

std::thread *handle[THREAD_MAX]; // スレッドのハンドル
std::vector<std::unique_ptr<std::thread>> handle; // スレッドのハンドル

static volatile bool running;

// UCB Bonusの等価パラメータ
double bonus_equivalence = BONUS_EQUIVALENCE;
// UCB Bonusの重み
double bonus_weight = BONUS_WEIGHT;

// 乱数生成器
std::mt19937_64 *mt[THREAD_MAX];
std::vector<std::unique_ptr<std::mt19937_64>> mt;

// Last-Good-Reply
LGR lgr;
Expand Down Expand Up @@ -287,6 +289,8 @@ static void UpdateNodeStatistic( game_info_t *game, int winner, statistic_t *nod
// 結果の更新
static void UpdateResult( child_node_t *child, int result, int current );

// 乱数の初期化
static void InitRand();

/////////////////////
// 予測読みの設定 //
Expand Down Expand Up @@ -337,6 +341,8 @@ SetThread( int new_thread )
threads = new_thread;

lgr_ctx.resize(threads);

InitRand();
}


Expand Down Expand Up @@ -444,6 +450,20 @@ SetTimeSettings( int main_time, int byoyomi, int stone )
}
}

////////////////
// 乱数の初期化 //
////////////////
static void
InitRand()
{
mt.clear();
random_device rd;
for (int i = 0; i < threads; i++) {
mt.push_back(make_unique<mt19937_64>(rd()));
}
}


/////////////////////////
// UCT探索の初期設定 //
/////////////////////////
Expand Down Expand Up @@ -490,12 +510,7 @@ InitializeSearchSetting( void )
}

// 乱数の初期化
for (int i = 0; i < THREAD_MAX; i++) {
if (mt[i]) {
delete mt[i];
}
mt[i] = new mt19937_64((unsigned int)(time(NULL) + i));
}
InitRand();

// Initialize Last-Good-Reply
lgr.reset();
Expand Down Expand Up @@ -558,16 +573,10 @@ StopPondering( void )

if (ponder) {
pondering_stop = true;
for (int i = 0; i < threads; i++) {
handle[i]->join();
delete handle[i];
handle[i] = nullptr;
}
if (use_nn) {
handle[threads]->join();
delete handle[threads];
handle[threads] = nullptr;
for (auto& t : handle) {
t->join();
}
handle.clear();

ponder = false;
pondered = true;
Expand Down Expand Up @@ -642,26 +651,22 @@ UctSearchGenmove( game_info_t *game, int color )
// 探索時間とプレイアウト回数の予定値を出力
PrintPlayoutLimits(time_limit, po_info.halt);

t_arg.resize(threads);
running = true;
for (int i = 0; i < threads; i++) {
t_arg[i].thread_id = i;
t_arg[i].game = game;
t_arg[i].color = color;
handle[i] = new thread(ParallelUctSearch, &t_arg[i]);
handle.push_back(make_unique<thread>(ParallelUctSearch, &t_arg[i]));
}

if (use_nn)
handle[threads] = new thread(EvalNode);
handle.push_back(make_unique<thread>(EvalNode));

for (int i = 0; i < threads; i++) {
handle[i]->join();
delete handle[i];
handle[i] = nullptr;
}
if (use_nn) {
handle[threads]->join();
delete handle[threads];
handle[threads] = nullptr;
for (auto &t : handle) {
t->join();
}
handle.clear();

// 着手が41手以降で,
// 時間延長を行う設定になっていて,
Expand All @@ -672,22 +677,17 @@ UctSearchGenmove( game_info_t *game, int color )
ExtendTime()) {
po_info.halt = (int)(1.5 * po_info.halt);
time_limit *= 1.5;
running = true;
for (int i = 0; i < threads; i++) {
handle[i] = new thread(ParallelUctSearch, &t_arg[i]);
handle.push_back(make_unique<thread>(ParallelUctSearch, &t_arg[i]));
}
if (use_nn)
handle[threads] = new thread(EvalNode);
handle.push_back(make_unique<thread>(EvalNode));

for (int i = 0; i < threads; i++) {
handle[i]->join();
delete handle[i];
handle[i] = nullptr;
}
if (use_nn) {
handle[threads]->join();
delete handle[threads];
handle[threads] = nullptr;
for (auto &t : handle) {
t->join();
}
handle.clear();
}

uct_child = uct_node[current_root].child;
Expand Down Expand Up @@ -831,15 +831,17 @@ UctSearchPondering( game_info_t *game, int color )
// Dynamic Komiの算出(置碁のときのみ)
DynamicKomi(game, &uct_node[current_root], color);

t_arg.resize(threads);
running = true;
for (int i = 0; i < threads; i++) {
t_arg[i].thread_id = i;
t_arg[i].game = game;
t_arg[i].color = color;
handle[i] = new thread(ParallelUctSearchPondering, &t_arg[i]);
handle.push_back(make_unique<thread>(ParallelUctSearchPondering, &t_arg[i]));
}

if (use_nn)
handle[threads] = new thread(EvalNode);
handle.push_back(make_unique<thread>(EvalNode));

return ;
}
Expand Down Expand Up @@ -905,18 +907,19 @@ UctSearchStat(game_info_t *game, int color, int num)
// Dynamic Komiの算出(置碁のときのみ)
DynamicKomi(game, &uct_node[current_root], color);

t_arg.resize(threads);
running = true;
for (i = 0; i < threads; i++) {
t_arg[i].thread_id = i;
t_arg[i].game = game;
t_arg[i].color = color;
handle[i] = new thread(ParallelUctSearch, &t_arg[i]);
handle.push_back(make_unique<thread>(ParallelUctSearch, &t_arg[i]));
}

for (i = 0; i < threads; i++) {
handle[i]->join();
delete handle[i];
handle[i] = nullptr;
for (auto &t : handle) {
t->join();
}
handle.clear();

use_nn = org_use_nn;

Expand Down Expand Up @@ -1481,7 +1484,7 @@ ParallelUctSearch( thread_arg_t *arg )
memcpy(game->seki, seki, sizeof(bool) * BOARD_MAX);
// 1回プレイアウトする
std::vector<int> path;
UctSearch(game, color, mt[targ->thread_id], lgr, lgr_ctx[targ->thread_id], current_root, &winner, path);
UctSearch(game, color, mt[targ->thread_id].get(), lgr, lgr_ctx[targ->thread_id], current_root, &winner, path);
// 探索を打ち切るか確認
interruption = InterruptionCheck();
// ハッシュに余裕があるか確認
Expand All @@ -1495,6 +1498,7 @@ ParallelUctSearch( thread_arg_t *arg )
if (GetSpendTime(begin_time) > time_limit) break;
if (!enough_size) cerr << "HASH TABLE FULL" << endl;
} while (po_info.count < po_info.halt && !interruption && enough_size);
running = false;
} else {
do {
// Wait if dcnn queue is full
Expand All @@ -1506,7 +1510,7 @@ ParallelUctSearch( thread_arg_t *arg )
memcpy(game->seki, seki, sizeof(bool) * BOARD_MAX);
// 1回プレイアウトする
std::vector<int> path;
UctSearch(game, color, mt[targ->thread_id], lgr, lgr_ctx[targ->thread_id], current_root, &winner, path);
UctSearch(game, color, mt[targ->thread_id].get(), lgr, lgr_ctx[targ->thread_id], current_root, &winner, path);
// 探索を打ち切るか確認
interruption = InterruptionCheck();
// ハッシュに余裕があるか確認
Expand Down Expand Up @@ -1550,7 +1554,7 @@ ParallelUctSearchPondering( thread_arg_t *arg )
CopyGame(game, targ->game);
// 1回プレイアウトする
std::vector<int> path;
UctSearch(game, color, mt[targ->thread_id], lgr, lgr_ctx[targ->thread_id], current_root, &winner, path);
UctSearch(game, color, mt[targ->thread_id].get(), lgr, lgr_ctx[targ->thread_id], current_root, &winner, path);
// ハッシュに余裕があるか確認
enough_size = CheckRemainingHashSize();
// OwnerとCriticalityを計算する
Expand All @@ -1560,6 +1564,7 @@ ParallelUctSearchPondering( thread_arg_t *arg )
interval += CRITICALITY_INTERVAL;
}
} while (!pondering_stop && enough_size);
running = false;
} else {
do {
// Wait if dcnn queue is full
Expand All @@ -1570,7 +1575,7 @@ ParallelUctSearchPondering( thread_arg_t *arg )
CopyGame(game, targ->game);
// 1回プレイアウトする
std::vector<int> path;
UctSearch(game, color, mt[targ->thread_id], lgr, lgr_ctx[targ->thread_id], current_root, &winner, path);
UctSearch(game, color, mt[targ->thread_id].get(), lgr, lgr_ctx[targ->thread_id], current_root, &winner, path);
// ハッシュに余裕があるか確認
enough_size = CheckRemainingHashSize();
} while (!pondering_stop && enough_size);
Expand Down Expand Up @@ -2183,7 +2188,7 @@ int
UctAnalyze( game_info_t *game, int color )
{
int pos;
thread *handle[THREAD_MAX];
vector<unique_ptr<std::thread>> handle;

// 探索情報をクリア
memset(statistic, 0, sizeof(statistic_t) * board_max);
Expand All @@ -2206,19 +2211,20 @@ UctAnalyze( game_info_t *game, int color )

po_info.halt = 10000;

t_arg.resize(threads);
running = true;
for (int i = 0; i < threads; i++) {
t_arg[i].thread_id = i;
t_arg[i].game = game;
t_arg[i].color = color;
handle[i] = new std::thread(ParallelUctSearch, &t_arg[i]);
handle.push_back(make_unique<thread>(ParallelUctSearch, &t_arg[i]));
}


for (int i = 0; i < threads; i++) {
handle[i]->join();
delete handle[i];
handle[i] = nullptr;
for (auto &t : handle) {
t->join();
}
handle.clear();

use_nn = org_use_nn;

Expand Down Expand Up @@ -2290,7 +2296,7 @@ UctSearchGenmoveCleanUp( game_info_t *game, int color )
double wp;
int count;
child_node_t *uct_child;
thread *handle[THREAD_MAX];
vector<unique_ptr<std::thread>> handle;

memset(statistic, 0, sizeof(statistic_t)* board_max);
fill_n(criticality_index, board_max, 0);
Expand All @@ -2316,20 +2322,26 @@ UctSearchGenmoveCleanUp( game_info_t *game, int color )

po_info.halt = po_info.num;

bool org_use_nn = use_nn;
use_nn = false;

DynamicKomi(game, &uct_node[current_root], color);

t_arg.reserve(threads);
running = true;
for (int i = 0; i < threads; i++) {
t_arg[i].thread_id = i;
t_arg[i].game = game;
t_arg[i].color = color;
handle[i] = new std::thread(ParallelUctSearch, &t_arg[i]);
handle.push_back(make_unique<thread>(ParallelUctSearch, &t_arg[i]));
}

for (int i = 0; i < threads; i++) {
handle[i]->join();
delete handle[i];
handle[i] = nullptr;
for (auto &t : handle) {
t->join();
}
handle.clear();

use_nn = org_use_nn;

uct_child = uct_node[current_root].child;

Expand Down Expand Up @@ -2664,7 +2676,6 @@ void EvalNode() {
#if 1
while (true) {
mutex_queue.lock();
bool running = handle[0] != nullptr;
if (!running
&& ((!reuse_subtree && !ponder) || (eval_policy_queue.empty() && eval_value_queue.empty()))) {
mutex_queue.unlock();
Expand Down
1 change: 0 additions & 1 deletion src/UctSearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
class LGR;
class LGRContext;

const int THREAD_MAX = 32; // 使用するスレッド数の最大値
const int MAX_NODES = 1000000; // UCTのノードの配列のサイズ
const double ALL_THINKING_TIME = 90.0; // 持ち時間(デフォルト)
const int CONST_PLAYOUT = 10000; // 1手あたりのプレイアウト回数(デフォルト)
Expand Down

0 comments on commit 524b9d0

Please sign in to comment.