Skip to content

Commit

Permalink
new kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
sshin23 committed Sep 6, 2023
1 parent 52548bc commit eeaa113
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 113 deletions.
4 changes: 2 additions & 2 deletions examples/src/distillation.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function distillation_column_model(N = 3; T = Float64, backend=nothing)
function distillation_column_model(N = 3; T = Float64, backend=nothing, kwargs...)
NT = 30
FT = 17
Ac = 0.5
Expand Down Expand Up @@ -66,5 +66,5 @@ function distillation_column_model(N = 3; T = Float64, backend=nothing)
yA[t, i] * (1 - xA[t, i]) - alpha * xA[t, i] * (1 - yA[t, i]) for (t, i) in itr2
)

return ExaModels.ExaModel(c)
return ExaModels.ExaModel(c; kwargs...)
end
4 changes: 2 additions & 2 deletions examples/src/luksanvlcek.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function luksan_vlcek_model(N = 3; T = Float64, backend=nothing)
function luksan_vlcek_model(N = 3; T = Float64, backend=nothing, kwargs...)

c = ExaModels.ExaCore(T, backend)
x = ExaModels.variable(c, N; start = (mod(i, 2) == 1 ? -1.2 : 1.0 for i = 1:N))
Expand All @@ -8,6 +8,6 @@ function luksan_vlcek_model(N = 3; T = Float64, backend=nothing)
x[i]exp(x[i] - x[i+1]) - 3 for i = 1:N-2
)
ExaModels.objective(c, 100 * (x[i-1]^2 - x[i])^2 + (x[i-1] - 1)^2 for i = 2:N)
return ExaModels.ExaModel(c)
return ExaModels.ExaModel(c; kwargs...)
end

5 changes: 3 additions & 2 deletions examples/src/opf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ end
function ac_power_model(
filename = "pglib_opf_case3_lmbd.m";
backend = nothing,
T = Float64
T = Float64,
kwargs...
)

data = parse_ac_power_data(filename, backend)
Expand Down Expand Up @@ -267,7 +268,7 @@ function ac_power_model(
g.bus =>-qg[g.i]
for g in data.gen)

return ExaModels.ExaModel(w)
return ExaModels.ExaModel(w; kwargs...)

end

Expand Down
4 changes: 2 additions & 2 deletions examples/src/quadrotor.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function quadrotor_model(N=3; T = Float64, backend=nothing)
function quadrotor_model(N=3; T = Float64, backend=nothing, kwargs...)

n = 9
p = 4
Expand Down Expand Up @@ -33,6 +33,6 @@ function quadrotor_model(N=3; T = Float64, backend=nothing)
ExaModels.objective(c, .5*Q*(x[i,j]-d)^2 for (i,j,Q,d) in itr1)
ExaModels.objective(c, .5*Qf*(x[N+1,j]-d)^2 for (j,Qf,d) in itr2)

return ExaModels.ExaModel(c)
return ExaModels.ExaModel(c; kwargs...)
end

189 changes: 104 additions & 85 deletions ext/ExaModelsKernelAbstractions.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module ExaModelsKernelAbstractions

import ExaModels
import ExaModels: ExaModels, NLPModels
import KernelAbstractions: KernelAbstractions, @kernel, @index, @Const, synchronize, CPU

ExaModels.convert_array(v, backend::CPU) = v
Expand All @@ -23,7 +23,7 @@ function getptr(backend, array; cmp = isequal)
end


struct KAExtension{T,VT<:AbstractVector{T},VI1,VI2,VI3,B}
struct KAExtension{T,VT<:AbstractVector{T},VI1,VI2,B, H}
backend::B
objbuffer::VT
gradbuffer::VT
Expand All @@ -32,73 +32,92 @@ struct KAExtension{T,VT<:AbstractVector{T},VI1,VI2,VI3,B}
conbuffer::VT
conaugsparsity::VI1
conaugptr::VI2
jacbuffer::VT
jacsparsityi::VI3
jacsparsityj::VI3
jacptri::VI2
jacptrj::VI2
hessbuffer::VT
hesssparsityi::VI3
hesssparsityj::VI3
hessptri::VI2
hessptrj::VI2
prodhelper::H
end

