Skip to content

Commit

Permalink
在conv中修改openmp的使用条件
Browse files Browse the repository at this point in the history
  • Loading branch information
zjhellofss committed Aug 8, 2023
1 parent b536705 commit bb6b5c9
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions source/layer/details/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ InferStatus ConvolutionLayer::Forward(
CHECK(output_h > 0 && output_w > 0)
<< "The size of the output tensor should be greater than zero " << i
<< " th";

#pragma omp parallel for if (groups_ > 1)
for (uint32_t g = 0; g < groups_; ++g) {
std::shared_ptr<Tensor<float>> output_tensor = outputs.at(i);
if (output_tensor == nullptr || output_tensor->empty()) {
Expand Down Expand Up @@ -243,14 +243,14 @@ InferStatus ConvolutionLayer::Forward(
if (conv_type_ == ConvType::OpConv) {
input_matrix = ConvIm2Col(input, kernel_h, kernel_w, input_h, input_w,
input_c_group, g, row_len, col_len);
-#pragma omp parallel for schedule(dynamic)
+#pragma omp parallel for
for (uint32_t k = 0; k < kernel_count_group; ++k) {
ConvGemmBias(input_matrix, output_tensor, g, k, kernel_count_group,
output_h, output_w);
}
} else {
CHECK(conv_type_ == ConvType::OpDeconv);
-#pragma omp parallel for schedule(dynamic)
+#pragma omp parallel for
for (uint32_t k = 0; k < kernel_count_group; ++k) {
const arma::fmat& gemm_result = DeconvGemm(
input, input_h, input_w, input_c_group, g, k, kernel_count_group);
Expand Down Expand Up @@ -282,7 +282,7 @@ void ConvolutionLayer::DeconvCol2ImWithBias(

uint32_t slide_count_w = (size_w - kernel_w) / stride_w_ + 1;
uint32_t slide_count_h = (size_h - kernel_h) / stride_h_ + 1;
-#pragma omp parallel for collapse(2) schedule(dynamic)
+#pragma omp parallel for collapse(2)
for (uint32_t x = 0; x < slide_count_w; ++x) {
for (uint32_t y = 0; y < slide_count_h; ++y) {
const uint32_t offset_x = x * stride_w_;
Expand Down Expand Up @@ -356,7 +356,7 @@ arma::fmat ConvolutionLayer::ConvIm2Col(sftensor input, uint32_t kernel_h,
const uint32_t input_padded_h = input_h + 2 * padding_h_;
const uint32_t input_padded_w = input_w + 2 * padding_w_;
const float padding_value = 0.f;
-#pragma omp parallel for schedule(dynamic)
+#pragma omp parallel for
for (uint32_t ic = 0; ic < input_c_group; ++ic) {
float* input_channel_ptr =
input->matrix_raw_ptr(ic + group * input_c_group);
Expand Down

0 comments on commit bb6b5c9

Please sign in to comment.