Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit d83139f

Browse files
committed
Revise training process for FlowOverCircle
1 parent f5a9017 commit d83139f

File tree

4 files changed

+58
-78
lines changed

4 files changed

+58
-78
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ docs/site/
2525
# environment.
2626
Manifest.toml
2727

28-
*.jld2
28+
*.bson

example/FlowOverCircle/Project.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@ name = "FlowOverCircle"
22
uuid = "1fc04e5d-1dd1-42ff-8d75-1d53504b2476"
33

44
[deps]
5+
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
56
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
67
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
7-
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
8+
FluxTraining = "7bf95e4d-ca32-48da-9824-f0dc5310474f"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
911
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
1012
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
1113
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
14+
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1215
WaterLily = "ed894a53-35f9-47f1-b17f-85db9237eebd"
1316

1417
[extras]
+53-33
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,44 @@
11
module FlowOverCircle
22

3-
using NeuralOperators
4-
using Flux
5-
using CUDA
6-
using JLD2
3+
using WaterLily, LinearAlgebra, ProgressMeter, MLUtils
4+
using NeuralOperators, Flux
5+
using CUDA, FluxTraining, BSON
76

8-
include("data.jl")
7+
function circle(n, m; Re=250) # copy from [WaterLily](https://github.com/weymouth/WaterLily.jl)
8+
# Set physical parameters
9+
U, R, center = 1., m/8., [m/2, m/2]
10+
ν = U * R / Re
911

10-
function update_model!(model_file_path, model)
11-
model = cpu(model)
12-
jldsave(model_file_path; model)
13-
@warn "model updated!"
12+
body = AutoBody((x,t) -> LinearAlgebra.norm2(x .- center) - R)
13+
Simulation((n+2, m+2), [U, 0.], R; ν, body)
14+
end
15+
16+
function gen_data(ts::AbstractRange)
17+
@info "gen data... "
18+
p = Progress(length(ts))
19+
20+
n, m = 3(2^5), 2^6
21+
circ = circle(n, m)
22+
23+
𝐩s = Array{Float32}(undef, 1, n, m, length(ts))
24+
for (i, t) in enumerate(ts)
25+
sim_step!(circ, t)
26+
𝐩s[1, :, :, i] .= Float32.(circ.flow.p)[2:end-1, 2:end-1]
27+
28+
next!(p)
29+
end
30+
31+
return 𝐩s
32+
end
33+
34+
function get_dataloader(; ts::AbstractRange=LinRange(100, 11000, 10000), ratio::Float64=0.95, batchsize=100)
35+
data = gen_data(ts)
36+
data_train, data_test = splitobs((𝐱=data[:, :, :, 1:end-1], 𝐲=data[:, :, :, 2:end]), at=ratio)
37+
38+
loader_train = Flux.DataLoader(data_train, batchsize=batchsize, shuffle=true)
39+
loader_test = Flux.DataLoader(data_test, batchsize=batchsize, shuffle=false)
40+
41+
return loader_train, loader_test
1442
end
1543

1644
function train()
@@ -22,42 +50,34 @@ function train()
2250
device = cpu
2351
end
2452

25-
m = Chain(
53+
model = Chain(
2654
Dense(1, 64),
2755
OperatorKernel(64=>64, (24, 24), FourierTransform, gelu),
2856
OperatorKernel(64=>64, (24, 24), FourierTransform, gelu),
2957
OperatorKernel(64=>64, (24, 24), FourierTransform, gelu),
3058
OperatorKernel(64=>64, (24, 24), FourierTransform, gelu),
3159
Dense(64, 1),
32-
) |> device
60+
)
61+
data = get_dataloader()
62+
optimiser = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
63+
loss_func = l₂loss
3364

34-
loss(𝐱, 𝐲) = l₂loss(m(𝐱), 𝐲)
65+
learner = Learner(
66+
model, data, optimiser, loss_func,
67+
ToDevice(device, device),
68+
Checkpointer(joinpath(@__DIR__, "../model/"))
69+
)
3570

36-
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
37-
38-
@info "gen data... "
39-
@time loader_train, loader_test = get_dataloader()
71+
fit!(learner, 50)
4072

41-
losses = Float32[]
42-
function validate()
43-
validation_loss = sum(loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test)/length(loader_test)
44-
@info "loss: $validation_loss"
45-
46-
push!(losses, validation_loss)
47-
(losses[end] == minimum(losses)) && update_model!(joinpath(@__DIR__, "../model/model.jld2"), m)
48-
end
49-
call_back = Flux.throttle(validate, 5, leading=false, trailing=true)
50-
51-
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
52-
Flux.@epochs 50 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
73+
return learner
5374
end
5475

5576
function get_model()
56-
f = jldopen(joinpath(@__DIR__, "../model/model.jld2"))
57-
model = f["model"]
58-
close(f)
77+
model_path = joinpath(@__DIR__, "../model/")
78+
model_file = readdir(model_path)[end]
5979

60-
return model
80+
return BSON.load(joinpath(model_path, model_file), @__MODULE__)[:model]
6181
end
6282

63-
end
83+
end # module

example/FlowOverCircle/src/data.jl

-43
This file was deleted.

0 commit comments

Comments
 (0)