function ExaModels.extension(
w::C,
) where {T,VT,B<:KernelAbstractions.Backend,C<:ExaModels.ExaCore{T,VT,B}}

gsparsity = similar(w.x0, Tuple{Int,Int}, w.nnzg)
function ExaModels.ExaModel(c::C; prod = false, kwargs...) where {
T,
VT<:AbstractVector{T},
B<:KernelAbstractions.Backend,
C<:ExaModels.ExaCore{T,VT,B}
}

gsparsity = similar(c.x0, Tuple{Int,Int}, c.nnzg)

_grad_structure!(w.backend, w.obj, gsparsity)
_grad_structure!(c.backend, c.obj, gsparsity)
ExaModels.sort!(gsparsity; lt = ((i, j), (k, l)) -> i < k)
gptr = getptr(w.backend, gsparsity)
gptr = getptr(c.backend, gsparsity)

conaugsparsity = similar(w.x0, Tuple{Int,Int}, w.nconaug)
_conaug_structure!(w.backend, w.con, conaugsparsity)
conaugsparsity = similar(c.x0, Tuple{Int,Int}, c.nconaug)
_conaug_structure!(c.backend, c.con, conaugsparsity)
length(conaugsparsity) > 0 && ExaModels.sort!(conaugsparsity; lt = ((i, j), (k, l)) -> i < k)
conaugptr = getptr(w.backend, conaugsparsity)

jacbuffer = similar(w.x0, w.nnzj)
hessbuffer = similar(w.x0, w.nnzh)
jacsparsityi = similar(w.x0, Tuple{Tuple{Int,Int},Int}, w.nnzj)
hesssparsityi = similar(w.x0, Tuple{Tuple{Int,Int},Int}, w.nnzh)

_jac_structure!(w.backend, w.con, jacsparsityi, nothing)
jacsparsityj = copy(jacsparsityi)
_obj_hess_structure!(w.backend, w.obj, hesssparsityi, nothing)
_con_hess_structure!(w.backend, w.con, hesssparsityi, nothing)
hesssparsityj = copy(hesssparsityi)

ExaModels.sort!(jacsparsityi; lt = (((i,j), k), ((n,m), l)) -> i < n)
ExaModels.sort!(jacsparsityj; lt = (((i,j), k), ((n,m), l)) -> j < m)
jacptri = getptr(w.backend, jacsparsityi; cmp = (x,y)->x[1] == y[1])
jacptrj = getptr(w.backend, jacsparsityj; cmp = (x,y)->x[2] == y[2])

ExaModels.sort!(hesssparsityi; lt = (((i,j), k), ((n,m), l)) -> i < n)
ExaModels.sort!(hesssparsityj; lt = (((i,j), k), ((n,m), l)) -> j < m)
hessptri = getptr(w.backend, hesssparsityi; cmp = (x,y)->x[1] == y[1])
hessptrj = getptr(w.backend, hesssparsityj; cmp = (x,y)->x[2] == y[2])
conaugptr = getptr(c.backend, conaugsparsity)

if prod
jacbuffer = similar(c.x0, c.nnzj)
hessbuffer = similar(c.x0, c.nnzh)
jacsparsityi = similar(c.x0, Tuple{Tuple{Int,Int},Int}, c.nnzj)
hesssparsityi = similar(c.x0, Tuple{Tuple{Int,Int},Int}, c.nnzh)

_jac_structure!(c.backend, c.con, jacsparsityi, nothing)
jacsparsityj = copy(jacsparsityi)
_obj_hess_structure!(c.backend, c.obj, hesssparsityi, nothing)
_con_hess_structure!(c.backend, c.con, hesssparsityi, nothing)
hesssparsityj = copy(hesssparsityi)

