diff --git a/ext/MTKFMIExt.jl b/ext/MTKFMIExt.jl index 5cfe9a82ef..2e9fa70de2 100644 --- a/ext/MTKFMIExt.jl +++ b/ext/MTKFMIExt.jl @@ -261,7 +261,7 @@ function MTK.FMIComponent(::Val{Ver}; fmu = nothing, tolerance = 1e-6, # use `ImperativeAffect` for instance management here cb_observed = (; inputs = __mtk_internal_x, params = copy(params), - t, wrapper, dt = communication_step_size) + t, wrapper) cb_modified = (;) # modify the outputs if present if symbolic_type(__mtk_internal_o) != NotSymbolic() @@ -272,11 +272,12 @@ function MTK.FMIComponent(::Val{Ver}; fmu = nothing, tolerance = 1e-6, cb_modified = (cb_modified..., states = __mtk_internal_u) end initialize_affect = MTK.ImperativeAffect(fmiCSInitialize!; observed = cb_observed, - modified = cb_modified, ctx = _functor) + modified = cb_modified, ctx = (_functor, communication_step_size)) finalize_affect = MTK.FunctionalAffect(fmiFinalize!, [], [wrapper], []) # the callback affect performs the stepping step_affect = MTK.ImperativeAffect( - fmiCSStep!; observed = cb_observed, modified = cb_modified, ctx = _functor) + fmiCSStep!; observed = cb_observed, modified = cb_modified, + ctx = (_functor, communication_step_size)) instance_management_callback = MTK.SymbolicDiscreteCallback( communication_step_size, step_affect; initialize = initialize_affect, finalize = finalize_affect, reinitializealg = reinitializealg @@ -775,7 +776,8 @@ the value being the output vector if the FMU has output variables. `o` should co Initializes the FMU. Only for use with CoSimulation FMUs. """ -function fmiCSInitialize!(m, o, ctx::FMI2CSFunctor, integrator) +function fmiCSInitialize!(m, o, ctx::Tuple{FMI2CSFunctor, Vararg}, integrator) + functor, dt = ctx states = isdefined(m, :states) ? m.states : () inputs = o.inputs params = o.params @@ -787,10 +789,10 @@ function fmiCSInitialize!(m, o, ctx::FMI2CSFunctor, integrator) instance = get_instance_CS!(wrapper, states, inputs, params, t) if isdefined(m, :states) - @statuscheck FMI.fmi2GetReal!(instance, ctx.state_value_references, m.states) + @statuscheck FMI.fmi2GetReal!(instance, functor.state_value_references, m.states) end if isdefined(m, :outputs) - @statuscheck FMI.fmi2GetReal!(instance, ctx.output_value_references, m.outputs) + @statuscheck FMI.fmi2GetReal!(instance, functor.output_value_references, m.outputs) end return m @@ -804,13 +806,13 @@ periodically to communicte with the CoSimulation FMU. Has the same requirements `fmiCSInitialize!` for `m` and `o`, with the addition that `o` should have a key `:dt` with the value being the communication step size. """ -function fmiCSStep!(m, o, ctx::FMI2CSFunctor, integrator) +function fmiCSStep!(m, o, ctx::Tuple{FMI2CSFunctor, Vararg}, integrator) + functor, dt = ctx wrapper = o.wrapper states = isdefined(m, :states) ? m.states : () inputs = o.inputs params = o.params t = o.t - dt = o.dt instance = get_instance_CS!(wrapper, states, inputs, params, integrator.t) if !isempty(inputs) @@ -820,10 +822,10 @@ function fmiCSStep!(m, o, ctx::FMI2CSFunctor, integrator) @statuscheck FMI.fmi2DoStep(instance, integrator.t - dt, dt, FMI.fmi2True) if isdefined(m, :states) - @statuscheck FMI.fmi2GetReal!(instance, ctx.state_value_references, m.states) + @statuscheck FMI.fmi2GetReal!(instance, functor.state_value_references, m.states) end if isdefined(m, :outputs) - @statuscheck FMI.fmi2GetReal!(instance, ctx.output_value_references, m.outputs) + @statuscheck FMI.fmi2GetReal!(instance, functor.output_value_references, m.outputs) end return m @@ -874,7 +876,8 @@ end """ $(TYPEDSIGNATURES) """ -function fmiCSInitialize!(m, o, ctx::FMI3CSFunctor, integrator) +function fmiCSInitialize!(m, o, ctx::Tuple{FMI3CSFunctor, Vararg}, integrator) + functor, dt = ctx states = isdefined(m, :states) ? m.states : () inputs = o.inputs params = o.params @@ -885,10 +888,11 @@ function fmiCSInitialize!(m, o, ctx::FMI3CSFunctor, integrator) end instance = get_instance_CS!(wrapper, states, inputs, params, t) if isdefined(m, :states) - @statuscheck FMI.fmi3GetFloat64!(instance, ctx.state_value_references, m.states) + @statuscheck FMI.fmi3GetFloat64!(instance, functor.state_value_references, m.states) end if isdefined(m, :outputs) - @statuscheck FMI.fmi3GetFloat64!(instance, ctx.output_value_references, m.outputs) + @statuscheck FMI.fmi3GetFloat64!( + instance, functor.output_value_references, m.outputs) end return m @@ -897,13 +901,13 @@ end """ $(TYPEDSIGNATURES) """ -function fmiCSStep!(m, o, ctx::FMI3CSFunctor, integrator) +function fmiCSStep!(m, o, ctx::Tuple{FMI3CSFunctor, Vararg}, integrator) + functor, dt = ctx wrapper = o.wrapper states = isdefined(m, :states) ? m.states : () inputs = o.inputs params = o.params t = o.t - dt = o.dt instance = get_instance_CS!(wrapper, states, inputs, params, integrator.t) if !isempty(inputs) @@ -921,10 +925,11 @@ function fmiCSStep!(m, o, ctx::FMI3CSFunctor, integrator) @assert earlyReturn[] == FMI.fmi3False if isdefined(m, :states) - @statuscheck FMI.fmi3GetFloat64!(instance, ctx.state_value_references, m.states) + @statuscheck FMI.fmi3GetFloat64!(instance, functor.state_value_references, m.states) end if isdefined(m, :outputs) - @statuscheck FMI.fmi3GetFloat64!(instance, ctx.output_value_references, m.outputs) + @statuscheck FMI.fmi3GetFloat64!( + instance, functor.output_value_references, m.outputs) end return m diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index a58ff3f8ec..eb490a5417 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -970,6 +970,14 @@ end end) end +@generated function _generated_readback(integ, getters::NamedTuple{NS1, <:Tuple}) where {NS1} + getter_exprs = [] + for name in NS1 + push!(getter_exprs, :($name = getters.$name(integ))) + end + return :((; $(getter_exprs...))) +end + function check_assignable(sys, sym) if symbolic_type(sym) == ScalarSymbolic() is_variable(sys, sym) || is_parameter(sys, sym) diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl index b0742e70a7..c209c9931e 100644 --- a/src/systems/imperative_affect.jl +++ b/src/systems/imperative_affect.jl @@ -26,7 +26,7 @@ The NamedTuple returned from `f` includes the values to be written back to the s Where we use Setfield to copy the tuple `m` with a new value for `x`, then return the modified value of `m`. All values updated by the tuple must have names originally declared in `modified`; a runtime error will be produced if a value is written that does not appear in `modified`. The user can dynamically decide not to write a value back by not including it -in the returned tuple, in which case the associated field will not be updated. +in the returned tuple, in which case the associated field will not be updated. To avoid writing back, either return `nothing` or an empty named tuple. """ @kwdef struct ImperativeAffect f::Any @@ -189,18 +189,13 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. else zeros(sz) end - obs_fun = build_explicit_observed_function( - sys, Symbolics.scalarize.(obs_exprs); - mkarray = (es, _) -> MakeTuple(es)) - obs_sym_tuple = (obs_syms...,) + geto_funs = NamedTuple{(obs_syms...,)}((getsym.((sys,), obs_exprs)...,)) # okay so now to generate the stuff to assign it back into the system + getm_funs = NamedTuple{(mod_syms...,)}((getsym.((sys,), mod_exprs)...,)) + mod_pairs = mod_exprs .=> mod_syms mod_names = (mod_syms...,) - mod_og_val_fun = build_explicit_observed_function( - sys, Symbolics.scalarize.(first.(mod_pairs)); - mkarray = (es, _) -> MakeTuple(es)) - upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,)) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing @@ -212,21 +207,20 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. let user_affect = func(affect), ctx = context(affect) function (integ) # update the to-be-mutated values; this ensures that if you do a no-op then nothing happens - modvals = mod_og_val_fun(integ.u, integ.p, integ.t) - upd_component_array = NamedTuple{mod_names}(modvals) + upd_component_array = _generated_readback(integ, getm_funs) # update the observed values - obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun( - integ.u, integ.p, integ.t)) + obs_component_array = _generated_readback(integ, geto_funs) # let the user do their thing upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ) # write the new values back to the integrator - _generated_writeback(integ, upd_funs, upd_vals) - - for idx in save_idxs - SciMLBase.save_discretes!(integ, idx) + if !isnothing(upd_vals) + _generated_writeback(integ, upd_funs, upd_vals) + for idx in save_idxs + SciMLBase.save_discretes!(integ, idx) + end end end end diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 5c0a2ee7fc..a409d0c942 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1461,3 +1461,50 @@ end sys = structural_simplify(sys) sol = solve(ODEProblem(sys, [], (0.0, 1.0)), Tsit5()) end + +@testset "Tuples in ImperativeAffect arguments" begin + @mtkmodel ImperativeAffectTupleMWE begin + @parameters begin + y(t) = 1.0 + end + @variables begin + x(t) = 0.0 + end + @equations begin + D(x) ~ y + end + @continuous_events begin + (x ~ 0.5) => ModelingToolkit.ImperativeAffect( + observed = (; mypars = (x, 2 * x)), modified = (; y)) do m, o, c, i + return (; y = 2 * o.mypars[1] + o.mypars[2]) + end + end + end + @mtkbuild sys = ImperativeAffectTupleMWE() + prob = ODEProblem(sys, [], (0.0, 1.0)) + sol = solve(prob, Tsit5()) +end + +@testset "ImperativeAffect skips writing back when nothing is returned" begin + @mtkmodel ImperativeAffectWriteNothingMWE begin + @parameters begin + y(t) = 1.0 + end + @variables begin + x(t) = 0.0 + end + @equations begin + D(x) ~ y + end + @continuous_events begin + (x ~ 0.5) => ModelingToolkit.ImperativeAffect( + observed = (; mypars = (x, 2 * x)), modified = (; y)) do m, o, c, i + return nothing + end + end + end + @mtkbuild sys = ImperativeAffectWriteNothingMWE() + prob = ODEProblem(sys, [], (0.0, 1.0)) + sol = solve(prob, Tsit5()) + @test length(sol[sys.y]) == 1 +end