Skip to content

Commit

Permalink
[test, scalapack.py] Add tests for scalapack.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhangxyz committed Sep 7, 2023
1 parent d8298e2 commit 3d3d9ed
Show file tree
Hide file tree
Showing 6 changed files with 295 additions and 2 deletions.
23 changes: 21 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,34 @@ jobs:
with:
python-version: "3.x"
- name: install pytest
run: pip install pytest pytest-cov
run: pip install pytest pytest-cov numpy
- name: run test
working-directory: ${{github.workspace}}/lazy_graph
run: python -m pytest --cov=lazy --cov-fail-under=100

test_scalapack_py:
name: test scalapack.py
runs-on: ubuntu-latest
if: "! contains(toJSON(github.event.commits.*.message), '[skip ci]')"
steps:
- name: checkout
uses: actions/checkout@v3
- name: install mpi, scalapack
run: sudo apt-get install -y mpi-default-dev libscalapack-mpi-dev
- name: setup python
uses: actions/setup-python@v4
with:
python-version: "3.x"
- name: install pytest
run: pip install pytest pytest-cov pytest-mpi
- name: run test
working-directory: ${{github.workspace}}/PyScalapack
run: mpirun --oversubscribe -n 6 coverage run -m pytest --capture=no --with-mpi && coverage combine && coverage report

