Skip to content

Change allow_symbolic to default to true #3481

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 27 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
e7bb024
Change allow_symbolic to default to true
ChrisRackauckas Mar 21, 2025
83ec981
test: update some tests to account for `use_symbolic = true`
AayushSabharwal Mar 24, 2025
1049bad
feat: add `denominators` field to `SystemStructure`
AayushSabharwal Mar 27, 2025
cc8fa2b
feat: populate `state.structure.denominators` in `find_eq_solvables!`
AayushSabharwal Mar 27, 2025
a0fc653
feat: add `allow_algebraic`, default `allow_symbolic` to false
AayushSabharwal Mar 27, 2025
ae0c750
Revert "test: update some tests to account for `use_symbolic = true`"
AayushSabharwal Mar 27, 2025
828bcdb
feat: propagate and implement `allow_algebraic`
AayushSabharwal Mar 27, 2025
8a304b3
test: update tests
AayushSabharwal Mar 27, 2025
b20fc7f
fix: account for `allow_algebraic` in `partial_state_selection_graph!`
AayushSabharwal Mar 28, 2025
c7e923d
fix: remove CSE hack
AayushSabharwal Apr 9, 2025
77bea8a
feat: add `remove_denominators`
AayushSabharwal Apr 9, 2025
148d3dd
refactor: remove denominators in `generate_function` and `calculate_j…
AayushSabharwal Apr 9, 2025
f5d037e
refactor: remove denominators in `NonlinearSystem`
AayushSabharwal Apr 9, 2025
e7a6245
test: update tests to account for removal of CSE hack
AayushSabharwal Apr 9, 2025
8ce87ea
test: update test to account for `allow_algebraic`
AayushSabharwal Apr 9, 2025
4a7b12d
feat: allow `full_equations` on a singular system
AayushSabharwal Apr 9, 2025
c988e2b
refactor: default `allow_algebraic` to `fully_determined`
AayushSabharwal Apr 9, 2025
8f8bf39
refactor: format
AayushSabharwal Apr 9, 2025
8a177ff
fixup! refactor: default `allow_algebraic` to `fully_determined`
AayushSabharwal Apr 9, 2025
3423435
fix: fix usage of `make_differential_denominators_unsolvable!`
AayushSabharwal Apr 11, 2025
f019275
fix: don't accept `allow_algebraic` as a kwarg to `InitializationProb…
AayushSabharwal Apr 11, 2025
38f2b0a
fix: fix defaulting `allow_algebraic = fully_determined`
AayushSabharwal Apr 11, 2025
0022f19
test: account for `allow_algebraic` in tests
AayushSabharwal Apr 11, 2025
76d9045
test: account for `allow_algebraic` in HC tests
AayushSabharwal Apr 11, 2025
ad8730c
fix: better handle `allow_algebraic=false`
AayushSabharwal Apr 11, 2025
66034d6
fix: check edge existence before removing
AayushSabharwal Apr 11, 2025
25917bd
test: account for `allow_algebraic` in tests
AayushSabharwal Apr 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,8 @@ export structural_simplify, expand_connections, linearize, linearization_functio
LinearizationProblem
export solve

