Skip to content

Commit

Permalink
Support multiple GPUs
Browse files Browse the repository at this point in the history
  • Loading branch information
zakki committed Aug 20, 2017
1 parent 524b9d0 commit 4987925
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 62 deletions.
14 changes: 12 additions & 2 deletions src/Command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,23 @@ AnalyzeCommand( int argc, char **argv )
SetUseNN(false);
break;
case COMMAND_NO_GPU:
SetDeviceId(-1);
SetDeviceIds({ -1 });
break;
case COMMAND_NO_EXPAND:
SetNoExpand(true);
break;
case COMMAND_DEVICE_ID:
SetDeviceId(atoi(argv[++i]));
{
vector<int> device_ids;
char copy[BUF_SIZE];
STRCPY(copy, BUF_SIZE, argv[++i]);
char* token = strtok(copy, ",");
while (token != NULL) {
device_ids.push_back(atoi(token));
token = strtok(NULL, ",");
}
SetDeviceIds(device_ids);
}
break;
default:
for (int j = 0; j < COMMAND_MAX; j++){
Expand Down
155 changes: 96 additions & 59 deletions src/UctSearch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ struct policy_eval_req {
};

void ReadWeights();
void EvalNode();
void EvalNode(int thread_no);
//void EvalUctNode(std::vector<int>& indices, std::vector<int>& color, std::vector<int>& trans, std::vector<float>& data, std::vector<int>& path);

////////////////
Expand Down Expand Up @@ -198,14 +198,14 @@ ray_clock::time_point begin_time;
static bool early_pass = true;

static bool use_nn = true;
static int device_id = 0;
static vector<int> device_ids = { 0 };
static std::queue<std::shared_ptr<policy_eval_req>> eval_policy_queue;
static std::queue<std::shared_ptr<value_eval_req>> eval_value_queue;
static int eval_count_policy, eval_count_value;
static double owner_nn[BOARD_MAX];

static CNTK::FunctionPtr nn_policy = nullptr;
static CNTK::FunctionPtr nn_value = nullptr;
static vector<CNTK::FunctionPtr> nn_policy_list;
static vector<CNTK::FunctionPtr> nn_value_list;

//template<double>
double atomic_fetch_add(std::atomic<double> *obj, double arg) {
Expand Down Expand Up @@ -410,9 +410,9 @@ SetUseNN(bool flag)
}

void
SetDeviceId(int id)
SetDeviceIds(const vector<int>& id)
{
device_id = id;
device_ids = id;
}

void
Expand Down Expand Up @@ -491,7 +491,7 @@ InitializeUctSearch( void )
exit(1);
}

if (use_nn && !nn_policy)
if (use_nn && nn_policy_list.size() == 0)
ReadWeights();
}

Expand Down Expand Up @@ -660,8 +660,11 @@ UctSearchGenmove( game_info_t *game, int color )
handle.push_back(make_unique<thread>(ParallelUctSearch, &t_arg[i]));
}

if (use_nn)
handle.push_back(make_unique<thread>(EvalNode));
if (use_nn) {
for (int i = 0; i < device_ids.size(); i++) {
handle.push_back(make_unique<thread>(EvalNode, i));
}
}