build_wheels_trigger:
name: trigger for building wheels
runs-on: ubuntu-latest
needs: [test_TAT_hpp, test_lazy_py]
needs: [test_TAT_hpp, test_lazy_py, test_scalapack_py]
if: "github.event_name == 'push' && (startsWith(github.ref, 'refs/tags') || contains(toJSON(github.event.commits.*.message), '[force ci]'))"
steps:
- name: nothing
Expand Down
6 changes: 6 additions & 0 deletions PyScalapack/.coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[run]
parallel=True
omit=tests/*

[report]
fail_under=100
54 changes: 54 additions & 0 deletions PyScalapack/tests/test_create_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
import PyScalapack
import numpy as np

scalapack = PyScalapack("libscalapack.so")


@pytest.mark.mpi(min_size=4)
def test_array_create_incorrect():
with scalapack(b'C', nprow=2, npcol=2) as context:
with pytest.raises(RuntimeError) as e_info:
array = context.array(m=10, n=10, mb=1, nb=1)
with pytest.raises(RuntimeError) as e_info:
array = context.array(m=10, n=10, mb=1, nb=1, dtype=np.float64, data=np.zeros([5, 5], dtype=np.float64))


@pytest.mark.mpi(min_size=4)
def test_array_create_own_data():
with scalapack(b'C', nprow=2, npcol=2) as context:
array = context.array(m=10, n=10, mb=1, nb=1, dtype=np.float64)
assert array.dtype == array.c_dtype.value == 1
assert array.ctxt == array.c_ctxt.value == context.ictxt.value
assert array.m == array.c_m.value == 10
assert array.n == array.c_n.value == 10
assert array.mb == array.c_mb.value == 1
assert array.nb == array.c_nb.value == 1
assert array.rsrc == array.c_rsrc.value == 0
assert array.csrc == array.c_csrc.value == 0
if context:
assert array.lld == array.c_lld.value == 5


@pytest.mark.mpi(min_size=4)
def test_array_create_share_data():
with scalapack(b'C', nprow=2, npcol=2) as context:
array = context.array(m=10, n=10, mb=1, nb=1, data=np.zeros([5, 5], dtype=np.float64, order='F'))
with pytest.raises(RuntimeError) as e_info:
array = context.array(m=10, n=10, mb=1, nb=1, data=np.zeros([5, 5], dtype=np.float64, order='C'))

with scalapack(b'R', nprow=2, npcol=2) as context:
array = context.array(m=10, n=10, mb=1, nb=1, data=np.zeros([5, 5], dtype=np.float64, order='C'))
with pytest.raises(RuntimeError) as e_info:
array = context.array(m=10, n=10, mb=1, nb=1, data=np.zeros([5, 5], dtype=np.float64, order='F'))
#with pytest.raises(RuntimeError) as e_info:
# array = context.array(m=10, n=10, mb=1, nb=1, data=np.zeros([6, 6], dtype=np.float64, order='C'))


@pytest.mark.mpi(min_size=4)
def test_array_create_share_data_local_mismatch():
with scalapack(b'C', nprow=2, npcol=2) as context:
with pytest.raises(RuntimeError) as e_info:
array = context.array(m=10, n=10, mb=1, nb=1, data=np.zeros([6, 6], dtype=np.float64, order='F'))
if not context:
raise RuntimeError("process in context should raise, process out context should not, raise it manually")
84 changes: 84 additions & 0 deletions PyScalapack/tests/test_create_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import pytest
import PyScalapack

scalapack = PyScalapack("libscalapack.so")


@pytest.mark.mpi(min_size=4)
def test_context_column_major():
with scalapack(b'C', nprow=2, npcol=2) as context:
assert context.layout.value == b'C'
if context:
assert context.rank.value // context.nprow.value == context.mycol.value
assert context.rank.value % context.nprow.value == context.myrow.value
assert context.rank.value < context.nprow.value * context.npcol.value
else:
assert context.rank.value >= context.nprow.value * context.npcol.value


@pytest.mark.mpi(min_size=4)
def test_context_row_major():
with scalapack(b'R', nprow=1, npcol=2) as context:
assert context.layout.value == b'R'
if context:
assert context.rank.value % context.npcol.value == context.mycol.value
assert context.rank.value // context.npcol.value == context.myrow.value
assert context.rank.value < context.nprow.value * context.npcol.value
else:
assert context.rank.value >= context.nprow.value * context.npcol.value


def test_context_error_major():
with pytest.raises(RuntimeError):
with scalapack(b'W', nprow=2, npcol=2) as context:
pass


@pytest.mark.mpi(min_size=2)
def test_context_auto_row():
with scalapack(b'R', nprow=-1, npcol=2) as context:
assert context.nprow.value == context.size.value // context.npcol.value
assert context.layout.value == b'R'
if context:
assert context.rank.value % context.npcol.value == context.mycol.value
assert context.rank.value // context.npcol.value == context.myrow.value
assert context.rank.value < context.nprow.value * context.npcol.value
else:
assert context.rank.value >= context.nprow.value * context.npcol.value


@pytest.mark.mpi(min_size=2)
def test_context_auto_column():
with scalapack(b'R', nprow=2, npcol=-1) as context:
assert context.npcol.value == context.size.value // context.nprow.value
assert context.layout.value == b'R'
if context:
assert context.rank.value % context.npcol.value == context.mycol.value
assert context.rank.value // context.npcol.value == context.myrow.value
assert context.rank.value < context.nprow.value * context.npcol.value
else:
assert context.rank.value >= context.nprow.value * context.npcol.value


@pytest.mark.mpi(min_size=4)
def test_context_barrier():
with scalapack(b'R', nprow=2, npcol=2) as context:
context.barrier()
context.barrier(b'A')
context.barrier(b'R')
context.barrier(b'C')
context.barrier(scope=b'A')
context.barrier(scope=b'R')
context.barrier(scope=b'C')

with pytest.raises(RuntimeError):
context.barrier(b'W')
with pytest.raises(RuntimeError):
context.barrier(scope=b'W')


def test_context_raise():
with pytest.raises(RuntimeError) as e_info:
with scalapack(b'R', nprow=1, npcol=1) as context:
raise RuntimeError("Test Error")
assert e_info.value.args == ("Test Error",)
67 changes: 67 additions & 0 deletions PyScalapack/tests/test_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import pytest
import numpy as np
import PyScalapack

scalapack = PyScalapack("libscalapack.so")


@pytest.mark.mpi(min_size=4)
def test_pdgemm():
L1 = 128
L2 = 512
with scalapack(b'C', 2, 2) as context, scalapack(b'C', 1, 1) as context0:
if context:
# Create array0 add 1*1 grid
array0 = context0.array(m=L1, n=L2, mb=1, nb=1, dtype=np.float64)
if context0:
array0.data[...] = np.random.randn(*array0.data.shape)

# Redistribute array0 to 2*2 grid as array
array = context.array(m=L1, n=L2, mb=1, nb=1, dtype=np.float64)
scalapack.pgemr2d["D"](*(L1, L2), *array0.scalapack_params(), *array.scalapack_params(), context.ictxt)

# Call pdgemm to get the product of array and array in 2*2 grid
result = context.array(m=L1, n=L1, mb=1, nb=1, dtype=np.float64)
scalapack.pdgemm(
b'N',
b'T',
*(L1, L1, L2),
scalapack.d_one,
*array.scalapack_params(),
*array.scalapack_params(),
scalapack.d_zero,
*result.scalapack_params(),
)

# Redistribute result to 1*1 grid as result0
result0 = context0.array(m=L1, n=L1, mb=1, nb=1, dtype=np.float64)
scalapack.pgemr2d["D"](*(L1, L1), *result.scalapack_params(), *result0.scalapack_params(), context.ictxt)

# Check result0 == array0 * array0^T
if context0:
diff = result0.data - array0.data @ array0.data.T
assert np.linalg.norm(diff) < 1e-8


def test_dgemm():
L1 = 128
L2 = 512
with scalapack(b'C', 1, 1) as context:
if context:
array = context.array(m=L1, n=L2, mb=1, nb=1, dtype=np.float64)
array.data[...] = np.random.randn(*array.data.shape)

result = context.array(m=L1, n=L1, mb=1, nb=1, dtype=np.float64)
scalapack.dgemm(
b'N',
b'T',
*(L1, L1, L2),
scalapack.d_one,
*array.lapack_params(),
*array.lapack_params(),
scalapack.d_zero,
*result.lapack_params(),
)

diff = result.data - array.data @ array.data.T
assert np.linalg.norm(diff) < 1e-8
63 changes: 63 additions & 0 deletions PyScalapack/tests/test_redistribute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import pytest
import PyScalapack
import numpy as np

scalapack = PyScalapack("libscalapack.so")


@pytest.mark.mpi(min_size=2)
def test_array_redistribute_0():
# matrix:
# 1 2
# 3 4
with scalapack(b'R', 1, 2) as context1, scalapack(b'R', 2, 1) as context2:
if context1:
m = 2
n = 2
array1 = context1.array(m=m, n=n, mb=1, nb=1, dtype=np.float64)
if context1.rank.value == 0:
array1.data[0, 0] = 1
array1.data[1, 0] = 3
else:
array1.data[0, 0] = 2
array1.data[1, 0] = 4
array2 = context2.array(m=m, n=n, mb=1, nb=1, dtype=np.float64)
scalapack.pgemr2d["D"](
*(m, n),
*array1.scalapack_params(),
*array2.scalapack_params(),
context1.ictxt,
)
if context1.rank.value == 0:
assert array2.data[0, 0] == 1
assert array2.data[0, 1] == 2
else:
assert array2.data[0, 0] == 3
assert array2.data[0, 1] == 4


@pytest.mark.mpi(min_size=4)
def test_array_redistribute_1():
with scalapack(b'R', 2, 2) as context1, scalapack(b'R', 1, 1) as context0:
if context1:
m = 100
n = 100
array0 = context0.array(m=m, n=n, mb=1, nb=1, dtype=np.float64)
array2 = context0.array(m=m, n=n, mb=1, nb=1, dtype=np.float64)
if context0:
array0.data[...] = np.random.randn(*array0.data.shape)
array1 = context1.array(m=m, n=n, mb=1, nb=1, dtype=np.float64)
scalapack.pgemr2d["D"](
*(m, n),
*array0.scalapack_params(),
*array1.scalapack_params(),
context1.ictxt,
)
scalapack.pgemr2d["D"](
*(m, n),
*array1.scalapack_params(),
*array2.scalapack_params(),
context1.ictxt,
)
if context1:
assert np.linalg.norm(array2.data - array0.data) == 0

0 comments on commit 3d3d9ed

Please sign in to comment.