Commit c7b4d21

Merge pull request ColfaxResearch#2 from ColfaxResearch/smem-naive

adding naive and bank conflict versions

jayhshah authored Apr 18, 2024
2 parents fed3a00 + 0e095f1 commit c7b4d21

Showing 3 changed files with 332 additions and 1 deletion.
6 changes: 5 additions & 1 deletion transpose-cute/main.cu
@@ -1,6 +1,8 @@
#include "cutlass/util/command_line.h"

#include "transpose_naive.h"
#include "transpose_smem.h"
#include "transpose_smem_bank_conflict.h"
#include "transpose_tmastore_vectorized.h"

int main(int argc, char const **argv) {
@@ -14,8 +16,10 @@ int main(int argc, char const **argv) {

std::cout << "(M, N): " << M << ", " << N << std::endl;

transpose_host_kernel_naive(M, N);
transpose_host_kernel_smem_bank_conflict(M, N);
transpose_host_kernel_smem(M, N);
transpose_host_kernel_tma(M, N);

return 0;
}
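Both new headers pull in shared_storage.h, which this commit does not modify. For orientation, a minimal sketch of the SharedStorageTranspose shape they rely on, assuming the usual CuTe aligned-array idiom (the actual file may differ):

#include <cute/tensor.hpp>

// Hypothetical sketch of SharedStorageTranspose (the real definition lives
// in shared_storage.h, untouched by this commit): an aligned array sized to
// the codomain of the shared-memory layout, exposed as the .smem member the
// kernels below reinterpret dynamic shared memory into.
template <class Element, class SmemLayout>
struct SharedStorageTranspose {
  cute::array_aligned<Element, cute::cosize_v<SmemLayout>> smem;
};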
152 changes: 152 additions & 0 deletions transpose-cute/transpose_naive.h
@@ -0,0 +1,152 @@
#pragma once

#include <cassert>
#include <cstdio>
#include <cstdlib>

#include <chrono>

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

#include "cutlass/numeric_types.h"
#include <cute/arch/cluster_sm90.hpp>
#include <cute/tensor.hpp>
#include <cutlass/cluster_launch.hpp>
#include <cutlass/cutlass.h>

#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/helper_cuda.hpp"
#include "cutlass/util/print_error.hpp"

#include "cutlass/detail/layout.hpp"

#include "shared_storage.h"

template <class TensorS, class TensorD, class ThreadLayoutS, class ThreadLayoutD>
__global__ static void __launch_bounds__(256, 1)
transposeKernelNaive(TensorS const S, TensorD const DT,
ThreadLayoutS const tS,
ThreadLayoutD const tD) {
using namespace cute;
using Element = typename TensorS::value_type;

Tensor gS = S(make_coord(_, _), blockIdx.x, blockIdx.y); // (bM, bN)
Tensor gDT = DT(make_coord(_, _), blockIdx.x, blockIdx.y); // (bN, bM)


Tensor tSgS = local_partition(gS, tS, threadIdx.x); // (ThrValM, ThrValN)
Tensor tDgDT = local_partition(gDT, tD, threadIdx.x);
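// The gmem reads below are coalesced (consecutive threads read consecutive
// addresses of S), but stores through the transposed view DT are strided by
// M elements, i.e. fully uncoalesced -- the cost this "naive" version pays.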

cute::copy(tSgS, tDgDT);

}
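The two local_partition calls hand each of the 256 threads its own slice of the 32x32 tile. A standalone host-side illustration of the resulting fragment shape, assuming only the CuTe headers and no code from this commit:

#include <cute/tensor.hpp>
#include <vector>

// With an (8, 32) thread layout over a (32, 32) tile, local_partition gives
// every thread a (4, 1) fragment: four rows, 8 apart, in one column.
int main() {
  using namespace cute;
  std::vector<float> buf(32 * 32);
  Tensor tile = make_tensor(buf.data(),
      make_layout(make_shape(Int<32>{}, Int<32>{}), GenRowMajor{}));
  auto thr = make_layout(make_shape(Int<8>{}, Int<32>{}), GenRowMajor{});
  Tensor frag = local_partition(tile, thr, 0); // thread 0's slice
  print(frag.layout()); print("\n");           // e.g. (_4,_1):(_256,_32)
  return 0;
}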

int transpose_host_kernel_naive(int M, int N) {
printf("NO tma, NO smem, not vectorized\n");

using Element = float;
using namespace cute;

auto tensor_shape = make_shape(M, N);
auto tensor_shape_trans = make_shape(N, M);

// Allocate and initialize
thrust::host_vector<Element> h_S(size(tensor_shape)); // (M, N)
thrust::host_vector<Element> h_D(size(tensor_shape_trans)); // (N, M)

for (size_t i = 0; i < h_S.size(); ++i) {
h_S[i] = static_cast<Element>(i);
h_D[i] = Element{};
}

thrust::device_vector<Element> d_S = h_S;
thrust::device_vector<Element> d_D = h_D;

//
// Make tensors
//

auto gmemLayoutS = make_layout(tensor_shape, GenRowMajor{});
auto gmemLayoutD = make_layout(tensor_shape_trans, GenRowMajor{});
Tensor tensor_S = make_tensor(
make_gmem_ptr(thrust::raw_pointer_cast(d_S.data())), gmemLayoutS);
Tensor tensor_D = make_tensor(
make_gmem_ptr(thrust::raw_pointer_cast(d_D.data())), gmemLayoutD);

// Make a transposed view of the output
auto gmemLayoutDT = make_layout(tensor_shape, GenColMajor{});
Tensor tensor_DT = make_tensor(
make_gmem_ptr(thrust::raw_pointer_cast(d_D.data())), gmemLayoutDT);

//
// Tile tensors
//

using bM = Int<32>;
using bN = Int<32>;

auto block_shape = make_shape(bM{}, bN{}); // (bM, bN)
auto block_shape_trans = make_shape(bN{}, bM{}); // (bN, bM)

Tensor tiled_tensor_S =
tiled_divide(tensor_S, block_shape); // ((bM, bN), m', n')
Tensor tiled_tensor_DT =
tiled_divide(tensor_DT, block_shape_trans); // ((bN, bM), m', n')
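// tiled_divide reshapes an (M, N) tensor into ((bM, bN), M/bM, N/bN):
// mode 0 is a single 32x32 tile, and modes 1 and 2 select which tile,
// which the kernel indexes with blockIdx.x and blockIdx.y.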

auto threadLayoutS =
make_layout(make_shape(Int<8>{}, Int<32>{}), GenRowMajor{});
auto threadLayoutD =
make_layout(make_shape(Int<8>{}, Int<32>{}), GenRowMajor{});

//
// Determine grid and block dimensions
//

dim3 gridDim(
size<1>(tiled_tensor_S),
size<2>(tiled_tensor_S)); // Grid shape corresponds to modes m' and n'
dim3 blockDim(size(threadLayoutS)); // 256 threads

int iterations = 10;

for (int i = 0; i < iterations; i++) {
auto t1 = std::chrono::high_resolution_clock::now();
transposeKernelNaive<<<gridDim, blockDim>>>(
tiled_tensor_S, tiled_tensor_DT, threadLayoutS, threadLayoutD);
cudaError result = cudaDeviceSynchronize();
auto t2 = std::chrono::high_resolution_clock::now();
if (result != cudaSuccess) {
std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result)
<< std::endl;
return -1;
}
std::chrono::duration<double, std::milli> tDiff = t2 - t1;
double time_ms = tDiff.count();
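// Each of the M * N elements is read once and written once, so bytes moved
// = 2 * M * N * sizeof(Element); scaling bytes/ms by 1e-6 yields GB/s.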
std::cout << "Trial " << i << " Completed in " << time_ms << "ms ("
<< 2e-6 * M * N * sizeof(Element) / time_ms << " GB/s)"
<< std::endl;
}

//
// Verify
//

h_D = d_D;

int good = 0, bad = 0;

auto transposeFunction = make_layout(tensor_shape, GenRowMajor{});
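// A CuTe layout doubles as an index map: transposeFunction lifts the 1D
// index i to the coordinate (m, n) = (i % M, i / M) and returns m * N + n,
// exactly the source offset of the element that should land at h_D[i].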

for (size_t i = 0; i < h_D.size(); ++i) {
if (h_D[i] == h_S[transposeFunction(i)])
good++;
else
bad++;
}

std::cout << "Success " << good << ", Fail " << bad << std::endl;

return 0;
}
175 changes: 175 additions & 0 deletions transpose-cute/transpose_smem_bank_conflict.h
@@ -0,0 +1,175 @@
#pragma once

#include <cassert>
#include <cstdio>
#include <cstdlib>

#include <chrono>

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

#include "cutlass/numeric_types.h"
#include <cute/arch/cluster_sm90.hpp>
#include <cute/tensor.hpp>
#include <cutlass/cluster_launch.hpp>
#include <cutlass/cutlass.h>

#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/helper_cuda.hpp"
#include "cutlass/util/print_error.hpp"

#include "cutlass/detail/layout.hpp"

#include "shared_storage.h"

template <class TensorS, class TensorD, class SmemLayout, class ThreadLayoutS,
class SmemLayoutT, class ThreadLayoutD>
__global__ static void __launch_bounds__(256, 1)
transposeKernelSmemBC(TensorS const S, TensorD const D,
SmemLayout const smemLayout, ThreadLayoutS const tS,
SmemLayoutT const smemLayoutT, ThreadLayoutD const tD) {
using namespace cute;
using Element = typename TensorS::value_type;

// Use Shared Storage structure to allocate aligned SMEM addresses.
extern __shared__ char shared_memory[];
using SharedStorage = SharedStorageTranspose<Element, SmemLayout>;
SharedStorage &shared_storage =
*reinterpret_cast<SharedStorage *>(shared_memory);

// two different views of smem
Tensor sS = make_tensor(make_smem_ptr(shared_storage.smem.data()),
smemLayout); // (bM, bN)
Tensor sD = make_tensor(make_smem_ptr(shared_storage.smem.data()),
smemLayoutT); // (bN, bM)
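// sS and sD alias the same shared memory: threads write the tile through
// the row-major view (conflict-free) and read it back through the
// column-major view, where the 32 threads of a warp access addresses
// 32 floats (128 B) apart -- the same 4-byte bank, a 32-way conflict.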

Tensor gS = S(make_coord(_, _), blockIdx.x, blockIdx.y); // (bM, bN)
Tensor gD = D(make_coord(_, _), blockIdx.y, blockIdx.x); // (bN, bM)

Tensor tSgS = local_partition(gS, tS, threadIdx.x); // (ThrValM, ThrValN)
Tensor tSsS = local_partition(sS, tS, threadIdx.x); // (ThrValM, ThrValN)
Tensor tDgD = local_partition(gD, tD, threadIdx.x);
Tensor tDsD = local_partition(sD, tD, threadIdx.x);

cute::copy(tSgS, tSsS); // LDGSTS

cp_async_fence();
cp_async_wait<0>();
__syncthreads();

cute::copy(tDsD, tDgD);
}

int transpose_host_kernel_smem_bank_conflict(int M, int N) {
printf("NO tma, smem passthrough, not vectorized, not swizzled\n");

using Element = float;
using namespace cute;

auto tensor_shape = make_shape(M, N);
auto tensor_shape_trans = make_shape(N, M);

// Allocate and initialize
thrust::host_vector<Element> h_S(size(tensor_shape)); // (M, N)
thrust::host_vector<Element> h_D(size(tensor_shape_trans)); // (N, M)

for (size_t i = 0; i < h_S.size(); ++i) {
h_S[i] = static_cast<Element>(i);
h_D[i] = Element{};
}

thrust::device_vector<Element> d_S = h_S;
thrust::device_vector<Element> d_D = h_D;

//
// Make tensors
//

// Could also have ColMajor.
auto gmemLayoutS = make_layout(tensor_shape, GenRowMajor{});
auto gmemLayoutD = make_layout(tensor_shape_trans, GenRowMajor{});

Tensor tensor_S = make_tensor(
make_gmem_ptr(thrust::raw_pointer_cast(d_S.data())), gmemLayoutS);
Tensor tensor_D = make_tensor(
make_gmem_ptr(thrust::raw_pointer_cast(d_D.data())), gmemLayoutD);

//
// Tile tensors
//

using bM = Int<32>;
using bN = Int<32>;

auto block_shape = make_shape(bM{}, bN{}); // (bM, bN)
auto block_shape_trans = make_shape(bN{}, bM{}); // (bN, bM)

Tensor tiled_tensor_S =
tiled_divide(tensor_S, block_shape); // ((bM, bN), m', n')
Tensor tiled_tensor_D =
tiled_divide(tensor_D, block_shape_trans); // ((bN, bM), n', m')

auto smemLayout = make_layout(block_shape, GenRowMajor{});
auto smemLayoutT = make_layout(block_shape, GenColMajor{});

auto threadLayoutS =
make_layout(make_shape(Int<8>{}, Int<32>{}), GenRowMajor{});
auto threadLayoutD =
make_layout(make_shape(Int<8>{}, Int<32>{}), GenRowMajor{});

size_t smem_size = int(
sizeof(SharedStorageTranspose<Element, decltype(smemLayout)>));
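// One 32x32 tile of floats: 4 KiB of dynamic shared memory per block.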

//
// Determine grid and block dimensions
//

dim3 gridDim(
size<1>(tiled_tensor_S),
size<2>(tiled_tensor_S)); // Grid shape corresponds to modes m' and n'
dim3 blockDim(size(threadLayoutS)); // 256 threads

int iterations = 10;

for (int i = 0; i < iterations; i++) {
auto t1 = std::chrono::high_resolution_clock::now();
transposeKernelSmemBC<<<gridDim, blockDim, smem_size>>>(
tiled_tensor_S, tiled_tensor_D, smemLayout, threadLayoutS,
smemLayoutT, threadLayoutD);
cudaError result = cudaDeviceSynchronize();
auto t2 = std::chrono::high_resolution_clock::now();
if (result != cudaSuccess) {
std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result)
<< std::endl;
return -1;
}
std::chrono::duration<double, std::milli> tDiff = t2 - t1;
double time_ms = tDiff.count();
std::cout << "Trial " << i << " Completed in " << time_ms << "ms ("
<< 2e-6 * M * N * sizeof(Element) / time_ms << " GB/s)"
<< std::endl;
}

//
// Verify
//

h_D = d_D;

int good = 0, bad = 0;

auto transposeFunction = make_layout(tensor_shape, GenRowMajor{});

for (size_t i = 0; i < h_D.size(); ++i) {
if (h_D[i] == h_S[transposeFunction(i)])
good++;
else
bad++;
}

std::cout << "Success " << good << ", Fail " << bad << std::endl;

return 0;
}
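As the printf notes, this version is deliberately left unswizzled so the bank conflicts are measurable. A hedged sketch of the standard CuTe remedy, along the lines of the swizzled transpose_smem.h variant already called from main.cu (the exact code there may differ):

#include <cute/tensor.hpp>

// Sketch, assuming a Swizzle<5, 0, 5> is the right choice for a 32x32 float
// tile: composing it onto the row-major layout XORs the 5 row bits into the
// 5 column bits of each offset, so a warp reading a "column" of the
// transposed view touches 32 distinct banks instead of one.
inline auto make_swizzled_smem_layout() {
  using namespace cute;
  auto tile = make_layout(make_shape(Int<32>{}, Int<32>{}), GenRowMajor{});
  return composition(Swizzle<5, 0, 5>{}, tile);
}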
