Skip to content

Commit

Permalink
Add missing CMake code for using Pytorch with MSVC
Browse files Browse the repository at this point in the history
  • Loading branch information
feldspath committed Oct 23, 2023
1 parent 4aa7664 commit 318c143
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
13 changes: 11 additions & 2 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,14 @@ if(MSVC)
endif()

find_package(Torch REQUIRED)

target_link_libraries(dphpc "${TORCH_LIBRARIES}")
target_link_libraries(dphpc "${TORCH_LIBRARIES}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

if (MSVC)
file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
add_custom_command(TARGET dphpc
POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different
${TORCH_DLLS}
$<TARGET_FILE_DIR:dphpc>)
endif (MSVC)
4 changes: 2 additions & 2 deletions src/competitors/cpu/cpu_pytorch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ namespace Competitors {

torch::Tensor result = at::native::sparse_sampled_addmm_sparse_csr_cpu(sparse_tensor, A_tensor, B_tensor, 0, 1).to_sparse_csr();

P.setRowPositions(std::vector<int>(result.crow_indices().data_ptr<int>(), result.crow_indices().data_ptr<int>() + result.crow_indices().numel()));
P.setColPositions(std::vector<int>(result.col_indices().data_ptr<int>(), result.col_indices().data_ptr<int>() + result.col_indices().numel()));
P.setRowPositions(std::vector<int>(result.crow_indices().data_ptr<int64_t>(), result.crow_indices().data_ptr<int64_t>() + result.crow_indices().numel()));
P.setColPositions(std::vector<int>(result.col_indices().data_ptr<int64_t>(), result.col_indices().data_ptr<int64_t>() + result.col_indices().numel()));
P.setValues(std::vector<T>(result.values().data_ptr<T>(), result.values().data_ptr<T>() + result.values().numel()));
}

Expand Down

0 comments on commit 318c143

Please sign in to comment.