export calculate_jacobian, generate_jacobian, generate_function, generate_custom_function, generate_W
export calculate_jacobian, generate_jacobian, generate_function, generate_custom_function,
generate_W
export calculate_control_jacobian, generate_control_jacobian
export calculate_tgrad, generate_tgrad
export calculate_gradient, generate_gradient
Expand Down
22 changes: 15 additions & 7 deletions src/structural_transformation/partial_state_selection.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
function partial_state_selection_graph!(state::TransformationState)
find_solvables!(state; allow_symbolic = true)
var_eq_matching = complete(pantelides!(state))
var_eq_matching = complete(pantelides!(state; allow_algebraic = false))
complete!(state.structure)
partial_state_selection_graph!(state.structure, var_eq_matching)
end
Expand Down Expand Up @@ -170,11 +170,14 @@ function partial_state_selection_graph!(structure::SystemStructure, var_eq_match
end

function dummy_derivative_graph!(state::TransformationState, jac = nothing;
state_priority = nothing, log = Val(false), kwargs...)
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
state_priority = nothing, log = Val(false), allow_symbolic = false, allow_algebraic = true, kwargs...)
state.structure.solvable_graph === nothing &&
find_solvables!(state; allow_symbolic, allow_algebraic, kwargs...)
complete!(state.structure)
var_eq_matching = complete(pantelides!(state; kwargs...))
dummy_derivative_graph!(state.structure, var_eq_matching, jac, state_priority, log)
var_eq_matching = complete(pantelides!(
state; allow_symbolic, allow_algebraic, kwargs...))
dummy_derivative_graph!(
state.structure, var_eq_matching, jac, state_priority, log; allow_symbolic, allow_algebraic)
end

struct DummyDerivativeSummary
Expand All @@ -184,7 +187,8 @@ end

function dummy_derivative_graph!(
structure::SystemStructure, var_eq_matching, jac = nothing,
state_priority = nothing, ::Val{log} = Val(false)) where {log}
state_priority = nothing, ::Val{log} = Val(false); allow_symbolic = false,
allow_algebraic = true) where {log}
@unpack eq_to_diff, var_to_diff, graph = structure
diff_to_eq = invview(eq_to_diff)
diff_to_var = invview(var_to_diff)
Expand Down Expand Up @@ -342,7 +346,11 @@ function dummy_derivative_graph!(
@warn "The number of dummy derivatives ($n_dummys) does not match the number of differentiated equations ($n_diff_eqs)."
end

ret = tearing_with_dummy_derivatives(structure, BitSet(dummy_derivatives))
dummy_derivatives_set = BitSet(dummy_derivatives)
make_differential_denominators_unsolvable!(
structure, dummy_derivatives_set; allow_algebraic)

ret = tearing_with_dummy_derivatives(structure, dummy_derivatives_set)
if log
(ret..., DummyDerivativeSummary(var_dummy_scc, var_state_priority))
else
Expand Down
97 changes: 15 additions & 82 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ These equations matches generated numerical code.

See also [`equations`](@ref) and [`ModelingToolkit.get_eqs`](@ref).
"""
function full_equations(sys::AbstractSystem; simplify = false)
function full_equations(sys::AbstractSystem; simplify = false, allow_singular = false)
empty_substitutions(sys) && return equations(sys)
substitutions = get_substitutions(sys)
substitutions.subed_eqs === nothing || return substitutions.subed_eqs
Expand All @@ -119,7 +119,7 @@ function full_equations(sys::AbstractSystem; simplify = false)
eq = 0 ~ eq.rhs - eq.lhs
end
rhs = tearing_sub(eq.rhs, solved, simplify)
if rhs isa Symbolic
if rhs isa Symbolic || allow_singular
return 0 ~ rhs
else # a number
error("tearing failed because the system is singular")
Expand Down Expand Up @@ -708,7 +708,7 @@ Update the system equations, unknowns, and observables after simplification.
"""
function update_simplified_system!(
state::TearingState, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns;
cse_hack = true, array_hack = true)
array_hack = true)
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
diff_to_var = invview(var_to_diff)

