#4148: Temporary changes to allow all_gather to work with op profiler. Will be removed once new profiler + software queues come in
tt-aho committed Feb 7, 2024
1 parent d752cca commit 280ef36
Showing 4 changed files with 60 additions and 14 deletions.
35 changes: 33 additions & 2 deletions tt_eager/tt_dnn/op_library/all_gather/all_gather_op.cpp
@@ -71,11 +71,42 @@ std::vector<Tensor> all_gather(const std::vector<Tensor>& input_tensors, uint32_

    std::vector<Tensor> output_tensors;
    output_tensors.reserve(input_tensors.size());

    // Temporary changes to allow multi-device ops to work with op profiler
    // Should be removed with new profiler + software queue changes
    tt::tt_metal::operation::skip_profile = getDeviceProfilerState();
    std::vector<AllGather> ops;
    ops.reserve(input_tensors.size());
    for (uint32_t i = 0; i < input_tensors.size(); ++i) {
        chip_id_t receiver_device_id = input_tensors[(i + 1) % input_tensors.size()].device()->id();
        chip_id_t sender_device_id = input_tensors[i == 0 ? input_tensors.size() - 1 : i - 1].device()->id();
        output_tensors.push_back(operation::run(AllGather{dim, (uint32_t)input_tensors.size(), i, receiver_device_id, sender_device_id, output_mem_config}, {input_tensors.at(i)}).at(0));
        ops.emplace_back(AllGather{dim, static_cast<uint32_t>(input_tensors.size()), i, receiver_device_id, sender_device_id, output_mem_config});
        output_tensors.push_back(operation::run(ops[i], {input_tensors[i]}).at(0));
    }
    if (tt::tt_metal::operation::skip_profile) {
        for (uint32_t i = 0; i < input_tensors.size(); ++i) {
            const auto& operation = ops[i];
            const std::vector<Tensor> inputs = {input_tensors[i]};
            const std::vector<Tensor> outputs = {output_tensors[i]};
            const auto& program = operation::skipped_programs.at(input_tensors[i].device()->id());

            tt::tt_metal::operation::ProfilerInfo profiler_info = {.preferred_name = "tt::tt_metal::AllGather", .parallelization_strategy = std::nullopt};
            auto profile_scope = op_profiler::OpProfileScope(profiler_info.preferred_name.value(), op_profiler::OpType::tt_dnn_device);
            auto do_profile = op_profiler::get_profiler_flag();
            if (do_profile) {
                if (profiler_info.preferred_name.has_value()) {
                    op_profiler::set_preferred_name(profiler_info.preferred_name.value());
                }
                if (profiler_info.parallelization_strategy.has_value()) {
                    op_profiler::set_parallelization_strategy(profiler_info.parallelization_strategy.value());
                }
                op_profiler::append_math_fidelities(program);
                op_profiler::append_meta_data(fmt::format("{}", operation.attributes()));
            }
            op_profiler::dump_device_profiler_results(input_tensors[i].device(), program);
            op_profiler::append_all_tensor_io_data(inputs, {}, outputs);
        }
        tt::tt_metal::operation::skip_profile = false;
        operation::skipped_programs.clear();
    }
    return output_tensors;
}
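The hunk above runs one AllGather per device with profiling suppressed, then replays the profiler work from the per-device programs that were stashed during the launches. A minimal standalone sketch of that defer-then-replay pattern follows; Program, run_op, and profile are hypothetical stand-ins for illustration, not the tt-metal API.

#include <cstdio>
#include <functional>
#include <map>
#include <string>
#include <vector>

struct Program { std::string name; };  // stand-in for a compiled device program

static bool skip_profile = false;                                        // mirrors operation::skip_profile
static std::map<int, std::reference_wrapper<Program>> skipped_programs;  // device id -> stashed program

// Run an op on one device; when profiling is skipped, stash the program for later.
void run_op(int device_id, std::vector<Program>& programs) {
    programs.push_back(Program{"all_gather_dev" + std::to_string(device_id)});
    if (skip_profile) {
        skipped_programs.emplace(device_id, std::ref(programs.back()));
    }
}

void profile(int device_id, const Program& program) {
    std::printf("profiling %s on device %d\n", program.name.c_str(), device_id);
}

int main() {
    const int num_devices = 4;
    std::vector<Program> programs;
    programs.reserve(num_devices);  // stashed references must stay valid: no reallocation

    skip_profile = true;            // phase 1: launch on every device, defer the profiler work
    for (int d = 0; d < num_devices; ++d) {
        run_op(d, programs);
    }

    for (auto& [device_id, program] : skipped_programs) {  // phase 2: replay once all launches are done
        profile(device_id, program.get());
    }
    skip_profile = false;
    skipped_programs.clear();
    return 0;
}

Reserving the storage up front matters for the same reason ops.reserve does in the hunk above: the map holds references into the vector, which therefore must not reallocate while they are in use.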
32 changes: 22 additions & 10 deletions tt_eager/tt_dnn/op_library/run_operation.cpp
@@ -19,6 +19,9 @@

namespace tt::tt_metal::operation {

bool skip_profile = false;
std::map<chip_id_t, std::reference_wrapper<Program>> skipped_programs;

bool is_logging_enabled() {
    bool enabled = false;
    if (std::getenv("TT_METAL_LOGGER_TYPES") != nullptr and std::getenv("TT_METAL_LOGGER_LEVEL") != nullptr) {
@@ -155,8 +158,10 @@ std::vector<Tensor> run_device_operation(
    const std::vector<std::optional<Tensor>>& optional_output_tensors) {
    ZoneScoped;
    ZoneText(operation.get_type_name().c_str(), operation.get_type_name().size());

    auto profile_scope = op_profiler::OpProfileScope(operation.get_type_name(), op_profiler::OpType::tt_dnn_device);
    std::unique_ptr<op_profiler::OpProfileScope> profile_scope;
    if (!operation::skip_profile) {
        profile_scope = std::make_unique<op_profiler::OpProfileScope>(operation.get_type_name(), op_profiler::OpType::tt_dnn_device);
    }

    std::function<std::variant<Program, std::reference_wrapper<Program>>(
        const DeviceOperation&,
@@ -215,9 +220,11 @@
        [&operation, &input_tensors, &optional_input_tensors](auto& program) {
            auto device = detail::get_device(input_tensors, optional_input_tensors);

            auto do_profile = op_profiler::get_profiler_flag();
            if (do_profile) {
                detail::setup_profiler(operation, input_tensors, program);
            if (!operation::skip_profile) {
                auto do_profile = op_profiler::get_profiler_flag();
                if (do_profile) {
                    detail::setup_profiler(operation, input_tensors, program);
                }
            }

            if (USE_FAST_DISPATCH) {
@@ -235,17 +242,22 @@
                    operation.get_type_name(),
                    elapsed_seconds);
#endif
                // Only need to dump device data when in dispatch mode
                // LaunchKernel automatically dumps device data
                op_profiler::dump_device_profiler_results(device, program);
                if (!operation::skip_profile) {
                    // Only need to dump device data when in dispatch mode
                    // LaunchKernel automatically dumps device data
                    op_profiler::dump_device_profiler_results(device, program);
                } else {
                    operation::skipped_programs.emplace(device->id(), std::ref(program));
                }
            } else {
                ::detail::LaunchProgram(device, program);
            }
        },
        program);

    op_profiler::append_all_tensor_io_data(input_tensors, optional_input_tensors, output_tensors);

    if (!operation::skip_profile) {
        op_profiler::append_all_tensor_io_data(input_tensors, optional_input_tensors, output_tensors);
    }
    return output_tensors;
}
} // namespace detail
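The run_device_operation changes make the per-op profiling scope conditional: the RAII OpProfileScope is held in a std::unique_ptr and only constructed when profiling is not being skipped. A small self-contained sketch of that guard pattern, using a hypothetical ScopeTimer in place of OpProfileScope:

#include <chrono>
#include <cstdio>
#include <memory>
#include <string>

// Hypothetical RAII scope: reports elapsed time when it is destroyed.
struct ScopeTimer {
    std::string label;
    std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();
    explicit ScopeTimer(std::string l) : label(std::move(l)) {}
    ~ScopeTimer() {
        auto us = std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::steady_clock::now() - start).count();
        std::printf("%s took %lld us\n", label.c_str(), static_cast<long long>(us));
    }
};

void run_operation(const std::string& name, bool skip_profile) {
    // Construct the scope only when profiling is wanted; an empty unique_ptr
    // means no timing side effects, and destruction stays automatic.
    std::unique_ptr<ScopeTimer> profile_scope;
    if (!skip_profile) {
        profile_scope = std::make_unique<ScopeTimer>(name);
    }
    // ... run the actual work here ...
}

int main() {
    run_operation("all_gather", /*skip_profile=*/false);  // timed
    run_operation("all_gather", /*skip_profile=*/true);   // not timed
    return 0;
}

A std::optional with emplace() would give the same conditional lifetime without a heap allocation; unique_ptr is the simpler drop-in when the guarded type's move semantics are unknown.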
4 changes: 4 additions & 0 deletions tt_eager/tt_dnn/op_library/run_operation.hpp
@@ -17,6 +17,10 @@ namespace tt::tt_metal {

namespace operation {

// Temporary changes to allow multi-device ops to work with op profiler
// Should be removed with new profiler + software queue changes
extern bool skip_profile;
extern std::map<chip_id_t, std::reference_wrapper<Program>> skipped_programs;

template<typename ConcreteOperation>
std::vector<Tensor> generic_create_output_tensors(
3 changes: 1 addition & 2 deletions tt_metal/hw/firmware/src/erisc.cc
@@ -80,7 +80,6 @@ void __attribute__((section("code_l1"))) router_init() {
}

void __attribute__((section("erisc_l1_code"))) ApplicationHandler(void) {
    kernel_profiler::init_profiler();
    rtos_context_switch_ptr = (void (*)())RtosTable[0];

    risc_init();
@@ -105,6 +104,7 @@ void __attribute__((section("erisc_l1_code"))) ApplicationHandler(void) {
    while (routing_info->routing_enabled) {
        // FD: assume that no more host -> remote writes are pending
        if (erisc_info->launch_user_kernel == 1) {
            kernel_profiler::init_profiler();
            kernel_profiler::mark_time(CC_MAIN_START);
            kernel_init();
            kernel_profiler::mark_time(CC_MAIN_END);
@@ -155,5 +155,4 @@ void __attribute__((section("erisc_l1_code"))) ApplicationHandler(void) {
        }
    }
    internal_::disable_erisc_app();
    kernel_profiler::mark_time(CC_MAIN_END);
}
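The erisc.cc change moves kernel_profiler::init_profiler() from the start of ApplicationHandler into the branch that launches a user kernel, so each launch begins with fresh profiler state, and the trailing CC_MAIN_END mark after the loop goes away because every launch now closes its own trace. A rough host-side analogue of that per-launch pattern, with a hypothetical Profiler standing in for kernel_profiler:

#include <cstdio>
#include <vector>

// Hypothetical stand-in for the kernel_profiler marker buffer.
struct Profiler {
    std::vector<const char*> marks;
    void init() { marks.clear(); }               // reset state for a fresh launch
    void mark(const char* tag) { marks.push_back(tag); }
    void dump(int launch) {
        std::printf("launch %d:", launch);
        for (const char* m : marks) std::printf(" %s", m);
        std::printf("\n");
    }
};

int main() {
    Profiler profiler;
    for (int launch = 0; launch < 3; ++launch) {  // dispatch loop
        profiler.init();                          // per-launch init, as in the erisc change
        profiler.mark("MAIN_START");
        // ... run the user kernel here ...
        profiler.mark("MAIN_END");
        profiler.dump(launch);
    }
    // No trailing mark after the loop: each launch already closed its own trace.
    return 0;
}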
