diff --git a/r2i/tensorflow/engine.cc b/r2i/tensorflow/engine.cc
index 91252f67..0f31ef2a 100644
--- a/r2i/tensorflow/engine.cc
+++ b/r2i/tensorflow/engine.cc
@@ -10,9 +10,7 @@
  */
 
 #include "r2i/tensorflow/engine.h"
-
 #include
-
 #include "r2i/tensorflow/prediction.h"
 #include "r2i/tensorflow/frame.h"
 
@@ -55,6 +53,25 @@ RuntimeError Engine::SetModel (std::shared_ptr in_model) {
   return error;
 }
 
+/* Select the GPU memory fraction that the session will be configured with.
+ * Accepts values in [0.1, 1.0]; the value is truncated down to a whole
+ * number of tenths to pick a row of the precomputed config table. NaN and
+ * out-of-range values are rejected with WRONG_API_USAGE.
+ */
+RuntimeError Engine::SetMemoryUsage (double memory_usage) {
+  RuntimeError error;
+
+  /* Written as a negated in-range test so that NaN, which compares false
+   * against every bound, is rejected instead of reaching the cast below */
+  if (!(memory_usage >= 0.1 && memory_usage <= 1.0)) {
+    error.Set (RuntimeError::Code::WRONG_API_USAGE, "Invalid memory usage value");
+    return error;
+  }
+
+  this->session_memory_usage_index = (static_cast<int>(memory_usage * 10) - 1);
+  return error;
+}
+
 static RuntimeError FreeSession (TF_Session *session) {
   RuntimeError error;
   std::shared_ptr pstatus(TF_NewStatus (), TF_DeleteStatus);
@@ -97,6 +114,14 @@ RuntimeError Engine::Start () {
   TF_Graph *graph = pgraph.get();
   TF_Status *status = pstatus.get ();
   TF_SessionOptions *opt = popt.get ();
+  /* Only override the session config when a memory fraction was explicitly
+   * requested; the index stays negative until SetMemoryUsage succeeds */
+  if (this->session_memory_usage_index >= 0) {
+    TF_SetConfig(opt, this->config[this->session_memory_usage_index],
+                 RAM_ARRAY_SIZE, status);
+    /* NOTE(review): status is handed to TF_NewSession right below without
+     * being inspected here, so a TF_SetConfig failure goes unreported */
+  }
 
   std::shared_ptr session (TF_NewSession(graph, opt, status),
                                    FreeSession);
diff --git a/r2i/tensorflow/engine.h b/r2i/tensorflow/engine.h
index a8f4b78b..a2283c0e 100644
--- a/r2i/tensorflow/engine.h
+++ b/r2i/tensorflow/engine.h
@@ -18,6 +18,9 @@
 
 #include
 
+/* Size in bytes of each serialized session configuration row */
+#define RAM_ARRAY_SIZE 11
+
 namespace r2i {
 namespace tensorflow {
 
@@ -27,6 +30,10 @@ class Engine : public IEngine {
 
   r2i::RuntimeError SetModel (std::shared_ptr in_model) override;
 
+  /* Request the GPU memory fraction, in [0.1, 1.0], that Start () applies
+   * to the session options when it creates the session */
+  r2i::RuntimeError SetMemoryUsage (double memory_usage);
+
   r2i::RuntimeError Start () override;
 
   r2i::RuntimeError Stop () override;
@@ -41,10 +48,31 @@ class Engine : public IEngine {
     STARTED,
     STOPPED
   };
+
   State state;
+  /* Row of `config` applied by Start (); stays negative until a memory
+   * fraction is set, in which case the session options are left untouched */
+  int session_memory_usage_index = -1;
 
   std::shared_ptr session;
   std::shared_ptr model;
+
+  /* Serialized tensorflow.ConfigProto messages, one per supported fraction:
+   * row i requests (i + 1) / 10 of the GPU memory. Each row is the wire
+   * form of gpu_options (tag 0x32, length 9) carrying
+   * per_process_gpu_memory_fraction (tag 0x09) as a little-endian double */
+  const uint8_t config[10][RAM_ARRAY_SIZE] = {
+    {0x32, 0x9, 0x9, 0x9a, 0x99, 0x99, 0x99, 0x99, 0x99, 0xb9, 0x3f}, /* 0.1 */
+    {0x32, 0x9, 0x9, 0x9a, 0x99, 0x99, 0x99, 0x99, 0x99, 0xc9, 0x3f}, /* 0.2 */
+    {0x32, 0x9, 0x9, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0xd3, 0x3f}, /* 0.3 */
+    {0x32, 0x9, 0x9, 0x9a, 0x99, 0x99, 0x99, 0x99, 0x99, 0xd9, 0x3f}, /* 0.4 */
+    {0x32, 0x9, 0x9, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xe0, 0x3f},       /* 0.5 */
+    {0x32, 0x9, 0x9, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0xe3, 0x3f}, /* 0.6 */
+    {0x32, 0x9, 0x9, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0xe6, 0x3f}, /* 0.7 */
+    {0x32, 0x9, 0x9, 0x9a, 0x99, 0x99, 0x99, 0x99, 0x99, 0xe9, 0x3f}, /* 0.8 */
+    {0x32, 0x9, 0x9, 0xcd, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xec, 0x3f}, /* 0.9 */
+    {0x32, 0x9, 0x9, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xf0, 0x3f}        /* 1.0 */
+  };
 };
 
 }
diff --git a/r2i/tensorflow/parameters.cc b/r2i/tensorflow/parameters.cc
index 384406f8..e6b2cae1 100644
--- a/r2i/tensorflow/parameters.cc
+++ b/r2i/tensorflow/parameters.cc
@@ -38,6 +38,10 @@ Parameters::Parameters () :
           r2i::ParameterMeta::Flags::READ,
           r2i::ParameterMeta::Type::STRING,
           std::make_shared(this)),
+    PARAM("gpu-memory-usage", "Per process GPU memory usage fraction",
+          r2i::ParameterMeta::Flags::READWRITE | r2i::ParameterMeta::Flags::WRITE_BEFORE_START,
+          r2i::ParameterMeta::Type::DOUBLE,
+          std::make_shared<MemoryUsageAccessor>(this)),
 
   /* Model parameters */
   PARAM("input-layer", "Name of the input layer in the graph",
@@ -157,11 +161,6 @@ RuntimeError Parameters::Get (const std::string &in_parameter, double &value) {
   }
 
   auto accessor = std::dynamic_pointer_cast(param.accessor);
-  if (nullptr == model) {
-    error.Set (RuntimeError::Code::INCOMPATIBLE_MODEL,
-               "The provided engine is not an tensorflow model");
-    return error;
-  }
 
   error = accessor->Get ();
   if (error.IsError ()) {
@@ -230,11 +229,6 @@ RuntimeError Parameters::Set (const std::string &in_parameter,
   }
 
   auto accessor = std::dynamic_pointer_cast(param.accessor);
-  if (nullptr == model) {
-    error.Set (RuntimeError::Code::INCOMPATIBLE_MODEL,
-               "The provided engine is not an tensorflow model");
-    return error;
-  }
 
   accessor->value = in_value;
   return accessor->Set ();
diff --git a/r2i/tensorflow/parameters.h b/r2i/tensorflow/parameters.h
index 550f1bb5..1fabafb9 100644
--- a/r2i/tensorflow/parameters.h
+++ b/r2i/tensorflow/parameters.h
@@ -99,6 +99,24 @@ class Parameters : public IParameters {
     }
   };
 
+  /* Accessor backing the "gpu-memory-usage" parameter */
+  class MemoryUsageAccessor : public DoubleAccessor {
+   public:
+    MemoryUsageAccessor (Parameters *target) : DoubleAccessor(target) {}
+
+    /* Forward the stored value to the engine, which validates it */
+    RuntimeError Set () {
+      return target->engine->SetMemoryUsage(this->value);
+    }
+
+    /* Always succeeds and leaves `value` untouched.
+     * NOTE(review): a read therefore presumably yields whatever was last
+     * stored in `value` rather than the engine state - confirm */
+    RuntimeError Get () {
+      return RuntimeError ();
+    }
+  };
+
   class InputLayerAccessor : public StringAccessor {
    public:
     InputLayerAccessor (Parameters *target) : StringAccessor(target) {}