diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 6d79ee746a..0c53a7779e 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -1,7 +1,7 @@ name: IntegrationTest on: push: - branches: [master] + branches: [master, 'backport-v9'] tags: [v*] pull_request: paths-ignore: diff --git a/.github/workflows/ReleaseTest.yml b/.github/workflows/ReleaseTest.yml index f8b592e9d0..96b84f4d79 100644 --- a/.github/workflows/ReleaseTest.yml +++ b/.github/workflows/ReleaseTest.yml @@ -1,7 +1,7 @@ name: ReleaseTest on: push: - branches: [master] + branches: [master, 'backport-v9'] tags: [v*] pull_request: paths-ignore: diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index 52c5482970..835ee358a6 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -5,11 +5,13 @@ on: branches: - master - 'release-' + - 'backport-v9' paths-ignore: - 'docs/**' push: branches: - master + - 'backport-v9' paths-ignore: - 'docs/**' diff --git a/Project.toml b/Project.toml index 305f29c691..d68990f6e4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ModelingToolkit" uuid = "961ee093-0014-501f-94e3-6117800e7a78" authors = ["Yingbo Ma ", "Chris Rackauckas and contributors"] -version = "9.80.1" +version = "9.84.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -101,7 +101,7 @@ DiffEqBase = "6.170.1" DiffEqCallbacks = "2.16, 3, 4" DiffEqNoiseProcess = "5" DiffRules = "0.1, 1.0" -DifferentiationInterface = "0.6.47" +DifferentiationInterface = "0.6.47, 0.7" Distributed = "1" Distributions = "0.23, 0.24, 0.25" DocStringExtensions = "0.7, 0.8, 0.9" @@ -111,7 +111,7 @@ EnumX = "1.0.4" ExprTools = "0.1.10" FMI = "0.14" FindFirstFunctions = "1" -ForwardDiff = "0.10.3" +ForwardDiff = "0.10.3, 1" FunctionWrappers = "1.1" FunctionWrappersWrappers = "0.1" Graphs = "1.5.2" @@ -147,7 +147,7 @@ Serialization = "1" Setfield = "0.7, 0.8, 1" SimpleNonlinearSolve = "0.1.0, 1, 2" SparseArrays = "1" -SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2" +SpecialFunctions = "1, 2" StaticArrays = "0.10, 0.11, 0.12, 1.0" StochasticDelayDiffEq = "1.10" StochasticDiffEq = "6.72.1" diff --git a/src/bipartite_graph.jl b/src/bipartite_graph.jl index b6665646c9..8cdb76cca0 100644 --- a/src/bipartite_graph.jl +++ b/src/bipartite_graph.jl @@ -535,13 +535,39 @@ function set_neighbors!(g::BipartiteGraph, i::Integer, new_neighbors) end end -function delete_srcs!(g::BipartiteGraph, srcs) +function delete_srcs!(g::BipartiteGraph{I}, srcs; rm_verts = false) where {I} for s in srcs set_neighbors!(g, s, ()) end + if rm_verts + old_to_new_idxs = collect(one(I):I(nsrcs(g))) + for s in srcs + old_to_new_idxs[s] = zero(I) + end + offset = zero(I) + for i in eachindex(old_to_new_idxs) + if iszero(old_to_new_idxs[i]) + offset += one(I) + continue + end + old_to_new_idxs[i] -= offset + end + + if g.badjlist isa AbstractVector + for i in 1:ndsts(g) + for j in eachindex(g.badjlist[i]) + g.badjlist[i][j] = old_to_new_idxs[g.badjlist[i][j]] + end + filter!(!iszero, g.badjlist[i]) + end + end + deleteat!(g.fadjlist, srcs) + end g end -delete_dsts!(g::BipartiteGraph, srcs) = delete_srcs!(invview(g), srcs) +function delete_dsts!(g::BipartiteGraph, srcs; rm_verts = false) + delete_srcs!(invview(g), srcs; rm_verts) +end ### ### Edges iteration diff --git a/src/linearization.jl b/src/linearization.jl index b30d275818..fec5711dde 100644 --- a/src/linearization.jl +++ b/src/linearization.jl @@ -163,13 +163,13 @@ struct PreparedJacobian{iip, P, F, B, A} end function PreparedJacobian{true}(f, buf, autodiff, args...) - prep = DI.prepare_jacobian(f, buf, autodiff, args...) + prep = DI.prepare_jacobian(f, buf, autodiff, args...; strict = Val(false)) return PreparedJacobian{true, typeof(prep), typeof(f), typeof(buf), typeof(autodiff)}( prep, f, buf, autodiff) end function PreparedJacobian{false}(f, autodiff, args...) - prep = DI.prepare_jacobian(f, autodiff, args...) + prep = DI.prepare_jacobian(f, autodiff, args...; strict = Val(false)) return PreparedJacobian{true, typeof(prep), typeof(f), Nothing, typeof(autodiff)}( prep, f, nothing) end diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index d41cbf12d2..244a737705 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -724,7 +724,7 @@ Update the system equations, unknowns, and observables after simplification. """ function update_simplified_system!( state::TearingState, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns; - cse_hack = true, array_hack = true) + array_hack = true) @unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure diff_to_var = invview(var_to_diff) @@ -740,7 +740,8 @@ function update_simplified_system!( obs_sub[eq.lhs] = eq.rhs end # TODO: compute the dependency correctly so that we don't have to do this - obs = [fast_substitute(observed(sys), obs_sub); solved_eqs] + obs = [fast_substitute(observed(sys), obs_sub); solved_eqs; + fast_substitute(state.additional_observed, obs_sub)] unknowns = Any[v for (i, v) in enumerate(state.fullvars) @@ -748,12 +749,13 @@ function update_simplified_system!( unknowns = [unknowns; extra_unknowns] @set! sys.unknowns = unknowns - obs, subeqs, deps = cse_and_array_hacks( - sys, obs, solved_eqs, unknowns, neweqs; cse = cse_hack, array = array_hack) + obs = tearing_hacks(sys, obs, unknowns, neweqs; array = array_hack) + deps = Vector{Int}[i == 1 ? Int[] : collect(1:(i - 1)) + for i in 1:length(solved_eqs)] @set! sys.eqs = neweqs @set! sys.observed = obs - @set! sys.substitutions = Substitutions(subeqs, deps) + @set! sys.substitutions = Substitutions(solved_eqs, deps) # Only makes sense for time-dependent # TODO: generalize to SDE @@ -791,7 +793,7 @@ appear in the system. Algebraic variables are variables that are not differential variables. """ function tearing_reassemble(state::TearingState, var_eq_matching, - full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true) + full_var_eq_matching = nothing; simplify = false, mm = nothing, array_hack = true) extra_vars = Int[] if full_var_eq_matching !== nothing for v in 𝑑vertices(state.structure.graph) @@ -827,7 +829,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, state, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var) sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_eq_matching, - extra_unknowns; cse_hack, array_hack) + extra_unknowns; array_hack) @set! state.sys = sys @set! sys.tearing_state = state @@ -835,14 +837,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, end """ -# HACK 1 - -Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]` -gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets -_very_ expensive. this hack performs a limited form of CSE specifically for this case to -avoid the unnecessary cost. This and the below hack are implemented simultaneously - -# HACK 2 +# HACK Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation @@ -850,13 +845,7 @@ if all `p[i]` are present and the unscalarized form is used in any equation (obs not) we first count the number of times the scalarized form of each observed variable occurs in observed equations (and unknowns if it's split). """ -function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, array = true) - # HACK 1 - # mapping of rhs to temporary CSE variable - # `f(...) => tmpvar` in above example - rhs_to_tempvar = Dict() - - # HACK 2 +function tearing_hacks(sys, obs, unknowns, neweqs; array = true) # map of array observed variable (unscalarized) to number of its # scalarized terms that appear in observed equations arr_obs_occurrences = Dict() @@ -864,36 +853,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr lhs = eq.lhs rhs = eq.rhs - # HACK 1 - if cse && is_getindexed_array(rhs) - rhs_arr = arguments(rhs)[1] - iscall(rhs_arr) && operation(rhs_arr) isa Symbolics.Operator && continue - if !haskey(rhs_to_tempvar, rhs_arr) - tempvar = gensym(Symbol(lhs)) - N = length(rhs_arr) - tempvar = unwrap(Symbolics.variable( - tempvar; T = Symbolics.symtype(rhs_arr))) - tempvar = setmetadata( - tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr)) - tempeq = tempvar ~ rhs_arr - rhs_to_tempvar[rhs_arr] = tempvar - push!(obs, tempeq) - push!(subeqs, tempeq) - end - - # getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different, - # so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr` - # which fails the topological sort - neweq = lhs ~ getindex_wrapper( - rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end])) - obs[i] = neweq - subeqi = findfirst(isequal(eq), subeqs) - if subeqi !== nothing - subeqs[subeqi] = neweq - end - end - # end HACK 1 - array || continue iscall(lhs) || continue operation(lhs) === getindex || continue @@ -904,32 +863,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr continue end - # Also do CSE for `equations(sys)` - if cse - for (i, eq) in enumerate(neweqs) - (; lhs, rhs) = eq - is_getindexed_array(rhs) || continue - rhs_arr = arguments(rhs)[1] - if !haskey(rhs_to_tempvar, rhs_arr) - tempvar = gensym(Symbol(lhs)) - N = length(rhs_arr) - tempvar = unwrap(Symbolics.variable( - tempvar; T = Symbolics.symtype(rhs_arr))) - tempvar = setmetadata( - tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr)) - tempeq = tempvar ~ rhs_arr - rhs_to_tempvar[rhs_arr] = tempvar - push!(obs, tempeq) - push!(subeqs, tempeq) - end - # don't need getindex_wrapper, but do it anyway to know that this - # hack took place - neweq = lhs ~ getindex_wrapper( - rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end])) - neweqs[i] = neweq - end - end - # count variables in unknowns if they are scalarized forms of variables # also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)` # is an observed equation. @@ -960,29 +893,11 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr push!(obs_arr_eqs, arrvar ~ rhs) end append!(obs, obs_arr_eqs) - append!(subeqs, obs_arr_eqs) - - # need to re-sort subeqs - subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs]) - - deps = Vector{Int}[i == 1 ? Int[] : collect(1:(i - 1)) - for i in 1:length(subeqs)] - - return obs, subeqs, deps -end -function is_getindexed_array(rhs) - (!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) && - iscall(rhs) && operation(rhs) === getindex && - Symbolics.shape(rhs) != Symbolics.Unknown() + return obs end -# PART OF HACK 1 -getindex_wrapper(x, i) = x[i...] - -@register_symbolic getindex_wrapper(x::AbstractArray, i::Tuple{Vararg{Int}}) - -# PART OF HACK 2 +# PART OF HACK function change_origin(origin, arr) if all(isone, Tuple(origin)) return arr @@ -1010,10 +925,10 @@ new residual equations after tearing. End users are encouraged to call [`structu instead, which calls this function internally. """ function tearing(sys::AbstractSystem, state = TearingState(sys); mm = nothing, - simplify = false, cse_hack = true, array_hack = true, kwargs...) + simplify = false, array_hack = true, kwargs...) var_eq_matching, full_var_eq_matching = tearing(state) invalidate_cache!(tearing_reassemble( - state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack)) + state, var_eq_matching, full_var_eq_matching; mm, simplify, array_hack)) end """ @@ -1035,7 +950,7 @@ Perform index reduction and use the dummy derivative technique to ensure that the system is balanced. """ function dummy_derivative(sys, state = TearingState(sys); simplify = false, - mm = nothing, cse_hack = true, array_hack = true, kwargs...) + mm = nothing, array_hack = true, kwargs...) jac = let state = state (eqs, vars) -> begin symeqs = EquationsView(state)[eqs] @@ -1059,5 +974,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false, end var_eq_matching = dummy_derivative_graph!(state, jac; state_priority, kwargs...) - tearing_reassemble(state, var_eq_matching; simplify, mm, cse_hack, array_hack) + tearing_reassemble(state, var_eq_matching; simplify, mm, array_hack) end diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 6c00806b7d..d78c48a857 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -2835,15 +2835,20 @@ function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair}, elseif sys isa ODESystem rules = todict(map(r -> Symbolics.unwrap(r[1]) => Symbolics.unwrap(r[2]), collect(rules))) - eqs = fast_substitute(get_eqs(sys), rules) - pdeps = fast_substitute(get_parameter_dependencies(sys), rules) - defs = Dict(fast_substitute(k, rules) => fast_substitute(v, rules) + newsys = @set sys.eqs = fast_substitute(get_eqs(sys), rules) + @set! newsys.unknowns = map(get_unknowns(sys)) do var + get(rules, var, var) + end + @set! newsys.ps = map(get_ps(sys)) do var + get(rules, var, var) + end + @set! newsys.parameter_dependencies = fast_substitute( + get_parameter_dependencies(sys), rules) + @set! newsys.defaults = Dict(fast_substitute(k, rules) => fast_substitute(v, rules) for (k, v) in get_defaults(sys)) - guess = Dict(fast_substitute(k, rules) => fast_substitute(v, rules) + @set! newsys.guesses = Dict(fast_substitute(k, rules) => fast_substitute(v, rules) for (k, v) in get_guesses(sys)) - subsys = map(s -> substitute(s, rules), get_systems(sys)) - ODESystem(eqs, get_iv(sys); name = nameof(sys), defaults = defs, - guesses = guess, parameter_dependencies = pdeps, systems = subsys) + @set! newsys.systems = map(s -> substitute(s, rules), get_systems(sys)) else error("substituting symbols is not supported for $(typeof(sys))") end diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index d802f49fee..f87c4d0071 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -1495,7 +1495,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem, @warn errmsg end - uninit = setdiff(unknowns(sys), [unknowns(isys); observables(isys)]) + uninit = setdiff(unknowns(sys), unknowns(isys), observables(isys)) # TODO: throw on uninitialized arrays filter!(x -> !(x isa Symbolics.Arr), uninit) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 4a13d7ccf1..334835af58 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -476,6 +476,8 @@ Generates a function that computes the observed value(s) `ts` in the system `sys - `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist. - `mkarray`: only used if the output is an array (that is, `!isscalar(ts)` and `ts` is not a tuple, in which case the result will always be a tuple). Called as `mkarray(ts, output_type)` where `ts` are the expressions to put in the array and `output_type` is the argument of the same name passed to build_explicit_observed_function. - `cse = true`: Whether to use Common Subexpression Elimination (CSE) to generate a more efficient function. +- `wrap_delays = is_dde(sys)`: Whether to add an argument for the history function and use + it to calculate all delayed variables. ## Returns @@ -514,7 +516,8 @@ function build_explicit_observed_function(sys, ts; op = Operator, throw = true, cse = true, - mkarray = nothing) + mkarray = nothing, + wrap_delays = is_dde(sys)) is_tuple = ts isa Tuple if is_tuple ts = collect(ts) @@ -600,14 +603,15 @@ function build_explicit_observed_function(sys, ts; p_end = length(dvs) + length(inputs) + length(ps) fns = build_function_wrapper( sys, ts, args...; p_start, p_end, filter_observed = obsfilter, - output_type, mkarray, try_namespaced = true, expression = Val{true}, cse) + output_type, mkarray, try_namespaced = true, expression = Val{true}, cse, + wrap_delays) if fns isa Tuple if expression return return_inplace ? fns : fns[1] end oop, iip = eval_or_rgf.(fns; eval_expression, eval_module) f = GeneratedFunctionWrapper{( - p_start + is_dde(sys), length(args) - length(ps) + 1 + is_dde(sys), is_split(sys))}( + p_start + wrap_delays, length(args) - length(ps) + 1 + wrap_delays, is_split(sys))}( oop, iip) return return_inplace ? (f, f) : f else @@ -616,7 +620,7 @@ function build_explicit_observed_function(sys, ts; end f = eval_or_rgf(fns; eval_expression, eval_module) f = GeneratedFunctionWrapper{( - p_start + is_dde(sys), length(args) - length(ps) + 1 + is_dde(sys), is_split(sys))}( + p_start + wrap_delays, length(args) - length(ps) + 1 + wrap_delays, is_split(sys))}( f, nothing) return f end diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index d0b687c212..ae4feec62b 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -388,8 +388,8 @@ function IndexCache(sys::AbstractSystem) observed_syms_to_timeseries, dependent_pars_to_timeseries, disc_buffer_templates, - BufferTemplate(Real, tunable_buffer_size), - BufferTemplate(Real, initials_buffer_size), + BufferTemplate(Number, tunable_buffer_size), + BufferTemplate(Number, initials_buffer_size), const_buffer_sizes, nonnumeric_buffer_sizes, symbol_to_variable diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index d2a988dc07..bd2c7de753 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -598,24 +598,44 @@ function SciMLBase.remake_initialization_data( return SciMLBase.remake_initialization_data(sys, odefn, newu0, t0, newp, newu0, newp) end -function promote_u0_p(u0, p::MTKParameters, t0) - u0 = DiffEqBase.promote_u0(u0, p.tunable, t0) - u0 = DiffEqBase.promote_u0(u0, p.initials, t0) +promote_type_with_nothing(::Type{T}, ::Nothing) where {T} = T +promote_type_with_nothing(::Type{T}, ::SizedVector{0}) where {T} = T +function promote_type_with_nothing(::Type{T}, ::AbstractArray{T2}) where {T, T2} + promote_type(T, T2) +end +function promote_type_with_nothing(::Type{T}, p::MTKParameters) where {T} + promote_type_with_nothing(promote_type_with_nothing(T, p.tunable), p.initials) +end - if !isempty(p.tunable) - tunables = DiffEqBase.promote_u0(p.tunable, u0, t0) - p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables) - end - if !isempty(p.initials) - initials = DiffEqBase.promote_u0(p.initials, u0, t0) - p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials) +promote_with_nothing(::Type, ::Nothing) = nothing +promote_with_nothing(::Type, x::SizedVector{0}) = x +promote_with_nothing(::Type{T}, x::AbstractArray{T}) where {T} = x +function promote_with_nothing(::Type{T}, x::AbstractArray{T2}) where {T, T2} + if ArrayInterface.ismutable(x) + y = similar(x, T) + copyto!(y, x) + return y + else + yT = similar_type(x, T) + return yT(x) end - - return u0, p +end +function promote_with_nothing(::Type{T}, p::MTKParameters) where {T} + tunables = promote_with_nothing(T, p.tunable) + p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables) + initials = promote_with_nothing(T, p.initials) + p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials) + return p end function promote_u0_p(u0, p, t0) - return DiffEqBase.promote_u0(u0, p, t0), DiffEqBase.promote_u0(p, u0, t0) + T = Union{} + T = promote_type_with_nothing(T, u0) + T = promote_type_with_nothing(T, p) + + u0 = promote_with_nothing(T, u0) + p = promote_with_nothing(T, p) + return u0, p end function SciMLBase.late_binding_update_u0_p( @@ -628,7 +648,26 @@ function SciMLBase.late_binding_update_u0_p( newu0, newp = promote_u0_p(newu0, newp, t0) # non-symbolic u0 updates initials... - if !(eltype(u0) <: Pair) + if eltype(u0) <: Pair + syms = [] + vals = [] + allsyms = all_symbols(sys) + for (k, v) in u0 + v === nothing && continue + (symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue + if k isa Symbol + k2 = symbol_to_symbolic(sys, k; allsyms) + # if it is returned as-is, there is no match so skip it + k2 === k && continue + k = k2 + end + is_parameter(sys, Initial(k)) || continue + push!(syms, Initial(k)) + push!(vals, v) + end + newp = setp_oop(sys, syms)(newp, vals) + else + allsyms = nothing # if `p` is not provided or is symbolic p === missing || eltype(p) <: Pair || return newu0, newp (newu0 === nothing || isempty(newu0)) && return newu0, newp @@ -641,31 +680,36 @@ function SciMLBase.late_binding_update_u0_p( throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))")) end newp = meta.set_initial_unknowns!(newp, newu0) - return newu0, newp - end - - syms = [] - vals = [] - allsyms = all_symbols(sys) - for (k, v) in u0 - v === nothing && continue - (symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue - if k isa Symbol - k2 = symbol_to_symbolic(sys, k; allsyms) - # if it is returned as-is, there is no match so skip it - k2 === k && continue - k = k2 + end + + if eltype(p) <: Pair + syms = [] + vals = [] + if allsyms === nothing + allsyms = all_symbols(sys) + end + for (k, v) in p + v === nothing && continue + (symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue + if k isa Symbol + k2 = symbol_to_symbolic(sys, k; allsyms) + # if it is returned as-is, there is no match so skip it + k2 === k && continue + k = k2 + end + is_parameter(sys, Initial(k)) || continue + push!(syms, Initial(k)) + push!(vals, v) end - is_parameter(sys, Initial(k)) || continue - push!(syms, Initial(k)) - push!(vals, v) + newp = setp_oop(sys, syms)(newp, vals) end - newp = setp_oop(sys, syms)(newp, vals) return newu0, newp end -function DiffEqBase.get_updated_symbolic_problem(sys::AbstractSystem, prob; kw...) +function DiffEqBase.get_updated_symbolic_problem( + sys::AbstractSystem, prob; u0 = state_values(prob), + p = parameter_values(prob), kw...) supports_initialization(sys) || return prob initdata = prob.f.initialization_data initdata isa SciMLBase.OverrideInitData || return prob @@ -673,10 +717,8 @@ function DiffEqBase.get_updated_symbolic_problem(sys::AbstractSystem, prob; kw.. meta isa InitializationMetadata || return prob meta.get_updated_u0 === nothing && return prob - u0 = state_values(prob) - u0 === nothing && return prob + u0 === nothing && return remake(prob; p) - p = parameter_values(prob) t0 = is_time_dependent(prob) ? current_time(prob) : nothing if p isa MTKParameters @@ -693,7 +735,7 @@ function DiffEqBase.get_updated_symbolic_problem(sys::AbstractSystem, prob; kw.. T = StaticArrays.similar_type(u0) end - return remake(prob; u0 = T(meta.get_updated_u0(prob, initdata.initializeprob))) + return remake(prob; u0 = T(meta.get_updated_u0(prob, initdata.initializeprob)), p) end """ @@ -719,20 +761,6 @@ function unhack_observed(obseqs::Vector{Equation}, eqs::Vector{Equation}) push!(rm_idxs, i) continue end - if operation(eq.rhs) == StructuralTransformations.getindex_wrapper - var, idxs = arguments(eq.rhs) - subs[eq.rhs] = var[idxs...] - push!(tempvars, var) - end - end - - for (i, eq) in enumerate(eqs) - iscall(eq.rhs) || continue - if operation(eq.rhs) == StructuralTransformations.getindex_wrapper - var, idxs = arguments(eq.rhs) - subs[eq.rhs] = var[idxs...] - push!(tempvars, var) - end end for (i, eq) in enumerate(obseqs) diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index bfe15b62d7..fcb502efdf 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -741,8 +741,9 @@ function structural_simplify(sys::OptimizationSystem; split = true, kwargs...) nlsys = NonlinearSystem(econs, unknowns(sys), parameters(sys); name = :___tmp_nlsystem) snlsys = structural_simplify(nlsys; fully_determined = false, kwargs...) obs = observed(snlsys) - subs = Dict(eq.lhs => eq.rhs for eq in observed(snlsys)) seqs = equations(snlsys) + trueobs, _ = unhack_observed(obs, seqs) + subs = Dict(eq.lhs => eq.rhs for eq in trueobs) cons_simplified = similar(cons, length(icons) + length(seqs)) for (i, eq) in enumerate(Iterators.flatten((seqs, icons))) cons_simplified[i] = fixpoint_sub(eq, subs) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index c3d2a0e831..00742f028c 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -29,7 +29,7 @@ the default behavior). function MTKParameters( sys::AbstractSystem, p, u0 = Dict(); tofloat = false, t0 = nothing, substitution_limit = 1000, floatT = nothing, - p_constructor = identity) + p_constructor = identity, fast_path = false) ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing get_index_cache(sys) else @@ -50,9 +50,15 @@ function MTKParameters( is_time_dependent(sys) && add_observed!(sys, u0) add_parameter_dependencies!(sys, p) - op, missing_unknowns, missing_pars = build_operating_point!(sys, - u0, p, defs, cmap, dvs, ps) - + u0map = anydict() + pmap = anydict() + if fast_path + missing_pars = missingvars(p, ps) + op = p + else + op, _, missing_pars = build_operating_point!(sys, + u0, p, defs, cmap, dvs, ps) + end if t0 !== nothing op[get_iv(sys)] = t0 end diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 5ff00b4845..9d3b3b5a3e 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -450,8 +450,11 @@ in `varmap`, it is ignored. """ function evaluate_varmap!(varmap::AbstractDict, vars; limit = 100) for k in vars + v = get(varmap, k, nothing) + v === nothing && continue + symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) && continue haskey(varmap, k) || continue - varmap[k] = fixpoint_sub(varmap[k], varmap; maxiters = limit) + varmap[k] = fixpoint_sub(v, varmap; maxiters = limit) end end @@ -580,15 +583,19 @@ function build_operating_point!(sys::AbstractSystem, end end - for k in keys(u0map) - v = fixpoint_sub(u0map[k], neithermap; operator = Symbolics.Operator) - isequal(k, v) && continue - u0map[k] = v - end - for k in keys(pmap) - v = fixpoint_sub(pmap[k], neithermap; operator = Symbolics.Operator) - isequal(k, v) && continue - pmap[k] = v + if !isempty(neithermap) + for (k, v) in u0map + symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) && continue + v = fixpoint_sub(v, neithermap; operator = Symbolics.Operator) + isequal(k, v) && continue + u0map[k] = v + end + for (k, v) in pmap + symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) && continue + v = fixpoint_sub(v, neithermap; operator = Symbolics.Operator) + isequal(k, v) && continue + pmap[k] = v + end end return op, missing_unknowns, missing_pars @@ -617,37 +624,50 @@ struct ReconstructInitializeprob{GP, GU} ugetter::GU end +""" + $(TYPEDEF) + +A wrapper over an observed function which allows calling it on a problem-like object. +`TD` determines whether the getter function is `(u, p, t)` (if `true`) or `(u, p)` (if +`false`). +""" +struct ObservedWrapper{TD, F} + f::F +end + +ObservedWrapper{TD}(f::F) where {TD, F} = ObservedWrapper{TD, F}(f) + +function (ow::ObservedWrapper{true})(prob) + # Edge case for steady state problems + t = applicable(current_time, prob) ? current_time(prob) : Inf + ow.f(state_values(prob), parameter_values(prob), t) +end + +function (ow::ObservedWrapper{false})(prob) + ow.f(state_values(prob), parameter_values(prob)) +end + """ $(TYPEDSIGNATURES) Given an index provider `indp` and a vector of symbols `syms` return a type-stable getter -function by splitting `syms` into contiguous buffers where the getter of each buffer -is type-stable and constructing a function that calls and concatenates the results. -""" -function concrete_getu(indp, syms::AbstractVector) - # a list of contiguous buffer - split_syms = [Any[syms[1]]] - # the type of the getter of the last buffer - current = typeof(getu(indp, syms[1])) - for sym in syms[2:end] - getter = getu(indp, sym) - if typeof(getter) != current - # if types don't match, build a new buffer - push!(split_syms, []) - current = typeof(getter) - end - push!(split_syms[end], sym) - end - split_syms = Tuple(split_syms) - # the getter is now type-stable, and we can vcat it to get the full buffer - return Base.Fix1(reduce, vcat) ∘ getu(indp, split_syms) +function. + +Note that the getter ONLY works for problem-like objects, since it generates an observed +function. It does NOT work for solutions. +""" +Base.@nospecializeinfer function concrete_getu(indp, syms::AbstractVector) + @nospecialize + obsfn = build_explicit_observed_function(indp, syms; wrap_delays = false) + return ObservedWrapper{is_time_dependent(indp)}(obsfn) end """ $(TYPEDEF) A callable struct which applies `p_constructor` to possibly nested arrays. It also -ensures that views (including nested ones) are concretized. +ensures that views (including nested ones) are concretized. This is implemented manually +of using `narrow_buffer_type` to preserve type-stability. """ struct PConstructorApplicator{F} p_constructor::F @@ -657,10 +677,18 @@ function (pca::PConstructorApplicator)(x::AbstractArray) pca.p_constructor(x) end +function (pca::PConstructorApplicator)(x::AbstractArray{Bool}) + pca.p_constructor(BitArray(x)) +end + function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray) collect(x) end +function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray{Bool}) + BitArray(x) +end + function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray{<:AbstractArray}) collect(pca.(x)) end @@ -683,6 +711,7 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns """ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem; initials = false, unwrap_initials = false, p_constructor = identity) + _p_constructor = p_constructor p_constructor = PConstructorApplicator(p_constructor) # if we call `getu` on this (and it were able to handle empty tuples) we get the # fields of `MTKParameters` except caches. @@ -736,14 +765,24 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, syms[3]) end - rest_getters = map(Base.tail(Base.tail(Base.tail(syms)))) do buf - if buf == () - return Returns(()) - else - return Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, buf) - end + const_getter = if syms[4] == () + Returns(()) + else + Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, syms[4]) end - getters = (tunable_getter, initials_getter, discs_getter, rest_getters...) + nonnumeric_getter = if syms[5] == () + Returns(()) + else + ic = get_index_cache(dstsys) + buftypes = Tuple(map(ic.nonnumeric_buffer_sizes) do bufsize + Vector{bufsize.type} + end) + # nonnumerics retain the assigned buffer type without narrowing + Base.Fix1(broadcast, _p_constructor) ∘ + Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘ getu(srcsys, syms[5]) + end + getters = ( + tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter) getter = let getters = getters function _getter(valp, initprob) oldcache = parameter_values(initprob).caches @@ -756,6 +795,10 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac return getter end +function call(f, args...) + f(args...) +end + """ $(TYPEDSIGNATURES) @@ -926,7 +969,7 @@ end $(TYPEDEF) A callable struct to use as the `get_updated_u0` field of `InitializationMetadata`. -Returns the value to use for the `u0` of the problem. +Returns the value to use for the `u0` of the problem. # Fields @@ -1000,6 +1043,9 @@ function (siu::SetInitialUnknowns)(p::AbstractVector, u0) return p end +safe_float(x) = x +safe_float(x::AbstractArray) = isempty(x) ? x : float(x) + """ $(TYPEDSIGNATURES) @@ -1064,7 +1110,8 @@ function maybe_build_initialization_problem( if is_time_dependent(sys) all_init_syms = Set(all_symbols(initializeprob)) solved_unknowns = filter(var -> var in all_init_syms, unknowns(sys)) - initializeprobmap = u0_constructor ∘ getu(initializeprob, solved_unknowns) + initializeprobmap = u0_constructor ∘ safe_float ∘ + getu(initializeprob, solved_unknowns) else initializeprobmap = nothing end @@ -1087,20 +1134,24 @@ function maybe_build_initialization_problem( update_initializeprob! = ModelingToolkit.update_initializeprob! end - for p in punknowns - is_parameter_solvable(p, pmap, defs, guesses) || continue - get(op, p, missing) === missing || continue + filter!(punknowns) do p + is_parameter_solvable(p, op, defs, guesses) && get(op, p, missing) === missing + end + pvals = getu(initializeprob, punknowns)(initializeprob) + for (p, pval) in zip(punknowns, pvals) p = unwrap(p) - op[p] = getu(initializeprob, p)(initializeprob) + op[p] = pval if iscall(p) && operation(p) === getindex arrp = arguments(p)[1] + get(op, arrp, nothing) !== missing && continue op[arrp] = collect(arrp) end end if is_time_dependent(sys) - for v in missing_unknowns - op[v] = getu(initializeprob, v)(initializeprob) + uvals = getu(initializeprob, collect(missing_unknowns))(initializeprob) + for (v, val) in zip(missing_unknowns, uvals) + op[v] = val end empty!(missing_unknowns) end @@ -1124,7 +1175,7 @@ function float_type_from_varmap(varmap, floatT = Bool) if v isa AbstractArray floatT = promote_type(floatT, eltype(v)) - elseif v isa Real + elseif v isa Number floatT = promote_type(floatT, typeof(v)) end end @@ -1335,7 +1386,7 @@ function process_SciMLProblem( if !(pType <: AbstractArray) pType = Array end - p = MTKParameters(sys, op; floatT = floatT, p_constructor) + p = MTKParameters(sys, op; floatT = floatT, p_constructor, fast_path = true) else p = p_constructor(better_varmap_to_vars(op, ps; tofloat, container_type = pType)) end @@ -1396,7 +1447,7 @@ function check_inputmap_keys(sys, u0map, pmap) end const BAD_KEY_MESSAGE = """ - Undefined keys found in the parameter or initial condition maps. Check if symbolic variable names have been reassigned. + Undefined keys found in the parameter or initial condition maps. Check if symbolic variable names have been reassigned. The following keys are invalid: """ diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 52f93afb9b..aa7ccce917 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -1,3 +1,5 @@ +const System = AbstractODESystem + function System(eqs::AbstractVector{<:Equation}, iv, args...; name = nothing, kw...) ODESystem(eqs, iv, args...; name, kw..., checks = false) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index e0feb0d34d..48214e01a4 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -208,12 +208,19 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} structure::SystemStructure extra_eqs::Vector param_derivative_map::Dict{BasicSymbolic, Any} + original_eqs::Vector{Equation} + """ + Additional user-provided observed equations. The variables calculated here + are not used in the rest of the system. + """ + additional_observed::Vector{Equation} end TransformationState(sys::AbstractSystem) = TearingState(sys) function system_subset(ts::TearingState, ieqs::Vector{Int}) eqs = equations(ts) @set! ts.sys.eqs = eqs[ieqs] + @set! ts.original_eqs = ts.original_eqs[ieqs] @set! ts.structure = system_subset(ts.structure, ieqs) ts end @@ -266,6 +273,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) iv = length(ivs) == 1 ? ivs[1] : nothing # scalarize array equations, without scalarizing arguments to registered functions eqs = flatten_equations(copy(equations(sys))) + original_eqs = copy(eqs) neqs = length(eqs) dervaridxs = OrderedSet{Int}() var2idx = Dict{Any, Int}() @@ -378,6 +386,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) end end eqs = eqs[eqs_to_retain] + original_eqs = original_eqs[eqs_to_retain] neqs = length(eqs) symbolic_incidence = symbolic_incidence[eqs_to_retain] @@ -386,6 +395,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) # depending on order due to NP-completeness of tearing. sortidxs = Base.sortperm(eqs, by = string) eqs = eqs[sortidxs] + original_eqs = original_eqs[sortidxs] symbolic_incidence = symbolic_incidence[sortidxs] end @@ -475,13 +485,116 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) ts = TearingState(sys, fullvars, SystemStructure(complete(var_to_diff), complete(eq_to_diff), complete(graph), nothing, var_types, sys isa AbstractDiscreteSystem), - Any[], param_derivative_map) + Any[], param_derivative_map, original_eqs, Equation[]) if sys isa DiscreteSystem ts = shift_discrete_system(ts) end return ts end +""" + $(TYPEDSIGNATURES) + +Preemptively identify observed equations in the system and tear them. This happens before +any simplification. The equations torn by this process are ones that are already given in +an explicit form in the system and where the LHS is not present in any other equation of +the system except for other such preempitvely torn equations. +""" +function trivial_tearing!(ts::TearingState) + @assert length(ts.original_eqs) == length(equations(ts)) + # equations that can be trivially torn an observed equations + trivial_idxs = BitSet() + # equations to never check + blacklist = BitSet() + torn_eqs = Equation[] + # variables that have been matched to trivially torn equations + matched_vars = BitSet() + # variable to index in fullvars + var_to_idx = Dict{Any, Int}(ts.fullvars .=> eachindex(ts.fullvars)) + + complete!(ts.structure) + var_to_diff = ts.structure.var_to_diff + graph = ts.structure.graph + while true + # track whether we added an equation to the trivial list this iteration + added_equation = false + for (i, eq) in enumerate(ts.original_eqs) + # don't check already torn equations + i in trivial_idxs && continue + i in blacklist && continue + # ensure it is an observed equation matched to a variable in fullvars + vari = get(var_to_idx, eq.lhs, 0) + iszero(vari) && continue + # don't tear irreducible variables + if isirreducible(eq.lhs) + push!(blacklist, i) + continue + end + # if a variable was the LHS of two trivial observed equations, we wouldn't have + # included it in the list. Error if somehow it made it through. + @assert !(vari in matched_vars) + # don't tear differential/shift equations (or differentiated/shifted variables) + var_to_diff[vari] === nothing || continue + invview(var_to_diff)[vari] === nothing || continue + # get the equations that the candidate matched variable is present in, except + # those equations which have already been torn as observed + eqidxs = setdiff(𝑑neighbors(graph, vari), trivial_idxs) + # it should only be present in this equation + length(eqidxs) == 1 || continue + eqi = only(eqidxs) + @assert eqi == i + + # for every variable present in this equation, make sure it isn't _only_ + # present in trivial equations + isvalid = true + for v in 𝑠neighbors(graph, eqi) + v == vari && continue + v in matched_vars && continue + # `> 1` and not `0` because one entry will be this equation (`eqi`) + isvalid &= count(!in(trivial_idxs), 𝑑neighbors(graph, v)) > 1 + isvalid || break + end + isvalid || continue + # skip if the LHS is present in the RHS, since then this isn't explicit + if occursin(eq.lhs, eq.rhs) + push!(blacklist, i) + continue + end + + added_equation = true + push!(trivial_idxs, eqi) + push!(torn_eqs, eq) + push!(matched_vars, vari) + end + + # if we didn't add an equation this iteration, we won't add one next iteration + added_equation || break + end + + deleteat!(var_to_diff.primal_to_diff, matched_vars) + deleteat!(var_to_diff.diff_to_primal, matched_vars) + deleteat!(ts.structure.eq_to_diff.primal_to_diff, trivial_idxs) + deleteat!(ts.structure.eq_to_diff.diff_to_primal, trivial_idxs) + delete_srcs!(ts.structure.graph, trivial_idxs; rm_verts = true) + delete_dsts!(ts.structure.graph, matched_vars; rm_verts = true) + if ts.structure.solvable_graph !== nothing + delete_srcs!(ts.structure.solvable_graph, trivial_idxs; rm_verts = true) + delete_dsts!(ts.structure.solvable_graph, matched_vars; rm_verts = true) + end + if ts.structure.var_types !== nothing + deleteat!(ts.structure.var_types, matched_vars) + end + deleteat!(ts.fullvars, matched_vars) + deleteat!(ts.original_eqs, trivial_idxs) + ts.additional_observed = torn_eqs + sys = ts.sys + eqs = copy(get_eqs(sys)) + deleteat!(eqs, trivial_idxs) + @set! sys.eqs = eqs + ts.sys = sys + return ts +end + function lower_order_var(dervar, t) if isdifferential(dervar) diffvar = arguments(dervar)[1] @@ -739,6 +852,7 @@ function _structural_simplify!(state::TearingState, io; simplify = false, else input_idxs = 0:-1 # Empty range end + trivial_tearing!(state) sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...) if check_consistency fully_determined = ModelingToolkit.check_consistency( diff --git a/src/utils.jl b/src/utils.jl index 90431a7749..6fa574ba35 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1083,7 +1083,7 @@ Keyword arguments: `available_vars` will not be searched for in the observed equations. """ function observed_equations_used_by(sys::AbstractSystem, exprs; - involved_vars = vars(exprs; op = Union{Shift, Differential}), obs = observed(sys), available_vars = []) + involved_vars = vars(exprs; op = Union{Shift, Differential, Initial}), obs = observed(sys), available_vars = []) obsvars = getproperty.(obs, :lhs) graph = observed_dependency_graph(obs) if !(available_vars isa Set) diff --git a/test/code_generation.jl b/test/code_generation.jl index cf3d660b81..b48e4acf64 100644 --- a/test/code_generation.jl +++ b/test/code_generation.jl @@ -78,3 +78,34 @@ end @test SciMLBase.successful_retcode(sol) end end + +@testset "scalarized array observed calling same function multiple times" begin + @variables x(t) y(t)[1:2] + @parameters foo(::Real)[1:2] + val = Ref(0) + function _tmp_fn2(x) + val[] += 1 + return [x, 2x] + end + @mtkbuild sys = ODESystem([D(x) ~ y[1] + y[2], y ~ foo(x)], t) + @test length(equations(sys)) == 1 + @test length(ModelingToolkit.observed(sys)) == 3 + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn2]) + val[] = 0 + @test_nowarn prob.f(prob.u0, prob.p, 0.0) + @test val[] == 1 + + @testset "CSE in equations(sys)" begin + val[] = 0 + @variables z(t)[1:2] + @mtkbuild sys = ODESystem( + [D(y) ~ foo(x), D(x) ~ sum(y), zeros(2) ~ foo(prod(z))], t) + @test length(equations(sys)) == 5 + @test length(ModelingToolkit.observed(sys)) == 0 + prob = ODEProblem( + sys, [y => ones(2), z => 2ones(2), x => 3.0], (0.0, 1.0), [foo => _tmp_fn2]) + val[] = 0 + @test_nowarn prob.f(prob.u0, prob.p, 0.0) + @test val[] == 2 + end +end diff --git a/test/complex.jl b/test/complex.jl index 69cc22c985..04be8e4dac 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -1,4 +1,5 @@ using ModelingToolkit +using OrdinaryDiffEq using ModelingToolkit: t_nounits as t using Test @@ -14,3 +15,30 @@ using Test end @named mixed = ComplexModel() @test length(equations(mixed)) == 2 + +@testset "Complex ODEProblem" begin + using ModelingToolkit: t_nounits as t, D_nounits as D + + vars = @variables x(t) y(t) z(t) + pars = @parameters a b + + eqs = [ + D(x) ~ y - x, + D(y) ~ -x * z + b * abs(z), + D(z) ~ x * y - a + ] + @named modlorenz = System(eqs, t) + sys = structural_simplify(modlorenz) + + ic = ModelingToolkit.get_index_cache(sys) + @test ic.tunable_buffer_size.type == Number + + u0 = ComplexF64[-4.0, 5.0, 0.0] .+ randn(ComplexF64, 3) + p = ComplexF64[5.0, 0.1] + dict = merge(Dict(unknowns(sys) .=> u0), Dict(parameters(sys) .=> p)) + prob = ODEProblem(sys, dict, (0.0, 1.0)) + + sol = solve(prob, Tsit5(), saveat = 0.1) + + @test sol.u[1] isa Vector{ComplexF64} +end diff --git a/test/extensions/Project.toml b/test/extensions/Project.toml index 9f43e6f4a4..f91231fe05 100644 --- a/test/extensions/Project.toml +++ b/test/extensions/Project.toml @@ -1,3 +1,6 @@ +[sources] +ModelingToolkit = {path = "../.."} + [deps] BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665" CasADi = "c49709b8-5c63-11e9-2fb2-69db5844192f" @@ -23,6 +26,7 @@ OrdinaryDiffEqVerner = "79d7bb75-1356-48c1-b8c0-6832512096c2" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" SimpleDiffEq = "05bca326-078c-5bf0-a5bf-ce7c7982d7fd" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/extensions/ad.jl b/test/extensions/ad.jl index 14649b6bb6..b5e71bf0af 100644 --- a/test/extensions/ad.jl +++ b/test/extensions/ad.jl @@ -8,6 +8,7 @@ using OrdinaryDiffEqNonlinearSolve using NonlinearSolve using SciMLSensitivity using ForwardDiff +using StableRNGs using ChainRulesCore using ChainRulesCore: NoTangent using ChainRulesTestUtils: test_rrule, rand_tangent @@ -135,3 +136,46 @@ end prob[sys.x] end end + +@testset "`p` provided to `solve` is respected" begin + @mtkmodel Linear begin + @variables begin + x(t) = 1.0, [description = "Prey"] + end + @parameters begin + α = 1.5 + end + @equations begin + D(x) ~ -α * x + end + end + + @mtkbuild linear = Linear() + problem = ODEProblem(linear, [], (0.0, 1.0)) + solution = solve(problem, Tsit5(), saveat = 0.1) + rng = StableRNG(42) + data = (; + t = solution.t, + # [[y, x], :] + measurements = Array(solution) + ) + data.measurements .+= 0.05 * randn(rng, size(data.measurements)) + + p0, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), problem.p) + + objective = let repack = repack, problem = problem + (p, data) -> begin + pnew = repack(p) + sol = solve(problem, Tsit5(), p = pnew, saveat = data.t) + sum(abs2, sol .- data.measurements) / size(data.t, 1) + end + end + + # Check 0.0031677344878386607 + @test_nowarn objective(p0, data) + + fd = ForwardDiff.gradient(Base.Fix2(objective, data), p0) + zg = Zygote.gradient(Base.Fix2(objective, data), p0) + + @test fd≈zg[1] atol=1e-6 +end diff --git a/test/initial_values.jl b/test/initial_values.jl index 0ed8f7bffe..f66d103cc9 100644 --- a/test/initial_values.jl +++ b/test/initial_values.jl @@ -362,3 +362,13 @@ end @test state_values(initdata.initializeprob) isa SVector @test parameter_values(initdata.initializeprob) isa SVector end + +@testset "Type promotion of `p` works with non-dual types" begin + @variables x(t) y(t) + @mtkbuild sys = ODESystem([D(x) ~ x + y, x^3 + y^3 ~ 5], t; guesses = [y => 1.0]) + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0)) + prob2 = remake(prob; u0 = BigFloat.(prob.u0)) + @test prob2.p.initials isa Vector{BigFloat} + sol = solve(prob2) + @test SciMLBase.successful_retcode(sol) +end diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 7c512d37af..a7c6f4a27e 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1253,9 +1253,9 @@ end @test init(prob3)[x] ≈ 1.0 prob4 = remake(prob; p = [p => 1.0]) test_dummy_initialization_equation(prob4, x) - prob5 = remake(prob; p = [p => missing, q => 2.0]) + prob5 = remake(prob; p = [p => missing, q => 4.0]) @test prob5.f.initialization_data !== nothing - @test init(prob5).ps[p] ≈ 1.0 + @test init(prob5).ps[p] ≈ 2.0 end @testset "Variables provided as symbols" begin @@ -1650,3 +1650,44 @@ end @test !SciMLBase.isinplace(prob) @test !SciMLBase.isinplace(prob.f.initialization_data.initializeprob) end + +@testset "Array unknowns occurring unscalarized in initializeprobpmap" begin + @variables begin + u(t)[1:2] = 0.9ones(2) + x(t)[1:2], [guess = 0.01ones(2)] + o(t)[1:2] + end + @parameters p[1:4] = [2.0, 1.875, 2.0, 1.875] + + eqs = [D(u[1]) ~ p[1] * u[1] - p[2] * u[1] * u[2] + x[1] + 0.1 + D(u[2]) ~ p[4] * u[1] * u[2] - p[3] * u[2] - x[2] + o[1] ~ sum(p) * sum(u) + o[2] ~ sum(p) * sum(x) + x[1] ~ 0.01exp(-1) + x[2] ~ 0.01cos(t)] + + @mtkbuild sys = ODESystem(eqs, t) + prob = ODEProblem(sys, [], (0.0, 1.0)) + sol = solve(prob, Tsit5()) + @test SciMLBase.successful_retcode(sol) +end + +@testset "Nonnumerics aren't narrowed" begin + @mtkmodel Foo begin + @variables begin + x(t) = 1.0 + end + @parameters begin + p::AbstractString + r = 1.0 + end + @equations begin + D(x) ~ r * x + end + end + @mtkbuild sys = Foo(p = "a") + prob = ODEProblem(sys, [], (0.0, 1.0)) + @test prob.p.nonnumeric[1] isa Vector{AbstractString} + integ = init(prob) + @test integ.p.nonnumeric[1] isa Vector{AbstractString} +end diff --git a/test/odesystem.jl b/test/odesystem.jl index 3220555e62..dcda766177 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -5,10 +5,14 @@ using OrdinaryDiffEq, Sundials using DiffEqBase, SparseArrays using StaticArrays using Test -using SymbolicUtils: issym +using SymbolicUtils.Code +using SymbolicUtils: Sym, issym using ForwardDiff using ModelingToolkit: value using ModelingToolkit: t_nounits as t, D_nounits as D +using Symbolics +using Symbolics: unwrap +using DiffEqBase: isinplace # Define some variables @parameters σ ρ β @@ -607,13 +611,6 @@ sys = complete(sys) @test_throws Any ODEFunction(sys) @testset "Preface tests" begin - using OrdinaryDiffEq - using Symbolics - using DiffEqBase: isinplace - using ModelingToolkit - using SymbolicUtils.Code - using SymbolicUtils: Sym - c = [0] function f(c, du::AbstractVector{Float64}, u::AbstractVector{Float64}, p, t::Float64) c .= [c[1] + 1] @@ -656,7 +653,9 @@ sys = complete(sys) @named sys = ODESystem(eqs, t, us, ps; defaults = defs, preface = preface) sys = complete(sys) - prob = ODEProblem(sys, [], (0.0, 1.0)) + # don't build initializeprob because it will use preface in other functions and + # affect `c` + prob = ODEProblem(sys, [], (0.0, 1.0); build_initializeprob = false) sol = solve(prob, Euler(); dt = 0.1) @test c[1] == length(sol) @@ -1755,3 +1754,61 @@ end sol = solve(prob, Tsit5()) @test SciMLBase.successful_retcode(sol) end + +@testset "`full_equations` doesn't recurse infinitely" begin + code = """ + using ModelingToolkit + using ModelingToolkit: t_nounits as t, D_nounits as D + @variables x(t)[1:3]=[0,0,1] + @variables u1(t)=0 u2(t)=0 + y₁, y₂, y₃ = x + k₁, k₂, k₃ = 1,1,1 + eqs = [ + D(y₁) ~ -k₁*y₁ + k₃*y₂*y₃ + u1 + D(y₂) ~ k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2 + u2 + y₁ + y₂ + y₃ ~ 1 + ] + + @named sys = ODESystem(eqs, t) + + inputs = [u1, u2] + outputs = [y₁, y₂, y₃] + ss, = structural_simplify(sys, (inputs, [])) + full_equations(ss) + """ + + cmd = `$(Base.julia_cmd()) --project=$(@__DIR__) -e $code` + proc = run(cmd, stdin, stdout, stderr; wait = false) + sleep(120) + @test !process_running(proc) + kill(proc, Base.SIGKILL) +end + +@testset "`substitute` retains events and metadata" begin + @parameters p(t) = 1.0 + @variables x(t) = 0.0 + event = [0.5] => [p ~ t] + event2 = [x ~ 0.75] => [p ~ 2 * t] + + eq = [ + D(x) ~ p + ] + @named sys = ODESystem(eq, t, [x], [p], discrete_events = [event], + continuous_events = [event2], metadata = "TEST") + + @variables x2(t) = 0.0 + sys2 = substitute(sys, [x => x2]) + + @test length(ModelingToolkit.get_discrete_events(sys)) == 1 + @test length(ModelingToolkit.get_discrete_events(sys2)) == 1 + @test length(ModelingToolkit.get_continuous_events(sys)) == 1 + @test length(ModelingToolkit.get_continuous_events(sys2)) == 1 + @test ModelingToolkit.get_metadata(sys) == "TEST" + @test ModelingToolkit.get_metadata(sys2) == "TEST" +end + +@testset "`System` works as a type" begin + @variables x(t) + @named sys = System([D(x) ~ 2x], t) + @test sys isa System +end diff --git a/test/structural_transformation/utils.jl b/test/structural_transformation/utils.jl index 67cf2f72a0..24fb166755 100644 --- a/test/structural_transformation/utils.jl +++ b/test/structural_transformation/utils.jl @@ -51,7 +51,7 @@ end @mtkbuild sys = ODESystem( [D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t) @test length(equations(sys)) == 1 - @test length(observed(sys)) == 7 + @test length(observed(sys)) == 6 @test any(obs -> isequal(obs, y), observables(sys)) @test any(obs -> isequal(obs, z), observables(sys)) prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn]) @@ -61,60 +61,11 @@ end @test length(unknowns(isys)) == 5 @test length(equations(isys)) == 4 @test !any(equations(isys)) do eq - iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.getindex_wrapper, - StructuralTransformations.change_origin] + iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.change_origin] end end -@testset "scalarized array observed calling same function multiple times" begin - @variables x(t) y(t)[1:2] - @parameters foo(::Real)[1:2] - val = Ref(0) - function _tmp_fn2(x) - val[] += 1 - return [x, 2x] - end - @mtkbuild sys = ODESystem([D(x) ~ y[1] + y[2], y ~ foo(x)], t) - @test length(equations(sys)) == 1 - @test length(observed(sys)) == 4 - prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn2]) - val[] = 0 - @test_nowarn prob.f(prob.u0, prob.p, 0.0) - @test val[] == 1 - - isys = ModelingToolkit.generate_initializesystem(sys) - @test length(unknowns(isys)) == 3 - @test length(equations(isys)) == 2 - @test !any(equations(isys)) do eq - iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.getindex_wrapper, - StructuralTransformations.change_origin] - end - - @testset "CSE hack in equations(sys)" begin - val[] = 0 - @variables z(t)[1:2] - @mtkbuild sys = ODESystem( - [D(y) ~ foo(x), D(x) ~ sum(y), zeros(2) ~ foo(prod(z))], t) - @test length(equations(sys)) == 5 - @test length(observed(sys)) == 2 - prob = ODEProblem( - sys, [y => ones(2), z => 2ones(2), x => 3.0], (0.0, 1.0), [foo => _tmp_fn2]) - val[] = 0 - @test_nowarn prob.f(prob.u0, prob.p, 0.0) - @test val[] == 2 - - isys = ModelingToolkit.generate_initializesystem(sys) - @test length(unknowns(isys)) == 5 - @test length(equations(isys)) == 2 - @test !any(equations(isys)) do eq - iscall(eq.rhs) && - operation(eq.rhs) in [StructuralTransformations.getindex_wrapper, - StructuralTransformations.change_origin] - end - end -end - -@testset "array and cse hacks can be disabled" begin +@testset "array hack can be disabled" begin @testset "fully_determined = true" begin @variables x(t) y(t)[1:2] z(t)[1:2] @parameters foo(::AbstractVector)[1:2] @@ -122,15 +73,8 @@ end @named sys = ODESystem( [D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t) - sys1 = structural_simplify(sys; cse_hack = false) - @test length(observed(sys1)) == 6 - @test !any(observed(sys1)) do eq - iscall(eq.rhs) && - operation(eq.rhs) == StructuralTransformations.getindex_wrapper - end - sys2 = structural_simplify(sys; array_hack = false) - @test length(observed(sys2)) == 5 + @test length(observed(sys2)) == 4 @test !any(observed(sys2)) do eq iscall(eq.rhs) && operation(eq.rhs) == StructuralTransformations.change_origin end @@ -143,15 +87,8 @@ end @named sys = ODESystem( [D(x) ~ z[1] + z[2] + foo(z)[1] + w, y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t) - sys1 = structural_simplify(sys; cse_hack = false, fully_determined = false) - @test length(observed(sys1)) == 6 - @test !any(observed(sys1)) do eq - iscall(eq.rhs) && - operation(eq.rhs) == StructuralTransformations.getindex_wrapper - end - sys2 = structural_simplify(sys; array_hack = false, fully_determined = false) - @test length(observed(sys2)) == 5 + @test length(observed(sys2)) == 4 @test !any(observed(sys2)) do eq iscall(eq.rhs) && operation(eq.rhs) == StructuralTransformations.change_origin end