for (auto &t : handle) {
t->join();
Expand All @@ -681,8 +684,12 @@ UctSearchGenmove( game_info_t *game, int color )
for (int i = 0; i < threads; i++) {
handle.push_back(make_unique<thread>(ParallelUctSearch, &t_arg[i]));
}
if (use_nn)
handle.push_back(make_unique<thread>(EvalNode));

if (use_nn) {
for (int i = 0; i < device_ids.size(); i++) {
handle.push_back(make_unique<thread>(EvalNode, i));
}
}

for (auto &t : handle) {
t->join();
Expand Down Expand Up @@ -787,21 +794,21 @@ UctSearchGenmove( game_info_t *game, int color )
// 予測読み //
///////////////
void
UctSearchPondering( game_info_t *game, int color )
UctSearchPondering(game_info_t *game, int color)
{
int pos;

if (!pondering_mode) {
return ;
return;
}

// 探索情報をクリア
memset(statistic, 0, sizeof(statistic_t) * board_max);
fill_n(criticality_index, board_max, 0);
memset(statistic, 0, sizeof(statistic_t) * board_max);
fill_n(criticality_index, board_max, 0);
for (int i = 0; i < board_max; i++) {
criticality[i] = 0.0;
criticality[i] = 0.0;
}

po_info.count = 0;

for (int i = 0; i < pure_board_max; i++) {
Expand All @@ -822,7 +829,7 @@ UctSearchPondering( game_info_t *game, int color )
if (uct_node[current_root].child_num <= 1) {
ponder = false;
pondering_stop = true;
return ;
return;
}

ponder = true;
Expand All @@ -840,8 +847,11 @@ UctSearchPondering( game_info_t *game, int color )
handle.push_back(make_unique<thread>(ParallelUctSearchPondering, &t_arg[i]));
}

if (use_nn)
handle.push_back(make_unique<thread>(EvalNode));
if (use_nn) {
for (int i = 0; i < device_ids.size(); i++) {
handle.push_back(make_unique<thread>(EvalNode, i));
}
}

return ;
}
Expand Down Expand Up @@ -2393,6 +2403,16 @@ UctSearchGenmoveCleanUp( game_info_t *game, int color )

extern char uct_params_path[1024];

// Resolve the entry at index `no` of device_ids to a CNTK device descriptor.
// A negative id selects the CPU device; any other id selects the GPU device
// with that id. `no` is used as a plain index and is not range-checked here.
static CNTK::DeviceDescriptor
GetDevice(int no)
{
    const int id = device_ids[no];
    return (id < 0)
        ? CNTK::DeviceDescriptor::CPUDevice()
        : CNTK::DeviceDescriptor::GPUDevice(id);
}

void
ReadWeights()
{
Expand All @@ -2401,37 +2421,43 @@ ReadWeights()

cerr << "Init CNTK" << endl;

CNTK::DeviceDescriptor device = CNTK::DeviceDescriptor::GPUDevice(device_id);
for (int i = 0; i < device_ids.size(); i++) {
auto device = GetDevice(i);

wstring policy_name = converter.from_bytes(uct_params_path);
policy_name += L"/model2.bin";
nn_policy = CNTK::Function::Load(policy_name, device);
wstring policy_name = converter.from_bytes(uct_params_path);
policy_name += L"/model2.bin";
auto nn_policy = CNTK::Function::Load(policy_name, device);

wstring value_name = converter.from_bytes(uct_params_path);
value_name += L"/model3.bin";
nn_value = CNTK::Function::Load(value_name, device);
wstring value_name = converter.from_bytes(uct_params_path);
value_name += L"/model3.bin";
auto nn_value = CNTK::Function::Load(value_name, device);

if (!nn_policy || !nn_value)
{
cerr << "Get EvalModel failed\n";
}
if (!nn_policy || !nn_value)
{
cerr << "Get EvalModel failed\n";
abort();
}

nn_policy_list.push_back(nn_policy);
nn_value_list.push_back(nn_value);

#if 0
wcerr << L"***POLICY" << endl;
for (auto var : nn_policy->Inputs()) {
wcerr << var.AsString() << endl;
}
for (auto var : nn_policy->Outputs()) {
wcerr << var.AsString() << endl;
}
wcerr << L"***VALUE" << endl;
for (auto var : nn_value->Inputs()) {
wcerr << var.AsString() << endl;
}
for (auto var : nn_value->Outputs()) {
wcerr << var.AsString() << endl;
}
wcerr << L"***POLICY" << endl;
for (auto var : nn_policy->Inputs()) {
wcerr << var.AsString() << endl;
}
for (auto var : nn_policy->Outputs()) {
wcerr << var.AsString() << endl;
}
wcerr << L"***VALUE" << endl;
for (auto var : nn_value->Inputs()) {
wcerr << var.AsString() << endl;
}
for (auto var : nn_value->Outputs()) {
wcerr << var.AsString() << endl;
}
#endif
}

cerr << "ok" << endl;
}
Expand Down Expand Up @@ -2464,13 +2490,17 @@ GetOutputVaraiableByName(CNTK::FunctionPtr evalFunc, wstring varName, CNTK::Vari
}

void
EvalPolicy(const std::vector<std::shared_ptr<policy_eval_req>>& requests,
EvalPolicy(int thread_no,
const std::vector<std::shared_ptr<policy_eval_req>>& requests,
std::vector<float>& data_basic, std::vector<float>& data_features, std::vector<float>& data_history,
std::vector<float>& data_color, std::vector<float>& data_komi)
{
if (requests.size() == 0)
return;

auto device = GetDevice(thread_no);
auto nn_policy = nn_policy_list[thread_no];

CNTK::Variable var_basic, var_features, var_history, var_color, var_komi;
GetInputVariableByName(nn_policy, L"basic", var_basic);
GetInputVariableByName(nn_policy, L"features", var_features);
Expand All @@ -2496,7 +2526,6 @@ EvalPolicy(const std::vector<std::shared_ptr<policy_eval_req>>& requests,

CNTK::ValuePtr value_ol;

CNTK::DeviceDescriptor device = CNTK::DeviceDescriptor::GPUDevice(device_id);
std::unordered_map<CNTK::Variable, CNTK::ValuePtr> inputs = {
{ var_basic, value_basic },
{ var_features, value_features },
Expand Down Expand Up @@ -2572,13 +2601,17 @@ EvalPolicy(const std::vector<std::shared_ptr<policy_eval_req>>& requests,


void
EvalValue(const std::vector<std::shared_ptr<value_eval_req>>& requests,
EvalValue(int thread_no,
const std::vector<std::shared_ptr<value_eval_req>>& requests,
std::vector<float>& data_basic, std::vector<float>& data_features, std::vector<float>& data_history,
std::vector<float>& data_color, std::vector<float>& data_komi, std::vector<float>& data_safety)
{
if (requests.size() == 0)
return;

auto device = GetDevice(thread_no);
auto nn_value = nn_value_list[thread_no];

CNTK::Variable var_basic, var_features, var_history, var_color, var_komi, var_safety;
GetInputVariableByName(nn_value, L"basic", var_basic);
GetInputVariableByName(nn_value, L"features", var_features);
Expand Down Expand Up @@ -2607,7 +2640,6 @@ EvalValue(const std::vector<std::shared_ptr<value_eval_req>>& requests,

CNTK::ValuePtr value_p;

CNTK::DeviceDescriptor device = CNTK::DeviceDescriptor::GPUDevice(device_id);
std::unordered_map<CNTK::Variable, CNTK::ValuePtr> inputs = {
{ var_basic, value_basic },
{ var_features, value_features },
Expand Down Expand Up @@ -2665,20 +2697,23 @@ EvalValue(const std::vector<std::shared_ptr<value_eval_req>>& requests,
eval_count_value += requests.size();
}

static std::vector<float> eval_input_data_basic;
static std::vector<float> eval_input_data_features;
static std::vector<float> eval_input_data_history;
static std::vector<float> eval_input_data_color;
static std::vector<float> eval_input_data_komi;
static std::vector<float> eval_input_data_safety;

void EvalNode() {
void EvalNode(int thread_no) {
#if 1
std::vector<float> eval_input_data_basic;
std::vector<float> eval_input_data_features;
std::vector<float> eval_input_data_history;
std::vector<float> eval_input_data_color;
std::vector<float> eval_input_data_komi;
std::vector<float> eval_input_data_safety;

int num_eval = 0;

while (true) {
mutex_queue.lock();
if (!running
&& ((!reuse_subtree && !ponder) || (eval_policy_queue.empty() && eval_value_queue.empty()))) {
mutex_queue.unlock();
cerr << "Eval #" << thread_no << " " << num_eval << endl;
break;
}

Expand Down Expand Up @@ -2713,7 +2748,8 @@ void EvalNode() {
eval_input_data_color.push_back(req->color - 1);
eval_input_data_komi.push_back(komi[0]);
}
EvalPolicy(requests, eval_input_data_basic, eval_input_data_features, eval_input_data_history, eval_input_data_color, eval_input_data_komi);
num_eval += requests.size();
EvalPolicy(thread_no, requests, eval_input_data_basic, eval_input_data_features, eval_input_data_history, eval_input_data_color, eval_input_data_komi);
mutex_queue.lock();
}

Expand Down Expand Up @@ -2742,7 +2778,8 @@ void EvalNode() {
eval_input_data_komi.push_back(komi[0]);
}
eval_input_data_safety.resize(requests.size() * pure_board_max * 8);
EvalValue(requests, eval_input_data_basic, eval_input_data_features, eval_input_data_history, eval_input_data_color, eval_input_data_komi, eval_input_data_safety);
num_eval += requests.size();
EvalValue(thread_no, requests, eval_input_data_basic, eval_input_data_features, eval_input_data_history, eval_input_data_color, eval_input_data_komi, eval_input_data_safety);
}
}
#endif
Expand Down
2 changes: 1 addition & 1 deletion src/UctSearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,6 @@ void SetReuseSubtree( bool flag );

void SetUseNN(bool flag);

void SetDeviceId(int id);
void SetDeviceIds( const std::vector<int>& ids );

#endif

0 comments on commit 4987925

Please sign in to comment.