-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[test, scalapack.py] Add tests for scalapack.py
- Loading branch information
Showing
6 changed files
with
295 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
[run] | ||
parallel=True | ||
omit=tests/* | ||
|
||
[report] | ||
fail_under=100 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import pytest | ||
import PyScalapack | ||
import numpy as np | ||
|
||
scalapack = PyScalapack("libscalapack.so") | ||
|
||
|
||
@pytest.mark.mpi(min_size=4) | ||
def test_array_create_incorrect(): | ||
with scalapack(b'C', nprow=2, npcol=2) as context: | ||
with pytest.raises(RuntimeError) as e_info: | ||
array = context.array(m=10, n=10, mb=1, nb=1) | ||
with pytest.raises(RuntimeError) as e_info: | ||
array = context.array(m=10, n=10, mb=1, nb=1, dtype=np.float64, data=np.zeros([5, 5], dtype=np.float64)) | ||
|
||
|
||
@pytest.mark.mpi(min_size=4) | ||
def test_array_create_own_data(): | ||
with scalapack(b'C', nprow=2, npcol=2) as context: | ||
array = context.array(m=10, n=10, mb=1, nb=1, dtype=np.float64) | ||
assert array.dtype == array.c_dtype.value == 1 | ||
assert array.ctxt == array.c_ctxt.value == context.ictxt.value | ||
assert array.m == array.c_m.value == 10 | ||
assert array.n == array.c_n.value == 10 | ||
assert array.mb == array.c_mb.value == 1 | ||
assert array.nb == array.c_nb.value == 1 | ||
assert array.rsrc == array.c_rsrc.value == 0 | ||
assert array.csrc == array.c_csrc.value == 0 | ||
if context: | ||
assert array.lld == array.c_lld.value == 5 | ||
|
||
|
||
@pytest.mark.mpi(min_size=4) | ||
def test_array_create_share_data(): | ||
with scalapack(b'C', nprow=2, npcol=2) as context: | ||
array = context.array(m=10, n=10, mb=1, nb=1, data=np.zeros([5, 5], dtype=np.float64, order='F')) | ||
with pytest.raises(RuntimeError) as e_info: | ||
array = context.array(m=10, n=10, mb=1, nb=1, data=np.zeros([5, 5], dtype=np.float64, order='C')) | ||
|
||
with scalapack(b'R', nprow=2, npcol=2) as context: | ||
array = context.array(m=10, n=10, mb=1, nb=1, data=np.zeros([5, 5], dtype=np.float64, order='C')) | ||
with pytest.raises(RuntimeError) as e_info: | ||
array = context.array(m=10, n=10, mb=1, nb=1, data=np.zeros([5, 5], dtype=np.float64, order='F')) | ||
#with pytest.raises(RuntimeError) as e_info: | ||
# array = context.array(m=10, n=10, mb=1, nb=1, data=np.zeros([6, 6], dtype=np.float64, order='C')) | ||
|
||
|
||
@pytest.mark.mpi(min_size=4) | ||
def test_array_create_share_data_local_mismatch(): | ||
with scalapack(b'C', nprow=2, npcol=2) as context: | ||
with pytest.raises(RuntimeError) as e_info: | ||
array = context.array(m=10, n=10, mb=1, nb=1, data=np.zeros([6, 6], dtype=np.float64, order='F')) | ||
if not context: | ||
raise RuntimeError("process in context should raise, process out context should not, raise it manually") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import pytest | ||
import PyScalapack | ||
|
||
scalapack = PyScalapack("libscalapack.so") | ||
|
||
|
||
@pytest.mark.mpi(min_size=4) | ||
def test_context_column_major(): | ||
with scalapack(b'C', nprow=2, npcol=2) as context: | ||
assert context.layout.value == b'C' | ||
if context: | ||
assert context.rank.value // context.nprow.value == context.mycol.value | ||
assert context.rank.value % context.nprow.value == context.myrow.value | ||
assert context.rank.value < context.nprow.value * context.npcol.value | ||
else: | ||
assert context.rank.value >= context.nprow.value * context.npcol.value | ||
|
||
|
||
@pytest.mark.mpi(min_size=4) | ||
def test_context_row_major(): | ||
with scalapack(b'R', nprow=1, npcol=2) as context: | ||
assert context.layout.value == b'R' | ||
if context: | ||
assert context.rank.value % context.npcol.value == context.mycol.value | ||
assert context.rank.value // context.npcol.value == context.myrow.value | ||
assert context.rank.value < context.nprow.value * context.npcol.value | ||
else: | ||
assert context.rank.value >= context.nprow.value * context.npcol.value | ||
|
||
|
||
def test_context_error_major(): | ||
with pytest.raises(RuntimeError): | ||
with scalapack(b'W', nprow=2, npcol=2) as context: | ||
pass | ||
|
||
|
||
@pytest.mark.mpi(min_size=2) | ||
def test_context_auto_row(): | ||
with scalapack(b'R', nprow=-1, npcol=2) as context: | ||
assert context.nprow.value == context.size.value // context.npcol.value | ||
assert context.layout.value == b'R' | ||
if context: | ||
assert context.rank.value % context.npcol.value == context.mycol.value | ||
assert context.rank.value // context.npcol.value == context.myrow.value | ||
assert context.rank.value < context.nprow.value * context.npcol.value | ||
else: | ||
assert context.rank.value >= context.nprow.value * context.npcol.value | ||
|
||
|
||
@pytest.mark.mpi(min_size=2) | ||
def test_context_auto_column(): | ||
with scalapack(b'R', nprow=2, npcol=-1) as context: | ||
assert context.npcol.value == context.size.value // context.nprow.value | ||
assert context.layout.value == b'R' | ||
if context: | ||
assert context.rank.value % context.npcol.value == context.mycol.value | ||
assert context.rank.value // context.npcol.value == context.myrow.value | ||
assert context.rank.value < context.nprow.value * context.npcol.value | ||
else: | ||
assert context.rank.value >= context.nprow.value * context.npcol.value | ||
|
||
|
||
@pytest.mark.mpi(min_size=4) | ||
def test_context_barrier(): | ||
with scalapack(b'R', nprow=2, npcol=2) as context: | ||
context.barrier() | ||
context.barrier(b'A') | ||
context.barrier(b'R') | ||
context.barrier(b'C') | ||
context.barrier(scope=b'A') | ||
context.barrier(scope=b'R') | ||
context.barrier(scope=b'C') | ||
|
||
with pytest.raises(RuntimeError): | ||
context.barrier(b'W') | ||
with pytest.raises(RuntimeError): | ||
context.barrier(scope=b'W') | ||
|
||
|
||
def test_context_raise(): | ||
with pytest.raises(RuntimeError) as e_info: | ||
with scalapack(b'R', nprow=1, npcol=1) as context: | ||
raise RuntimeError("Test Error") | ||
assert e_info.value.args == ("Test Error",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import pytest | ||
import numpy as np | ||
import PyScalapack | ||
|
||
scalapack = PyScalapack("libscalapack.so") | ||
|
||
|
||
@pytest.mark.mpi(min_size=4) | ||
def test_pdgemm(): | ||
L1 = 128 | ||
L2 = 512 | ||
with scalapack(b'C', 2, 2) as context, scalapack(b'C', 1, 1) as context0: | ||
if context: | ||
# Create array0 add 1*1 grid | ||
array0 = context0.array(m=L1, n=L2, mb=1, nb=1, dtype=np.float64) | ||
if context0: | ||
array0.data[...] = np.random.randn(*array0.data.shape) | ||
|
||
# Redistribute array0 to 2*2 grid as array | ||
array = context.array(m=L1, n=L2, mb=1, nb=1, dtype=np.float64) | ||
scalapack.pgemr2d["D"](*(L1, L2), *array0.scalapack_params(), *array.scalapack_params(), context.ictxt) | ||
|
||
# Call pdgemm to get the product of array and array in 2*2 grid | ||
result = context.array(m=L1, n=L1, mb=1, nb=1, dtype=np.float64) | ||
scalapack.pdgemm( | ||
b'N', | ||
b'T', | ||
*(L1, L1, L2), | ||
scalapack.d_one, | ||
*array.scalapack_params(), | ||
*array.scalapack_params(), | ||
scalapack.d_zero, | ||
*result.scalapack_params(), | ||
) | ||
|
||
# Redistribute result to 1*1 grid as result0 | ||
result0 = context0.array(m=L1, n=L1, mb=1, nb=1, dtype=np.float64) | ||
scalapack.pgemr2d["D"](*(L1, L1), *result.scalapack_params(), *result0.scalapack_params(), context.ictxt) | ||
|
||
# Check result0 == array0 * array0^T | ||
if context0: | ||
diff = result0.data - array0.data @ array0.data.T | ||
assert np.linalg.norm(diff) < 1e-8 | ||
|
||
|
||
def test_dgemm(): | ||
L1 = 128 | ||
L2 = 512 | ||
with scalapack(b'C', 1, 1) as context: | ||
if context: | ||
array = context.array(m=L1, n=L2, mb=1, nb=1, dtype=np.float64) | ||
array.data[...] = np.random.randn(*array.data.shape) | ||
|
||
result = context.array(m=L1, n=L1, mb=1, nb=1, dtype=np.float64) | ||
scalapack.dgemm( | ||
b'N', | ||
b'T', | ||
*(L1, L1, L2), | ||
scalapack.d_one, | ||
*array.lapack_params(), | ||
*array.lapack_params(), | ||
scalapack.d_zero, | ||
*result.lapack_params(), | ||
) | ||
|
||
diff = result.data - array.data @ array.data.T | ||
assert np.linalg.norm(diff) < 1e-8 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import pytest | ||
import PyScalapack | ||
import numpy as np | ||
|
||
scalapack = PyScalapack("libscalapack.so") | ||
|
||
|
||
@pytest.mark.mpi(min_size=2) | ||
def test_array_redistribute_0(): | ||
# matrix: | ||
# 1 2 | ||
# 3 4 | ||
with scalapack(b'R', 1, 2) as context1, scalapack(b'R', 2, 1) as context2: | ||
if context1: | ||
m = 2 | ||
n = 2 | ||
array1 = context1.array(m=m, n=n, mb=1, nb=1, dtype=np.float64) | ||
if context1.rank.value == 0: | ||
array1.data[0, 0] = 1 | ||
array1.data[1, 0] = 3 | ||
else: | ||
array1.data[0, 0] = 2 | ||
array1.data[1, 0] = 4 | ||
array2 = context2.array(m=m, n=n, mb=1, nb=1, dtype=np.float64) | ||
scalapack.pgemr2d["D"]( | ||
*(m, n), | ||
*array1.scalapack_params(), | ||
*array2.scalapack_params(), | ||
context1.ictxt, | ||
) | ||
if context1.rank.value == 0: | ||
assert array2.data[0, 0] == 1 | ||
assert array2.data[0, 1] == 2 | ||
else: | ||
assert array2.data[0, 0] == 3 | ||
assert array2.data[0, 1] == 4 | ||
|
||
|
||
@pytest.mark.mpi(min_size=4) | ||
def test_array_redistribute_1(): | ||
with scalapack(b'R', 2, 2) as context1, scalapack(b'R', 1, 1) as context0: | ||
if context1: | ||
m = 100 | ||
n = 100 | ||
array0 = context0.array(m=m, n=n, mb=1, nb=1, dtype=np.float64) | ||
array2 = context0.array(m=m, n=n, mb=1, nb=1, dtype=np.float64) | ||
if context0: | ||
array0.data[...] = np.random.randn(*array0.data.shape) | ||
array1 = context1.array(m=m, n=n, mb=1, nb=1, dtype=np.float64) | ||
scalapack.pgemr2d["D"]( | ||
*(m, n), | ||
*array0.scalapack_params(), | ||
*array1.scalapack_params(), | ||
context1.ictxt, | ||
) | ||
scalapack.pgemr2d["D"]( | ||
*(m, n), | ||
*array1.scalapack_params(), | ||
*array2.scalapack_params(), | ||
context1.ictxt, | ||
) | ||
if context1: | ||
assert np.linalg.norm(array2.data - array0.data) == 0 |