diff --git a/src/symbolic.jl b/src/symbolic.jl index 35fa41d..36150e4 100644 --- a/src/symbolic.jl +++ b/src/symbolic.jl @@ -129,6 +129,29 @@ function mul_numeric_args(args) (prod, sym_args) end +cancel_common(num, den) = (num, den) +cancel_common(num::Symbol, den::Symbol) = num==den ? (1, 1) : (num, den) +cancel_common(num::Expr, den::Symbol) = cancel_common(num, Expr(:call, :*, den)) +cancel_common(num::Symbol, den::Expr) = cancel_common(Expr(:call, :*, num), den) + +function cancel_common(num::Expr, den::Expr) + if num.args[1] != :* || den.args[1] != :* + return (num, den) + end + an = num.args[2:end] + ad = den.args[2:end] + i = 1 + while i <= length(ad) + idx = findfirst(an, ad[i]) + if idx != 0 + splice!(ad, i) + splice!(an, idx) + end + i += 1 + end + (Expr(:call, :*, an...), Expr(:call, :*, ad...)) +end + # Handle `args` of all lengths function simplify(::SymbolParameter{:+}, args) # Remove any 0's in a sum @@ -188,6 +211,8 @@ function simplify(::SymbolParameter{:/}, args) elseif args[1] == 0 return 0 else + args = cancel_common(args[1], args[2]) + args = map(simplify, args) return Expr(:call, :/, args...) end end diff --git a/test/symbolic.jl b/test/symbolic.jl index d6fae32..5678b25 100644 --- a/test/symbolic.jl +++ b/test/symbolic.jl @@ -88,3 +88,11 @@ end @test isequal(simplify(:(x*3)), :(*(3,x))) @test isequal(simplify(:(x*3*4)), :(*(12,x))) @test isequal(simplify(:(2*y*x*3)), :(*(6,y,x))) + +@test isequal(simplify(:(x/x)), 1.0) +@test isequal(simplify(:((2*x)/x)), 2.0) +@test isequal(simplify(:(x/(2*x))), 0.5) +@test isequal(simplify(:((2*y)/x)), :(/(*(2,y),x))) +@test isequal(simplify(:((5*x)/(2*x))), 2.5) +@test isequal(simplify(:((y*x*5)/(2*x))), :(5*y/2)) +@test isequal(simplify(:((5*z)/(2*x))), :((5*z)/(2*x)))