Skip to content

Commit

Permalink
[Bugfix] Fixed is allocated (#12109)
Browse files Browse the repository at this point in the history
 fixed is_allocated

 removed new include
  • Loading branch information
dmakoviichuk-tt authored Sep 1, 2024
1 parent c2c334a commit dcd47ef
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 28 deletions.
27 changes: 2 additions & 25 deletions ttnn/cpp/ttnn/tensor/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "tt_metal/graph/graph_tracking.hpp"
#include "ttnn/core.hpp"
#include "ttnn/tensor/tensor_ops.hpp"

using namespace tt::constants;


Expand Down Expand Up @@ -450,31 +451,7 @@ bool Tensor::is_allocated() const {
ZoneScoped;
auto output = std::visit(
[](auto&& storage) -> bool {
using T = std::decay_t<decltype(storage)>;
if constexpr (std::is_same_v<T, OwnedStorage>) {
return std::visit([](auto&& buffer) -> bool { return buffer.is_allocated(); }, storage.buffer);
} else if constexpr (std::is_same_v<T, DeviceStorage>) {
return bool(storage.buffer) and storage.buffer->size() > 0;
} else if constexpr (std::is_same_v<T, BorrowedStorage>) {
return true;
} else if constexpr (std::is_same_v<T, MultiDeviceHostStorage>) {
bool is_allocated = true;
for (int i = 0; i < storage.num_buffers(); i++) {
is_allocated &=
std::visit([](auto&& buffer) -> bool { return buffer.is_allocated(); }, storage.get_buffer(i));
}
return is_allocated;
} else if constexpr (std::is_same_v<T, MultiDeviceStorage>) {
bool is_allocated = true;
for (int i = 0; i < storage.ordered_device_ids.size(); ++i) {
auto device_id = storage.ordered_device_ids[i];
const auto& buffer = storage.get_buffer_for_device_id(device_id);
is_allocated &= bool(buffer) and buffer->size() > 0;
}
return is_allocated;
} else {
raise_unsupported_storage<T>();
}
return storage.is_allocated();
},
this->get_storage());
return output;
Expand Down
38 changes: 35 additions & 3 deletions ttnn/cpp/ttnn/tensor/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,14 +306,18 @@ struct OwnedStorage {
static constexpr auto attribute_names = std::forward_as_tuple();
const auto attribute_values() const { return std::forward_as_tuple(); }

inline void insert_buffer(OwnedBuffer buffer_) {
inline void insert_buffer(const OwnedBuffer& buffer_) {
this->buffer = buffer_;
}

inline OwnedBuffer get_buffer() const {
return this->buffer;
}

inline bool is_allocated() const {
return std::visit([](auto&& buffer) -> bool { return buffer.is_allocated(); }, buffer);
}

};

using DeviceBuffer = std::shared_ptr<Buffer>;
Expand Down Expand Up @@ -344,6 +348,10 @@ struct DeviceStorage {
inline DeviceBuffer get_buffer() const { return this->buffer; }
static constexpr auto attribute_names = std::forward_as_tuple("memory_config");
const auto attribute_values() const { return std::make_tuple(this->memory_config()); }

inline bool is_allocated() const {
return buffer && buffer->size() > 0;
}
};

using BorrowedBuffer = std::variant<
Expand Down Expand Up @@ -404,6 +412,11 @@ struct BorrowedStorage {

static constexpr auto attribute_names = std::forward_as_tuple();
const auto attribute_values() const { return std::forward_as_tuple(); }

inline bool is_allocated() const {
return true;
}

};

struct MultiDeviceHostStorage {
Expand Down Expand Up @@ -452,7 +465,7 @@ struct MultiDeviceHostStorage {

// Helper Functions - Getters and setters to get/modify storage attributes. These are needed to
// preinitialize empty tensor handles and use/populate them in the worker threads.
void insert_buffer_and_shape_for_device(int buffer_index, const OwnedBuffer buffer, const Shape shape) {
void insert_buffer_and_shape_for_device(int buffer_index, const OwnedBuffer& buffer, const Shape shape) {
std::lock_guard<std::mutex> lock(mtx);
buffers[buffer_index] = buffer;
shapes[buffer_index] = shape;
Expand All @@ -461,7 +474,7 @@ struct MultiDeviceHostStorage {
OwnedBuffer get_buffer(int buffer_index) const {
std::lock_guard<std::mutex> lock(mtx);
TT_ASSERT(buffer_index < buffers.size(), "Buffer not found for buffer_index " + std::to_string(buffer_index));
return buffers[buffer_index];;
return buffers[buffer_index];
}

OwnedBuffer& get_buffer(int buffer_index) {
Expand All @@ -480,6 +493,16 @@ struct MultiDeviceHostStorage {
std::lock_guard<std::mutex> lock(mtx);
return buffers.size();
}

inline bool is_allocated() const {
// not sure what is better mutex for each buffer 10 times or one here.
// I think this one is better.
std::lock_guard<std::mutex> lock(mtx);

return std::all_of(buffers.begin(), buffers.end(), [](auto&& buffer) {
return std::visit([](auto&& buffer) -> bool { return buffer.is_allocated(); }, buffer);
});
}
};

struct MultiDeviceStorage {
Expand Down Expand Up @@ -609,6 +632,15 @@ struct MultiDeviceHostStorage {
std::lock_guard<std::mutex> lock(buffer_mtx);
return buffers.find(device_id) != buffers.end();
}

inline bool is_allocated() const {
std::lock_guard<std::mutex> lock(buffer_mtx);

return std::all_of(ordered_device_ids.begin(), ordered_device_ids.end(), [&buffers = this->buffers](auto&& device_id) {
const auto& buffer = buffers.at(device_id);
return buffer && buffer->size() > 0;
});
}
};

using Storage = std::variant<OwnedStorage, DeviceStorage, BorrowedStorage, MultiDeviceHostStorage, MultiDeviceStorage>;
Expand Down

0 comments on commit dcd47ef

Please sign in to comment.