
update() method for updating a fitted model with new data #2308


Closed

DominiqueMakowski opened this issue Aug 24, 2024 · 12 comments

@DominiqueMakowski
Contributor

DominiqueMakowski commented Aug 24, 2024

Currently, to generate predictions one has to refit the model on missing data, which requires having access to the model object.

It would be quite convenient to be able to update the data of a fitted model using update() (à la R), which would allow more flexibility. My use case is that I'm running and saving models locally and then running predictions in a separate step; currently I need to save the model, the fitted version, and the posteriors, which is a bit cumbersome.

Would that make sense in Turing? Thanks!


Related, from #2309:

  • Is it possible to extract the model object/method from the fitted object? In other words, as far as I understand, a Turing model is often defined as a function (which is hard to serialize), which gets turned into a DynamicPPL object through the @model macro. Can we recover/reconstruct that object from the fitted version?
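
For reference, the underlying function and the data it was applied to can be pulled out of a DynamicPPL.Model while it's in memory: f and args are fields of the model struct (a minimal sketch; these fields show up in later comments in this thread but aren't documented public API):

using Turing

@model function demo(y)
    x ~ Normal(0, 1)
    y ~ Normal(x, 1)
end

model = demo(1.5)
model.f     # the underlying function `demo`
model.args  # NamedTuple of the data arguments, here (y = 1.5,)

The catch, as the rest of the thread shows, is that model.f is only a reference to a function that must still exist in the workspace.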
@DominiqueMakowski
Contributor Author

Additionally, when I try to save the model using JLD2 and then load it, it throws a warning:

┌ Warning: type Main.#model_LogNormal does not exist in workspace; reconstructing
└ @ JLD2 C:\Users\domma\.julia\packages\JLD2\twZ5D\src\data\reconstructing_datatypes.jl:492
┌ Warning: some parameters could not be resolved for type DynamicPPL.Model{Main.#model_LogNormal,(:rt,),(:min_rt, :isi),(),Tuple{Vector{Float64}},Tuple{Float64, Vector{Float64}},DynamicPPL.DefaultContext}; reconstructing
└ @ JLD2 C:\Users\domma\.julia\packages\JLD2\twZ5D\src\data\reconstructing_datatypes.jl:617

and then errors when using it (the model is of type Reconstruct):

julia> pred = predict(fit([missing for i in 1:nrow(df)]; min_rt=minimum(df.RT), isi=df.ISI), posteriors)
ERROR: MethodError: objects of type JLD2.ReconstructedStatic{Symbol("DynamicPPL.Model{#model_LogNormal,(:rt,),(:min_rt, :isi),(),Tuple{Vector{Float64}},Tuple{Float64, Vector{Float64}},DynamicPPL.DefaultContext}"), (:args, :defaults), Tuple{@NamedTuple{rt::Vector{Float64}}, @NamedTuple{min_rt::Float64, isi::Vector{Float64}}}} are not callable
Stacktrace:
 [1] top-level scope
   @ c:\Users\domma\Dropbox\RECHERCHE\Studies\DoggoNogo\study1\analysis\2_models_comparison.jl:44

julia> model
Reconstruct@#model_LogNormal()

@penelopeysm
Member

Hey @DominiqueMakowski – sorry it's been a bit quiet in this issue; there's more work than all of us can keep up with 😅

Just reading #2309 as well. The predict function (docs here) takes both a model (which is conditioned on some data) and a chain. The chain is what really represents the 'fitted' model: the model object itself contains information about the data, but it never holds any 'learned' info (such as parameter values); that lives in the chain produced by sampling.

So if you want to be able to run the sampling and the predictions in separate files, you need to serialise both the model as well as the chain itself.

I think the workflow you're looking for could be something like this, perhaps?

# 1. Define a model
@model function f(data)
    ...
end

# 2. 'Fit' the model
model = f(data)

# 3. Sample from it
chain = sample(model, NUTS(), 1000)

# 4. Save both the fitted model and the chain to disk
save(model, "model")
save(chain, "chain")

And then in a different file:

# 1. Load the fitted model and the chain
model = load("model")
chain = load("chain")

# 2. 'Refit' to some other data
new_model = update(model, new_data)

# 3. Run predictions
preds = predict(new_model, chain)

The predict docs have an example of this workflow, just minus the save and load steps, where the 'original' data is a training set and the 'new' data is a test set.

I know the pseudocode doesn't look quite so different from what you suggested and it might seem like I'm splitting hairs, but I just wanted to be clear about the exact interface we want to work towards :)
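
For reference, the example in the predict docs looks roughly like this (lightly adapted from the docstring; the train/test split and names follow the docs):

using Turing

@model function linear_reg(x, y, σ = 0.1)
    β ~ Normal(0, 1)
    for i in eachindex(y)
        y[i] ~ Normal(β * x[i], σ)
    end
end

f(x) = 2 * x + 0.1 * randn()
xs_train = 0:0.1:10; ys_train = f.(xs_train)
xs_test = [10.1, 10.2]

# 'Fit': condition on the training data and sample
m_train = linear_reg(xs_train, ys_train)
chain = sample(m_train, NUTS(100, 0.65), 200)

# 'Refit' to the test inputs, leaving the outcomes missing so predict fills them in
m_test = linear_reg(xs_test, Vector{Union{Missing,Float64}}(undef, length(xs_test)))
predictions = predict(m_test, chain)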

@penelopeysm
Member

penelopeysm commented Oct 14, 2024

Ah, just writing that all out made me realise that I've basically repeated what you've already said. In that case, I guess you should consider the above comment as a request for you to check that I have correctly understood it, rather than a clarification for you :)

@DominiqueMakowski
Contributor Author

The issue is that it currently doesn't work if the model function is not available in the workspace, for instance if the model is loaded and stored within a dictionary or a list. This is unfortunate, given that all the necessary information seems like it should be contained within the "fitted" model object (by which I mean the model function applied to the data).

Basically, can we remove the need to keep depending on the original model function, and instead pull the necessary information from the fitted model when running predictions?
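
A naive in-session version of this is a one-liner, since the fitted model carries its function in the f field. A sketch (naive_update is a hypothetical helper), with the caveat that it only works while that function still resolves in the workspace, which is exactly what breaks after deserialization:

using DynamicPPL: Model

# hypothetical helper: re-apply the stored model function to new arguments
naive_update(model::Model, args...; kwargs...) = model.f(args...; kwargs...)

# e.g. naive_update(info["fit"], [missing for _ in 1:n]; min_rt=min_rt, isi=isi)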

@torfjelde
Member

The issue is that it currently doesn't work if the model function is not available in the workspace, for instance if the model is loaded and stored within a dictionary or a list. This is unfortunate, given that all the necessary information seems like it should be contained within the "fitted" model object (by which I mean the model function applied to the data).

This isn't possible in the way you describe, unfortunately 😕 Models are Julia code, and they can depend on arbitrary Julia code. This is, arguably, one of the selling points of Turing.jl :) But it means that you cannot save a Turing.jl model to disk without also being explicit about its dependencies. It might be possible to do something like compiling the model into a binary and reusing that once static compilation becomes more of a thing in Julia, but I highly doubt it's possible at the moment.

However, it's "easy enough" to set up a custom solution that does what you want by putting each model and whatever it needs into its own file and then calling include(...). You could then save the path to the file containing the model definition in the chain and include it as needed.
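
A minimal sketch of that idea, assuming MCMCChains' setinfo to stash the path alongside the samples (the file layout and names here are hypothetical):

using Turing, Serialization
using MCMCChains: setinfo

# sampling script: models/gaussian.jl defines `model_Gaussian` and its dependencies
include("models/gaussian.jl")
chain = sample(model_Gaussian(y), NUTS(), 1000)  # assuming data `y` is in scope

# record where the definition lives, then serialize the chain
chain = setinfo(chain, (model_file = "models/gaussian.jl",))
serialize("gaussian_chain.jls", chain)

# analysis script: re-include the definition before using the chain
chain = deserialize("gaussian_chain.jls")
include(chain.info.model_file)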

@penelopeysm
Member

@torfjelde I was thinking it might still be useful to expose a convenience function that does exactly that custom solution / defines a particular way to package the model together with the underlying methods? 😄

@DominiqueMakowski
Contributor Author

to expose a convenience function that does exactly that

For real though, making Turing models easily shareable and reusable is critical IMO, especially in applied contexts where users want to use a model without having to "train" it. I wonder how (and if) other Julia ML packages handle that.

@torfjelde
Member

I was thinking it might still be useful to expose a convenience function that does exactly that custom solution / defines a particular way to package the model together with the underlying methods? 😄

But this seems difficult to do without something like static compilation, no? The include approach I mentioned doesn't (AFAIK) really require any particular methods (other than include)?

@DominiqueMakowski
Contributor Author

I've given it a go with the include option, and it might work, but I'm facing an obstacle and I'm not sure how best to address it:

  1. I have saved all my model definitions inside a models.jl file that I include in the other files to have access to the functions. One of the models is, for instance, named model_Gaussian.

  2. I sample them in another file, sampling_and_saving.jl, and then save the serialized info:

out = Dict("fit" => fit, "posteriors" => posteriors)
Serialization.serialize(filename * ".turing", out)

fit being the output of model_Gaussian(y=y, x=x) (i.e., the model function called on the data)

  3. In a third file, analysis.jl:
  • I include("models.jl") to have access to all the model functions
  • I load the serialized object:
gaussian = Serialization.deserialize("models/Gaussian.turing")
exgaussian = Serialization.deserialize("models/ExGaussian.turing")
lognormal = Serialization.deserialize("models/LogNormal.turing")

objects = (Gaussian=gaussian, ExGaussian=exgaussian, LogNormal=lognormal)

The Problem

In this analysis file, I would like to programmatically generate posterior predictive check figures. In the past (when I thought I could save the model definition and serialize it under a "model" key), I created a function like:

function make_ppcheck(data, info, figure)
    # Compute predictions
    pred = predict(info["model"]([missing for i in 1:nrow(data)]), info["posteriors"])

    # Code to make a figure
    ...
end

And then loop using:

f = Figure(size=(1200, 1200))
for (i, k) in enumerate(keys(objects))
    f = make_ppcheck(data, objects[k], f)
end
f

Using the newly suggested workflow, however, I am not sure how to write the function so that it programmatically uses the model definition that exists in the environment (because it was included). One option I thought of was to save the model name (e.g., "model_Gaussian") as a serialized string and then somehow parse that string as code, but it feels hacky and clunky.

I still do think that having an update() method would make things easier, as I could just call update(info["fit"], newdata).

TLDR

Any suggestions on how to use predict() on new data programmatically (i.e., in a loop, on an arbitrary list of models stored in a list)?

@torfjelde
Member

One option I thought of was to save the model name (e.g., "model_Gaussian") as a serialized string and then somehow parse that string as code, but it feels hacky and clunky.

Slightly less clunky way: have a list of the relevant models in models.jl, e.g.

AVAILABLE_MODELS = Dict{String,Any}()

# After definition of `model_Gaussian`, you do
AVAILABLE_MODELS["model_Gaussian"] = model_Gaussian

Or something like this. I'd recommend this over eval, since it allows you to change the model, etc. in a backwards-compatible way: e.g., if you rename the model but already have chains stored under the old name, you can keep that entry pointing "model_Gaussian" at the new model instance.
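
A hypothetical use of that registry in the earlier posterior-predictive loop (the "name" key and the revised make_ppcheck are assumptions building on the code above):

# when saving, record the model's name instead of the function itself
out = Dict("name" => "model_Gaussian", "fit" => fit, "posteriors" => posteriors)

# in the analysis script, look the function up in the registry
function make_ppcheck(data, info, figure)
    model_fn = AVAILABLE_MODELS[info["name"]]
    pred = predict(model_fn([missing for i in 1:nrow(data)]), info["posteriors"])
    # ... figure code ...
end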

I still do think that having an update() method would make things easier, as I could just call update(info["fit"], newdata).

I think everyone agrees that this would be very nice :) It's just that we're talking about serializing code, which itself has dependencies, etc., and so it's quite a non-trivial thing to do in a way that actually works well.

@DominiqueMakowski
Contributor Author

I managed to make it work, thanks so much @torfjelde for your help, much appreciated!

@penelopeysm
Member

penelopeysm commented Oct 30, 2024

One option I thought of was to save the model name (e.g., "model_Gaussian") as a serialized string and then somehow parse that string as code

I'm not claiming this is the best code ever, but it works. I wasn't sure whether steps 4/5 could be turned into a macro; see the attempt at the end of this comment. eval is unsafe, but then again any form of deserialisation is intrinsically unsafe (you need to trust the thing you're deserialising), so I don't think eval adds any extra insecurity.

using Turing

# 1. Define a model
@model function hello()
    x ~ Normal(0, 1)
    y ~ Normal(x, 1)
end

# 2. Fit the model
model = condition(hello(), y=1.5)

# 3. Sample from it
chain = sample(model, MH(), 10)

# 4. Serialize literally everything, including the name of the function, so that we
# can call `eval` later. I know I'm not getting any presents from Santa this year :')
target_file = "model.tar"

# These lines could maybe be extracted out
using Serialization: serialize
import Tar
tmp_dir = mktempdir()
serialize("$tmp_dir/methods", methods(model.f))
serialize("$tmp_dir/model", model)
serialize("$tmp_dir/chain", chain)
serialize("$tmp_dir/name", string(Base.nameof(model.f)))
Tar.create(tmp_dir, target_file)

Then reopen a new Julia session and do this:

# 5. Start a new instance of Julia
# Load the model and chain back from disk
using Turing
target_file = "model.tar"

# These lines could maybe be extracted out
import Tar
using Serialization: deserialize
tmp_dir = mktempdir()
Tar.extract(target_file, tmp_dir)
__function_name_ = deserialize("$tmp_dir/name")
eval(Expr(:function, Symbol(__function_name_)))
model = deserialize("$tmp_dir/model")
deserialize("$tmp_dir/methods")
chain = deserialize("$tmp_dir/chain")

# 6. Then predict
new_model = condition(model, y=missing)
predict(new_model, chain)

Note that, in order to use condition(model, y=missing) (second-to-last line) and have the predictions come out correctly, you need to have defined the model as

@model function hello()
    x ~ Normal(0, 1)
    y ~ Normal(x, 1)
end
model = condition(hello(), y=1.5)

instead of the more 'traditional'

@model function hello(y)
    x ~ Normal(0, 1)
    y ~ Normal(x, 1)
end
model = hello(1.5)

We are trying to move towards the former syntax anyway, so this isn't a bad thing.
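
For completeness: under the 'traditional' syntax, the equivalent step would be to rebuild the model by re-calling the function with missing rather than conditioning. A sketch, assuming the loaded model's function still resolves:

new_model = model.f(missing)  # model.f is the field holding the original function
predict(new_model, chain)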


Here are the corresponding macros – I reckon they can definitely be improved, but it's a proof of concept 😄

import Tar
import Serialization

macro save_model(target_file, model)
    return quote
        tmp_dir = mktempdir()
        $(Serialization.serialize)(tmp_dir * "/methods", methods($model.f))
        $(Serialization.serialize)(tmp_dir * "/model", $model)
        $(Serialization.serialize)(tmp_dir * "/name", string(Base.nameof($model.f)))
        $(Tar.create)(tmp_dir, $target_file)
    end
end

macro load_model(target_file)
    # Note: paths are built with `*` rather than "$tmp_dir/..." because `$`
    # inside a quote interpolates at macro-expansion time, where tmp_dir
    # doesn't exist yet.
    return quote
        tmp_dir = Base.mktempdir()
        $(Tar.extract)($target_file, tmp_dir)
        __function_name_ = $(Serialization.deserialize)(tmp_dir * "/name")
        eval(Expr(:function, Symbol(__function_name_)))
        $(Serialization.deserialize)(tmp_dir * "/methods")
        $(Serialization.deserialize)(tmp_dir * "/model")
    end
end

Usage:

# ... (repeat steps 1-3 from above)
@save_model "model2.tar" model
Serialization.serialize("chain2", chain)

# restart session

model = @load_model "model2.tar"
chain = Serialization.deserialize("chain2")
new_model = condition(model, y=missing)
predict(new_model, chain)
