Skip to content

Commit

Permalink
Merge pull request #1111 from SciML/testsplit
Browse files Browse the repository at this point in the history
Split QA tests to a separate group
  • Loading branch information
ChrisRackauckas authored Sep 10, 2024
2 parents eded161 + c0bf32c commit 67e0245
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 28 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ jobs:
- Core5
- Core6
- Core7
- QA
- SDE1
- SDE2
- SDE3
Expand Down
22 changes: 13 additions & 9 deletions src/forward_sensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ function ODEForwardSensitivityProblem(f::F, args...; kwargs...) where {F}
ODEForwardSensitivityProblem(ODEFunction(f), args...; kwargs...)
end

function ODEForwardSensitivityProblem(prob::ODEProblem; sensealg = ForwardSensitivity(), kwargs...)
function ODEForwardSensitivityProblem(
prob::ODEProblem; sensealg = ForwardSensitivity(), kwargs...)
_ODEForwardSensitivityProblem(
prob.f, state_values(prob), prob.tspan, parameter_values(prob), sensealg; kwargs...)
end
Expand Down Expand Up @@ -351,11 +352,10 @@ at time `sol.t[i]`. Note that all the functionality available to ODE solutions
is available in this case, including interpolations and plot recipes (the recipes
will plot the expanded system).
"""
function ODEForwardSensitivityProblem(f::F, u0, tspan, p = nothing;
sensealg = ForwardSensitivity(),
kwargs...) where {F <: DiffEqBase.AbstractODEFunction}

_ODEForwardSensitivityProblem(f,u0,tspan,p,sensealg; kwargs...)
function ODEForwardSensitivityProblem(f::F, u0, tspan, p = nothing;
sensealg = ForwardSensitivity(),
kwargs...) where {F <: DiffEqBase.AbstractODEFunction}
_ODEForwardSensitivityProblem(f, u0, tspan, p, sensealg; kwargs...)
end

# deprecated
Expand All @@ -366,8 +366,10 @@ function ODEForwardSensitivityProblem(f::F, u0,
w0 = nothing,
v0 = nothing,
kwargs...) where {F <: DiffEqBase.AbstractODEFunction}
Base.depwarn("The form of this function with `alg` as a positional argument is deprecated. Please use the `sensealg` keyword argument instead.", :ODEForwardSensitivityProblem)
_ODEForwardSensitivityProblem(f,u0,tspan,p,alg; nus,w0,v0,kwargs...)
Base.depwarn(
"The form of this function with `alg` as a positional argument is deprecated. Please use the `sensealg` keyword argument instead.",
:ODEForwardSensitivityProblem)
_ODEForwardSensitivityProblem(f, u0, tspan, p, alg; nus, w0, v0, kwargs...)
end

function _ODEForwardSensitivityProblem(f::F, u0,
Expand Down Expand Up @@ -490,7 +492,9 @@ function ODEForwardSensitivityProblem(f::DiffEqBase.AbstractODEFunction, u0,
du0 = zeros(eltype(u0), length(u0), length(p)), # perturbations of initial condition
dp = I(length(p)), # perturbations of parameters
kwargs...)
Base.depwarn("The form of this function with `alg` as a positional argument is deprecated. Please use the `sensealg` keyword argument instead.", :ODEForwardSensitivity)
Base.depwarn(
"The form of this function with `alg` as a positional argument is deprecated. Please use the `sensealg` keyword argument instead.",
:ODEForwardSensitivity)
_ODEForwardSensitivityProblem(f, u0, tspan, p, alg, du0, dp, kwargs...)
end

Expand Down
6 changes: 4 additions & 2 deletions test/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ probvecmat = ODEForwardSensitivityProblem(fb, [1.0; 1.0], (0.0, 10.0), p,
autojacmat = true))

# tests that the deprecated version still works
dep_prob_const = ODEForwardSensitivityProblem(fb, [1.0; 1.0], (0.0, 10.0), p, ForwardSensitivity())
dep_prob_const = ODEForwardSensitivityProblem(
fb, [1.0; 1.0], (0.0, 10.0), p, ForwardSensitivity())

sol = solve(prob, Tsit5(), abstol = 1e-14, reltol = 1e-14)
@test_broken solve(probInpl, KenCarp4(), abstol = 1e-14, reltol = 1e-14).retcode == :Success
Expand Down Expand Up @@ -199,7 +200,8 @@ sol_MM_ForwardDiffSensitivity = solve(prob_MM_ForwardDiffSensitivity,
Rodas4(autodiff = false), reltol = 1e-14,
abstol = 1e-14)

prob_no_MM = ODEForwardSensitivityProblem(f_no_MM, u0, tspan, p, sensealg = ForwardSensitivity())
prob_no_MM = ODEForwardSensitivityProblem(
f_no_MM, u0, tspan, p, sensealg = ForwardSensitivity())
sol_no_MM = solve(prob_no_MM, Rodas4(autodiff = false), reltol = 1e-14, abstol = 1e-14)

sen_MM_ForwardSensitivity = extract_local_sensitivities(sol_MM_ForwardSensitivity, 10.0,
Expand Down
40 changes: 26 additions & 14 deletions test/noindex_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ end
Base.size(x::CustomArray) = size(x.x)
Base.axes(x::CustomArray) = axes(x.x)
Base.ndims(x::CustomArray) = ndims(x.x)
Base.ndims(::Type{<:CustomArray{T,N}}) where {T,N} = N
Base.ndims(::Type{<:CustomArray{T, N}}) where {T, N} = N
Base.zero(x::CustomArray) = CustomArray(zero(x.x))
Base.zero(::Type{<:CustomArray{T,N}}) where {T,N} = CustomArray(zero(Array{T,N}))
Base.similar(x::CustomArray, dims::Union{Integer, AbstractUnitRange}...) = CustomArray(similar(x.x, dims...))
Base.zero(::Type{<:CustomArray{T, N}}) where {T, N} = CustomArray(zero(Array{T, N}))
function Base.similar(x::CustomArray, dims::Union{Integer, AbstractUnitRange}...)
CustomArray(similar(x.x, dims...))
end
Base.copyto!(x::CustomArray, y::CustomArray) = CustomArray(copyto!(x.x, y.x))
Base.copy(x::CustomArray) = CustomArray(copy(x.x))
Base.length(x::CustomArray) = length(x.x)
Expand All @@ -29,22 +31,30 @@ Base.all(f::Function, x::CustomArray; kwargs...) = all(f, x.x; kwargs...)
Base.similar(x::CustomArray, t) = CustomArray(similar(x.x, t))
Base.:(+)(x::CustomArray, y::CustomArray) = CustomArray(x.x + y.x)
Base.:(==)(x::CustomArray, y::CustomArray) = x.x == y.x
Base.:(*)(x::Number, y::CustomArray) = CustomArray(x*y.x)
Base.:(/)(x::CustomArray, y::Number) = CustomArray(x.x/y)
Base.:(*)(x::Number, y::CustomArray) = CustomArray(x * y.x)
Base.:(/)(x::CustomArray, y::Number) = CustomArray(x.x / y)
LinearAlgebra.norm(x::CustomArray) = norm(x.x)
LinearAlgebra.vec(x::CustomArray) = CustomArray(vec(x.x))

struct CustomStyle{N} <: Broadcast.BroadcastStyle where {N} end
CustomStyle(::Val{N}) where N = CustomStyle{N}()
CustomStyle{M}(::Val{N}) where {N,M} = NoIndexStyle{N}()
Base.BroadcastStyle(::Type{<:CustomArray{T,N}}) where {T,N} = CustomStyle{N}()
Broadcast.BroadcastStyle(::CustomStyle{N}, ::Broadcast.DefaultArrayStyle{0}) where {N} = CustomStyle{N}()
Base.similar(bc::Base.Broadcast.Broadcasted{CustomStyle{N}}, ::Type{ElType}) where {N, ElType} = CustomArray(similar(Array{ElType, N}, axes(bc)))
CustomStyle(::Val{N}) where {N} = CustomStyle{N}()
CustomStyle{M}(::Val{N}) where {N, M} = NoIndexStyle{N}()
Base.BroadcastStyle(::Type{<:CustomArray{T, N}}) where {T, N} = CustomStyle{N}()
function Broadcast.BroadcastStyle(
::CustomStyle{N}, ::Broadcast.DefaultArrayStyle{0}) where {N}
CustomStyle{N}()
end
function Base.similar(
bc::Base.Broadcast.Broadcasted{CustomStyle{N}}, ::Type{ElType}) where {N, ElType}
CustomArray(similar(Array{ElType, N}, axes(bc)))
end
Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::CustomArray, i) = x.x[i]
Base.Broadcast.extrude(x::CustomArray) = x
Base.Broadcast.broadcastable(x::CustomArray) = x

@inline function Base.copyto!(dest::CustomArray, bc::Base.Broadcast.Broadcasted{<:Union{Base.Broadcast.AbstractArrayStyle,CustomStyle}})
@inline function Base.copyto!(dest::CustomArray,
bc::Base.Broadcast.Broadcasted{<:Union{
Base.Broadcast.AbstractArrayStyle, CustomStyle}})
axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))
bc′ = Base.Broadcast.preprocess(dest, bc)
dest′ = dest.x
Expand Down Expand Up @@ -75,7 +85,7 @@ RecursiveArrayTools.recursivefill!(x::CustomArray, a) = fill!(x, a)

Base.show_vector(io::IO, x::CustomArray) = Base.show_vector(io, x.x)

Base.show(io::IO, x::CustomArray) = (print(io, "CustomArray");show(io, x.x))
Base.show(io::IO, x::CustomArray) = (print(io, "CustomArray"); show(io, x.x))
function Base.show(io::IO, ::MIME"text/plain", x::CustomArray)
println(io, Base.summary(x), ":")
Base.print_array(io, x.x)
Expand All @@ -89,9 +99,11 @@ algs = [Tsit5(), BS3(), Vern9(), DP5()]

for alg in algs
function cost(p)
prob = ODEProblem((du, u, p, t) -> (du[1] = p[1]*u[1] + p[2]*u[2]; du[2] = p[2]*u[1]), ca0, tspan, p)
prob = ODEProblem(
(du, u, p, t) -> (du[1] = p[1] * u[1] + p[2] * u[2]; du[2] = p[2] * u[1]),
ca0, tspan, p)
sol = solve(prob, alg; save_everystep = false)
return 1 - norm(sol[end])^2
end
@test_nowarn Zygote.gradient(cost, par)
end
end
5 changes: 4 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ end
@time @safetestset "ForwardDiff Sparsity Components" include("forwarddiffsensitivity_sparsity_components.jl")
@time @safetestset "Complex No u" include("complex_no_u.jl")
@time @safetestset "Parameter Handling" include("parameter_handling.jl")
@time @safetestset "Quality Assurance" include("aqua.jl")
end
end

Expand All @@ -101,6 +100,10 @@ end
end
end

if GROUP == "All" || GROUP == "QA"
@time @safetestset "Quality Assurance" include("aqua.jl")
end

if GROUP == "All" || GROUP == "SDE1"
@testset "SDE 1" begin
@time @safetestset "SDE Adjoint" include("sde_stratonovich.jl")
Expand Down
8 changes: 6 additions & 2 deletions test/sde_transformation_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ prob_strat = SDEProblem{false}(
tspan,
p)
Random.seed!(seed)
sol_strat = solve(prob_strat, RKMil(interpretation = SciMLBase.AlgorithmInterpretation.Stratonovich), adaptive = false,
sol_strat = solve(
prob_strat, RKMil(interpretation = SciMLBase.AlgorithmInterpretation.Stratonovich),
adaptive = false,
dt = 0.0001, save_noise = true)
prob_strat1 = SDEProblem{false}(
SDEFunction((u, p, t) -> transformed_function(u, p, t) .+
Expand All @@ -56,7 +58,9 @@ prob_strat1 = SDEProblem{false}(
tspan,
p)
Random.seed!(seed)
sol_strat1 = solve(prob_strat1, RKMil(interpretation = SciMLBase.AlgorithmInterpretation.Stratonovich), adaptive = false,
sol_strat1 = solve(
prob_strat1, RKMil(interpretation = SciMLBase.AlgorithmInterpretation.Stratonovich),
adaptive = false,
dt = 0.0001, save_noise = true)

# Test if we recover Ito solution in Stratonovich sense
Expand Down

0 comments on commit 67e0245

Please sign in to comment.