Skip to content

Commit

Permalink
Optimize apply_K (#1006)
Browse files Browse the repository at this point in the history
  • Loading branch information
Technici4n authored Oct 24, 2024
1 parent cb7f2cb commit 7fa34e4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/densities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ end
# … and then we compute the real Fourier transform in the adequate basis.
ifft!(storage.δψnk_real, basis, δψ_plus_k[ik].kpt, δψ_plus_k[ik].ψk[:, n])

storage.δρ[:, :, :, kpt.spin] .+= real_qzero(
2 .* occupation[ik][n] .* basis.kweights[ik] .* conj(storage.ψnk_real)
storage.δρ[:, :, :, kpt.spin] .+= real_qzero.(
2 .* occupation[ik][n] .* basis.kweights[ik] .* conj.(storage.ψnk_real)
.* storage.δψnk_real
.+ δoccupation[ik][n] .* basis.kweights[ik] .* abs2.(storage.ψnk_real))

Expand Down
14 changes: 10 additions & 4 deletions src/response/hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,25 @@ end
Compute the application of K defined at ψ to δψ. ρ is the density issued from ψ.
δψ also generates a δρ, computed with `compute_δρ`.
"""
@views @timing function apply_K(basis::PlaneWaveBasis, δψ, ψ, ρ, occupation)
@views @timing function apply_K(basis::PlaneWaveBasis{T}, δψ, ψ, ρ, occupation) where {T}
# ~45% of apply_K is spent computing ifft(ψ) twice: once in compute_δρ and once again below.
# By caching the result, we could compute it only once for a single application of K,
# or even across many applications when using solve_ΩplusK.
# But we don't because the memory requirements would be too high (typically an order of magnitude higher than ψ).

δψ = proj_tangent(δψ, ψ)
δρ = compute_δρ(basis, ψ, δψ, occupation)
δV = apply_kernel(basis, δρ; ρ)

ψnk_real = similar(basis.G_vectors, promote_type(T, eltype(ψ[1])))
Kδψ = map(enumerate(ψ)) do (ik, ψk)
kpt = basis.kpoints[ik]
δVψk = similar(ψk)

for n = 1:size(ψk, 2)
ψnk_real = ifft(basis, kpt, ψk[:, n])
δVψnk_real = δV[:, :, :, kpt.spin] .* ψnk_real
δVψk[:, n] = fft(basis, kpt, δVψnk_real)
ifft!(ψnk_real, basis, kpt, ψk[:, n])
ψnk_real .*= δV[:, :, :, kpt.spin]
fft!(δVψk[:, n], basis, kpt, ψnk_real)
end
δVψk
end
Expand Down

0 comments on commit 7fa34e4

Please sign in to comment.