Expand All @@ -732,8 +732,8 @@ function update_simplified_system!(
unknowns = [unknowns; extra_unknowns]
@set! sys.unknowns = unknowns

obs, subeqs, deps = cse_and_array_hacks(
sys, obs, solved_eqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
obs, subeqs, deps = array_var_hack(
sys, obs, solved_eqs, unknowns, neweqs; array = array_hack)

@set! sys.eqs = neweqs
@set! sys.observed = obs
Expand Down Expand Up @@ -775,7 +775,7 @@ appear in the system. Algebraic variables are variables that are not
differential variables.
"""
function tearing_reassemble(state::TearingState, var_eq_matching,
full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
full_var_eq_matching = nothing; simplify = false, mm = nothing, array_hack = true)
extra_vars = Int[]
if full_var_eq_matching !== nothing
for v in 𝑑vertices(state.structure.graph)
Expand Down Expand Up @@ -811,21 +811,14 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
state, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)

sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_eq_matching,
extra_unknowns; cse_hack, array_hack)
extra_unknowns; array_hack)

@set! state.sys = sys
@set! sys.tearing_state = state
return invalidate_cache!(sys)
end

"""
# HACK 1

Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
_very_ expensive. this hack performs a limited form of CSE specifically for this case to
avoid the unnecessary cost. This and the below hack are implemented simultaneously

# HACK 2

Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
Expand All @@ -834,12 +827,7 @@ if all `p[i]` are present and the unscalarized form is used in any equation (obs
not) we first count the number of times the scalarized form of each observed variable
occurs in observed equations (and unknowns if it's split).
"""
function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, array = true)
# HACK 1
# mapping of rhs to temporary CSE variable
# `f(...) => tmpvar` in above example
rhs_to_tempvar = Dict()

function array_var_hack(sys, obs, subeqs, unknowns, neweqs; array = true)
# HACK 2
# map of array observed variable (unscalarized) to number of its
# scalarized terms that appear in observed equations
Expand All @@ -851,36 +839,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
rhs = eq.rhs
vars!(all_vars, rhs)

# HACK 1
if cse && is_getindexed_array(rhs)
rhs_arr = arguments(rhs)[1]
iscall(rhs_arr) && operation(rhs_arr) isa Symbolics.Operator && continue
if !haskey(rhs_to_tempvar, rhs_arr)
tempvar = gensym(Symbol(lhs))
N = length(rhs_arr)
tempvar = unwrap(Symbolics.variable(
tempvar; T = Symbolics.symtype(rhs_arr)))
tempvar = setmetadata(
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
tempeq = tempvar ~ rhs_arr
rhs_to_tempvar[rhs_arr] = tempvar
push!(obs, tempeq)
push!(subeqs, tempeq)
end

# getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
# so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
# which fails the topological sort
neweq = lhs ~ getindex_wrapper(
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
obs[i] = neweq
subeqi = findfirst(isequal(eq), subeqs)
if subeqi !== nothing
subeqs[subeqi] = neweq
end
end
# end HACK 1

array || continue
iscall(lhs) || continue
operation(lhs) === getindex || continue
Expand All @@ -891,33 +849,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
continue
end

# Also do CSE for `equations(sys)`
if cse
for (i, eq) in enumerate(neweqs)
(; lhs, rhs) = eq
is_getindexed_array(rhs) || continue
rhs_arr = arguments(rhs)[1]
if !haskey(rhs_to_tempvar, rhs_arr)
tempvar = gensym(Symbol(lhs))
N = length(rhs_arr)
tempvar = unwrap(Symbolics.variable(
tempvar; T = Symbolics.symtype(rhs_arr)))
tempvar = setmetadata(
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
vars!(all_vars, rhs_arr)
tempeq = tempvar ~ rhs_arr
rhs_to_tempvar[rhs_arr] = tempvar
push!(obs, tempeq)
push!(subeqs, tempeq)
end
# don't need getindex_wrapper, but do it anyway to know that this
# hack took place
neweq = lhs ~ getindex_wrapper(
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
neweqs[i] = neweq
end
end

# count variables in unknowns if they are scalarized forms of variables
# also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
# is an observed equation.
Expand Down Expand Up @@ -995,6 +926,8 @@ end
function tearing(state::TearingState; kwargs...)
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
complete!(state.structure)
make_differential_denominators_unsolvable!(
state.structure; allow_algebraic = get(kwargs, :allow_algebraic, true))
tearing_with_dummy_derivatives(state.structure, ())
end

Expand All @@ -1006,10 +939,10 @@ new residual equations after tearing. End users are encouraged to call [`structu
instead, which calls this function internally.
"""
function tearing(sys::AbstractSystem, state = TearingState(sys); mm = nothing,
simplify = false, cse_hack = true, array_hack = true, kwargs...)
var_eq_matching, full_var_eq_matching = tearing(state)
simplify = false, array_hack = true, kwargs...)
var_eq_matching, full_var_eq_matching = tearing(state; kwargs...)
invalidate_cache!(tearing_reassemble(
state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack))
state, var_eq_matching, full_var_eq_matching; mm, simplify, array_hack))
end