ExaModels.sort!(jacsparsityi; lt = (((i,j), k), ((n,m), l)) -> i < n)
ExaModels.sort!(jacsparsityj; lt = (((i,j), k), ((n,m), l)) -> j < m)
jacptri = getptr(c.backend, jacsparsityi; cmp = (x,y)->x[1] == y[1])
jacptrj = getptr(c.backend, jacsparsityj; cmp = (x,y)->x[2] == y[2])

ExaModels.sort!(hesssparsityi; lt = (((i,j), k), ((n,m), l)) -> i < n)
ExaModels.sort!(hesssparsityj; lt = (((i,j), k), ((n,m), l)) -> j < m)
hessptri = getptr(c.backend, hesssparsityi; cmp = (x,y)->x[1] == y[1])
hessptrj = getptr(c.backend, hesssparsityj; cmp = (x,y)->x[2] == y[2])

prodhelper = (
jacbuffer = jacbuffer,
jacsparsityi = jacsparsityi,
jacsparsityj = jacsparsityj,
jacptri = jacptri,
jacptrj = jacptrj,
hessbuffer = hessbuffer,
hesssparsityi = hesssparsityi,
hesssparsityj = hesssparsityj,
hessptri = hessptri,
hessptrj = hessptrj,
)
else
prodhelper = nothing
end

