diff --git a/include/data/tensor.hpp b/include/data/tensor.hpp
index f712c398..e4f816d7 100644
--- a/include/data/tensor.hpp
+++ b/include/data/tensor.hpp
@@ -258,6 +258,9 @@ class Tensor {
    */
   T* matrix_raw_ptr(uint32_t index);
 
+ private:
+  void Review(const std::vector<uint32_t>& shapes);
+
  private:
   std::vector<uint32_t> raw_shapes_;  // actual dimensions of the tensor data
   arma::Cube<T> data_;                // tensor data
diff --git a/source/data/tensor.cpp b/source/data/tensor.cpp
index 5c1abe88..ad5d829d 100644
--- a/source/data/tensor.cpp
+++ b/source/data/tensor.cpp
@@ -321,23 +321,28 @@ void Tensor<T>::Reshape(const std::vector<uint32_t>& shapes, bool row_major) {
   CHECK(shapes.size() <= 3);
   CHECK(current_size == origin_size);
 
-  std::vector<T> values;
-  if (row_major) {
-    values = this->values(true);
-  }
-  if (shapes.size() == 3) {
-    this->data_.reshape(shapes.at(1), shapes.at(2), shapes.at(0));
-    this->raw_shapes_ = {shapes.at(0), shapes.at(1), shapes.at(2)};
-  } else if (shapes.size() == 2) {
-    this->data_.reshape(shapes.at(0), shapes.at(1), 1);
-    this->raw_shapes_ = {shapes.at(0), shapes.at(1)};
+  if (!row_major) {
+    if (shapes.size() == 3) {
+      this->data_.reshape(shapes.at(1), shapes.at(2), shapes.at(0));
+      this->raw_shapes_ = {shapes.at(0), shapes.at(1), shapes.at(2)};
+    } else if (shapes.size() == 2) {
+      this->data_.reshape(shapes.at(0), shapes.at(1), 1);
+      this->raw_shapes_ = {shapes.at(0), shapes.at(1)};
+    } else {
+      this->data_.reshape(1, shapes.at(0), 1);
+      this->raw_shapes_ = {shapes.at(0)};
+    }
   } else {
-    this->data_.reshape(1, shapes.at(0), 1);
-    this->raw_shapes_ = {shapes.at(0)};
-  }
-
-  if (row_major) {
-    this->Fill(values, true);
+    if (shapes.size() == 3) {
+      this->Review({shapes.at(0), shapes.at(1), shapes.at(2)});
+      this->raw_shapes_ = {shapes.at(0), shapes.at(1), shapes.at(2)};
+    } else if (shapes.size() == 2) {
+      this->Review({1, shapes.at(0), shapes.at(1)});
+      this->raw_shapes_ = {shapes.at(0), shapes.at(1)};
+    } else {
+      this->Review({1, 1, shapes.at(0)});
+      this->raw_shapes_ = {shapes.at(0)};
+    }
   }
 }
 
@@ -395,6 +400,40 @@ void Tensor<T>::set_data(arma::Cube<T>&& data) {
   this->data_ = std::move(data);
 }
 
+template <typename T>
+void Tensor<T>::Review(const std::vector<uint32_t>& shapes) {
+  CHECK(!this->data_.empty());
+  CHECK_EQ(shapes.size(), 3);
+  const uint32_t target_channels = shapes.at(0);
+  const uint32_t target_rows = shapes.at(1);
+  const uint32_t target_cols = shapes.at(2);
+
+  CHECK_EQ(this->data_.size(), target_channels * target_cols * target_rows);
+  arma::Cube<T> new_data(target_rows, target_cols, target_channels);
+  const uint32_t plane_size = target_rows * target_cols;
+#pragma omp parallel for
+  for (uint32_t channel = 0; channel < this->data_.n_slices; ++channel) {
+    const arma::Mat<T>& channel_data = this->data_.slice(channel);
+    const uint32_t plane_start = channel * data_.n_rows * data_.n_cols;
+    for (uint32_t src_col = 0; src_col < this->data_.n_cols; ++src_col) {
+      const T* col_ptr = channel_data.colptr(src_col);
+      for (uint32_t src_row = 0; src_row < this->data_.n_rows; ++src_row) {
+        const uint32_t pos_index =
+            plane_start + src_row * data_.n_cols + src_col;
+        const uint32_t dest_channel = pos_index / plane_size;
+        const uint32_t dest_row =
+            (pos_index - dest_channel * plane_size) / target_cols;
+        const uint32_t dest_col =
+            (pos_index - dest_channel * plane_size - dest_row * target_cols);
+        CHECK(dest_channel < new_data.n_slices && dest_col < new_data.n_cols &&
+              dest_row < new_data.n_rows);
+        new_data.at(dest_row, dest_col, dest_channel) = *(col_ptr + src_row);
+      }
+    }
+  }
+  this->data_ = std::move(new_data);
+}
+
 template class Tensor<float>;
 template class Tensor<int32_t>;
 template class Tensor<uint8_t>;
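Note (not part of the patch): the snippet below is a minimal standalone sketch of the row-major index arithmetic that Review applies; the shapes and variable names are illustrative assumptions, not the project's API. It prints where each logical (channel, row, col) position of a 2x2x3 tensor lands in a (3, 4) target shape, which is what Reshape({3, 4}, /*row_major=*/true) would hand to Review as {1, 3, 4}.

// Standalone sketch (hypothetical shapes): mirrors the arithmetic in Review —
// flatten a logical row-major (channel, row, col) position, then decompose it
// against the target shape's plane size and column count.
#include <cstdint>
#include <iostream>

int main() {
  // Source logical shape (channels, rows, cols) = (2, 2, 3): 12 elements.
  const std::uint32_t src_channels = 2, src_rows = 2, src_cols = 3;
  // Target shape passed to the remapping: (1, 3, 4).
  const std::uint32_t target_rows = 3, target_cols = 4;
  const std::uint32_t plane_size = target_rows * target_cols;

  for (std::uint32_t channel = 0; channel < src_channels; ++channel) {
    for (std::uint32_t row = 0; row < src_rows; ++row) {
      for (std::uint32_t col = 0; col < src_cols; ++col) {
        // Flat row-major position of the source element.
        const std::uint32_t pos_index =
            channel * src_rows * src_cols + row * src_cols + col;
        // Decompose that position against the target shape.
        const std::uint32_t dest_channel = pos_index / plane_size;
        const std::uint32_t dest_row =
            (pos_index - dest_channel * plane_size) / target_cols;
        const std::uint32_t dest_col =
            pos_index - dest_channel * plane_size - dest_row * target_cols;
        std::cout << "(" << channel << "," << row << "," << col << ") -> ("
                  << dest_channel << "," << dest_row << "," << dest_col
                  << ")\n";
      }
    }
  }
  return 0;
}

Each source element keeps its place in the row-major reading order, which is the behaviour the new row_major branch of Reshape relies on instead of the old values()/Fill() round trip.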