#0: Migrate ssm_prefix_scan to ttnn
esmalTT committed Jul 29, 2024
1 parent 38166ac commit fcf24f6
Showing 19 changed files with 279 additions and 142 deletions.
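In short: call sites move from the deprecated `tt_lib` primary-op namespace to `ttnn.experimental`, and the keyword arguments are renamed from `output_mem_config`/`output_dtype` to `memory_config`/`dtype`. A minimal before/after sketch of the change, distilled from the diffs below (tensor and config names are placeholders):

```python
# Before: deprecated tt_lib entry point (a, bx, h_prev, mem_cfg are placeholders)
out = ttl.operations.primary.transformers.ssm_prefix_scan(
    a, bx, h_prev, output_mem_config=mem_cfg, output_dtype=ttnn.bfloat8_b
)

# After: the same op under ttnn's experimental namespace, with renamed kwargs
out = ttnn.experimental.prefix_scan(a, bx, h_prev, memory_config=mem_cfg, dtype=ttnn.bfloat8_b)
```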
6 changes: 3 additions & 3 deletions models/demos/wormhole/mamba/tt/mamba_ssm.py
@@ -267,12 +267,12 @@ def forward(self, x):
         ttnn.deallocate(abar2)
         bmulx0_sharded = ttnn.to_memory_config(bmulx0, self.configs["sharded_scan"])
         ttnn.deallocate(bmulx0)
-        hidden_states_sharded = ttl.operations.primary.transformers.ssm_prefix_scan(
+        hidden_states_sharded = ttnn.experimental.prefix_scan(
             abar2_sharded,
             bmulx0_sharded,
             prev_hidden_state,
-            output_mem_config=self.configs["sharded_scan"],
-            output_dtype=ttnn.bfloat8_b,
+            memory_config=self.configs["sharded_scan"],
+            dtype=ttnn.bfloat8_b,
         )
         ttnn.deallocate(abar2_sharded)
         ttnn.deallocate(bmulx0_sharded)
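For context on what the op computes: this call produces the SSM hidden states for the whole sequence. A sequential PyTorch sketch of the recurrence, assuming the standard Mamba formulation in which `abar2` is the discretized state coefficient and `bmulx0` the input contribution (a reference model only, not the device kernel):

```python
import torch

def prefix_scan_reference(a, bx, h0):
    # Sequential reference for h[t] = a[t] * h[t-1] + bx[t] along dim 2.
    # a, bx: [1, 1, L, D]; h0: [1, 1, 1, D] initial hidden state.
    h, out = h0, []
    for t in range(a.shape[2]):
        h = a[:, :, t : t + 1, :] * h + bx[:, :, t : t + 1, :]
        out.append(h)
    return torch.cat(out, dim=2)
```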
@@ -4,6 +4,7 @@
 
 import torch
 
+import ttnn
 import tt_lib as ttl
 import pytest
 from loguru import logger
@@ -64,9 +65,7 @@ def run_ssm_prefix_scan(L: int, E: int, N: int, num_cores: int, dtype, device):
         .to(device, h_memory_config)
     )
 
-    actual = ttl.operations.primary.transformers.ssm_prefix_scan(
-        a, bx, h_prev, output_mem_config=memory_config, output_dtype=dtype
-    )
+    actual = ttnn.experimental.prefix_scan(a, bx, h_prev, memory_config=memory_config, dtype=dtype)
     assert list(actual.get_legacy_shape()) == list(expected.shape)
     assert actual.dtype == dtype

@@ -164,9 +163,7 @@ def to_device(x):
         a_chunk = to_device(a_chunks[idx])
         bx_chunk = to_device(bx_chunks[idx])
 
-        h_chunk = ttl.operations.primary.transformers.ssm_prefix_scan(
-            a_chunk, bx_chunk, h_prev, output_mem_config=memory_config, output_dtype=dtype
-        )
+        h_chunk = ttnn.experimental.prefix_scan(a_chunk, bx_chunk, h_prev, memory_config=memory_config, dtype=dtype)
         actual.append(tt2torch_tensor(h_chunk))
 
     actual = torch.concat(actual, dim=2)
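The chunked loop above carries the final hidden state of each chunk into the next call. The property under test can be sketched with the `prefix_scan_reference` helper from the previous sketch: chunked scanning with a carried hidden state should match one scan over the full sequence.

```python
def chunked_scan(a, bx, h0, chunk_len=32):
    # Scan chunk by chunk, seeding each chunk with the previous final state.
    outputs, h = [], h0
    for start in range(0, a.shape[2], chunk_len):
        out = prefix_scan_reference(
            a[:, :, start : start + chunk_len, :],
            bx[:, :, start : start + chunk_len, :],
            h,
        )
        h = out[:, :, -1:, :]  # carry the last hidden state forward
        outputs.append(out)
    return torch.cat(outputs, dim=2)

# Expectation: torch.allclose(chunked_scan(a, bx, h0), prefix_scan_reference(a, bx, h0))
```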
4 changes: 4 additions & 0 deletions ttnn/CMakeLists.txt
@@ -74,6 +74,10 @@ set(TTNN_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/device/multi_core/groupnorm_op_multi_core.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/transformer/device/transformer_device_operation.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ssm/prefix_scan/prefix_scan.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ssm/prefix_scan/prefix_scan_pybind.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/prefix_scan_op.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/prefix_scan_program_factory.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/avgpool/avg_pool.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/device/max_pool_multi_core.cpp
1 change: 0 additions & 1 deletion ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt
@@ -138,7 +138,6 @@ set(TT_DNN_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_group_attn_matmul/multi_core_group_attn_matmul.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_ssm_eltwise_mul/multi_core_ssm_eltwise_mul.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_ssm_1d_sum_reduce/multi_core_ssm_1d_sum_reduce.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_ssm_prefix_scan/multi_core_ssm_prefix_scan.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/split/split_tiled.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/split/split_last_dim_two_chunks_tiled.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/concat/multi_core/concat_op_multi_core.cpp
@@ -472,58 +472,6 @@ operation::ProgramWithCallbacks SSM1DSumReduce::create_program(
     return multi_core_ssm_1d_sum_reduce(input_tensor_a, output_tensor, math_fidelity, device_compute_with_storage_grid_size);
 }
 
-void SSMPrefixScan::validate(const std::vector<Tensor>& input_tensors) const {
-    TT_FATAL(input_tensors.size() == 3, "Expected 3 input tensors (A, Bx, H)");
-
-    const auto& a = input_tensors.at(0);
-    const auto& bx = input_tensors.at(1);
-    TT_FATAL(a.dtype() == bx.dtype(), "Expected input tensors to have the same data type");
-    TT_FATAL(a.layout() == Layout::TILE && bx.layout() == Layout::TILE, "Expected input tensors to be tile layout");
-    TT_FATAL(a.get_legacy_shape() == bx.get_legacy_shape(), "Expected input tensors to have the same shape");
-
-    const auto& shape = a.get_legacy_shape();
-    TT_FATAL(shape.rank() == 4, "Expected input tensors to be rank 4");
-    TT_FATAL(shape[0] == 1 && shape[1] == 1, "Dimension 0 and 1 should be size 1");
-    TT_FATAL(shape[2] >= TILE_HEIGHT && shape[2] % TILE_HEIGHT == 0, "Sequence length should be a multiple of 32");
-
-    const auto& h = input_tensors.at(2);
-    TT_FATAL(h.dtype() == DataType::BFLOAT16, "Expected initial hidden state to be bfloat16");
-    TT_FATAL(h.layout() == Layout::ROW_MAJOR, "Expected initial hidden state to be row-major");
-    //TT_FATAL(h.get_legacy_shape() == {1, 1, 1, shape[3]}, "Expected initial hidden state to have the same hidden size as A and Bx");
-
-    TT_FATAL(a.is_sharded() && bx.is_sharded() && h.is_sharded(), "Expected input tensors to be sharded");
-    TT_FATAL(a.shard_spec().has_value() && bx.shard_spec().has_value() && h.shard_spec().has_value(), "Expected input tensors to be sharded");
-    TT_FATAL(
-        a.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR,
-        "Expected A tensor to be row major orientation");
-    TT_FATAL(
-        bx.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR,
-        "Expected Bx tensor to be row major orientation");
-    TT_FATAL(
-        h.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR,
-        "Expected h tensor to be row major orientation");
-}
-
-std::vector<Shape> SSMPrefixScan::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
-    const auto& a = input_tensors.at(0);
-    return {a.get_legacy_shape()};
-}
-
-std::vector<Tensor> SSMPrefixScan::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
-    return operation::generic_create_output_tensors(
-        *this, input_tensors, this->output_dtype, Layout::TILE, this->output_mem_config);
-}
-
-operation::ProgramWithCallbacks SSMPrefixScan::create_program(
-    const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const {
-    const auto& a = input_tensors.at(0);
-    const auto& bx = input_tensors.at(1);
-    const auto& h = input_tensors.at(2);
-    auto& output = output_tensors.at(0);
-    auto device_compute_with_storage_grid_size = a.device()->compute_with_storage_grid_size();
-    return multi_core_ssm_prefix_scan(a, bx, h, output, math_fidelity, device_compute_with_storage_grid_size);
-}
-
 } // namespace transformers
 } // namespace primary
 } // namespace operations
@@ -180,42 +180,6 @@ inline Tensor ssm_1d_sum_reduce(const Tensor &input_tensor_a, const MemoryConfig
     return output_tensors.at(0);
 }
 
-struct SSMPrefixScan {
-    MemoryConfig output_mem_config;
-    DataType output_dtype;
-    MathFidelity math_fidelity;
-
-    void validate(const std::vector<Tensor>& input_tensors) const;
-    std::vector<Shape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
-    std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
-    operation::ProgramWithCallbacks create_program(
-        const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const;
-};
-
-inline Tensor ssm_prefix_scan(
-    const Tensor& a,
-    const Tensor& bx,
-    const Tensor& h,
-    const MemoryConfig& mem_config,
-    std::optional<const DataType> output_dtype = std::nullopt,
-    MathFidelity math_fidelity = MathFidelity::HiFi4) {
-    std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({a, bx}))};
-    operation::launch_op(
-        [mem_config, output_dtype, math_fidelity](
-            const std::vector<Tensor>& input_tensors,
-            const std::vector<std::optional<const Tensor>>& optional_input_tensors,
-            const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
-            const auto& a = input_tensors.at(0);
-            const auto& bx = input_tensors.at(1);
-            const auto& h = input_tensors.at(2);
-            return operation::run(
-                SSMPrefixScan{mem_config, output_dtype.value_or(a.get_dtype()), math_fidelity}, input_tensors);
-        },
-        {a, bx, h},
-        output_tensors);
-    return output_tensors.at(0);
-}
-
 } // namespace transformers
 
 } // namespace primary
@@ -44,19 +44,6 @@ void py_module(py::module& m_transformers) {
         Performs a custom reduction along dim 3 which is used in the SSM block of the Mamba architecture. Performs the following PyTorch equivalent (where latent_size = 32):
         x = torch.sum(x.reshape(1, 1, shape[2], shape[3] // latent_size, latent_size), dim=-1).reshape(1, 1, shape[2], shape[3] // latent_size)
     )doc");
-    m_transformers.def(
-        "ssm_prefix_scan",
-        &ssm_prefix_scan,
-        py::arg().noconvert(),
-        py::arg().noconvert(),
-        py::arg().noconvert(),
-        py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
-        py::arg("output_dtype").noconvert() = std::nullopt,
-        py::arg("math_fidelity").noconvert() = MathFidelity::HiFi4,
-        R"doc(
-        Performs a prefix scan to produce the SSM hidden states across an entire sequence. All input and output tensors are expected to be shape [1, 1, L, 2EN] where E = 2560 and N = 32. L can be any multiple of 32.)doc");
-
-
     py::class_<SDPADefaultProgramConfig>(m_transformers, "SDPADefaultProgramConfig")
         .def(py::init<>());

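The removed docstring pins down the operand contract: shape [1, 1, L, 2EN] with E = 2560 and N = 32, i.e. a hidden width of 2 · 2560 · 32 = 163,840, with L any multiple of 32; this echoes the rank and sequence-length checks in `PrefixScan::validate` further down in this commit. A small illustrative sketch of that contract (the helper name is hypothetical):

```python
E, N = 2560, 32  # embedding and latent sizes from the removed docstring

def check_prefix_scan_shape(shape):
    # Shape contract documented for the op: [1, 1, L, 2EN], L a multiple of 32.
    assert len(shape) == 4
    assert shape[0] == 1 and shape[1] == 1
    assert shape[2] >= 32 and shape[2] % 32 == 0
    assert shape[3] == 2 * E * N  # = 163840

check_prefix_scan_shape((1, 1, 64, 2 * E * N))
```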
2 changes: 2 additions & 0 deletions ttnn/cpp/ttnn/operations/experimental/experimental_pybind.hpp
@@ -9,6 +9,7 @@
 
 #include "ttnn/operations/experimental/transformer/transformer_pybind.hpp"
 #include "ttnn/operations/experimental/reduction/argmax/argmax_pybind.hpp"
+#include "ttnn/operations/experimental/ssm/prefix_scan/prefix_scan_pybind.hpp"
 
 namespace ttnn::operations::experimental {

Expand All @@ -17,6 +18,7 @@ void py_module(py::module& module) {
transformer::detail::bind_experimental_transformer_operations(module);
reduction::detail::bind_argmax_operation(module);
reduction::detail::bind_argmin_operation(module);
ssm::detail::bind_prefix_scan(module);
}

} // namespace ttnn::operations::experimental
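With `bind_prefix_scan` registered here, the op surfaces under `ttnn.experimental`, which is what the updated `mamba_ssm.py` call site relies on. A quick smoke check, assuming a built `ttnn` package:

```python
import ttnn

assert hasattr(ttnn.experimental, "prefix_scan")
```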
64 changes: 64 additions & 0 deletions ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/prefix_scan_op.cpp
@@ -0,0 +1,64 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "prefix_scan_op.hpp"
+
+#include "prefix_scan_program_factory.hpp"
+
+namespace ttnn::operations::experimental::ssm {
+
+void PrefixScan::validate(const std::vector<Tensor>& input_tensors) const {
+    TT_FATAL(input_tensors.size() == 3, "Expected 3 input tensors (A, Bx, H)");
+
+    const auto& a = input_tensors[0];
+    const auto& bx = input_tensors[1];
+    TT_FATAL(a.dtype() == bx.dtype(), "Expected input tensors to have the same data type");
+    TT_FATAL(a.layout() == Layout::TILE && bx.layout() == Layout::TILE, "Expected input tensors to be tile layout");
+    TT_FATAL(a.get_legacy_shape() == bx.get_legacy_shape(), "Expected input tensors to have the same shape");
+
+    const auto& shape = a.get_legacy_shape();
+    TT_FATAL(shape.rank() == 4, "Expected input tensors to be rank 4");
+    TT_FATAL(shape[0] == 1 && shape[1] == 1, "Dimension 0 and 1 should be size 1");
+    TT_FATAL(shape[2] >= TILE_HEIGHT && shape[2] % TILE_HEIGHT == 0, "Sequence length should be a multiple of 32");
+
+    const auto& h = input_tensors.at(2);
+    TT_FATAL(h.dtype() == DataType::BFLOAT16, "Expected initial hidden state to be bfloat16");
+    TT_FATAL(h.layout() == Layout::ROW_MAJOR, "Expected initial hidden state to be row-major");
+
+    TT_FATAL(a.is_sharded() && bx.is_sharded() && h.is_sharded(), "Expected input tensors to be sharded");
+    TT_FATAL(
+        a.shard_spec().has_value() && bx.shard_spec().has_value() && h.shard_spec().has_value(),
+        "Expected input tensors to be sharded");
+    TT_FATAL(
+        a.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR,
+        "Expected A tensor to be row major orientation");
+    TT_FATAL(
+        bx.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR,
+        "Expected Bx tensor to be row major orientation");
+    TT_FATAL(
+        h.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR,
+        "Expected h tensor to be row major orientation");
+}
+
+std::vector<tt::tt_metal::Shape> PrefixScan::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
+    const auto& a = input_tensors.at(0);
+    return {a.get_legacy_shape()};
+}
+
+std::vector<Tensor> PrefixScan::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
+    return operation::generic_create_output_tensors(
+        *this, input_tensors, this->dtype, Layout::TILE, this->memory_config);
+}
+
+operation::ProgramWithCallbacks PrefixScan::create_program(
+    const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const {
+    const auto& a = input_tensors.at(0);
+    const auto& bx = input_tensors.at(1);
+    const auto& h = input_tensors.at(2);
+    auto& output = output_tensors.at(0);
+    auto device_compute_with_storage_grid_size = a.device()->compute_with_storage_grid_size();
+    return tt::operations::experimental::ssm::detail::multi_core_ssm_prefix_scan(
+        a, bx, h, output, math_fidelity, device_compute_with_storage_grid_size);
+}
+} // namespace ttnn::operations::experimental::ssm
24 changes: 24 additions & 0 deletions ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/prefix_scan_op.hpp
@@ -0,0 +1,24 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "ttnn/run_operation.hpp"
+#include "ttnn/tensor/tensor.hpp"
+
+namespace ttnn::operations::experimental::ssm {
+
+struct PrefixScan {
+    MemoryConfig memory_config;
+    DataType dtype;
+    MathFidelity math_fidelity;
+
+    void validate(const std::vector<Tensor>& input_tensors) const;
+    std::vector<tt::tt_metal::Shape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
+    std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
+    operation::ProgramWithCallbacks create_program(
+        const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const;
+};
+
+} // namespace ttnn::operations::experimental::ssm