diff --git a/src/Command.cpp b/src/Command.cpp index 83d02f1..20e3019 100644 --- a/src/Command.cpp +++ b/src/Command.cpp @@ -148,13 +148,23 @@ AnalyzeCommand( int argc, char **argv ) SetUseNN(false); break; case COMMAND_NO_GPU: - SetDeviceId(-1); + SetDeviceIds({ -1 }); break; case COMMAND_NO_EXPAND: SetNoExpand(true); break; case COMMAND_DEVICE_ID: - SetDeviceId(atoi(argv[++i])); + { + vector device_ids; + char copy[BUF_SIZE]; + STRCPY(copy, BUF_SIZE, argv[++i]); + char* token = strtok(copy, ","); + while (token != NULL) { + device_ids.push_back(atoi(token)); + token = strtok(NULL, ","); + } + SetDeviceIds(device_ids); + } break; default: for (int j = 0; j < COMMAND_MAX; j++){ diff --git a/src/UctSearch.cpp b/src/UctSearch.cpp index d221d33..fc24fbf 100644 --- a/src/UctSearch.cpp +++ b/src/UctSearch.cpp @@ -73,7 +73,7 @@ struct policy_eval_req { }; void ReadWeights(); -void EvalNode(); +void EvalNode(int thread_no); //void EvalUctNode(std::vector& indices, std::vector& color, std::vector& trans, std::vector& data, std::vector& path); //////////////// @@ -197,14 +197,14 @@ ray_clock::time_point begin_time; static bool early_pass = true; static bool use_nn = true; -static int device_id = -2; +static vector device_ids = { -2 }; static std::queue> eval_policy_queue; static std::queue> eval_value_queue; static int eval_count_policy, eval_count_value; static double owner_nn[BOARD_MAX]; -static CNTK::FunctionPtr nn_policy = nullptr; -static CNTK::FunctionPtr nn_value = nullptr; +static vector nn_policy_list; +static vector nn_value_list; //template double atomic_fetch_add(std::atomic *obj, double arg) { @@ -412,9 +412,9 @@ SetUseNN(bool flag) } void -SetDeviceId(int id) +SetDeviceIds(const vector& id) { - device_id = id; + device_ids = id; } void @@ -493,7 +493,7 @@ InitializeUctSearch( void ) exit(1); } - if (use_nn && !nn_policy) + if (use_nn && nn_policy_list.size() == 0) ReadWeights(); } @@ -649,8 +649,11 @@ UctSearchGenmove( game_info_t *game, int color ) handle.push_back(make_unique(ParallelUctSearch, &t_arg[i])); } - if (use_nn) - handle.push_back(make_unique(EvalNode)); + if (use_nn) { + for (int i = 0; i < device_ids.size(); i++) { + handle.push_back(make_unique(EvalNode, i)); + } + } for (auto &t : handle) { t->join(); @@ -670,8 +673,12 @@ UctSearchGenmove( game_info_t *game, int color ) for (int i = 0; i < threads; i++) { handle.push_back(make_unique(ParallelUctSearch, &t_arg[i])); } - if (use_nn) - handle.push_back(make_unique(EvalNode)); + + if (use_nn) { + for (int i = 0; i < device_ids.size(); i++) { + handle.push_back(make_unique(EvalNode, i)); + } + } for (auto &t : handle) { t->join(); @@ -776,21 +783,21 @@ UctSearchGenmove( game_info_t *game, int color ) // 予測読み // /////////////// void -UctSearchPondering( game_info_t *game, int color ) +UctSearchPondering(game_info_t *game, int color) { int pos; if (!pondering_mode) { - return ; + return; } // 探索情報をクリア - memset(statistic, 0, sizeof(statistic_t) * board_max); - fill_n(criticality_index, board_max, 0); + memset(statistic, 0, sizeof(statistic_t) * board_max); + fill_n(criticality_index, board_max, 0); for (int i = 0; i < board_max; i++) { - criticality[i] = 0.0; + criticality[i] = 0.0; } - + po_info.count = 0; for (int i = 0; i < pure_board_max; i++) { @@ -809,7 +816,7 @@ UctSearchPondering( game_info_t *game, int color ) if (uct_node[current_root].child_num <= 1) { ponder = false; pondering_stop = true; - return ; + return; } ponder = true; @@ -827,8 +834,11 @@ UctSearchPondering( game_info_t *game, int color ) handle.push_back(make_unique(ParallelUctSearchPondering, &t_arg[i])); } - if (use_nn) - handle.push_back(make_unique(EvalNode)); + if (use_nn) { + for (int i = 0; i < device_ids.size(); i++) { + handle.push_back(make_unique(EvalNode, i)); + } + } return ; } @@ -2415,8 +2425,9 @@ CorrectDescendentNodes(vector &indexes, int index) extern char uct_params_path[1024]; static CNTK::DeviceDescriptor -GetDevice() +GetDevice(int no) { + int device_id = device_ids[no]; if (device_id == -1) return CNTK::DeviceDescriptor::CPUDevice(); if (device_id == -2) @@ -2436,37 +2447,43 @@ ReadWeights() cerr << "Init CNTK" << endl; - CNTK::DeviceDescriptor device = GetDevice(); + for (int i = 0; i < device_ids.size(); i++) { + auto device = GetDevice(i); - wstring policy_name = path; - policy_name += L"/model2.bin"; - nn_policy = CNTK::Function::Load(policy_name, device); + wstring policy_name = path; + policy_name += L"/model2.bin"; + auto nn_policy = CNTK::Function::Load(policy_name, device); - wstring value_name = path; - value_name += L"/model3.bin"; - nn_value = CNTK::Function::Load(value_name, device); + wstring value_name = path; + value_name += L"/model3.bin"; + auto nn_value = CNTK::Function::Load(value_name, device); - if (!nn_policy || !nn_value) - { - cerr << "Get EvalModel failed\n"; - } + if (!nn_policy || !nn_value) + { + cerr << "Get EvalModel failed\n"; + abort(); + } + + nn_policy_list.push_back(nn_policy); + nn_value_list.push_back(nn_value); #if 0 - wcerr << L"***POLICY" << endl; - for (auto var : nn_policy->Inputs()) { - wcerr << var.AsString() << endl; - } - for (auto var : nn_policy->Outputs()) { - wcerr << var.AsString() << endl; - } - wcerr << L"***VALUE" << endl; - for (auto var : nn_value->Inputs()) { - wcerr << var.AsString() << endl; - } - for (auto var : nn_value->Outputs()) { - wcerr << var.AsString() << endl; - } + wcerr << L"***POLICY" << endl; + for (auto var : nn_policy->Inputs()) { + wcerr << var.AsString() << endl; + } + for (auto var : nn_policy->Outputs()) { + wcerr << var.AsString() << endl; + } + wcerr << L"***VALUE" << endl; + for (auto var : nn_value->Inputs()) { + wcerr << var.AsString() << endl; + } + for (auto var : nn_value->Outputs()) { + wcerr << var.AsString() << endl; + } #endif + } cerr << "ok" << endl; } @@ -2499,13 +2516,17 @@ GetOutputVaraiableByName(CNTK::FunctionPtr evalFunc, wstring varName, CNTK::Vari } void -EvalPolicy(const std::vector>& requests, +EvalPolicy(int thread_no, + const std::vector>& requests, std::vector& data_basic, std::vector& data_features, std::vector& data_history, std::vector& data_color, std::vector& data_komi) { if (requests.size() == 0) return; + auto device = GetDevice(thread_no); + auto nn_policy = nn_policy_list[thread_no]; + CNTK::Variable var_basic, var_features, var_history, var_color, var_komi; GetInputVariableByName(nn_policy, L"basic", var_basic); GetInputVariableByName(nn_policy, L"features", var_features); @@ -2531,7 +2552,6 @@ EvalPolicy(const std::vector>& requests, CNTK::ValuePtr value_ol; - CNTK::DeviceDescriptor device = GetDevice(); std::unordered_map inputs = { { var_basic, value_basic }, { var_features, value_features }, @@ -2607,13 +2627,17 @@ EvalPolicy(const std::vector>& requests, void -EvalValue(const std::vector>& requests, +EvalValue(int thread_no, + const std::vector>& requests, std::vector& data_basic, std::vector& data_features, std::vector& data_history, std::vector& data_color, std::vector& data_komi, std::vector& data_safety) { if (requests.size() == 0) return; + auto device = GetDevice(thread_no); + auto nn_value = nn_value_list[thread_no]; + CNTK::Variable var_basic, var_features, var_history, var_color, var_komi, var_safety; GetInputVariableByName(nn_value, L"basic", var_basic); GetInputVariableByName(nn_value, L"features", var_features); @@ -2642,7 +2666,6 @@ EvalValue(const std::vector>& requests, CNTK::ValuePtr value_p; - CNTK::DeviceDescriptor device = GetDevice(); std::unordered_map inputs = { { var_basic, value_basic }, { var_features, value_features }, @@ -2700,20 +2723,23 @@ EvalValue(const std::vector>& requests, eval_count_value += requests.size(); } -static std::vector eval_input_data_basic; -static std::vector eval_input_data_features; -static std::vector eval_input_data_history; -static std::vector eval_input_data_color; -static std::vector eval_input_data_komi; -static std::vector eval_input_data_safety; - -void EvalNode() { +void EvalNode(int thread_no) { #if 1 + std::vector eval_input_data_basic; + std::vector eval_input_data_features; + std::vector eval_input_data_history; + std::vector eval_input_data_color; + std::vector eval_input_data_komi; + std::vector eval_input_data_safety; + + int num_eval = 0; + while (true) { mutex_queue.lock(); if (!running && ((!reuse_subtree && !ponder) || (eval_policy_queue.empty() && eval_value_queue.empty()))) { mutex_queue.unlock(); + cerr << "Eval #" << thread_no << " " << num_eval << endl; break; } @@ -2748,7 +2774,8 @@ void EvalNode() { eval_input_data_color.push_back(req->color - 1); eval_input_data_komi.push_back(komi[0]); } - EvalPolicy(requests, eval_input_data_basic, eval_input_data_features, eval_input_data_history, eval_input_data_color, eval_input_data_komi); + num_eval += requests.size(); + EvalPolicy(thread_no, requests, eval_input_data_basic, eval_input_data_features, eval_input_data_history, eval_input_data_color, eval_input_data_komi); mutex_queue.lock(); } @@ -2777,7 +2804,8 @@ void EvalNode() { eval_input_data_komi.push_back(komi[0]); } eval_input_data_safety.resize(requests.size() * pure_board_max * 8); - EvalValue(requests, eval_input_data_basic, eval_input_data_features, eval_input_data_history, eval_input_data_color, eval_input_data_komi, eval_input_data_safety); + num_eval += requests.size(); + EvalValue(thread_no, requests, eval_input_data_basic, eval_input_data_features, eval_input_data_history, eval_input_data_color, eval_input_data_komi, eval_input_data_safety); } } #endif diff --git a/src/UctSearch.h b/src/UctSearch.h index 27722ac..a46b66d 100644 --- a/src/UctSearch.h +++ b/src/UctSearch.h @@ -236,6 +236,6 @@ void SetReuseSubtree( bool flag ); void SetUseNN(bool flag); -void SetDeviceId(int id); +void SetDeviceIds( const std::vector& ids ); #endif