return KAExtension(
w.backend,
similar(w.x0, w.nobj),
similar(w.x0, w.nnzg),
gsparsity,
gptr,
similar(w.x0, w.nconaug),
conaugsparsity,
conaugptr,
jacbuffer,
jacsparsityi,
jacsparsityj,
jacptri,
jacptrj,
hessbuffer,
hesssparsityi,
hesssparsityj,
hessptri,
hessptrj,
return ExaModels.ExaModel(
c.obj,
c.con,
NLPModels.NLPModelMeta(
c.nvar,
ncon = c.ncon,
nnzj = c.nnzj,
nnzh = c.nnzh,
x0 = c.x0,
lvar = c.lvar,
uvar = c.uvar,
y0 = c.y0,
lcon = c.lcon,
ucon = c.ucon,
),
NLPModels.Counters(),
KAExtension(
c.backend,
similar(c.x0, c.nobj),
similar(c.x0, c.nnzg),
gsparsity,
gptr,
similar(c.x0, c.nconaug),
conaugsparsity,
conaugptr,
prodhelper
),
)
end

Expand Down Expand Up @@ -263,59 +282,59 @@ function _jac_coord!(backend, y, cons::ExaModels.ConstraintNull, x) end
function ExaModels.jprod_nln!(m::ExaModels.ExaModel{T,VT,E}, x::AbstractVector, v::AbstractVector, Jv::AbstractVector) where {T,VT,E <: KAExtension}

fill!(Jv, zero(eltype(Jv)))
fill!(m.ext.jacbuffer, zero(eltype(Jv)))
_jac_coord!(m.ext.backend, m.ext.jacbuffer, m.cons, x)
fill!(m.ext.prodhelper.jacbuffer, zero(eltype(Jv)))
_jac_coord!(m.ext.backend, m.ext.prodhelper.jacbuffer, m.cons, x)
synchronize(m.ext.backend)
kerspmv(m.ext.backend)(
Jv,
v,
m.ext.jacsparsityi,
m.ext.jacbuffer,
m.ext.jacptri,
ndrange = length(m.ext.jacptri) - 1,
m.ext.prodhelper.jacsparsityi,
m.ext.prodhelper.jacbuffer,
m.ext.prodhelper.jacptri,
ndrange = length(m.ext.prodhelper.jacptri) - 1,
)
synchronize(m.ext.backend)
end
function ExaModels.jtprod_nln!(m::ExaModels.ExaModel{T,VT,E}, x::AbstractVector, v::AbstractVector, Jtv::AbstractVector) where {T,VT,E <: KAExtension}

fill!(Jtv, zero(eltype(Jtv)))
fill!(m.ext.jacbuffer, zero(eltype(Jtv)))
_jac_coord!(m.ext.backend, m.ext.jacbuffer, m.cons, x)
fill!(m.ext.prodhelper.jacbuffer, zero(eltype(Jtv)))
_jac_coord!(m.ext.backend, m.ext.prodhelper.jacbuffer, m.cons, x)
synchronize(m.ext.backend)
kerspmv2(m.ext.backend)(
Jtv,
v,
m.ext.jacsparsityj,
m.ext.jacbuffer,
m.ext.jacptrj,
ndrange = length(m.ext.jacptrj) - 1,
m.ext.prodhelper.jacsparsityj,
m.ext.prodhelper.jacbuffer,
m.ext.prodhelper.jacptrj,
ndrange = length(m.ext.prodhelper.jacptrj) - 1,
)
synchronize(m.ext.backend)
end
function ExaModels.hprod!(m::ExaModels.ExaModel{T,VT,E}, x::AbstractVector, y::AbstractVector, v::AbstractVector, Hv::AbstractVector; obj_weight= one(eltype(x))) where {T,VT,E <: KAExtension}

fill!(Hv, zero(eltype(Hv)))
fill!(m.ext.hessbuffer, zero(eltype(Hv)))
fill!(m.ext.prodhelper.hessbuffer, zero(eltype(Hv)))

_obj_hess_coord!(m.ext.backend, m.ext.hessbuffer, m.objs, x, obj_weight)
_con_hess_coord!(m.ext.backend, m.ext.hessbuffer, m.cons, x, y)
_obj_hess_coord!(m.ext.backend, m.ext.prodhelper.hessbuffer, m.objs, x, obj_weight)
_con_hess_coord!(m.ext.backend, m.ext.prodhelper.hessbuffer, m.cons, x, y)
synchronize(m.ext.backend)
kersyspmv(m.ext.backend)(
Hv,
v,
m.ext.hesssparsityi,
m.ext.hessbuffer,
m.ext.hessptri,
ndrange = length(m.ext.hessptri) - 1,
m.ext.prodhelper.hesssparsityi,
m.ext.prodhelper.hessbuffer,
m.ext.prodhelper.hessptri,
ndrange = length(m.ext.prodhelper.hessptri) - 1,
)
synchronize(m.ext.backend)
kersyspmv2(m.ext.backend)(
Hv,
v,
m.ext.hesssparsityj,
m.ext.hessbuffer,
m.ext.hessptrj,
ndrange = length(m.ext.hessptrj) - 1,
m.ext.prodhelper.hesssparsityj,
m.ext.prodhelper.hessbuffer,
m.ext.prodhelper.hessptrj,
ndrange = length(m.ext.prodhelper.hessptrj) - 1,
)
synchronize(m.ext.backend)
end
Expand Down
3 changes: 0 additions & 3 deletions ext/ExaModelsOneAPI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ end

ExaModels.convert_array(v, backend::oneAPI.oneAPIBackend) = oneAPI.oneArray(v)

function ExaModels.sum(a::A) where {A<:oneAPI.oneVector{Float64}}
return sum(Array(a))
end
ExaModels.sort!(array::A; lt = isless) where {A<:oneAPI.oneVector} =
copyto!(array, sort!(Array(array)))

Expand Down
4 changes: 1 addition & 3 deletions src/nlp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ function ExaModel(c::C) where {C<:ExaCore}
ucon = c.ucon,
),
NLPModels.Counters(),
extension(c),
nothing,
)
end

Expand Down Expand Up @@ -416,8 +416,6 @@ function constraint!(c::C, c1, gen) where {C<:ExaCore}
end


function extension(args...) end

function jac_structure!(m::ExaModel, rows::AbstractVector, cols::AbstractVector)

_jac_structure!(m.cons, rows, cols)
Expand Down
Loading

0 comments on commit eeaa113

Please sign in to comment.