Commit
Change the type of shapes in runtime_operand
zjhellofss committed Aug 12, 2023
1 parent bb6b5c9 commit 5e12234
Showing 9 changed files with 105 additions and 53 deletions.
17 changes: 16 additions & 1 deletion include/data/tensor_util.hpp
@@ -18,7 +18,7 @@
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

// Created by yizhu on 2023/3/20.

#ifndef KUIPER_INFER_TENSOR_UTIL_H
@@ -107,6 +107,21 @@ std::shared_ptr<Tensor<float>> TensorElementMultiply(
std::shared_ptr<Tensor<float>> TensorCreate(uint32_t channels, uint32_t rows,
uint32_t cols);

/**
 * Create a tensor
 * @param rows number of rows
 * @param cols number of columns
 * @return the created tensor
 */
std::shared_ptr<Tensor<float>> TensorCreate(uint32_t rows, uint32_t cols);

/**
 * Create a tensor
 * @param size number of elements
 * @return the created tensor
 */
std::shared_ptr<Tensor<float>> TensorCreate(uint32_t size);

/**
 * Create a tensor
 * @param shapes shape of the tensor
24 changes: 13 additions & 11 deletions include/runtime/runtime_operand.hpp
@@ -18,25 +18,27 @@
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

// Created by fss on 22-11-28.

#ifndef KUIPER_INFER_INCLUDE_PARSER_RUNTIME_OPERAND_HPP_
#define KUIPER_INFER_INCLUDE_PARSER_RUNTIME_OPERAND_HPP_
#include <vector>
#include <string>
#include <memory>
#include "status_code.hpp"
#include "runtime_datatype.hpp"
#include <string>
#include <vector>
#include "data/tensor.hpp"
#include "runtime_datatype.hpp"
#include "status_code.hpp"

namespace kuiper_infer {
/// Operand describing the inputs and outputs of a compute node
struct RuntimeOperand {
std::string name; /// Name of the operand
std::vector<int32_t> shapes; /// Shape of the operand
std::vector<std::shared_ptr<Tensor<float>>> datas; /// Data stored in the operand
RuntimeDataType type = RuntimeDataType::kTypeUnknown; /// Type of the operand, usually float
std::string name; /// Name of the operand
std::vector<uint32_t> shapes; /// Shape of the operand
std::vector<std::shared_ptr<Tensor<float>>> datas; /// Data stored in the operand
RuntimeDataType type =
RuntimeDataType::kTypeUnknown; /// Type of the operand, usually float

};
}
#endif //KUIPER_INFER_INCLUDE_PARSER_RUNTIME_OPERAND_HPP_
} // namespace kuiper_infer
#endif // KUIPER_INFER_INCLUDE_PARSER_RUNTIME_OPERAND_HPP_
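
Switching RuntimeOperand::shapes from int32_t to uint32_t lets downstream code fold or index the dimensions without the per-element casts the signed pnnx shapes required. A minimal sketch under that assumption; the helper name is made up for illustration and is not part of this commit:

#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical helper: total number of elements described by an operand shape.
inline uint32_t OperandElementCount(const std::vector<uint32_t>& shapes) {
  if (shapes.empty()) return 0;
  return std::accumulate(shapes.begin(), shapes.end(), 1u,
                         [](uint32_t acc, uint32_t dim) { return acc * dim; });
}
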
8 changes: 8 additions & 0 deletions source/data/tensor_utils.cpp
@@ -120,6 +120,14 @@ std::shared_ptr<Tensor<float>> TensorCreate(uint32_t channels, uint32_t rows,
return std::make_shared<Tensor<float>>(channels, rows, cols);
}

std::shared_ptr<Tensor<float>> TensorCreate(uint32_t rows, uint32_t cols) {
return std::make_shared<Tensor<float>>(1, rows, cols);
}

std::shared_ptr<Tensor<float>> TensorCreate(uint32_t size) {
return std::make_shared<Tensor<float>>(1, 1, size);
}

std::shared_ptr<Tensor<float>> TensorCreate(
const std::vector<uint32_t>& shapes) {
CHECK(!shapes.empty() && shapes.size() <= 3);
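
Both new overloads forward to the existing three-argument constructor with the leading dimensions fixed to 1, so the returned tensors are still laid out as channels x rows x cols underneath. A short usage sketch, assuming the overloads live in namespace kuiper_infer like the rest of the library:

// Sketch only: exercising the overloads added in this commit.
auto matrix = kuiper_infer::TensorCreate(4, 5);  // shape 1 x 4 x 5
auto flat = kuiper_infer::TensorCreate(8);       // shape 1 x 1 x 8
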
4 changes: 3 additions & 1 deletion source/layer/abstract/layer.cpp
@@ -69,7 +69,9 @@ InferStatus Layer::Forward() {

const std::shared_ptr<RuntimeOperand>& output_operand_datas =
runtime_operator->output_operands;

if (output_operand_datas == nullptr || output_operand_datas->datas.empty()) {
int a = 3;
}
CHECK(!layer_input_datas.empty())
<< runtime_operator->name << " Layer input data is empty";
CHECK(output_operand_datas != nullptr && !output_operand_datas->datas.empty())
18 changes: 9 additions & 9 deletions source/layer/details/activation_sse.cpp
@@ -39,14 +39,14 @@ static void SigmoidSSE(sftensor input, sftensor output) {
#ifdef __SSE2__
int32_t index = 0;
int32_t packet_size = 4;
const uint32_t in_size = input->size();
int32_t in_size = static_cast<int32_t>(input->size());
const float* in_ptr = input->raw_ptr();
float* out_ptr = output->raw_ptr();
#ifdef __AVX2__
packet_size = 8;
__m256 _one = _mm256_set1_ps(1.f);
__m256 _zero = _mm256_setzero_ps();
for (; index <= (int32_t)in_size - packet_size; index += packet_size) {
for (; index <= in_size - packet_size; index += packet_size) {
__m256 _p = _mm256_loadu_ps(in_ptr);
_p = _mm256_div_ps(
_one, _mm256_add_ps(_one, fmath::exp_ps256(_mm256_sub_ps(_zero, _p))));
@@ -57,7 +57,7 @@ static void SigmoidSSE(sftensor input, sftensor output) {
#else
__m128 _one = _mm_set1_ps(1.f);
__m128 _zero = _mm_setzero_ps();
for (; index <= (int32_t)in_size - packet_size; index += packet_size) {
for (; index <= in_size - packet_size; index += packet_size) {
__m128 _p = _mm_load_ps(in_ptr);
_p = _mm_div_ps(_one,
_mm_add_ps(_one, fmath::exp_ps(_mm_sub_ps(_zero, _p))));
@@ -93,13 +93,13 @@ static void ReluSSE(sftensor input, sftensor output) {
#else
int32_t j = 0;
int32_t packet_size = 4;
const uint32_t size = input->size();
int32_t size = static_cast<int32_t>(input->size());
const float* in_ptr = input->raw_ptr();
float* out_ptr = output->raw_ptr();
#ifdef __AVX2__
packet_size = 8;
__m256 _zero = _mm256_setzero_ps();
for (j = 0; j <= (int32_t)size - packet_size; j += packet_size) {
for (j = 0; j <= size - packet_size; j += packet_size) {
__m256 _p = _mm256_loadu_ps(in_ptr);
__m256 _value = _mm256_max_ps(_zero, _p);
_mm256_storeu_ps(out_ptr, _value);
@@ -108,7 +108,7 @@
}
#else
__m128 _zero = _mm_setzero_ps();
for (j = 0; j <= (int32_t)size - packet_size; j += packet_size) {
for (j = 0; j <= size - packet_size; j += packet_size) {
__m128 _p = _mm_load_ps(in_ptr);
__m128 _value = _mm_max_ps(_zero, _p);
_mm_store_ps(out_ptr, _value);
@@ -138,15 +138,15 @@ static void SiluSSE(sftensor input, sftensor output) {
#else
int32_t j = 0;
int32_t packet_size = 4;
const uint32_t size = input->size();
int32_t size = static_cast<int32_t>(input->size());
const float* in_ptr = input->raw_ptr();
float* out_ptr = output->raw_ptr();
#ifdef __AVX2__
packet_size = 8;
__m256 _one = _mm256_set1_ps(1.f);
__m256 _zero = _mm256_setzero_ps();

for (j = 0; j <= (int32_t)size - packet_size; j += packet_size) {
for (j = 0; j <= size - packet_size; j += packet_size) {
__m256 _p = _mm256_loadu_ps(in_ptr);
_p = _mm256_div_ps(
_p, _mm256_add_ps(_one, fmath::exp_ps256(_mm256_sub_ps(_zero, _p))));
@@ -158,7 +158,7 @@
__m128 _one = _mm_set1_ps(1.f);
__m128 _zero = _mm_setzero_ps();

for (j = 0; j <= (int32_t)size - packet_size; j += packet_size) {
for (j = 0; j <= size - packet_size; j += packet_size) {
__m128 _p = _mm_loadu_ps(in_ptr);
_p = _mm_div_ps(_p, _mm_add_ps(_one, fmath::exp_ps(_mm_sub_ps(_zero, _p))));
_mm_storeu_ps(out_ptr, _p);
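
The size variables in this file change from const uint32_t to int32_t so the signed conversion happens once at the declaration instead of being repeated as a cast in every loop condition. The subtraction in the loop bound is why signedness matters at all; a small illustration with numbers chosen for the example, not taken from the code:

uint32_t elems = 2;                          // fewer elements than one SIMD packet
int32_t packet_size = 4;
// Unsigned arithmetic wraps: 2u - 4 == 4294967294, so a fully unsigned
// comparison would enter the vectorized loop and read past the buffer.
int32_t size = static_cast<int32_t>(elems);
bool run_simd = (0 <= size - packet_size);   // false: only the scalar tail runs
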
48 changes: 29 additions & 19 deletions source/layer/details/convolution.cpp
@@ -53,6 +53,14 @@ ConvolutionLayer::ConvolutionLayer(ConvType conv_type, uint32_t output_channel,
if (use_bias_) {
this->InitBiasParam(output_channel, 1, 1, 1);
}

CHECK_GE(groups_, 1);
CHECK_GT(stride_h_, 0);
CHECK_GT(stride_w_, 0);
if (conv_type_ == ConvType::OpConv) {
CHECK_EQ(output_padding_h_, 0);
CHECK_EQ(output_padding_w_, 0);
}
CHECK(conv_type_ == ConvType::OpConv || conv_type_ == ConvType::OpDeconv);
}

@@ -161,6 +169,7 @@ InferStatus ConvolutionLayer::Forward(
}

const uint32_t batch_size = inputs.size();

const uint32_t kernel_count_group = kernel_count / groups_;

if (kernel_matrix_arr_.empty()) {
@@ -239,10 +248,10 @@
<< "The number of channel for the kernel "
"matrix and input tensor do not match";

arma::fmat input_matrix;
if (conv_type_ == ConvType::OpConv) {
input_matrix = ConvIm2Col(input, kernel_h, kernel_w, input_h, input_w,
input_c_group, g, row_len, col_len);
const arma::fmat& input_matrix =
ConvIm2Col(input, kernel_h, kernel_w, input_h, input_w,
input_c_group, g, row_len, col_len);
#pragma omp parallel for
for (uint32_t k = 0; k < kernel_count_group; ++k) {
ConvGemmBias(input_matrix, output_tensor, g, k, kernel_count_group,
@@ -282,21 +291,22 @@ void ConvolutionLayer::DeconvCol2ImWithBias(

uint32_t slide_count_w = (size_w - kernel_w) / stride_w_ + 1;
uint32_t slide_count_h = (size_h - kernel_h) / stride_h_ + 1;
#pragma omp parallel for collapse(2)
for (uint32_t x = 0; x < slide_count_w; ++x) {
for (uint32_t y = 0; y < slide_count_h; ++y) {
const uint32_t offset_x = x * stride_w_;
const uint32_t offset_y = y * stride_h_;
arma::fmat gemm_column((float*)gemm_result.colptr(x * slide_count_h + y),
kernel_h, kernel_w, false, true);

uint32_t gemm_rows = gemm_column.n_rows;
uint32_t gemm_cols = gemm_column.n_cols;
for (uint32_t col = 0; col < gemm_cols; ++col) {
float* gemm_ptr = gemm_column.colptr(col);
float* output_ptr = output_padding.colptr(offset_x + col);
memcpy(output_ptr + offset_y, gemm_ptr, sizeof(float) * gemm_rows);
}
#pragma omp parallel for
for (uint32_t index = 0; index < slide_count_w * slide_count_h; ++index) {
uint32_t x = index / slide_count_h;
uint32_t y = index % slide_count_h;
const uint32_t offset_x = x * stride_w_;
const uint32_t offset_y = y * stride_h_;
arma::fmat gemm_column((float*)gemm_result.colptr(index), kernel_h,
kernel_w, false, true);

uint32_t gemm_rows = gemm_column.n_rows;
uint32_t gemm_cols = gemm_column.n_cols;

for (uint32_t col = 0; col < gemm_cols; ++col) {
float* gemm_ptr = gemm_column.colptr(col);
float* output_ptr = output_padding.colptr(offset_x + col);
memcpy(output_ptr + offset_y, gemm_ptr, sizeof(float) * gemm_rows);
}
}

@@ -394,7 +404,7 @@ void ConvolutionLayer::ConvGemmBias(const arma::fmat& input_matrix,
uint32_t kernel_count_group,
uint32_t output_h,
uint32_t output_w) const {
CHECK(conv_type_ == ConvType::OpConv);
CHECK(conv_type_ == ConvType::OpConv) << "Convolution type need be OpConv";

CHECK(!input_matrix.empty());
CHECK(output_tensor && !output_tensor->empty());
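
Rewriting DeconvCol2ImWithBias as a single flattened loop lets a plain parallel for replace the collapse(2) clause, and the loop variable doubles as the GEMM column, so colptr(x * slide_count_h + y) becomes colptr(index). A minimal standalone sketch of the same index mapping:

#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t slide_count_w = 3;
  const uint32_t slide_count_h = 2;
  // The flattened index visits the same (x, y) pairs the old nested loops did,
  // in the same order, and is itself the column index into the GEMM result.
  for (uint32_t index = 0; index < slide_count_w * slide_count_h; ++index) {
    const uint32_t x = index / slide_count_h;
    const uint32_t y = index % slide_count_h;
    std::printf("gemm column %u -> window (x=%u, y=%u)\n", index, x, y);
  }
  return 0;
}
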
7 changes: 6 additions & 1 deletion source/layer/details/maxpooling.cpp
@@ -37,7 +37,12 @@ MaxPoolingLayer::MaxPoolingLayer(uint32_t padding_h, uint32_t padding_w,
pooling_size_h_(pooling_size_h),
pooling_size_w_(pooling_size_w),
stride_h_(stride_h),
stride_w_(stride_w) {}
stride_w_(stride_w) {
CHECK_GT(stride_h_, 0);
CHECK_GT(stride_w_, 0);
CHECK_GT(pooling_size_h_, 0);
CHECK_GT(pooling_size_w_, 0);
}

InferStatus MaxPoolingLayer::Forward(
const std::vector<std::shared_ptr<Tensor<float>>>& inputs,
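
The new CHECK_GT guards reject a zero stride or pooling window in the constructor instead of letting it surface later when the output extent is computed. For reference, the usual pooling output-size rule the guards protect; this is the standard formula, not code from this commit:

#include <cstdint>

// Hypothetical helper: window and stride appear as subtrahend and divisor,
// so both must be strictly positive for the result to be well defined.
uint32_t PoolingOutputDim(uint32_t input, uint32_t padding, uint32_t window,
                          uint32_t stride) {
  return (input + 2 * padding - window) / stride + 1;
}
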
6 changes: 5 additions & 1 deletion source/runtime/runtime_ir.cpp
@@ -253,8 +253,12 @@ void RuntimeGraph::InitGraphOperatorsInput(
const pnnx::Operator* producer = input->producer;
std::shared_ptr<RuntimeOperand> runtime_operand =
std::make_shared<RuntimeOperand>();

runtime_operand->name = producer->name;
runtime_operand->shapes = input->shape;
for (uint32_t dim : input->shape) {
runtime_operand->shapes.push_back(dim);
}
CHECK(!runtime_operand->shapes.empty());

switch (input->type) {
case 1: {
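
Because pnnx keeps operand shapes as std::vector<int32_t> while RuntimeOperand::shapes is now std::vector<uint32_t>, the direct assignment no longer compiles and the dimensions are copied one element at a time. An equivalent one-liner would be a range assign; this is a sketch, not what the commit uses:

// Assumes input->shape is std::vector<int32_t> and runtime_operand->shapes is
// std::vector<uint32_t>; each element is implicitly converted during the copy.
runtime_operand->shapes.assign(input->shape.begin(), input->shape.end());
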
26 changes: 16 additions & 10 deletions source/runtime/runtime_op.cpp
@@ -80,12 +80,18 @@ void RuntimeOperatorUtils::InitOperatorOutput(
// A node supports only one output; in practice pnnx never produces a node with two different outputs anyway
pnnx::Operand* operand = operands.front();
const auto& runtime_op = operators.at(i);

CHECK(operand != nullptr) << "Operand output is null";
const std::vector<int32_t>& operand_shapes = operand->shape;
std::vector<uint32_t> operand_shapes;
for (uint32_t dim : operand->shape) {
operand_shapes.push_back(dim);
}
CHECK(!operand_shapes.empty());

// Get the output space that needs to be initialized
const auto& output_tensors = runtime_op->output_operands;
// Get the shape the node's output tensor should have
const int32_t batch = operand_shapes.at(0);
const uint32_t batch = operand_shapes.front();
CHECK(batch >= 0) << "Dynamic batch size is not supported!";
CHECK(operand_shapes.size() == 2 || operand_shapes.size() == 4 ||
operand_shapes.size() == 3)
@@ -101,19 +107,19 @@
output_operand->type = RuntimeDataType::kTypeFloat32;
output_operand->name = operand->name + "_output";
// Initialize the output space
if(runtime_op)
for (int j = 0; j < batch; ++j) {
if (operand_shapes.size() == 4) {
sftensor output_tensor = TensorCreate(
operand_shapes.at(1), operand_shapes.at(2), operand_shapes.at(3));
output_operand->datas.push_back(output_tensor);
} else if (operand_shapes.size() == 2) {
sftensor output_tensor = TensorCreate(
std::vector<uint32_t>{(uint32_t)operand_shapes.at(1)});
sftensor output_tensor = TensorCreate(operand_shapes.at(1));
output_operand->datas.push_back(output_tensor);
} else {
// current shape is 3
sftensor output_tensor = TensorCreate(std::vector<uint32_t>{
(uint32_t)operand_shapes.at(1), (uint32_t)operand_shapes.at(2)});
sftensor output_tensor =
TensorCreate(operand_shapes.at(1), operand_shapes.at(2));
output_operand->datas.push_back(output_tensor);
}
}
@@ -134,8 +140,8 @@ void RuntimeOperatorUtils::InitOperatorOutput(
DLOG(WARNING)
<< "The shape of tensor do not adapting with output operand";
const auto& target_shapes = std::vector<uint32_t>{
(uint32_t)operand_shapes.at(1), (uint32_t)operand_shapes.at(2),
(uint32_t)operand_shapes.at(3)};
operand_shapes.at(1), operand_shapes.at(2),
operand_shapes.at(3)};
output_tensor->Reshape(target_shapes);
}
} else if (operand_shapes.size() == 2) {
@@ -145,7 +151,7 @@
DLOG(WARNING)
<< "The shape of tensor do not adapting with output operand";
const auto& target_shapes =
std::vector<uint32_t>{(uint32_t)operand_shapes.at(1)};
std::vector<uint32_t>{operand_shapes.at(1)};
output_tensor->Reshape(target_shapes);
}
} else {
@@ -156,7 +162,7 @@
DLOG(WARNING)
<< "The shape of tensor do not adapting with output operand";
const auto& target_shapes = std::vector<uint32_t>{
(uint32_t)operand_shapes.at(1), (uint32_t)operand_shapes.at(2)};
operand_shapes.at(1), operand_shapes.at(2)};
output_tensor->Reshape(target_shapes);
}
}
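
InitOperatorOutput now hands the pnnx operand shape straight to the new TensorCreate overloads: index 0 is always taken as the batch dimension, a 4-D shape keeps (channels, rows, cols), a 3-D shape becomes (rows, cols), and a 2-D shape becomes a flat vector. A condensed sketch of that dispatch, pulled out of the surrounding loop and checks; the helper name is made up:

// Sketch: map a pnnx operand shape (batch first) onto one output tensor.
sftensor MakeOutputTensor(const std::vector<uint32_t>& operand_shapes) {
  if (operand_shapes.size() == 4) {
    return TensorCreate(operand_shapes.at(1), operand_shapes.at(2),
                        operand_shapes.at(3));                        // channels, rows, cols
  } else if (operand_shapes.size() == 3) {
    return TensorCreate(operand_shapes.at(1), operand_shapes.at(2));  // rows, cols
  } else {
    return TensorCreate(operand_shapes.at(1));                        // flat vector
  }
}
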
