diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index dbd923a..10214b0 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -42,5 +42,14 @@ if(MSVC) endif() find_package(Torch REQUIRED) - -target_link_libraries(dphpc "${TORCH_LIBRARIES}") \ No newline at end of file +target_link_libraries(dphpc "${TORCH_LIBRARIES}") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") + +if (MSVC) + file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll") + add_custom_command(TARGET dphpc + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${TORCH_DLLS} + $) +endif (MSVC) \ No newline at end of file diff --git a/src/competitors/cpu/cpu_pytorch.hpp b/src/competitors/cpu/cpu_pytorch.hpp index ba701e2..68efdc7 100644 --- a/src/competitors/cpu/cpu_pytorch.hpp +++ b/src/competitors/cpu/cpu_pytorch.hpp @@ -46,8 +46,8 @@ namespace Competitors { torch::Tensor result = at::native::sparse_sampled_addmm_sparse_csr_cpu(sparse_tensor, A_tensor, B_tensor, 0, 1).to_sparse_csr(); - P.setRowPositions(std::vector(result.crow_indices().data_ptr(), result.crow_indices().data_ptr() + result.crow_indices().numel())); - P.setColPositions(std::vector(result.col_indices().data_ptr(), result.col_indices().data_ptr() + result.col_indices().numel())); + P.setRowPositions(std::vector(result.crow_indices().data_ptr(), result.crow_indices().data_ptr() + result.crow_indices().numel())); + P.setColPositions(std::vector(result.col_indices().data_ptr(), result.col_indices().data_ptr() + result.col_indices().numel())); P.setValues(std::vector(result.values().data_ptr(), result.values().data_ptr() + result.values().numel())); }