Skip to content

Commit

Permalink
update to incomplete_factorization and throw with omp using sparselib
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Nov 22, 2024
1 parent 46cd5c5 commit d681a82
Show file tree
Hide file tree
Showing 10 changed files with 159 additions and 124 deletions.
13 changes: 8 additions & 5 deletions core/factorization/ic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ Ic<ValueType, IndexType>::parse(const config::pnode& config,
params.with_both_factors(config::get_value<bool>(obj));
}
if (auto& obj = config.get("algorithm")) {
using gko::factorization::incomplete_factorize_algorithm;
using gko::factorization::incomplete_algorithm;
auto str = obj.get_string();
if (str == "sparselib") {
params.with_algorithm(incomplete_factorize_algorithm::sparselib);
params.with_algorithm(incomplete_algorithm::sparselib);
} else if (str == "syncfree") {
params.with_algorithm(incomplete_factorize_algorithm::syncfree);
params.with_algorithm(incomplete_algorithm::syncfree);
} else {
GKO_INVALID_CONFIG_VALUE("algorithm", str);
}
Expand Down Expand Up @@ -105,8 +105,7 @@ std::unique_ptr<Composition<ValueType>> Ic<ValueType, IndexType>::generate(

std::shared_ptr<const matrix_type> ic;
// Compute the incomplete Cholesky (IC) factorization
if (std::dynamic_pointer_cast<const OmpExecutor>(exec) ||
parameters_.algorithm == incomplete_factorize_algorithm::syncfree) {
if (parameters_.algorithm == incomplete_algorithm::syncfree) {
std::unique_ptr<gko::factorization::elimination_forest<IndexType>>
forest;
const auto nnz = local_system_matrix->get_num_stored_elements();
Expand Down Expand Up @@ -161,6 +160,10 @@ std::unique_ptr<Composition<ValueType>> Ic<ValueType, IndexType>::generate(
transpose_idxs.get_const_data(), *forest, factors.get(), false,
tmp));
ic = factors;
} else if (std::dynamic_pointer_cast<const OmpExecutor>(exec)) {
GKO_INVALID_STATE(
"OmpExecutor does not support sparselib algorithm. Please use "
"syncfree algorithm.");
} else {
exec->run(
ic_factorization::make_sparselib_ic(local_system_matrix.get()));
Expand Down
13 changes: 8 additions & 5 deletions core/factorization/ilu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ Ilu<ValueType, IndexType>::parse(const config::pnode& config,
params.with_skip_sorting(config::get_value<bool>(obj));
}
if (auto& obj = config.get("algorithm")) {
using gko::factorization::incomplete_factorize_algorithm;
using gko::factorization::incomplete_algorithm;
auto str = obj.get_string();
if (str == "sparselib") {
params.with_algorithm(incomplete_factorize_algorithm::sparselib);
params.with_algorithm(incomplete_algorithm::sparselib);
} else if (str == "syncfree") {
params.with_algorithm(incomplete_factorize_algorithm::syncfree);
params.with_algorithm(incomplete_algorithm::syncfree);
} else {
GKO_INVALID_CONFIG_VALUE("algorithm", str);
}
Expand Down Expand Up @@ -103,8 +103,7 @@ std::unique_ptr<Composition<ValueType>> Ilu<ValueType, IndexType>::generate_l_u(

std::shared_ptr<const matrix_type> ilu;
// Compute LU factorization
if (std::dynamic_pointer_cast<const OmpExecutor>(exec) ||
parameters_.algorithm == incomplete_factorize_algorithm::syncfree) {
if (parameters_.algorithm == incomplete_algorithm::syncfree) {
const auto nnz = local_system_matrix->get_num_stored_elements();
const auto num_rows = local_system_matrix->get_size()[0];
auto factors = share(
Expand Down Expand Up @@ -145,6 +144,10 @@ std::unique_ptr<Composition<ValueType>> Ilu<ValueType, IndexType>::generate_l_u(
storage.get_const_data(), diag_idxs.get_const_data(), factors.get(),
false, tmp));
ilu = factors;
} else if (std::dynamic_pointer_cast<const OmpExecutor>(exec)) {
GKO_INVALID_STATE(
"OmpExecutor does not support sparselib algorithm. Please use "
"syncfree algorithm.");
} else {
exec->run(
ilu_factorization::make_sparselib_ilu(local_system_matrix.get()));
Expand Down
6 changes: 4 additions & 2 deletions core/test/config/factorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ struct Ic : FactorizationConfigTest<gko::factorization::Ic<float, int>,
config_map["both_factors"] = pnode{false};
param.with_both_factors(false);
config_map["algorithm"] = pnode{"syncfree"};
param.with_algorithm(gko::factorization::factorize_algorithm::syncfree);
param.with_algorithm(
gko::factorization::incomplete_algorithm::syncfree);
}

template <typename AnswerType>
Expand Down Expand Up @@ -115,7 +116,8 @@ struct Ilu : FactorizationConfigTest<gko::factorization::Ilu<float, int>,
config_map["skip_sorting"] = pnode{true};
param.with_skip_sorting(true);
config_map["algorithm"] = pnode{"syncfree"};
param.with_algorithm(gko::factorization::factorize_algorithm::syncfree);
param.with_algorithm(
gko::factorization::incomplete_algorithm::syncfree);
}

template <typename AnswerType>
Expand Down
4 changes: 2 additions & 2 deletions include/ginkgo/core/factorization/ic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ class Ic : public Composition<ValueType> {
* cuSPARSE/hipSPARSE/reference (sparselib) implementation. Default is
* sparselib.
*/
incomplete_factorize_algorithm GKO_FACTORY_PARAMETER_SCALAR(
algorithm, incomplete_factorize_algorithm::sparselib);
incomplete_algorithm GKO_FACTORY_PARAMETER_SCALAR(
algorithm, incomplete_algorithm::sparselib);
};
GKO_ENABLE_LIN_OP_FACTORY(Ic, parameters, Factory);
GKO_ENABLE_BUILD_METHOD(Factory);
Expand Down
4 changes: 2 additions & 2 deletions include/ginkgo/core/factorization/ilu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ class Ilu : public Composition<ValueType> {
* cuSPARSE/hipSPARSE/reference (sparselib) implementation. Default is
* sparselib.
*/
incomplete_factorize_algorithm GKO_FACTORY_PARAMETER_SCALAR(
algorithm, incomplete_factorize_algorithm::sparselib);
incomplete_algorithm GKO_FACTORY_PARAMETER_SCALAR(
algorithm, incomplete_algorithm::sparselib);
};
GKO_ENABLE_LIN_OP_FACTORY(Ilu, parameters, Factory);
GKO_ENABLE_BUILD_METHOD(Factory);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace factorization {
* `syncfree` is Ginkgo's implementation by using the Lu/Cholesky factorization
* components with given sparsity.
*/
enum class incomplete_factorize_algorithm { sparselib, syncfree };
enum class incomplete_algorithm { sparselib, syncfree };


} // namespace factorization
Expand Down
9 changes: 3 additions & 6 deletions reference/test/factorization/ic_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,7 @@ TYPED_TEST(Ic, GenerateGeneralBySyncfree)

auto fact =
factorization_type::build()
.with_algorithm(
gko::factorization::incomplete_factorize_algorithm::syncfree)
.with_algorithm(gko::factorization::incomplete_algorithm::syncfree)
.on(this->exec)
->generate(this->mtx_system);

Expand Down Expand Up @@ -224,8 +223,7 @@ TYPED_TEST(Ic, GenerateIcWithBitmapIsEquivalentToRefBySyncfree)
auto result_lt = gko::as<Csr>(result_l->conj_transpose());
auto factory =
gko::factorization::Ic<value_type, index_type>::build()
.with_algorithm(
gko::factorization::incomplete_factorize_algorithm::syncfree)
.with_algorithm(gko::factorization::incomplete_algorithm::syncfree)
.on(this->ref);

auto ic = factory->generate(mtx);
Expand Down Expand Up @@ -272,8 +270,7 @@ TYPED_TEST(Ic, GenerateIcWithHashmapIsEquivalentToRefBySyncfree)
auto result_lt = gko::as<Csr>(result_l->conj_transpose());
auto factory =
gko::factorization::Ic<value_type, index_type>::build()
.with_algorithm(
gko::factorization::incomplete_factorize_algorithm::syncfree)
.with_algorithm(gko::factorization::incomplete_algorithm::syncfree)
.on(this->ref);

auto ic = factory->generate(mtx);
Expand Down
9 changes: 3 additions & 6 deletions reference/test/factorization/ilu_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,7 @@ TYPED_TEST(Ilu, GenerateForCsrSmallBySyncfree)
using ilu_type = typename TestFixture::ilu_type;
auto factors =
ilu_type::build()
.with_algorithm(
gko::factorization::incomplete_factorize_algorithm::syncfree)
.with_algorithm(gko::factorization::incomplete_algorithm::syncfree)
.on(this->exec)
->generate(this->mtx_csr_small);
auto l_factor = factors->get_l_factor();
Expand Down Expand Up @@ -404,8 +403,7 @@ TYPED_TEST(Ilu, GenerateIluWithBitmapIsEquivalentToRefBySyncfree)
result_u->read(result_u_data);
auto factory =
gko::factorization::Ilu<value_type, index_type>::build()
.with_algorithm(
gko::factorization::incomplete_factorize_algorithm::syncfree)
.with_algorithm(gko::factorization::incomplete_algorithm::syncfree)
.on(this->ref);

auto lu = factory->generate(mtx);
Expand Down Expand Up @@ -449,8 +447,7 @@ TYPED_TEST(Ilu, GenerateIluWithHashmapIsEquivalentToRefBySyncfree)
result_u->read(result_u_data);
auto factory =
gko::factorization::Ilu<value_type, index_type>::build()
.with_algorithm(
gko::factorization::incomplete_factorize_algorithm::syncfree)
.with_algorithm(gko::factorization::incomplete_algorithm::syncfree)
.on(this->ref);

auto lu = factory->generate(mtx);
Expand Down
108 changes: 62 additions & 46 deletions test/factorization/ic_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <gtest/gtest.h>

#include <ginkgo/core/base/exception.hpp>
#include <ginkgo/core/base/executor.hpp>
#include <ginkgo/core/factorization/ic.hpp>
#include <ginkgo/core/factorization/par_ic.hpp>
Expand Down Expand Up @@ -37,24 +38,6 @@ class Ic : public CommonTestFixture {
};


// Generates an Ic factorization with sorting skipped on both `ref` (from
// `mtx`) and `exec` (from `dmtx`) and checks that the two results agree.
// No algorithm is selected, so the factory's default algorithm is used.
TEST_F(Ic, ComputeICIsEquivalentToRefSorted)
{
// Factorization produced on the `ref` executor.
auto fact = gko::factorization::Ic<>::build()
.with_skip_sorting(true)
.on(ref)
->generate(mtx);
// Same factory configuration, produced on the `exec` executor.
auto dfact = gko::factorization::Ic<>::build()
.with_skip_sorting(true)
.on(exec)
->generate(dmtx);

// Both factors are compared with a 1e-14 tolerance argument, followed by a
// sparsity check (presumably the sparsity pattern, going by the macro name).
GKO_ASSERT_MTX_NEAR(fact->get_l_factor(), dfact->get_l_factor(), 1e-14);
GKO_ASSERT_MTX_NEAR(fact->get_lt_factor(), dfact->get_lt_factor(), 1e-14);
GKO_ASSERT_MTX_EQ_SPARSITY(fact->get_l_factor(), dfact->get_l_factor());
GKO_ASSERT_MTX_EQ_SPARSITY(fact->get_lt_factor(), dfact->get_lt_factor());
}


TEST_F(Ic, ComputeICBySyncfreeIsEquivalentToRefSorted)
{
auto fact = gko::factorization::Ic<>::build()
Expand All @@ -64,8 +47,7 @@ TEST_F(Ic, ComputeICBySyncfreeIsEquivalentToRefSorted)
auto dfact =
gko::factorization::Ic<>::build()
.with_skip_sorting(true)
.with_algorithm(
gko::factorization::incomplete_factorize_algorithm::syncfree)
.with_algorithm(gko::factorization::incomplete_algorithm::syncfree)
.on(exec)
->generate(dmtx);

Expand All @@ -76,21 +58,6 @@ TEST_F(Ic, ComputeICBySyncfreeIsEquivalentToRefSorted)
}


// Same comparison as the sorted variant, but the input matrix is unsorted
// first and the factories are built without `with_skip_sorting`.
TEST_F(Ic, ComputeICIsEquivalentToRefUnsorted)
{
// Unsort `mtx` using the fixture's random engine, then mirror it into
// `dmtx` so both executors factorize the same data.
gko::test::unsort_matrix(mtx, rand_engine);
dmtx->copy_from(mtx);

// Default-configured factories on `ref` and `exec`.
auto fact = gko::factorization::Ic<>::build().on(ref)->generate(mtx);
auto dfact = gko::factorization::Ic<>::build().on(exec)->generate(dmtx);

// Both factors are compared with a 1e-14 tolerance argument, followed by a
// sparsity check (presumably the sparsity pattern, going by the macro name).
GKO_ASSERT_MTX_NEAR(fact->get_l_factor(), dfact->get_l_factor(), 1e-14);
GKO_ASSERT_MTX_NEAR(fact->get_lt_factor(), dfact->get_lt_factor(), 1e-14);
GKO_ASSERT_MTX_EQ_SPARSITY(fact->get_l_factor(), dfact->get_l_factor());
GKO_ASSERT_MTX_EQ_SPARSITY(fact->get_lt_factor(), dfact->get_lt_factor());
}


TEST_F(Ic, ComputeICWithBitmapIsEquivalentToRefBySyncfree)
{
// diag + full first row and column
Expand All @@ -104,13 +71,11 @@ TEST_F(Ic, ComputeICWithBitmapIsEquivalentToRefBySyncfree)

auto factory =
gko::factorization::Ic<value_type, index_type>::build()
.with_algorithm(
gko::factorization::incomplete_factorize_algorithm::syncfree)
.with_algorithm(gko::factorization::incomplete_algorithm::syncfree)
.on(this->ref);
auto dfactory =
gko::factorization::Ic<value_type, index_type>::build()
.with_algorithm(
gko::factorization::incomplete_factorize_algorithm::syncfree)
.with_algorithm(gko::factorization::incomplete_algorithm::syncfree)
.on(this->exec);

auto ic = factory->generate(mtx);
Expand Down Expand Up @@ -149,13 +114,11 @@ TEST_F(Ic, ComputeICWithHashmapIsEquivalentToRefBySyncfree)
auto dmtx = gko::share(mtx->clone(this->exec));
auto factory =
gko::factorization::Ic<value_type, index_type>::build()
.with_algorithm(
gko::factorization::incomplete_factorize_algorithm::syncfree)
.with_algorithm(gko::factorization::incomplete_algorithm::syncfree)
.on(this->ref);
auto dfactory =
gko::factorization::Ic<value_type, index_type>::build()
.with_algorithm(
gko::factorization::incomplete_factorize_algorithm::syncfree)
.with_algorithm(gko::factorization::incomplete_algorithm::syncfree)
.on(this->exec);

auto ic = factory->generate(mtx);
Expand All @@ -172,11 +135,64 @@ TEST_F(Ic, ComputeICWithHashmapIsEquivalentToRefBySyncfree)

// Checks that the strategy passed via `with_l_strategy` is reported by both
// generated factors. The syncfree algorithm is selected explicitly.
TEST_F(Ic, SetsCorrectStrategy)
{
// Build on `exec` with a Csr::merge_path strategy and the syncfree
// algorithm, then factorize `dmtx`.
auto dfact =
gko::factorization::Ic<>::build()
.with_l_strategy(std::make_shared<Csr::merge_path>())
.with_algorithm(gko::factorization::incomplete_algorithm::syncfree)
.on(exec)
->generate(dmtx);

// Both the l factor and the lt factor must report the "merge_path" strategy.
ASSERT_EQ(dfact->get_l_factor()->get_strategy()->get_name(), "merge_path");
ASSERT_EQ(dfact->get_lt_factor()->get_strategy()->get_name(), "merge_path");
}


#ifdef GKO_COMPILING_OMP


// Expects generation to throw gko::InvalidStateError when no algorithm is
// selected on the factory (i.e. the default algorithm is used) on `exec`.
// This test sits in the GKO_COMPILING_OMP branch of the surrounding #ifdef.
TEST_F(Ic, OmpComputeICBySparselibShouldThrow)
{
// The whole build-and-generate chain is the expression under test.
ASSERT_THROW(gko::factorization::Ic<>::build()
.with_skip_sorting(true)
.on(exec)
->generate(dmtx),
gko::InvalidStateError);
}


#else


// Generates an Ic factorization with sorting skipped on both `ref` (from
// `mtx`) and `exec` (from `dmtx`) and checks that the two results agree.
// This test sits in the #else branch of the GKO_COMPILING_OMP conditional.
// NOTE(review): only the `exec`-side factory sets a merge_path strategy and
// only its strategy name is asserted; these lines look unrelated to the
// equivalence check -- confirm they are intended here.
TEST_F(Ic, ComputeICIsEquivalentToRefSorted)
{
// Factorization produced on the `ref` executor.
auto fact = gko::factorization::Ic<>::build()
.with_skip_sorting(true)
.on(ref)
->generate(mtx);
// Factorization produced on the `exec` executor, additionally configured
// with a Csr::merge_path strategy.
auto dfact = gko::factorization::Ic<>::build()
.with_l_strategy(std::make_shared<Csr::merge_path>())
.with_skip_sorting(true)
.on(exec)
->generate(dmtx);

// Strategy name reported by both factors generated on `exec`.
ASSERT_EQ(dfact->get_l_factor()->get_strategy()->get_name(), "merge_path");
ASSERT_EQ(dfact->get_lt_factor()->get_strategy()->get_name(), "merge_path");
// Both factors are compared with a 1e-14 tolerance argument, followed by a
// sparsity check (presumably the sparsity pattern, going by the macro name).
GKO_ASSERT_MTX_NEAR(fact->get_l_factor(), dfact->get_l_factor(), 1e-14);
GKO_ASSERT_MTX_NEAR(fact->get_lt_factor(), dfact->get_lt_factor(), 1e-14);
GKO_ASSERT_MTX_EQ_SPARSITY(fact->get_l_factor(), dfact->get_l_factor());
GKO_ASSERT_MTX_EQ_SPARSITY(fact->get_lt_factor(), dfact->get_lt_factor());
}


// Same comparison as the sorted variant, but the input matrix is unsorted
// first and the factories are built without `with_skip_sorting`.
// This test sits in the #else branch of the GKO_COMPILING_OMP conditional.
TEST_F(Ic, ComputeICIsEquivalentToRefUnsorted)
{
// Unsort `mtx` using the fixture's random engine, then mirror it into
// `dmtx` so both executors factorize the same data.
gko::test::unsort_matrix(mtx, rand_engine);
dmtx->copy_from(mtx);

// Default-configured factories on `ref` and `exec`.
auto fact = gko::factorization::Ic<>::build().on(ref)->generate(mtx);
auto dfact = gko::factorization::Ic<>::build().on(exec)->generate(dmtx);

// Both factors are compared with a 1e-14 tolerance argument, followed by a
// sparsity check (presumably the sparsity pattern, going by the macro name).
GKO_ASSERT_MTX_NEAR(fact->get_l_factor(), dfact->get_l_factor(), 1e-14);
GKO_ASSERT_MTX_NEAR(fact->get_lt_factor(), dfact->get_lt_factor(), 1e-14);
GKO_ASSERT_MTX_EQ_SPARSITY(fact->get_l_factor(), dfact->get_l_factor());
GKO_ASSERT_MTX_EQ_SPARSITY(fact->get_lt_factor(), dfact->get_lt_factor());
}

#endif
Loading

0 comments on commit d681a82

Please sign in to comment.