Skip to content

Commit

Permalink
Merge branch 'main' into torch2_concat_sweep
Browse files Browse the repository at this point in the history
  • Loading branch information
jaykru-tt authored Nov 18, 2024
2 parents 184b917 + 55912dc commit 78fe1de
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void Tilize::validate(const std::vector<Tensor>& input_tensors) const {
auto width = input_tensor_a.get_legacy_shape()[-1];
uint32_t stick_s = width;
uint32_t num_sticks = input_tensor_a.volume() / width;
TT_FATAL(input_tensor_a.get_dtype() == DataType::BFLOAT16, "Error");
TT_FATAL(input_tensor_a.get_dtype() == DataType::BFLOAT16 or input_tensor_a.get_dtype() == DataType::FLOAT32, "data type must be bfloat16 or float32");

uint32_t stick_size = stick_s * input_tensor_a.element_size(); // Assuming bfloat16 dataformat

Expand Down

0 comments on commit 78fe1de

Please sign in to comment.