Skip to content

Commit

Permalink
Enhance user experience with MPI (#997)
Browse files Browse the repository at this point in the history
  • Loading branch information
abussy authored Oct 31, 2024
1 parent cc05b51 commit ea0ffe4
Show file tree
Hide file tree
Showing 14 changed files with 102 additions and 44 deletions.
88 changes: 65 additions & 23 deletions src/PlaneWaveBasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ struct PlaneWaveBasis{T,
fft_grid::FFTtype

## MPI-local information of the kpoints this processor treats
# Irreducible kpoints. In the case of collinear spin,
# this lists all the spin up, then all the spin down
# In principle, irreducible kpoints (although some kpoints might be duplicated in parallel runs).
# In the case of collinear spin, this lists all the spin up, then all the spin down
kpoints::Vector{Kpoint{T, T_kpt_G_vecs}}
# BZ integration weights, summing up to model.n_spin_components
kweights::Vector{T}
Expand All @@ -58,10 +58,17 @@ struct PlaneWaveBasis{T,
## These fields are not actually used in computation, but can be used to reconstruct a basis
# Monkhorst-Pack grid used to generate the k-points, or nothing for custom k-points
kgrid::AbstractKgrid
# full list of (non spin doubled) k-point coordinates in the irreducible BZ
# Full list of (non spin doubled) k-point coordinates in the irreducible BZ (duplicates possible)
# Best to use the irreducible_kcoords_global() and irreducible_kweights_global() functions
# to insure none of the k-points are duplicated
kcoords_global::Vector{Vec3{T}}
kweights_global::Vector{T}

# Number of irreducible k-points in the basis. If there are more MPI ranks than irreducible
# k-points, some are duplicated over the MPI ranks (with adjusted weight). In such a case
# n_irreducible_kpoints < length(kcoords_global)
n_irreducible_kpoints::Int

## Setup for MPI-distributed processing over k-points
comm_kpts::MPI.Comm # communicator for the kpoints distribution
krange_thisproc::Vector{UnitRange{Int}} # Indices of kpoints treated explicitly by this
Expand Down Expand Up @@ -175,30 +182,32 @@ function PlaneWaveBasis(model::Model{T}, Ecut::Real, fft_size::Tuple{Int, Int, I
# Compute k-point information and spread them across processors
# Right now we split only the kcoords: both spin channels have to be handled
# by the same process
n_kpt = length(kcoords_global)
n_procs = mpi_nprocs(comm_kpts)
n_kpt = length(kcoords_global)
n_irreducible_kpoints = n_kpt

# The code cannot handle MPI ranks without k-points. If there are more prcocesses
# than k-points, we duplicate k-points with the highest weight on the empty MPI
# ranks (and scale the weight accordingly)
if n_procs > n_kpt
# XXX Supporting more processors than kpoints would require
# fixing a bunch of "reducing over empty collections" errors
# In the unit tests it is really annoying that this fails so we hack around it, but
# generally it leads to duplicated work that is not in the users interest.
if parse(Bool, get(ENV, "CI", "false"))
comm_kpts = MPI.COMM_SELF
krange_thisproc1 = 1:n_kpt
krange_allprocs1 = fill(1:n_kpt, n_procs)
else
error("No point in trying to parallelize $n_kpt kpoints over $n_procs " *
"processes; reduce the number of MPI processes.")
for i in n_kpt+1:n_procs
idx = argmax(kweights_global)
kweights_global[idx] *= 0.5
push!(kweights_global, kweights_global[idx])
push!(kcoords_global, kcoords_global[idx])
end
else
# get the slice of 1:n_kpt to be handled by this process
# Note: MPI ranks are 0-based
krange_allprocs1 = split_evenly(1:n_kpt, n_procs)
krange_thisproc1 = krange_allprocs1[1 + MPI.Comm_rank(comm_kpts)]
@assert mpi_sum(length(krange_thisproc1), comm_kpts) == n_kpt
@assert !isempty(krange_thisproc1)
@warn("Attempting to parallelize $n_kpt k-points over $n_procs MPI ranks. " *
"DFTK does not support processes empty of k-point. Some k-points were " *
"duplicated over the extra ranks with scaled weights.")
end
n_kpt = length(kcoords_global)

# get the slice of 1:n_kpt to be handled by this process
# Note: MPI ranks are 0-based
krange_allprocs1 = split_evenly(1:n_kpt, n_procs)
krange_thisproc1 = krange_allprocs1[1 + MPI.Comm_rank(comm_kpts)]
@assert mpi_sum(length(krange_thisproc1), comm_kpts) == n_kpt
@assert !isempty(krange_thisproc1)

# Setup k-point basis sets
!variational && @warn(
Expand Down Expand Up @@ -237,7 +246,7 @@ function PlaneWaveBasis(model::Model{T}, Ecut::Real, fft_size::Tuple{Int, Int, I
Ecut, variational,
fft_grid,
kpoints, kweights, kgrid,
kcoords_global, kweights_global,
kcoords_global, kweights_global, n_irreducible_kpoints,
comm_kpts, krange_thisproc, krange_allprocs, krange_thisproc_allspin,
architecture, symmetries, symmetries_respect_rgrid,
use_symmetries_for_kpoint_reduction, terms)
Expand Down Expand Up @@ -422,6 +431,38 @@ function weighted_ksum(basis::PlaneWaveBasis, array)
mpi_sum(res, basis.comm_kpts)
end

"""
Utilities to get information about the irreducible k-point mesh (in case of duplication)
Useful for I/O, where k-point information should not be duplicated
"""
function irreducible_kcoords_global(basis::PlaneWaveBasis)
# Assume that duplicated k-points are appended at the end of the kcoords array
basis.kcoords_global[1:basis.n_irreducible_kpoints]
end

function irreducible_kweights_global(basis::PlaneWaveBasis{T}) where {T}
function same_kpoint(i_irr, i_dupl)
maximum(abs, basis.kcoords_global[i_dupl]-basis.kcoords_global[i_irr]) < eps(T)
end

# Check that weights add up to 1 on entry (non spin doubled k-points)
@assert sum(basis.kweights_global) 1

# Assume that duplicated k-points are appended at the end of the kcoords array
irr_kweights = basis.kweights_global[1:basis.n_irreducible_kpoints]
for i_dupl = basis.n_irreducible_kpoints+1:length(basis.kweights_global)
for i_irr = 1:basis.n_irreducible_kpoints
if same_kpoint(i_irr, i_dupl)
irr_kweights[i_irr] += basis.kweights_global[i_dupl]
break
end
end
end

# Test that irreducible weight add up to 1 (non spin doubled k-points)
@assert sum(irr_kweights) 1
irr_kweights
end

"""
Gather the distributed ``k``-point data on the master process and return
Expand Down Expand Up @@ -461,6 +502,7 @@ and save it in `dest` as a dense `(size(kdata[1])..., n_kpoints)` array. On the

# Note: This function assumes that k-points are stored contiguously in rank-increasing
# order, i.e. it depends on the splitting realised by split_evenly.
# Note that if some k-points are duplicated over MPI ranks, they are also gathered here.
for σ in 1:basis.model.n_spin_components
if mpi_master(basis.comm_kpts)
# Setup variable buffer using appropriate data lengths and
Expand Down
25 changes: 18 additions & 7 deletions src/input_output.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ function Base.show(io::IO, ::MIME"text/plain", basis::PlaneWaveBasis)
end
showfieldln(io, "kgrid", basis.kgrid)
showfieldln(io, "num. red. kpoints", length(basis.kgrid))
showfieldln(io, "num. irred. kpoints", length(basis.kcoords_global))
showfieldln(io, "num. irred. kpoints", basis.n_irreducible_kpoints)

println(io)
modelstr = sprint(show, "text/plain", basis.model)
Expand Down Expand Up @@ -149,10 +149,10 @@ function todict!(dict, basis::PlaneWaveBasis)
todict!(dict, basis.model)

dict["kgrid"] = sprint(show, "text/plain", basis.kgrid)
dict["kcoords"] = basis.kcoords_global
dict["kcoords_cart"] = vector_red_to_cart.(basis.model, basis.kcoords_global)
dict["kweights"] = basis.kweights_global
dict["n_kpoints"] = length(basis.kcoords_global)
dict["kcoords"] = irreducible_kcoords_global(basis)
dict["kcoords_cart"] = vector_red_to_cart.(basis.model, irreducible_kcoords_global(basis))
dict["kweights"] = irreducible_kweights_global(basis)
dict["n_kpoints"] = basis.n_irreducible_kpoints
dict["fft_size"] = basis.fft_size
dict["dvol"] = basis.dvol
dict["Ecut"] = basis.Ecut
Expand Down Expand Up @@ -224,11 +224,22 @@ function band_data_to_dict!(dict, band_data::NamedTuple; save_ψ=false, save_ρ=
end

function gather_and_store!(dict, key, basis, data)
# Gather from all k-points, even possibly duplicated ones
gathered = gather_kpts_block(basis, data)
if !isnothing(gathered)
n_kpoints = length(basis.kcoords_global)
n_kpoints = basis.n_irreducible_kpoints
n_spin = basis.model.n_spin_components
dict[key] = reshape(gathered, (size(data[1])..., n_kpoints, n_spin))
n_kpt_tot = length(basis.kcoords_global)

reshaped_data = reshape(gathered, (size(data[1])..., n_kpt_tot, n_spin))

# Only store irreducible k-points (assumed to be first in an array)
if n_kpt_tot > n_kpoints
index = ntuple(_ -> Colon(), ndims(dict[key]))
index = Base.setindex(index, 1:n_kpoints, ndims(dict[key]) - 1)
reshaped_data = reshaped_data[index...]
end
dict[key] = reshaped_data
end
end

Expand Down
5 changes: 5 additions & 0 deletions src/symmetry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,11 @@ function unfold_array(basis_irred, basis_unfolded, data, is_ψ)
if !(basis_irred.comm_kpts == basis_irred.comm_kpts == MPI.COMM_WORLD)
error("Brillouin zone symmetry unfolding not supported with MPI yet")
end
if basis_irred.n_irreducible_kpoints < mpi_nprocs(basis_irred.comm_kpts)
# Note: if this routine is ever generalised for MPI,
# need special care for potentially duplicated KP
error("Brillouin zone symmetry unfolding not supported with duplicated k-points")
end
data_unfolded = similar(data, length(basis_unfolded.kpoints))
for ik_unfolded = 1:length(basis_unfolded.kpoints)
kpt_unfolded = basis_unfolded.kpoints[ik_unfolded]
Expand Down
2 changes: 1 addition & 1 deletion test/diag_compare.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testitem "Comparison of diagonalisaton procedures" begin
@testitem "Comparison of diagonalisaton procedures" tags=[:dont_test_mpi] begin
using DFTK

function test_solver(reference, eigensolver, prec_type)
Expand Down
2 changes: 1 addition & 1 deletion test/external/atoms_calculators.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testitem "Test AtomsCalculators interfaces" setup=[TestCases] tags=[:atomsbase] begin
@testitem "Test AtomsCalculators interfaces" setup=[TestCases] tags=[:atomsbase, :dont_test_mpi] begin
using AtomsCalculators
using AtomsCalculators.Testing: test_energy_forces_virial
using Unitful
Expand Down
2 changes: 1 addition & 1 deletion test/forces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ end
end
end

@testitem "Forces on oxygen with spin and temperature" setup=[TestCases] begin
@testitem "Forces on oxygen with spin and temperature" setup=[TestCases] tags=[:dont_test_mpi] begin
using DFTK
using DFTK: mpi_mean!
using MPI
Expand Down
2 changes: 1 addition & 1 deletion test/hamiltonian_consistency.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ end
end


@testitem "Hamiltonian consistency" setup=[TestCases, HamConsistency] begin
@testitem "Hamiltonian consistency" setup=[TestCases, HamConsistency] tags=[:dont_test_mpi] begin
using DFTK
using LinearAlgebra
using .HamConsistency: test_consistency_term
Expand Down
2 changes: 1 addition & 1 deletion test/helium_all_electron.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testitem "Helium all electron" tags=[:minimal, :core] begin
@testitem "Helium all electron" tags=[:minimal, :core, :dont_test_mpi] begin
using DFTK
using LinearAlgebra

Expand Down
2 changes: 1 addition & 1 deletion test/pairwise.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testitem "Pairwise forces" begin
@testitem "Pairwise forces" tags=[:dont_test_mpi] begin
using DFTK
using DFTK: energy_forces_pairwise
using LinearAlgebra
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using DFTK
using Interpolations

if base_tag == :mpi
nprocs = parse(Int, get(ENV, "DFTK_TEST_NPROCS", "$(clamp(Sys.CPU_THREADS, 2, 4))"))
nprocs = parse(Int, get(ENV, "DFTK_TEST_NPROCS", "$(clamp(Sys.CPU_THREADS, 2, 4))"))
run(`$(mpiexec()) -n $nprocs $(Base.julia_cmd())
--project --startup-file=no --compiled-modules=no
--check-bounds=yes --depwarn=yes --color=yes
Expand Down
2 changes: 1 addition & 1 deletion test/timeout.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testitem "Timeout of SCF" setup=[TestCases] begin
@testitem "Timeout of SCF" setup=[TestCases] tags=[:dont_test_mpi] begin
using DFTK
using Dates
using Logging
Expand Down
6 changes: 3 additions & 3 deletions test/todict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function test_agreement_bands(band_data, dict; explicit_reshape=false, test_ψ=t

basis = band_data.basis
model = basis.model
n_kpoints = length(basis.kcoords_global)
n_kpoints = basis.n_irreducible_kpoints
n_spin = model.n_spin_components
n_bands = length(band_data.eigenvalues[1])
max_n_G = DFTK.mpi_max(maximum(kpt -> length(G_vectors(basis, kpt)), basis.kpoints),
Expand Down Expand Up @@ -45,8 +45,8 @@ function test_agreement_bands(band_data, dict; explicit_reshape=false, test_ψ=t
@test dict["atomic_symbols"] == map(e -> string(atomic_symbol(e)), model.atoms)
@test dict["atomic_positions"] model.positions atol=1e-12
@test dict["εF"] band_data.εF atol=1e-12
@test dict["kcoords"] basis.kcoords_global atol=1e-12
@test dict["kweights"] basis.kweights_global atol=1e-12
@test dict["kcoords"] DFTK.irreducible_kcoords_global(basis) atol=1e-12
@test dict["kweights"] DFTK.irreducible_kweights_global(basis) atol=1e-12
@test dict["Ecut"] basis.Ecut
@test dict["dvol"] basis.dvol atol=1e-12
@test [dict["fft_size"]...] == [basis.fft_size...]
Expand Down
2 changes: 1 addition & 1 deletion test/transfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
@test norm-ψ_bb) < eps(eltype(basis))
end

@testitem "Transfer of density" begin
@testitem "Transfer of density" tags=[:dont_test_mpi] begin
using DFTK
using DFTK: transfer_density
using LinearAlgebra
Expand Down
4 changes: 2 additions & 2 deletions test/variational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ end
end


@testitem "Energy is exact for supersampling>2 without XC" #=
@testitem "Energy is exact for supersampling>2 without XC" tags=[:dont_test_mpi] #=
=# setup=[Variational, TestCases] begin
using LinearAlgebra: norm
testcase = TestCases.silicon
Expand All @@ -36,7 +36,7 @@ end
@test norm(energies[2] .- energies[3]) < 1e-5
end

@testitem "Energy is not exact for supersampling>2 with XC" #=
@testitem "Energy is not exact for supersampling>2 with XC" tags=[:dont_test_mpi] #=
=# setup=[Variational, TestCases] begin
testcase = TestCases.silicon

Expand Down

0 comments on commit ea0ffe4

Please sign in to comment.