Skip to content

Commit

Permalink
#0: Update begin_trace to take in a trace id. Prep for multi trace/de…
Browse files Browse the repository at this point in the history
…vice
  • Loading branch information
tt-aho committed May 14, 2024
1 parent 4ed6580 commit fcc0af2
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 11 deletions.
3 changes: 1 addition & 2 deletions tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1874,9 +1874,8 @@ bool Device::using_slow_dispatch() const {
return not (this->using_fast_dispatch);
}

uint32_t Device::begin_trace(const uint8_t cq_id, const uint32_t trace_buff_size, const bool inc_trace_id) {
uint32_t Device::begin_trace(const uint8_t cq_id, const uint32_t tid, const uint32_t trace_buff_size) {
auto desc = std::make_shared<detail::TraceDescriptor>();
uint32_t tid = Trace::next_id(inc_trace_id);
TT_ASSERT(this->trace_buffer_pool_.at(cq_id).count(tid) == 0, "Trace already exists for tid {} on device", tid);
detail::EnableAllocs(this);
this->trace_buffer_pool_[cq_id][tid] = Trace::create_trace_buffer(this->command_queue(cq_id), desc, trace_buff_size);
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/impl/device/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ class Device {
CommandQueue& command_queue(size_t cq_id = 0);

// Metal trace device capture mode
uint32_t begin_trace(const uint8_t cq_id, const uint32_t trace_buff_size, const bool inc_trace_id);
uint32_t begin_trace(const uint8_t cq_id, const uint32_t tid, const uint32_t trace_buff_size);
void end_trace(const uint8_t cq_id, const uint32_t tid);
void replay_trace(const uint8_t cq_id, const uint32_t tid, const bool blocking);
void release_trace(const uint8_t cq_id, const uint32_t tid);
Expand Down
8 changes: 2 additions & 6 deletions tt_metal/impl/trace/trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,8 @@ namespace tt::tt_metal {

std::atomic<uint32_t> Trace::global_trace_id = 0;

uint32_t Trace::next_id(bool inc_id) {
if (inc_id) {
return global_trace_id++;
} else {
return global_trace_id;
}
uint32_t Trace::next_id() {
return global_trace_id++;
}

std::shared_ptr<TraceBuffer> Trace::create_trace_buffer(
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/impl/trace/trace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Trace {
public:
Trace() = delete;

static uint32_t next_id(bool inc_id);
static uint32_t next_id();

// Thread-safe accessors to manage trace instances
static void validate_instance(const TraceBuffer& trace_buffer);
Expand Down
4 changes: 3 additions & 1 deletion tt_metal/tt_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "tools/profiler/profiler.hpp"
#include "tt_metal/detail/tt_metal.hpp"
#include "tt_metal/detail/program.hpp"
#include "tt_metal/impl/trace/trace.hpp"

#include "tt_metal/third_party/tracy/public/tracy/Tracy.hpp"

Expand Down Expand Up @@ -949,7 +950,8 @@ void UpdateRuntimeArgs(Device* device, const std::shared_ptr<Kernel> kernel, con
}

uint32_t BeginTraceCapture(Device *device, const uint8_t cq_id, const uint32_t trace_buff_size) {
return device->begin_trace(cq_id, trace_buff_size, true);
uint32_t tid = Trace::next_id();
return device->begin_trace(cq_id, tid, trace_buff_size);
}

void EndTraceCapture(Device *device, const uint8_t cq_id, const uint32_t tid) {
Expand Down

0 comments on commit fcc0af2

Please sign in to comment.