"""
Expand All @@ -1031,7 +964,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
the system is balanced.
"""
function dummy_derivative(sys, state = TearingState(sys); simplify = false,
mm = nothing, cse_hack = true, array_hack = true, kwargs...)
mm = nothing, array_hack = true, kwargs...)
jac = let state = state
(eqs, vars) -> begin
symeqs = EquationsView(state)[eqs]
Expand All @@ -1055,5 +988,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
end
var_eq_matching = dummy_derivative_graph!(state, jac; state_priority,
kwargs...)
tearing_reassemble(state, var_eq_matching; simplify, mm, cse_hack, array_hack)
tearing_reassemble(state, var_eq_matching; simplify, mm, array_hack)
end
40 changes: 33 additions & 7 deletions src/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,9 @@ end
### Structural and symbolic utilities
###

function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = nothing;
function find_eq_solvables!(state::TearingState, ieq::Int, to_rm = Int[], coeffs = nothing;
may_be_zero = false,
allow_symbolic = false, allow_parameter = true,
allow_symbolic = false, allow_parameter = true, allow_algebraic = true,
conservative = false,
kwargs...)
fullvars = state.fullvars
Expand All @@ -218,6 +218,7 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no
all_int_vars = true
coeffs === nothing || empty!(coeffs)
empty!(to_rm)
varsbuf = Set()
for j in 𝑠neighbors(graph, ieq)
var = fullvars[j]
isirreducible(var) && (all_int_vars = false; continue)
Expand All @@ -229,13 +230,18 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no
if a isa Symbolic
all_int_vars = false
if !allow_symbolic
if allow_parameter
all(
x -> ModelingToolkit.isparameter(x) || ModelingToolkit.isconstant(x),
vars(a)) || continue
else
allow_parameter || allow_algebraic || continue
empty!(varsbuf)
vars!(varsbuf, a)
denomvars = Int[]
for v in varsbuf
idx = findfirst(isequal(v), fullvars)
idx === nothing || push!(denomvars, idx)
end
if !allow_algebraic && !isempty(denomvars)
continue
end
state.structure.denominators[ieq => j] = denomvars
end
add_edge!(solvable_graph, ieq, j)
continue
Expand Down Expand Up @@ -269,6 +275,26 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no
all_int_vars, term
end

"""
$(TYPEDSIGNATURES)

Remove edges in `structure.solvable_graph` that require differential variables in the
denominator to solve. `additional_algevars` is a collection of integers corresponding to
differential variables that should be considered as algebraic for the purpose of this
transformation.
"""
function make_differential_denominators_unsolvable!(
structure::SystemStructure, additional_algevars = (); allow_algebraic)
for ((eqi, vari), denoms) in structure.denominators
if allow_algebraic &&
all(i -> isalgvar(structure, i) || i in additional_algevars, denoms) ||
!has_edge(structure.solvable_graph, BipartiteEdge(eqi, vari))
continue
end
rem_edge!(structure.solvable_graph, eqi, vari)
end
end

function find_solvables!(state::TearingState; kwargs...)
@assert state.structure.solvable_graph === nothing
eqs = equations(state)
Expand Down
Loading
Loading