#0: Simpler threading, get rid of canceling allocation feature
sminakov-tt committed Oct 23, 2024
1 parent f954e36 commit 9585a28
Showing 2 changed files with 21 additions and 80 deletions.
tt_metal/impl/buffers/buffer.cpp (93 changes: 18 additions & 75 deletions)
@@ -218,92 +218,40 @@ std::shared_ptr<Buffer> Buffer::create(
     buffer->weak_self = buffer;
 
     if (buffer->size_ == 0) {
-        buffer->allocation_status_ = AllocationStatus::ALLOCATED;
-        return buffer;
-    }
-
-    // Faster path for single-threaded mode
-    if (buffer->device_->can_use_passthrough_scheduling()) {
-        buffer->allocate_impl();
-        buffer->allocation_status_ = AllocationStatus::ALLOCATED;
+        buffer->allocation_status_.store(AllocationStatus::ALLOCATED, std::memory_order::relaxed);
         return buffer;
     }
 
     buffer->device_->push_work([buffer] {
-        auto expected_status = AllocationStatus::ALLOCATION_REQUESTED;
-        if (!buffer->allocation_status_.compare_exchange_strong(expected_status, AllocationStatus::ALLOCATING)) {
-            // Buffer was already deallocated before we got here
-            buffer->allocation_status_.notify_all();
-            return;
-        }
-
-        buffer->allocate_impl();
+        bool bottom_up = buffer->bottom_up_.value_or(buffer->is_dram());
+        buffer->address_ = detail::AllocateBuffer(buffer.get(), bottom_up);
+        detail::BUFFER_MAP.insert({buffer->device_->id(), buffer->address_}, buffer.get());
 
-        // We need compare exchange here to handle the case of deallocation being requested before we finished allocating
-        expected_status = AllocationStatus::ALLOCATING;
-        buffer->allocation_status_.compare_exchange_strong(expected_status, AllocationStatus::ALLOCATED);
+        buffer->allocation_status_.store(AllocationStatus::ALLOCATED, std::memory_order::release);
         buffer->allocation_status_.notify_all();
     });
 
     return buffer;
 }
 
-void Buffer::allocate_impl() {
-    bool bottom_up = bottom_up_.value_or(is_dram());
-    address_ = detail::AllocateBuffer(this, bottom_up);
-    detail::BUFFER_MAP.insert({device_->id(), address_}, this);
-}
-
-bool Buffer::prepare_deallocation(std::atomic<AllocationStatus>& status) {
-    while (true) {
-        auto current_status = status.load();
-        switch (current_status) {
-            case AllocationStatus::ALLOCATION_REQUESTED:
-                // Allocation was requested but not started, canceling allocation, nothing else to be done
-                if (status.compare_exchange_weak(current_status, AllocationStatus::DEALLOCATED)) {
-                    return false;
-                }
-                break;
-            case AllocationStatus::ALLOCATING:
-            case AllocationStatus::ALLOCATED:
-                // Allocation already started, will have to deallocate
-                if (status.compare_exchange_weak(current_status, AllocationStatus::DEALLOCATION_REQUESTED)) {
-                    return true;
-                }
-                break;
-            case AllocationStatus::DEALLOCATION_REQUESTED:
-            case AllocationStatus::DEALLOCATED:
-                // Deallocation was already started, nothing to be done
-                return false;
-        }
-    }
-}
-
 void Buffer::deallocate() {
-    if (!prepare_deallocation(allocation_status_)) {
-        return;
-    }
-
+    deallocation_requested_.store(true, std::memory_order::relaxed);
     device_->push_work([self = weak_self.lock()] {
-        auto expected_status = AllocationStatus::DEALLOCATION_REQUESTED;
-        if (!self->allocation_status_.compare_exchange_strong(expected_status, AllocationStatus::DEALLOCATED)) {
-            // Buffer was already deallocated, nothing to do
+        if (self->allocation_status_.load(std::memory_order::relaxed) != AllocationStatus::ALLOCATED) {
             return;
         }
 
         if (self->device_->initialized_ && self->size_ != 0) {
+            // address_ is only modified from this thread, no sync required
             detail::BUFFER_MAP.erase({self->device_->id(), self->address_});
             detail::DeallocateBuffer(self.get());
         }
+
+        self->allocation_status_.store(AllocationStatus::DEALLOCATED, std::memory_order::relaxed);
     });
 }
 
 void Buffer::deleter(Buffer* buffer) {
+    // There is no concurrent allocations/deallocations happening, so no extra checks are required
+    if (buffer->allocation_status_ == AllocationStatus::DEALLOCATED) {
+        return;
+    }
+
     buffer->device_->push_work([buffer] {
         std::unique_ptr<Buffer> unique_buffer = std::unique_ptr<Buffer>(buffer);
 
@@ -317,28 +265,23 @@ void Buffer::deleter(Buffer* buffer) {
 }
 
 bool Buffer::is_allocated() const {
-    auto allocation_status = allocation_status_.load();
+    if (deallocation_requested_.load(std::memory_order::relaxed)) {
+        return false;
+    }
+
+    auto allocation_status = allocation_status_.load(std::memory_order::relaxed);
 
     if (device_->can_use_passthrough_scheduling()) {
         return allocation_status == AllocationStatus::ALLOCATED;
     }
 
-    // For calls from different threads we consider buffer to be allocated even if it's just ALLOCATION_REQUESTED or ALLOCATING,
+    // For calls from different threads we consider buffer to be allocated even if it's just ALLOCATION_REQUESTED,
     // because once the caller will try to access it, the buffer will already be fully allocated
-    return allocation_status == AllocationStatus::ALLOCATION_REQUESTED
-        || allocation_status == AllocationStatus::ALLOCATING
-        || allocation_status == AllocationStatus::ALLOCATED;
+    return allocation_status == AllocationStatus::ALLOCATION_REQUESTED || allocation_status == AllocationStatus::ALLOCATED;
 }
 
 uint32_t Buffer::address() const {
-    if (device_->can_use_passthrough_scheduling()) {
-        return address_;
-    }
-
     // Waiting for the buffer to be allocated if the allocation is pending
-    allocation_status_.wait(AllocationStatus::ALLOCATION_REQUESTED);
-    allocation_status_.wait(AllocationStatus::ALLOCATING);
-
+    allocation_status_.wait(AllocationStatus::ALLOCATION_REQUESTED, std::memory_order::acquire);
     return address_;
 }
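The new create()/address() pair is a C++20 atomic wait/notify handshake: the worker thread publishes address_ with a release store of ALLOCATED, and address() blocks in wait(ALLOCATION_REQUESTED, acquire) until the status changes. A minimal self-contained sketch of that pattern (illustrative names and values, not tt-metal code) shows why address_ can drop its std::atomic wrapper:

// Sketch of the C++20 atomic wait/notify handshake used above.
// Status, g_address, and 0x1000 are illustrative stand-ins.
#include <atomic>
#include <cstdint>
#include <thread>

enum class Status : uint8_t { ALLOCATION_REQUESTED, ALLOCATED };

std::atomic<Status> g_status{Status::ALLOCATION_REQUESTED};
uint64_t g_address = 0;  // plain (non-atomic), like the demoted address_ member

int main() {
    std::thread worker([] {
        g_address = 0x1000;  // stand-in for detail::AllocateBuffer
        // Publish: the release store pairs with the acquire wait below,
        // making the non-atomic write to g_address visible to the waiter.
        g_status.store(Status::ALLOCATED, std::memory_order::release);
        g_status.notify_all();  // wake threads blocked in wait()
    });

    // Blocks while the value equals ALLOCATION_REQUESTED, like Buffer::address().
    g_status.wait(Status::ALLOCATION_REQUESTED, std::memory_order::acquire);
    uint64_t addr = g_address;  // safe: happens-after the release store

    worker.join();
    return addr == 0x1000 ? 0 : 1;
}

Because wait() only returns once the observed value differs from ALLOCATION_REQUESTED, the acquire load that sees ALLOCATED synchronizes with the worker's release store, so the preceding plain write to the address is guaranteed visible to the waiting thread.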
tt_metal/impl/buffers/buffer.hpp (8 changes: 3 additions & 5 deletions)

@@ -214,16 +214,12 @@ class Buffer final {
 
     enum class AllocationStatus : uint8_t {
         ALLOCATION_REQUESTED,
-        ALLOCATING,
         ALLOCATED,
-        DEALLOCATION_REQUESTED,
         DEALLOCATED,
     };
 
-    void allocate_impl();
     // Deallocate is allowed to be called multiple times on the same buffer
     void deallocate();
-    static bool prepare_deallocation(std::atomic<AllocationStatus>& status);
     static void deleter(Buffer* buffer);
     friend void DeallocateBuffer(Buffer &buffer);
 
@@ -236,7 +232,9 @@
     const std::optional<bool> bottom_up_;
 
     std::atomic<AllocationStatus> allocation_status_ = AllocationStatus::ALLOCATION_REQUESTED;
-    std::atomic<DeviceAddr> address_ = 0;
+    DeviceAddr address_ = 0;
+    // Used exclusively for is_allocated() method
+    std::atomic<bool> deallocation_requested_ = false;
 
     // These members must be only accessed on the device worker thread
     DeviceAddr page_size_; // Size of unit being interleaved. For non-interleaved buffers: size == page_size
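What makes dropping the ALLOCATING and DEALLOCATION_REQUESTED states safe is the device worker queue's FIFO ordering: a deallocation task pushed after an allocation task can assume the allocation already ran, so plain loads and stores replace the old compare-exchange cancellation protocol, and a relaxed flag is enough for is_allocated() to report false immediately. A toy model of that invariant (FifoWorker and MiniBuffer are illustrative stand-ins, not tt-metal types; single-threaded drain for brevity):

// Toy model of the simplified buffer lifecycle.
#include <atomic>
#include <cstdint>
#include <deque>
#include <functional>

enum class AllocationStatus : uint8_t { ALLOCATION_REQUESTED, ALLOCATED, DEALLOCATED };

struct FifoWorker {  // stand-in for the device worker thread's task queue
    std::deque<std::function<void()>> tasks;
    void push_work(std::function<void()> task) { tasks.push_back(std::move(task)); }
    void drain() {
        while (!tasks.empty()) { tasks.front()(); tasks.pop_front(); }
    }
};

struct MiniBuffer {
    std::atomic<AllocationStatus> status{AllocationStatus::ALLOCATION_REQUESTED};
    std::atomic<bool> deallocation_requested{false};
    uint64_t address = 0;

    void create(FifoWorker& worker) {
        worker.push_work([this] {
            address = 0x1000;  // stand-in for the real allocation
            status.store(AllocationStatus::ALLOCATED, std::memory_order::release);
        });
    }

    void deallocate(FifoWorker& worker) {
        deallocation_requested.store(true, std::memory_order::relaxed);
        worker.push_work([this] {
            // FIFO order guarantees the allocation task already ran, so a
            // plain load replaces the old compare-exchange protocol.
            if (status.load(std::memory_order::relaxed) != AllocationStatus::ALLOCATED) {
                return;  // already deallocated, e.g. deallocate() called twice
            }
            status.store(AllocationStatus::DEALLOCATED, std::memory_order::relaxed);
        });
    }

    bool is_allocated() const {
        // The flag flips immediately, before the worker runs the task.
        if (deallocation_requested.load(std::memory_order::relaxed)) {
            return false;
        }
        auto s = status.load(std::memory_order::relaxed);
        return s == AllocationStatus::ALLOCATION_REQUESTED || s == AllocationStatus::ALLOCATED;
    }
};

int main() {
    FifoWorker worker;
    MiniBuffer buffer;
    buffer.create(worker);
    buffer.deallocate(worker);
    bool hidden_early = !buffer.is_allocated();  // already false, task not yet run
    worker.drain();
    return (hidden_early && !buffer.is_allocated()) ? 0 : 1;
}

In this model deallocation_requested is purely a fast hint for is_allocated(); the correctness of the actual free rests on the queue ordering, mirroring the "Used exclusively for is_allocated() method" comment added to the header above.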
