Skip to content

Commit

Permalink
调整expression的代码格式
Browse files Browse the repository at this point in the history
  • Loading branch information
zjhellofss committed Aug 5, 2023
1 parent b586f6f commit c1e90e7
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 57 deletions.
34 changes: 16 additions & 18 deletions include/parser/parse_expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

// Created by fss on 22-12-1.

#ifndef KUIPER_INFER_INCLUDE_PARSER_PARSE_EXPRESSION_HPP_
#define KUIPER_INFER_INCLUDE_PARSER_PARSE_EXPRESSION_HPP_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <memory>

namespace kuiper_infer {

Expand All @@ -44,29 +44,27 @@ enum class TokenType {
// 词语Token
struct Token {
TokenType token_type = TokenType::TokenUnknown;
int32_t start_pos = 0; //词语开始的位置
int32_t end_pos = 0; // 词语结束的位置
int32_t start_pos = 0; // 词语开始的位置
int32_t end_pos = 0; // 词语结束的位置
Token(TokenType token_type, int32_t start_pos, int32_t end_pos)
: token_type(token_type), start_pos(start_pos), end_pos(end_pos) {

}
: token_type(token_type), start_pos(start_pos), end_pos(end_pos) {}
};

// 语法树的节点
struct TokenNode {
int32_t num_index = -1;
std::shared_ptr<TokenNode> left = nullptr; // 语法树的左节点
std::shared_ptr<TokenNode> right = nullptr; // 语法树的右节点
TokenNode(int32_t num_index, std::shared_ptr<TokenNode> left, std::shared_ptr<TokenNode> right);
std::shared_ptr<TokenNode> left = nullptr; // 语法树的左节点
std::shared_ptr<TokenNode> right = nullptr; // 语法树的右节点
TokenNode(int32_t num_index, std::shared_ptr<TokenNode> left,
std::shared_ptr<TokenNode> right);
TokenNode() = default;
};

// add(add(add(@0,@1),@1),add(@0,@2))
class ExpressionParser {
public:
explicit ExpressionParser(std::string statement) : statement_(std::move(statement)) {

}
explicit ExpressionParser(std::string statement)
: statement_(std::move(statement)) {}

/**
* 词法分析
Expand All @@ -84,23 +82,23 @@ class ExpressionParser {
* 返回词法分析的结果
* @return 词法分析的结果
*/
const std::vector<Token> &tokens() const;
const std::vector<Token>& tokens() const;

/**
* 返回词语字符串
* @return 词语字符串
*/
const std::vector<std::string> &token_strs() const;
const std::vector<std::string>& token_strs() const;

private:
std::shared_ptr<TokenNode> Generate_(int32_t &index);
std::shared_ptr<TokenNode> Generate_(int32_t& index);
// 被分割的词语数组
std::vector<Token> tokens_;
// 被分割的字符串数组
std::vector<std::string> token_strs_;
// 待分割的表达式
std::string statement_;
};
}
} // namespace kuiper_infer

