diff --git a/src/ODE.jl b/src/ODE.jl index f3b0d957b..8ea5ba99f 100644 --- a/src/ODE.jl +++ b/src/ODE.jl @@ -43,11 +43,8 @@ export ode23, ode4, ode45, ode4s, ode4ms # Initialize variables. # Adapted from Cleve Moler's textbook # http://www.mathworks.com/moler/ncm/ode23tx.m -function ode23(F::Function, tspan::AbstractVector, y_0::AbstractVector) - - rtol = 1.e-5 - atol = 1.e-8 +function ode23(F, tspan, y_0; rtol=1e-5, atol=1e-8, vecnorm=norm) t0 = tspan[1] tfinal = tspan[end] tdir = sign(tfinal - t0) @@ -64,7 +61,7 @@ function ode23(F::Function, tspan::AbstractVector, y_0::AbstractVector) # Compute initial step size. s1 = F(t, y) - r = norm(s1./max(abs(y), threshold), Inf) + realmin() # TODO: fix type bug in max() + r = vecnorm(s1./max(abs(y), threshold), Inf) + realmin() # TODO: fix type bug in max() h = tdir*0.8*rtol^(1/3)/r # The main loop. @@ -92,7 +89,7 @@ function ode23(F::Function, tspan::AbstractVector, y_0::AbstractVector) # Estimate the error. e = h*(-5*s1 + 6*s2 + 8*s3 - 9*s4)/72 - err = norm(e./max(max(abs(y), abs(ynew)), threshold), Inf) + realmin() + err = vecnorm(e./max(max(abs(y), abs(ynew)), threshold), Inf) + realmin() # Accept the solution if the estimated error is less than the tolerance. @@ -179,9 +176,8 @@ end # ode23 # created : 06 October 1999 # modified: 17 January 2001 -function oderkf{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, a, b4, b5) - tol = 1.0e-5 - + +function oderkf(F, tspan, x0, a, b4, b5; tol=1.0e-5, norm=Base.norm) # see p.91 in the Ascher & Petzold reference for more infomation. pow = 1/6; @@ -274,7 +270,7 @@ const dp_coefficients = ([ 0 0 0 0 0 # 5th order b-coefficients [35/384 0 500/1113 125/192 -2187/6784 11/84 0], ) -ode45_dp(F, tspan, x0) = oderkf(F, tspan, x0, dp_coefficients...) +ode45_dp(F, tspan, x0; tol=1.0e-5, norm=Base.norm) = oderkf(F, tspan, x0, dp_coefficients...; tol=1.0e-5, norm=norm) # Fehlberg coefficients const fb_coefficients = ([ 0 0 0 0 0 @@ -288,7 +284,7 @@ const fb_coefficients = ([ 0 0 0 0 0 # 5th order b-coefficients [16/135 0 6656/12825 28561/56430 -9/50 2/55], ) -ode45_fb(F, tspan, x0) = oderkf(F, tspan, x0, fb_coefficients...) +ode45_fb(F, tspan, x0; tol=1.0e-5, norm=Base.norm) = oderkf(F, tspan, x0, fb_coefficients...; tol=1.0e-5, norm=norm) # Cash-Karp coefficients # Numerical Recipes in Fortran 77 @@ -303,10 +299,12 @@ const ck_coefficients = ([ 0 0 0 0 0 # 5th order b-coefficients [2825/27648 0 18575/48384 13525/55296 277/14336 1/4], ) -ode45_ck(F, tspan, x0) = oderkf(F, tspan, x0, ck_coefficients...) +ode45_ck(F, tspan, x0; tol=1.0e-5, norm=Base.norm) = oderkf(F, tspan, x0, ck_coefficients...; tol=1.0e-5, norm=norm) # Use Dormand Prince version of ode45 by default -const ode45 = ode45_dp +function ode45(F, tspan, x0; tol=1e-5, norm=Base.norm) + return ode45_dp(F, tspan, x0; tol=tol, norm=norm) +end #ODE4 Solve non-stiff differential equations, fourth order # fixed-step Runge-Kutta method. @@ -317,7 +315,12 @@ const ode45 = ode45_dp # ODEFUN(T,X) must return a column vector corresponding to f(t,x). Each # row in the solution array X corresponds to a time returned in the # column vector T. -function ode4{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}) +function ode4(F::Function, tspan::AbstractVector, x0::Number) + tout, yout = ode4(F, tspan, [x0]) + return tout, yout[1,:] +end + +function ode4{T<:Number}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}) h = diff(tspan) x = Array(T, (length(tspan), length(x0))) x[1,:] = x0' @@ -338,13 +341,18 @@ end #ODEROSENBROCK Solve stiff differential equations, Rosenbrock method # with provided coefficients. -function oderosenbrock{T}(F::Function, G::Function, tspan::AbstractVector, x0::AbstractVector{T}, gamma, a, b, c) +function oderosenbrock(F::Function, G::Function, tspan::AbstractVector, x0::Number, gamma, a, b, c) + tout, yout = oderosenbrock(F, G, tspan, [x0], gamma, a, b, c) + return tout, yout[1,:] +end + +function oderosenbrock{T<:Number}(F::Function, G::Function, tspan::AbstractVector, x0::AbstractVector{T}, gamma, a, b, c) h = diff(tspan) x = Array(T, length(tspan), length(x0)) x[1,:] = x0' solstep = 1 - while tspan[solstep] < max(tspan) + while tspan[solstep] < maximum(tspan) ts = tspan[solstep] hs = h[solstep] xs = reshape(x[solstep,:], size(x0)) @@ -363,7 +371,12 @@ function oderosenbrock{T}(F::Function, G::Function, tspan::AbstractVector, x0::A return (tspan, x) end -function oderosenbrock{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, gamma, a, b, c) +function oderosenbrock(F::Function, tspan::AbstractVector, x0::Number, gamma, a, b, c) + tout, yout = oderosenbrock(F, tspan, [x0], gamma, a, b, c) + return tout, yout[1,:] +end + +function oderosenbrock{T<:Number}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, gamma, a, b, c) # Crude forward finite differences estimator as fallback function jacobian(F::Function, t::Number, x::AbstractVector) ftx = F(t, x) @@ -410,10 +423,21 @@ ode4s_s(F, tspan, x0) = oderosenbrock(F, tspan, x0, s4_coefficients...) ode4s_s(F, G, tspan, x0) = oderosenbrock(F, G, tspan, x0, s4_coefficients...) # Use Shampine coefficients by default (matching Numerical Recipes) -const ode4s = ode4s_s +function ode4s(F::Function, tspan::AbstractVector, x0::Number) + return ode4s_s(F, tspan, x0) +end +function ode4s{T<:Number}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}) + return ode4s_s(F, tspan, x0) +end # ODE_MS Fixed-step, fixed-order multi-step numerical method with Adams-Bashforth-Moulton coefficients -function ode_ms{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, order::Integer) + +function ode_ms(F::Function, tspan::AbstractVector, x0::Number, order::Integer) + tout, yout = ode_ms(F, tspan, [x0], order) + return tout, yout[1,:] +end + +function ode_ms{T<:Number}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, order::Integer) h = diff(tspan) x = zeros(T,(length(tspan), length(x0))) x[1,:] = x0 @@ -445,6 +469,14 @@ function ode_ms{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, or return (tspan, x) end -ode4ms(F, tspan, x0) = ode_ms(F, tspan, x0, 4) +function ode4ms(F::Function, tspan::AbstractVector, x0::Number) + return ode_ms(F, tspan, x0, 4) +end + +function ode4ms{T<:Number}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}) + return ode_ms(F, tspan, x0, 4) +end + +include("ODEObj.jl") end # module ODE diff --git a/src/ODEObj.jl b/src/ODEObj.jl new file mode 100644 index 000000000..79c4dd673 --- /dev/null +++ b/src/ODEObj.jl @@ -0,0 +1,26 @@ +export ode + +type ODESolver{T<:Number} + # The solver function with which to produce a solution + solver::Function + # The function f in y' = f(t,y) + F::Function + # The solution domain + tspan::AbstractVector + # Initial value for the solver + x0::Union(T,AbstractVector{T}) + + solve::Function + +end + +ODESolver(solver::Function, F::Function, tspan::AbstractVector, x0::Number) = ODESolver(solver, F, tspan, x0, () -> solver(F, tspan, x0)) +ODESolver{T}(solver::Function, F::Function, tspan::AbstractVector, x0::AbstractVector{T}) = ODESolver(solver, F, tspan, x0, () -> solver(F, tspan, x0)) + +function ode{T<:Number}(solver::Symbol, F::Function, tspan::AbstractVector, x0::Union(T,AbstractVector{T})) + if !contains((:ode23, :ode4, :ode45, :ode4s, :ode4ms), solver) + error("$solver is not one of the valid ODE solvers") + end + + ode = ODESolver(eval(solver), F, tspan, x0) +end diff --git a/test/test_ode.jl b/test/test_ode.jl index 79db049a8..a44afc844 100644 --- a/test/test_ode.jl +++ b/test/test_ode.jl @@ -1,38 +1,42 @@ using ODE -using Test +using Base.Test tol = 1e-2 solvers = [ ODE.ode23, ODE.ode4, + ODE.ode45, ODE.ode45_dp, ODE.ode45_fb, ODE.ode45_ck, ODE.ode4ms, ODE.ode4s_s, - ODE.ode4s_kr] + ODE.ode4s_kr + ] for solver in solvers - println("using $solver") + println("Testing $solver...") # dy # -- = 6 ==> y = 6t # dt t,y=solver((t,y)->6, [0:.1:1], [0.]) - @test max(abs(y-6t)) < tol + @test maximum(abs(y-6t)) < tol # dy # -- = 2t ==> y = t.^2 # dt t,y=solver((t,y)->2t, [0:.001:1], [0.]) - @test max(abs(y-t.^2)) < tol + @test maximum(abs(y-t.^2)) < tol # dy # -- = y ==> y = y0*e.^t # dt t,y=solver((t,y)->y, [0:.001:1], [1.]) - @test max(abs(y-e.^t)) < tol + @test maximum(abs(y-e.^t)) < tol + + end