Skip to content

Commit

Permalink
refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
MozammilQ committed Nov 11, 2024
1 parent 2ae116e commit ed9a907
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
10 changes: 1 addition & 9 deletions src/simulators/matrix_product_state/matrix_product_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,17 +362,9 @@ void State::set_config(const Config &config) {
MPS::set_mps_lapack_svd(config.mps_lapack);

// Set device for SVD
MPS::set_mps_svd_device(config.device);

// Get CUDA device, if GPU offloading enabled
if (config.device.compare("GPU") == 0) {
#ifdef AER_THRUST_CUDA
cudaDeviceProp prop;
int deviceId{-1};
HANDLE_CUDA_ERROR(cudaGetDevice(&deviceId));
HANDLE_CUDA_ERROR(cudaGetDeviceProperties(&prop, deviceId));
MPS::set_mps_svd_device(config.device);
#endif // AER_THRUST_CUDA
}
}

void State::add_metadata(ExperimentResult &result) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,22 @@ class MPS {
public:
MPS(uint_t num_qubits = 0) : num_qubits_(num_qubits) {
#ifdef AER_THRUST_CUDA
cuda_stream = NULL;
cutensor_handle = NULL;
if (mps_svd_device_.compare("GPU") == 0) {
cudaStreamCreate(&cuda_stream);
cutensornetCreate(&cutensor_handle);
}
#endif // AER_THRUST_CUDA
}
~MPS() {}
~MPS() {
#ifdef AER_THRUST_CUDA
if (cutensor_handle)
cutensornetDestroy(cutensor_handle);
if (cuda_stream)
cudaStreamDestroy(cuda_stream);
#endif // AER_THRUST_CUDA
}

//--------------------------------------------------------------------------
// Function name: initialize
Expand Down Expand Up @@ -328,9 +337,11 @@ class MPS {
}

static void set_mps_lapack_svd(bool mps_lapack) { mps_lapack_ = mps_lapack; }
#ifdef AER_THRUST_CUDA
static void set_mps_svd_device(std::string mps_svd_device) {
mps_svd_device_ = mps_svd_device;
}
#endif // AER_THRUST_CUDA

static uint_t get_omp_threads() { return omp_threads_; }
static uint_t get_omp_threshold() { return omp_threshold_; }
Expand Down Expand Up @@ -585,7 +596,9 @@ class MPS {
static bool mps_log_data_;
static MPS_swap_direction mps_swap_direction_;
static bool mps_lapack_;
#ifdef AER_THRUST_CUDA
static std::string mps_svd_device_;
#endif // AER_THRUST_CUDA
};

inline std::ostream &operator<<(std::ostream &out, const rvector_t &vec) {
Expand Down

0 comments on commit ed9a907

Please sign in to comment.