#endif //KUIPER_INFER_INCLUDE_PARSER_PARSE_EXPRESSION_HPP_
#endif // KUIPER_INFER_INCLUDE_PARSER_PARSE_EXPRESSION_HPP_
2 changes: 1 addition & 1 deletion source/layer/details/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ void ConvolutionLayer::set_weights(const std::vector<float>& weights) {
/*
* 卷积核权重摆放的顺序是c n h w, 需要将它调整到n c h w
* 其中n表示卷积核次序,kernel_idx = g * kernel_count_group + kg;
* origin_idx = ic * kernel_nhw (nhw) + kg(n) * kernel_hw + ...
* origin_pixel_idx = ic * kernel_nhw (nhw) + kg(n) * kernel_hw + ...
*/
for (uint32_t ic = 0; ic < kernel_channel; ++ic) {
const uint32_t kernel_offset = ic * kernel_nhw;
Expand Down
91 changes: 53 additions & 38 deletions source/parser/parse_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

// Created by fss on 22-12-1.
#include "parser/parse_expression.hpp"
#include <glog/logging.h>
#include <algorithm>
#include <cctype>
#include <stack>
#include <utility>
#include <glog/logging.h>

namespace kuiper_infer {

void ReversePolish(const std::shared_ptr<TokenNode> &root_node,
std::vector<std::shared_ptr<TokenNode>> &reverse_polish) {
void ReversePolish(const std::shared_ptr<TokenNode>& root_node,
std::vector<std::shared_ptr<TokenNode>>& reverse_polish) {
if (root_node != nullptr) {
ReversePolish(root_node->left, reverse_polish);
ReversePolish(root_node->right, reverse_polish);
Expand All @@ -44,36 +44,43 @@ void ExpressionParser::Tokenizer(bool retokenize) {
}

CHECK(!statement_.empty()) << "The input statement is empty!";
statement_.erase(std::remove_if(statement_.begin(), statement_.end(), [](char c) {
return std::isspace(c);
}), statement_.end());
statement_.erase(std::remove_if(statement_.begin(), statement_.end(),
[](char c) { return std::isspace(c); }),
statement_.end());
CHECK(!statement_.empty()) << "The input statement is empty!";

for (int32_t i = 0; i < statement_.size();) {
char c = statement_.at(i);
if (c == 'a') {
CHECK(i + 1 < statement_.size() && statement_.at(i + 1) == 'd')
<< "Parse add token failed, illegal character: " << statement_.at(i + 1);
<< "Parse add token failed, illegal character: "
<< statement_.at(i + 1);
CHECK(i + 2 < statement_.size() && statement_.at(i + 2) == 'd')
<< "Parse add token failed, illegal character: " << statement_.at(i + 2);
<< "Parse add token failed, illegal character: "
<< statement_.at(i + 2);
Token token(TokenType::TokenAdd, i, i + 3);
tokens_.push_back(token);
std::string token_operation = std::string(statement_.begin() + i, statement_.begin() + i + 3);
std::string token_operation =
std::string(statement_.begin() + i, statement_.begin() + i + 3);
token_strs_.push_back(token_operation);
i = i + 3;
} else if (c == 'm') {
CHECK(i + 1 < statement_.size() && statement_.at(i + 1) == 'u')
<< "Parse multiply token failed, illegal character: " << statement_.at(i + 1);
<< "Parse multiply token failed, illegal character: "
<< statement_.at(i + 1);
CHECK(i + 2 < statement_.size() && statement_.at(i + 2) == 'l')
<< "Parse multiply token failed, illegal character: " << statement_.at(i + 2);
<< "Parse multiply token failed, illegal character: "
<< statement_.at(i + 2);
Token token(TokenType::TokenMul, i, i + 3);
tokens_.push_back(token);
std::string token_operation = std::string(statement_.begin() + i, statement_.begin() + i + 3);
std::string token_operation =
std::string(statement_.begin() + i, statement_.begin() + i + 3);
token_strs_.push_back(token_operation);
i = i + 3;
} else if (c == '@') {
CHECK(i + 1 < statement_.size() && std::isdigit(statement_.at(i + 1)))
<< "Parse number token failed, illegal character: " << statement_.at(i + 1);
<< "Parse number token failed, illegal character: "
<< statement_.at(i + 1);
int32_t j = i + 1;
for (; j < statement_.size(); ++j) {
if (!std::isdigit(statement_.at(j))) {
Expand All @@ -83,25 +90,29 @@ void ExpressionParser::Tokenizer(bool retokenize) {
Token token(TokenType::TokenInputNumber, i, j);
CHECK(token.start_pos < token.end_pos);
tokens_.push_back(token);
std::string token_input_number = std::string(statement_.begin() + i, statement_.begin() + j);
std::string token_input_number =
std::string(statement_.begin() + i, statement_.begin() + j);
token_strs_.push_back(token_input_number);
i = j;
} else if (c == ',') {
Token token(TokenType::TokenComma, i, i + 1);
tokens_.push_back(token);
std::string token_comma = std::string(statement_.begin() + i, statement_.begin() + i + 1);
std::string token_comma =
std::string(statement_.begin() + i, statement_.begin() + i + 1);
token_strs_.push_back(token_comma);
i += 1;
} else if (c == '(') {
Token token(TokenType::TokenLeftBracket, i, i + 1);
tokens_.push_back(token);
std::string token_left_bracket = std::string(statement_.begin() + i, statement_.begin() + i + 1);
std::string token_left_bracket =
std::string(statement_.begin() + i, statement_.begin() + i + 1);
token_strs_.push_back(token_left_bracket);
i += 1;
} else if (c == ')') {
Token token(TokenType::TokenRightBracket, i, i + 1);
tokens_.push_back(token);
std::string token_right_bracket = std::string(statement_.begin() + i, statement_.begin() + i + 1);
std::string token_right_bracket =
std::string(statement_.begin() + i, statement_.begin() + i + 1);
token_strs_.push_back(token_right_bracket);
i += 1;
} else {
Expand All @@ -110,28 +121,32 @@ void ExpressionParser::Tokenizer(bool retokenize) {
}
}

const std::vector<Token> &ExpressionParser::tokens() const {
const std::vector<Token>& ExpressionParser::tokens() const {
return this->tokens_;
}

const std::vector<std::string> &ExpressionParser::token_strs() const {
const std::vector<std::string>& ExpressionParser::token_strs() const {
return this->token_strs_;
}

std::shared_ptr<TokenNode> ExpressionParser::Generate_(int32_t &index) {
std::shared_ptr<TokenNode> ExpressionParser::Generate_(int32_t& index) {
CHECK(index < this->tokens_.size());
const auto current_token = this->tokens_.at(index);
CHECK(current_token.token_type == TokenType::TokenInputNumber
|| current_token.token_type == TokenType::TokenAdd || current_token.token_type == TokenType::TokenMul);
CHECK(current_token.token_type == TokenType::TokenInputNumber ||
current_token.token_type == TokenType::TokenAdd ||
current_token.token_type == TokenType::TokenMul);
if (current_token.token_type == TokenType::TokenInputNumber) {
uint32_t start_pos = current_token.start_pos + 1;
uint32_t end_pos = current_token.end_pos;
CHECK(end_pos > start_pos || end_pos <= this->statement_.length()) << "Current token has a wrong length";
const std::string &str_number =
std::string(this->statement_.begin() + start_pos, this->statement_.begin() + end_pos);
CHECK(end_pos > start_pos || end_pos <= this->statement_.length())
<< "Current token has a wrong length";
const std::string& str_number =
std::string(this->statement_.begin() + start_pos,
this->statement_.begin() + end_pos);
return std::make_shared<TokenNode>(std::stoi(str_number), nullptr, nullptr);

} else if (current_token.token_type == TokenType::TokenMul || current_token.token_type == TokenType::TokenAdd) {
} else if (current_token.token_type == TokenType::TokenMul ||
current_token.token_type == TokenType::TokenAdd) {
std::shared_ptr<TokenNode> current_node = std::make_shared<TokenNode>();
current_node->num_index = int(current_token.token_type);

Expand All @@ -143,8 +158,9 @@ std::shared_ptr<TokenNode> ExpressionParser::Generate_(int32_t &index) {
CHECK(index < this->tokens_.size()) << "Missing correspond left token!";
const auto left_token = this->tokens_.at(index);

if (left_token.token_type == TokenType::TokenInputNumber
|| left_token.token_type == TokenType::TokenAdd || left_token.token_type == TokenType::TokenMul) {
if (left_token.token_type == TokenType::TokenInputNumber ||
left_token.token_type == TokenType::TokenAdd ||
left_token.token_type == TokenType::TokenMul) {
current_node->left = Generate_(index);
} else {
LOG(FATAL) << "Unknown token type: " << int(left_token.token_type);
Expand All @@ -157,8 +173,9 @@ std::shared_ptr<TokenNode> ExpressionParser::Generate_(int32_t &index) {
index += 1;
CHECK(index < this->tokens_.size()) << "Missing correspond right token!";
const auto right_token = this->tokens_.at(index);
if (right_token.token_type == TokenType::TokenInputNumber
|| right_token.token_type == TokenType::TokenAdd || right_token.token_type == TokenType::TokenMul) {
if (right_token.token_type == TokenType::TokenInputNumber ||
right_token.token_type == TokenType::TokenAdd ||
right_token.token_type == TokenType::TokenMul) {
current_node->right = Generate_(index);
} else {
LOG(FATAL) << "Unknown token type: " << int(right_token.token_type);
Expand All @@ -173,8 +190,7 @@ std::shared_ptr<TokenNode> ExpressionParser::Generate_(int32_t &index) {
}
}

std::vector<std::shared_ptr<TokenNode> > ExpressionParser::Generate() {

std::vector<std::shared_ptr<TokenNode>> ExpressionParser::Generate() {
if (this->tokens_.empty()) {
this->Tokenizer(true);
}
Expand All @@ -190,8 +206,7 @@ std::vector<std::shared_ptr<TokenNode> > ExpressionParser::Generate() {
return reverse_polish;
}

TokenNode::TokenNode(int32_t num_index, std::shared_ptr<TokenNode> left, std::shared_ptr<TokenNode> right) :
num_index(num_index), left(left), right(right) {

}
}
TokenNode::TokenNode(int32_t num_index, std::shared_ptr<TokenNode> left,
std::shared_ptr<TokenNode> right)
: num_index(num_index), left(left), right(right) {}
} // namespace kuiper_infer

0 comments on commit c1e90e7

Please sign in to comment.