Skip to content

Commit

Permalink
Fix pytorch crashing in release
Browse files Browse the repository at this point in the history
  • Loading branch information
Macdu committed Nov 14, 2023
1 parent 8694d0c commit 2c7db44
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/competitors/cpu/cpu_pytorch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace Competitors {
torch::Tensor crow_indices = torch::tensor(S.getRowPositions(), torch::kInt);
torch::Tensor col_indices = torch::tensor(S.getColPositions(), torch::kInt);
torch::Tensor values = torch::tensor(S.getValues(), scalar_type);
torch::IntArrayRef size = {S.getRows(), S.getCols()};
std::vector<int64_t> size = {S.getRows(), S.getCols()};
at::TensorOptions options = torch::device(torch::kCPU).dtype(scalar_type);
torch::Tensor sparse_tensor = torch::sparse_csr_tensor(crow_indices, col_indices, values, size, options);

Expand Down
2 changes: 1 addition & 1 deletion src/competitors/gpu/gpu_pytorch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ namespace Competitors {
torch::Tensor crow_indices = torch::tensor(S.getRowPositions(), int_type);
torch::Tensor col_indices = torch::tensor(S.getColPositions(), int_type);
torch::Tensor values = torch::tensor(S.getValues(), scalar_type);
torch::IntArrayRef size = {S.getRows(), S.getCols()};
std::vector<int64_t> size = {S.getRows(), S.getCols()};
torch::Tensor sparse_tensor = torch::sparse_csr_tensor(crow_indices, col_indices, values, size, scalar_type).to(gpu);

torch::Tensor result = at::native::sparse_sampled_addmm_sparse_csr_cuda(sparse_tensor, A_tensor, B_tensor, 0, 1).to(cpu);
Expand Down

0 comments on commit 2c7db44

Please sign in to comment.