Skip to content

Use getsym instead of an explicitly generated function and avoid writeback if nothing is returned #3610

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 22 additions & 17 deletions ext/MTKFMIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ function MTK.FMIComponent(::Val{Ver}; fmu = nothing, tolerance = 1e-6,

# use `ImperativeAffect` for instance management here
cb_observed = (; inputs = __mtk_internal_x, params = copy(params),
t, wrapper, dt = communication_step_size)
t, wrapper)
cb_modified = (;)
# modify the outputs if present
if symbolic_type(__mtk_internal_o) != NotSymbolic()
Expand All @@ -272,11 +272,12 @@ function MTK.FMIComponent(::Val{Ver}; fmu = nothing, tolerance = 1e-6,
cb_modified = (cb_modified..., states = __mtk_internal_u)
end
initialize_affect = MTK.ImperativeAffect(fmiCSInitialize!; observed = cb_observed,
modified = cb_modified, ctx = _functor)
modified = cb_modified, ctx = (_functor, communication_step_size))
finalize_affect = MTK.FunctionalAffect(fmiFinalize!, [], [wrapper], [])
# the callback affect performs the stepping
step_affect = MTK.ImperativeAffect(
fmiCSStep!; observed = cb_observed, modified = cb_modified, ctx = _functor)
fmiCSStep!; observed = cb_observed, modified = cb_modified,
ctx = (_functor, communication_step_size))
instance_management_callback = MTK.SymbolicDiscreteCallback(
communication_step_size, step_affect; initialize = initialize_affect,
finalize = finalize_affect, reinitializealg = reinitializealg
Expand Down Expand Up @@ -775,7 +776,8 @@ the value being the output vector if the FMU has output variables. `o` should co

Initializes the FMU. Only for use with CoSimulation FMUs.
"""
function fmiCSInitialize!(m, o, ctx::FMI2CSFunctor, integrator)
function fmiCSInitialize!(m, o, ctx::Tuple{FMI2CSFunctor, Vararg}, integrator)
functor, dt = ctx
states = isdefined(m, :states) ? m.states : ()
inputs = o.inputs
params = o.params
Expand All @@ -787,10 +789,10 @@ function fmiCSInitialize!(m, o, ctx::FMI2CSFunctor, integrator)

instance = get_instance_CS!(wrapper, states, inputs, params, t)
if isdefined(m, :states)
@statuscheck FMI.fmi2GetReal!(instance, ctx.state_value_references, m.states)
@statuscheck FMI.fmi2GetReal!(instance, functor.state_value_references, m.states)
end
if isdefined(m, :outputs)
@statuscheck FMI.fmi2GetReal!(instance, ctx.output_value_references, m.outputs)
@statuscheck FMI.fmi2GetReal!(instance, functor.output_value_references, m.outputs)
end

return m
Expand All @@ -804,13 +806,13 @@ periodically to communicte with the CoSimulation FMU. Has the same requirements
`fmiCSInitialize!` for `m` and `o`, with the addition that `o` should have a key
`:dt` with the value being the communication step size.
"""
function fmiCSStep!(m, o, ctx::FMI2CSFunctor, integrator)
function fmiCSStep!(m, o, ctx::Tuple{FMI2CSFunctor, Vararg}, integrator)
functor, dt = ctx
wrapper = o.wrapper
states = isdefined(m, :states) ? m.states : ()
inputs = o.inputs
params = o.params
t = o.t
dt = o.dt

instance = get_instance_CS!(wrapper, states, inputs, params, integrator.t)
if !isempty(inputs)
Expand All @@ -820,10 +822,10 @@ function fmiCSStep!(m, o, ctx::FMI2CSFunctor, integrator)
@statuscheck FMI.fmi2DoStep(instance, integrator.t - dt, dt, FMI.fmi2True)

if isdefined(m, :states)
@statuscheck FMI.fmi2GetReal!(instance, ctx.state_value_references, m.states)
@statuscheck FMI.fmi2GetReal!(instance, functor.state_value_references, m.states)
end
if isdefined(m, :outputs)
@statuscheck FMI.fmi2GetReal!(instance, ctx.output_value_references, m.outputs)
@statuscheck FMI.fmi2GetReal!(instance, functor.output_value_references, m.outputs)
end

return m
Expand Down Expand Up @@ -874,7 +876,8 @@ end
"""
$(TYPEDSIGNATURES)
"""
function fmiCSInitialize!(m, o, ctx::FMI3CSFunctor, integrator)
function fmiCSInitialize!(m, o, ctx::Tuple{FMI3CSFunctor, Vararg}, integrator)
functor, dt = ctx
states = isdefined(m, :states) ? m.states : ()
inputs = o.inputs
params = o.params
Expand All @@ -885,10 +888,11 @@ function fmiCSInitialize!(m, o, ctx::FMI3CSFunctor, integrator)
end
instance = get_instance_CS!(wrapper, states, inputs, params, t)
if isdefined(m, :states)
@statuscheck FMI.fmi3GetFloat64!(instance, ctx.state_value_references, m.states)
@statuscheck FMI.fmi3GetFloat64!(instance, functor.state_value_references, m.states)
end
if isdefined(m, :outputs)
@statuscheck FMI.fmi3GetFloat64!(instance, ctx.output_value_references, m.outputs)
@statuscheck FMI.fmi3GetFloat64!(
instance, functor.output_value_references, m.outputs)
end

return m
Expand All @@ -897,13 +901,13 @@ end
"""
$(TYPEDSIGNATURES)
"""
function fmiCSStep!(m, o, ctx::FMI3CSFunctor, integrator)
function fmiCSStep!(m, o, ctx::Tuple{FMI3CSFunctor, Vararg}, integrator)
functor, dt = ctx
wrapper = o.wrapper
states = isdefined(m, :states) ? m.states : ()
inputs = o.inputs
params = o.params
t = o.t
dt = o.dt

instance = get_instance_CS!(wrapper, states, inputs, params, integrator.t)
if !isempty(inputs)
Expand All @@ -921,10 +925,11 @@ function fmiCSStep!(m, o, ctx::FMI3CSFunctor, integrator)
@assert earlyReturn[] == FMI.fmi3False

if isdefined(m, :states)
@statuscheck FMI.fmi3GetFloat64!(instance, ctx.state_value_references, m.states)
@statuscheck FMI.fmi3GetFloat64!(instance, functor.state_value_references, m.states)
end
if isdefined(m, :outputs)
@statuscheck FMI.fmi3GetFloat64!(instance, ctx.output_value_references, m.outputs)
@statuscheck FMI.fmi3GetFloat64!(
instance, functor.output_value_references, m.outputs)
end

return m
Expand Down
8 changes: 8 additions & 0 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,14 @@ end
end)
end

