Skip to content

Commit

Permalink
修改reshape的实现方法
Browse files Browse the repository at this point in the history
  • Loading branch information
zjhellofss committed Sep 7, 2023
1 parent 8c43749 commit 41a6186
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 16 deletions.
3 changes: 3 additions & 0 deletions include/data/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@ class Tensor {
*/
T* matrix_raw_ptr(uint32_t index);

private:
void Review(const std::vector<uint32_t>& shapes);

private:
std::vector<uint32_t> raw_shapes_; // 张量数据的实际尺寸大小
arma::Cube<T> data_; // 张量数据
Expand Down
71 changes: 55 additions & 16 deletions source/data/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,23 +321,28 @@ void Tensor<T>::Reshape(const std::vector<uint32_t>& shapes, bool row_major) {
CHECK(shapes.size() <= 3);
CHECK(current_size == origin_size);

std::vector<T> values;
if (row_major) {
values = this->values(true);
}
if (shapes.size() == 3) {
this->data_.reshape(shapes.at(1), shapes.at(2), shapes.at(0));
this->raw_shapes_ = {shapes.at(0), shapes.at(1), shapes.at(2)};
} else if (shapes.size() == 2) {
this->data_.reshape(shapes.at(0), shapes.at(1), 1);
this->raw_shapes_ = {shapes.at(0), shapes.at(1)};
if (!row_major) {
if (shapes.size() == 3) {
this->data_.reshape(shapes.at(1), shapes.at(2), shapes.at(0));
this->raw_shapes_ = {shapes.at(0), shapes.at(1), shapes.at(2)};
} else if (shapes.size() == 2) {
this->data_.reshape(shapes.at(0), shapes.at(1), 1);
this->raw_shapes_ = {shapes.at(0), shapes.at(1)};
} else {
this->data_.reshape(1, shapes.at(0), 1);
this->raw_shapes_ = {shapes.at(0)};
}
} else {
this->data_.reshape(1, shapes.at(0), 1);
this->raw_shapes_ = {shapes.at(0)};
}

if (row_major) {
this->Fill(values, true);
if (shapes.size() == 3) {
this->Review({shapes.at(0), shapes.at(1), shapes.at(2)});
this->raw_shapes_ = {shapes.at(0), shapes.at(1), shapes.at(2)};
} else if (shapes.size() == 2) {
this->Review({1, shapes.at(0), shapes.at(1)});
this->raw_shapes_ = {shapes.at(0), shapes.at(1)};
} else {
this->Review({1, 1, shapes.at(0)});
this->raw_shapes_ = {shapes.at(0)};
}
}
}

Expand Down Expand Up @@ -395,6 +400,40 @@ void Tensor<T>::set_data(arma::Cube<T>&& data) {
this->data_ = std::move(data);
}

template <typename T>
void Tensor<T>::Review(const std::vector<uint32_t>& shapes) {
CHECK(!this->data_.empty());
CHECK_EQ(shapes.size(), 3);
const uint32_t target_channels = shapes.at(0);
const uint32_t target_rows = shapes.at(1);
const uint32_t target_cols = shapes.at(2);

CHECK_EQ(this->data_.size(), target_channels * target_cols * target_rows);
arma::Cube<T> new_data(target_rows, target_cols, target_channels);
const uint32_t plane_size = target_rows * target_cols;
#pragma omp parallel for
for (uint32_t channel = 0; channel < this->data_.n_slices; ++channel) {
const arma::Mat<T>& channel_data = this->data_.slice(channel);
const uint32_t plane_start = channel * data_.n_rows * data_.n_cols;
for (uint32_t src_col = 0; src_col < this->data_.n_cols; ++src_col) {
const T* col_ptr = channel_data.colptr(src_col);
for (uint32_t src_row = 0; src_row < this->data_.n_rows; ++src_row) {
const uint32_t pos_index =
plane_start + src_row * data_.n_cols + src_col;
const uint32_t dest_channel = pos_index / plane_size;
const uint32_t dest_row =
(pos_index - dest_channel * plane_size) / target_cols;
const uint32_t dest_col =
(pos_index - dest_channel * plane_size - dest_row * target_cols);
CHECK(dest_channel < new_data.n_slices && dest_col < new_data.n_cols &&
dest_row < new_data.n_rows);
new_data.at(dest_row, dest_col, dest_channel) = *(col_ptr + src_row);
}
}
}
this->data_ = std::move(new_data);
}

template class Tensor<float>;
template class Tensor<int>;
template class Tensor<uint8_t>;
Expand Down

0 comments on commit 41a6186

Please sign in to comment.