From c1e90e70bc8955a271eda61e909bf67051010894 Mon Sep 17 00:00:00 2001
From: zjhellofss
Date: Sat, 5 Aug 2023 10:35:57 +0800
Subject: [PATCH] Adjust the code format of expression
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 include/parser/parse_expression.hpp  | 34 +++++------
 source/layer/details/convolution.cpp |  2 +-
 source/parser/parse_expression.cpp   | 91 ++++++++++++++++------------
 3 files changed, 70 insertions(+), 57 deletions(-)

diff --git a/include/parser/parse_expression.hpp b/include/parser/parse_expression.hpp
index c1d8afc7..08279bad 100644
--- a/include/parser/parse_expression.hpp
+++ b/include/parser/parse_expression.hpp
@@ -18,15 +18,15 @@
 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 // SOFTWARE.
- 
+
 // Created by fss on 22-12-1.
 
 #ifndef KUIPER_INFER_INCLUDE_PARSER_PARSE_EXPRESSION_HPP_
 #define KUIPER_INFER_INCLUDE_PARSER_PARSE_EXPRESSION_HPP_
+#include <memory>
 #include <string>
 #include <utility>
 #include <vector>
-#include <memory>
 
 namespace kuiper_infer {
 
@@ -44,29 +44,27 @@ enum class TokenType {
 // A word token
 struct Token {
   TokenType token_type = TokenType::TokenUnknown;
-  int32_t start_pos = 0; // start position of the word
-  int32_t end_pos = 0; // end position of the word
+  int32_t start_pos = 0;  // start position of the word
+  int32_t end_pos = 0;    // end position of the word
   Token(TokenType token_type, int32_t start_pos, int32_t end_pos)
-      : token_type(token_type), start_pos(start_pos), end_pos(end_pos) {
-
-  }
+      : token_type(token_type), start_pos(start_pos), end_pos(end_pos) {}
 };
 
 // A node of the syntax tree
 struct TokenNode {
   int32_t num_index = -1;
-  std::shared_ptr<TokenNode> left = nullptr; // left node of the syntax tree
-  std::shared_ptr<TokenNode> right = nullptr; // right node of the syntax tree
-  TokenNode(int32_t num_index, std::shared_ptr<TokenNode> left, std::shared_ptr<TokenNode> right);
+  std::shared_ptr<TokenNode> left = nullptr;   // left node of the syntax tree
+  std::shared_ptr<TokenNode> right = nullptr;  // right node of the syntax tree
+  TokenNode(int32_t num_index, std::shared_ptr<TokenNode> left,
+            std::shared_ptr<TokenNode> right);
   TokenNode() = default;
 };
 
 // add(add(add(@0,@1),@1),add(@0,@2))
 class ExpressionParser {
  public:
-  explicit ExpressionParser(std::string statement) : statement_(std::move(statement)) {
-
-  }
+  explicit ExpressionParser(std::string statement)
+      : statement_(std::move(statement)) {}
 
   /**
    * Lexical analysis
@@ -84,16 +82,16 @@ class ExpressionParser {
    * Returns the result of lexical analysis
    * @return the result of lexical analysis
    */
-  const std::vector<Token> &tokens() const;
+  const std::vector<Token>& tokens() const;
 
   /**
    * Returns the token strings
   * @return the token strings
    */
-  const std::vector<std::string> &token_strs() const;
+  const std::vector<std::string>& token_strs() const;
 
  private:
-  std::shared_ptr<TokenNode> Generate_(int32_t &index);
+  std::shared_ptr<TokenNode> Generate_(int32_t& index);
   // the array of tokens after splitting
   std::vector<Token> tokens_;
   // the array of token strings after splitting
@@ -101,6 +99,6 @@ class ExpressionParser {
   // the expression to be split
   std::string statement_;
 };
-}
+}  // namespace kuiper_infer
 
-#endif //KUIPER_INFER_INCLUDE_PARSER_PARSE_EXPRESSION_HPP_
+#endif  // KUIPER_INFER_INCLUDE_PARSER_PARSE_EXPRESSION_HPP_
diff --git a/source/layer/details/convolution.cpp b/source/layer/details/convolution.cpp
index 09021265..6de16516 100644
--- a/source/layer/details/convolution.cpp
+++ b/source/layer/details/convolution.cpp
@@ -89,7 +89,7 @@ void ConvolutionLayer::set_weights(const std::vector<float>& weights) {
   /*
    * The kernel weights are stored in c n h w order and need to be rearranged to n c h w,
    * where n is the index of the kernel, kernel_idx = g * kernel_count_group + kg;
-   * origin_idx = ic * kernel_nhw (nhw) + kg(n) * kernel_hw + ...
+   * origin_pixel_idx = ic * kernel_nhw (nhw) + kg(n) * kernel_hw + ...
    */
   for (uint32_t ic = 0; ic < kernel_channel; ++ic) {
     const uint32_t kernel_offset = ic * kernel_nhw;
diff --git a/source/parser/parse_expression.cpp b/source/parser/parse_expression.cpp
index 1da3d73b..733b7f26 100644
--- a/source/parser/parse_expression.cpp
+++ b/source/parser/parse_expression.cpp
@@ -18,19 +18,19 @@
 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 // SOFTWARE.
- 
+
 // Created by fss on 22-12-1.
 
 #include "parser/parse_expression.hpp"
+#include <algorithm>
 #include <cctype>
 #include <stack>
 #include <string>
 #include <utility>
-#include <algorithm>
 
 namespace kuiper_infer {
-void ReversePolish(const std::shared_ptr<TokenNode> &root_node,
-                   std::vector<std::shared_ptr<TokenNode>> &reverse_polish) {
+void ReversePolish(const std::shared_ptr<TokenNode>& root_node,
+                   std::vector<std::shared_ptr<TokenNode>>& reverse_polish) {
   if (root_node != nullptr) {
     ReversePolish(root_node->left, reverse_polish);
     ReversePolish(root_node->right, reverse_polish);
@@ -44,36 +44,43 @@ void ExpressionParser::Tokenizer(bool retokenize) {
   }
 
   CHECK(!statement_.empty()) << "The input statement is empty!";
-  statement_.erase(std::remove_if(statement_.begin(), statement_.end(), [](char c) {
-    return std::isspace(c);
-  }), statement_.end());
+  statement_.erase(std::remove_if(statement_.begin(), statement_.end(),
+                                  [](char c) { return std::isspace(c); }),
+                   statement_.end());
   CHECK(!statement_.empty()) << "The input statement is empty!";
 
   for (int32_t i = 0; i < statement_.size();) {
     char c = statement_.at(i);
     if (c == 'a') {
       CHECK(i + 1 < statement_.size() && statement_.at(i + 1) == 'd')
-          << "Parse add token failed, illegal character: " << statement_.at(i + 1);
+          << "Parse add token failed, illegal character: "
+          << statement_.at(i + 1);
       CHECK(i + 2 < statement_.size() && statement_.at(i + 2) == 'd')
-          << "Parse add token failed, illegal character: " << statement_.at(i + 2);
+          << "Parse add token failed, illegal character: "
+          << statement_.at(i + 2);
       Token token(TokenType::TokenAdd, i, i + 3);
       tokens_.push_back(token);
-      std::string token_operation = std::string(statement_.begin() + i, statement_.begin() + i + 3);
+      std::string token_operation =
+          std::string(statement_.begin() + i, statement_.begin() + i + 3);
       token_strs_.push_back(token_operation);
       i = i + 3;
     } else if (c == 'm') {
       CHECK(i + 1 < statement_.size() && statement_.at(i + 1) == 'u')
-          << "Parse multiply token failed, illegal character: " << statement_.at(i + 1);
+          << "Parse multiply token failed, illegal character: "
+          << statement_.at(i + 1);
       CHECK(i + 2 < statement_.size() && statement_.at(i + 2) == 'l')
-          << "Parse multiply token failed, illegal character: " << statement_.at(i + 2);
+          << "Parse multiply token failed, illegal character: "
+          << statement_.at(i + 2);
       Token token(TokenType::TokenMul, i, i + 3);
       tokens_.push_back(token);
-      std::string token_operation = std::string(statement_.begin() + i, statement_.begin() + i + 3);
+      std::string token_operation =
+          std::string(statement_.begin() + i, statement_.begin() + i + 3);
       token_strs_.push_back(token_operation);
       i = i + 3;
     } else if (c == '@') {
       CHECK(i + 1 < statement_.size() && std::isdigit(statement_.at(i + 1)))
-          << "Parse number token failed, illegal character: " << statement_.at(i + 1);
+          << "Parse number token failed, illegal character: "
+          << statement_.at(i + 1);
       int32_t j = i + 1;
       for (; j < statement_.size(); ++j) {
         if (!std::isdigit(statement_.at(j))) {
@@ -83,25 +90,29 @@ void ExpressionParser::Tokenizer(bool retokenize) {
       Token token(TokenType::TokenInputNumber, i, j);
       CHECK(token.start_pos < token.end_pos);
       tokens_.push_back(token);
-      std::string token_input_number = std::string(statement_.begin() + i, statement_.begin() + j);
+      std::string token_input_number =
+          std::string(statement_.begin() + i, statement_.begin() + j);
       token_strs_.push_back(token_input_number);
       i = j;
     } else if (c == ',') {
       Token token(TokenType::TokenComma, i, i + 1);
       tokens_.push_back(token);
-      std::string token_comma = std::string(statement_.begin() + i, statement_.begin() + i + 1);
+      std::string token_comma =
+          std::string(statement_.begin() + i, statement_.begin() + i + 1);
       token_strs_.push_back(token_comma);
       i += 1;
     } else if (c == '(') {
       Token token(TokenType::TokenLeftBracket, i, i + 1);
       tokens_.push_back(token);
-      std::string token_left_bracket = std::string(statement_.begin() + i, statement_.begin() + i + 1);
+      std::string token_left_bracket =
+          std::string(statement_.begin() + i, statement_.begin() + i + 1);
       token_strs_.push_back(token_left_bracket);
       i += 1;
     } else if (c == ')') {
       Token token(TokenType::TokenRightBracket, i, i + 1);
       tokens_.push_back(token);
-      std::string token_right_bracket = std::string(statement_.begin() + i, statement_.begin() + i + 1);
+      std::string token_right_bracket =
+          std::string(statement_.begin() + i, statement_.begin() + i + 1);
       token_strs_.push_back(token_right_bracket);
       i += 1;
     } else {
@@ -110,28 +121,32 @@ void ExpressionParser::Tokenizer(bool retokenize) {
   }
 }
 
-const std::vector<Token> &ExpressionParser::tokens() const {
+const std::vector<Token>& ExpressionParser::tokens() const {
   return this->tokens_;
 }
 
-const std::vector<std::string> &ExpressionParser::token_strs() const {
+const std::vector<std::string>& ExpressionParser::token_strs() const {
   return this->token_strs_;
 }
 
-std::shared_ptr<TokenNode> ExpressionParser::Generate_(int32_t &index) {
+std::shared_ptr<TokenNode> ExpressionParser::Generate_(int32_t& index) {
   CHECK(index < this->tokens_.size());
   const auto current_token = this->tokens_.at(index);
-  CHECK(current_token.token_type == TokenType::TokenInputNumber
-        || current_token.token_type == TokenType::TokenAdd || current_token.token_type == TokenType::TokenMul);
+  CHECK(current_token.token_type == TokenType::TokenInputNumber ||
+        current_token.token_type == TokenType::TokenAdd ||
+        current_token.token_type == TokenType::TokenMul);
 
   if (current_token.token_type == TokenType::TokenInputNumber) {
     uint32_t start_pos = current_token.start_pos + 1;
     uint32_t end_pos = current_token.end_pos;
-    CHECK(end_pos > start_pos || end_pos <= this->statement_.length()) << "Current token has a wrong length";
-    const std::string &str_number =
-        std::string(this->statement_.begin() + start_pos, this->statement_.begin() + end_pos);
+    CHECK(end_pos > start_pos || end_pos <= this->statement_.length())
+        << "Current token has a wrong length";
+    const std::string& str_number =
+        std::string(this->statement_.begin() + start_pos,
+                    this->statement_.begin() + end_pos);
     return std::make_shared<TokenNode>(std::stoi(str_number), nullptr, nullptr);
 
-  } else if (current_token.token_type == TokenType::TokenMul || current_token.token_type == TokenType::TokenAdd) {
+  } else if (current_token.token_type == TokenType::TokenMul ||
+             current_token.token_type == TokenType::TokenAdd) {
     std::shared_ptr<TokenNode> current_node = std::make_shared<TokenNode>();
     current_node->num_index = int(current_token.token_type);
@@ -143,8 +158,9 @@ std::shared_ptr<TokenNode> ExpressionParser::Generate_(int32_t &index) {
 
     CHECK(index < this->tokens_.size()) << "Missing correspond left token!";
     const auto left_token = this->tokens_.at(index);
-    if (left_token.token_type == TokenType::TokenInputNumber
-        || left_token.token_type == TokenType::TokenAdd || left_token.token_type == TokenType::TokenMul) {
+    if (left_token.token_type == TokenType::TokenInputNumber ||
+        left_token.token_type == TokenType::TokenAdd ||
+        left_token.token_type == TokenType::TokenMul) {
       current_node->left = Generate_(index);
     } else {
       LOG(FATAL) << "Unknown token type: " << int(left_token.token_type);
@@ -157,8 +173,9 @@ std::shared_ptr<TokenNode> ExpressionParser::Generate_(int32_t &index) {
     index += 1;
     CHECK(index < this->tokens_.size()) << "Missing correspond right token!";
     const auto right_token = this->tokens_.at(index);
-    if (right_token.token_type == TokenType::TokenInputNumber
-        || right_token.token_type == TokenType::TokenAdd || right_token.token_type == TokenType::TokenMul) {
+    if (right_token.token_type == TokenType::TokenInputNumber ||
+        right_token.token_type == TokenType::TokenAdd ||
+        right_token.token_type == TokenType::TokenMul) {
       current_node->right = Generate_(index);
     } else {
       LOG(FATAL) << "Unknown token type: " << int(right_token.token_type);
@@ -173,8 +190,7 @@ std::shared_ptr<TokenNode> ExpressionParser::Generate_(int32_t &index) {
   }
 }
 
-std::vector<std::shared_ptr<TokenNode> > ExpressionParser::Generate() {
-
+std::vector<std::shared_ptr<TokenNode>> ExpressionParser::Generate() {
   if (this->tokens_.empty()) {
     this->Tokenizer(true);
   }
@@ -190,8 +206,7 @@ std::vector<std::shared_ptr<TokenNode> > ExpressionParser::Generate() {
   return reverse_polish;
 }
 
-TokenNode::TokenNode(int32_t num_index, std::shared_ptr<TokenNode> left, std::shared_ptr<TokenNode> right) :
-    num_index(num_index), left(left), right(right) {
-
-}
-}
\ No newline at end of file
+TokenNode::TokenNode(int32_t num_index, std::shared_ptr<TokenNode> left,
+                     std::shared_ptr<TokenNode> right)
+    : num_index(num_index), left(left), right(right) {}
+}  // namespace kuiper_infer
\ No newline at end of file
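
Note (not part of the patch): a minimal usage sketch of the lexer interface
reformatted above. The expression string and the printing are illustrative
only; the constructor, Tokenizer(bool), tokens() and token_strs() signatures
are the ones declared in parse_expression.hpp.

  #include <iostream>
  #include <string>

  #include "parser/parse_expression.hpp"

  int main() {
    using namespace kuiper_infer;
    ExpressionParser parser("add(mul(@0,@1),@2)");
    parser.Tokenizer(false);  // lexical analysis; pass true to re-tokenize
    // token_strs() yields one string per token: "add" "(" "mul" "(" "@0" ...
    for (const std::string& token_str : parser.token_strs()) {
      std::cout << token_str << ' ';
    }
    // tokens() exposes the typed view with start_pos/end_pos into the statement.
    std::cout << '\n' << parser.tokens().size() << " tokens\n";
    return 0;
  }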
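
Note (not part of the patch): Generate_ above builds a binary tree and
ReversePolish flattens it children-first, so the vector returned by Generate()
is in reverse Polish order and can be evaluated with a stack. A sketch under
one assumption: as elsewhere in this project, the operator TokenType values
are negative, so a non-negative num_index identifies an input operand @k. The
plain floats stand in for the runtime's tensors.

  #include <iostream>
  #include <stack>
  #include <string>
  #include <vector>

  #include "parser/parse_expression.hpp"

  using namespace kuiper_infer;

  static float Evaluate(const std::string& statement,
                        const std::vector<float>& inputs) {
    ExpressionParser parser(statement);
    std::stack<float> operands;
    for (const auto& node : parser.Generate()) {  // operands precede their operator
      if (node->num_index >= 0) {                 // input operand @k
        operands.push(inputs.at(node->num_index));
      } else {                                    // binary operator: pop two, push one
        const float rhs = operands.top();
        operands.pop();
        const float lhs = operands.top();
        operands.pop();
        if (node->num_index == int(TokenType::TokenAdd)) {
          operands.push(lhs + rhs);
        } else {
          operands.push(lhs * rhs);               // TokenType::TokenMul
        }
      }
    }
    return operands.top();
  }

  int main() {
    // add(mul(@0,@1),@2) flattens to @0 @1 mul @2 add, i.e. 2 * 3 + 4 = 10
    std::cout << Evaluate("add(mul(@0,@1),@2)", {2.f, 3.f, 4.f}) << '\n';
    return 0;
  }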
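
Note (not part of the patch): the convolution.cpp hunk only renames origin_idx
to origin_pixel_idx inside a comment, but the index arithmetic that comment
documents is easy to misread. Below is a standalone sketch of the c n h w to
n c h w copy it describes, ignoring the grouped case (g, kernel_count_group)
for brevity. ReorderWeights and its parameters are hypothetical stand-ins, not
the layer's actual set_weights interface.

  #include <cstdint>
  #include <vector>

  // weights:   flat buffer laid out as c n h w
  // reordered: flat buffer laid out as n c h w
  void ReorderWeights(const std::vector<float>& weights,
                      std::vector<float>& reordered, uint32_t kernel_count,
                      uint32_t kernel_channel, uint32_t kernel_h,
                      uint32_t kernel_w) {
    const uint32_t kernel_hw = kernel_h * kernel_w;
    const uint32_t kernel_nhw = kernel_count * kernel_hw;
    reordered.resize(weights.size());
    for (uint32_t ic = 0; ic < kernel_channel; ++ic) {
      const uint32_t kernel_offset = ic * kernel_nhw;  // start of channel ic
      for (uint32_t n = 0; n < kernel_count; ++n) {
        for (uint32_t hw = 0; hw < kernel_hw; ++hw) {
          // origin_pixel_idx = ic * kernel_nhw + n * kernel_hw + hw (c n h w)
          const uint32_t origin_pixel_idx = kernel_offset + n * kernel_hw + hw;
          // destination index in n c h w order
          const uint32_t dest_idx =
              n * kernel_channel * kernel_hw + ic * kernel_hw + hw;
          reordered.at(dest_idx) = weights.at(origin_pixel_idx);
        }
      }
    }
  }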