@generated function _generated_readback(integ, getters::NamedTuple{NS1, <:Tuple}) where {NS1}
getter_exprs = []
for name in NS1
push!(getter_exprs, :($name = getters.$name(integ)))
end
return :((; $(getter_exprs...)))
end

function check_assignable(sys, sym)
if symbolic_type(sym) == ScalarSymbolic()
is_variable(sys, sym) || is_parameter(sys, sym)
Expand Down
28 changes: 11 additions & 17 deletions src/systems/imperative_affect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ The NamedTuple returned from `f` includes the values to be written back to the s

Where we use Setfield to copy the tuple `m` with a new value for `x`, then return the modified value of `m`. All values updated by the tuple must have names originally declared in
`modified`; a runtime error will be produced if a value is written that does not appear in `modified`. The user can dynamically decide not to write a value back by not including it
in the returned tuple, in which case the associated field will not be updated.
in the returned tuple, in which case the associated field will not be updated. To avoid writing back, either return `nothing` or an empty named tuple.
"""
@kwdef struct ImperativeAffect
f::Any
Expand Down Expand Up @@ -189,18 +189,13 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs.
else
zeros(sz)
end
obs_fun = build_explicit_observed_function(
sys, Symbolics.scalarize.(obs_exprs);
mkarray = (es, _) -> MakeTuple(es))
obs_sym_tuple = (obs_syms...,)
geto_funs = NamedTuple{(obs_syms...,)}((getsym.((sys,), obs_exprs)...,))

# okay so now to generate the stuff to assign it back into the system
getm_funs = NamedTuple{(mod_syms...,)}((getsym.((sys,), mod_exprs)...,))

mod_pairs = mod_exprs .=> mod_syms
mod_names = (mod_syms...,)
mod_og_val_fun = build_explicit_observed_function(
sys, Symbolics.scalarize.(first.(mod_pairs));
mkarray = (es, _) -> MakeTuple(es))

upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,))

if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
Expand All @@ -212,21 +207,20 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs.
let user_affect = func(affect), ctx = context(affect)
function (integ)
# update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
modvals = mod_og_val_fun(integ.u, integ.p, integ.t)
upd_component_array = NamedTuple{mod_names}(modvals)
upd_component_array = _generated_readback(integ, getm_funs)

# update the observed values
obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(
integ.u, integ.p, integ.t))
obs_component_array = _generated_readback(integ, geto_funs)

# let the user do their thing
upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ)

# write the new values back to the integrator
_generated_writeback(integ, upd_funs, upd_vals)

for idx in save_idxs
SciMLBase.save_discretes!(integ, idx)
if !isnothing(upd_vals)
_generated_writeback(integ, upd_funs, upd_vals)
for idx in save_idxs
SciMLBase.save_discretes!(integ, idx)
end
end
end
end
Expand Down
47 changes: 47 additions & 0 deletions test/symbolic_events.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1461,3 +1461,50 @@ end
sys = structural_simplify(sys)
sol = solve(ODEProblem(sys, [], (0.0, 1.0)), Tsit5())
end

@testset "Tuples in ImperativeAffect arguments" begin
@mtkmodel ImperativeAffectTupleMWE begin
@parameters begin
y(t) = 1.0
end
@variables begin
x(t) = 0.0
end
@equations begin
D(x) ~ y
end
@continuous_events begin
(x ~ 0.5) => ModelingToolkit.ImperativeAffect(
observed = (; mypars = (x, 2 * x)), modified = (; y)) do m, o, c, i
return (; y = 2 * o.mypars[1] + o.mypars[2])
end
end
end
@mtkbuild sys = ImperativeAffectTupleMWE()
prob = ODEProblem(sys, [], (0.0, 1.0))
sol = solve(prob, Tsit5())
end

@testset "ImperativeAffect skips writing back when nothing is returned" begin
@mtkmodel ImperativeAffectWriteNothingMWE begin
@parameters begin
y(t) = 1.0
end
@variables begin
x(t) = 0.0
end
@equations begin
D(x) ~ y
end
@continuous_events begin
(x ~ 0.5) => ModelingToolkit.ImperativeAffect(
observed = (; mypars = (x, 2 * x)), modified = (; y)) do m, o, c, i
return nothing
end
end
end
@mtkbuild sys = ImperativeAffectWriteNothingMWE()
prob = ODEProblem(sys, [], (0.0, 1.0))
sol = solve(prob, Tsit5())
@test length(sol[sys.y]) == 1
end
Loading