
Commit 4df8b02
asl: adding additional symbolic loss functions
This commit adds additional symbolic loss functions (ASL): loss functions analogous to those derived from PDEs or boundary conditions, except that the integrand is given explicitly in symbolic form.
1 parent 5d01161 commit 4df8b02
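
For context, a minimal sketch of what a symbolic-form integrand looks like. The exact user-facing API for registering such a loss is not shown in this diff, so everything below is an illustrative assumption built on ModelingToolkit symbols:

```julia
using ModelingToolkit

# Illustrative only: a PDE loss integrand is derived from a residual such as
# Dxx(u(x)) + sin(pi * x), whereas an additional symbolic loss (ASL) lets the
# user state the integrand directly, e.g. an energy-style penalty on u itself.
@parameters x
@variables u(..)
asl_integrand = u(x)^2   # integrand given explicitly in symbolic form
```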

File tree

8 files changed: +421 additions, −135 deletions

Project.toml (1 addition, 0 deletions)

````diff
@@ -28,6 +28,7 @@ MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
+OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
 QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
````

src/adaptive_losses.jl (39 additions, 12 deletions)
````diff
@@ -15,6 +15,7 @@ end
 """
 ```julia
 NonAdaptiveLoss{T}(; pde_loss_weights = 1,
+                     asl_loss_weights = 1,
                      bc_loss_weights = 1,
                      additional_loss_weights = 1)
 ```
@@ -24,31 +25,34 @@ change during optimization
 """
 mutable struct NonAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss
     pde_loss_weights::Vector{T}
+    asl_loss_weights::Vector{T}
     bc_loss_weights::Vector{T}
     additional_loss_weights::Vector{T}
     SciMLBase.@add_kwonly function NonAdaptiveLoss{T}(; pde_loss_weights = 1,
+                                                        asl_loss_weights = 1,
                                                         bc_loss_weights = 1,
                                                         additional_loss_weights = 1) where {T <: Real}
-        new(vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
+        new(vectorify(pde_loss_weights, T), vectorify(asl_loss_weights, T),
+            vectorify(bc_loss_weights, T),
             vectorify(additional_loss_weights, T))
     end
 end
 
 # default to Float64
-SciMLBase.@add_kwonly function NonAdaptiveLoss(; pde_loss_weights = 1, bc_loss_weights = 1,
+SciMLBase.@add_kwonly function NonAdaptiveLoss(; pde_loss_weights = 1, asl_loss_weights = 1,
+                                                 bc_loss_weights = 1,
                                                  additional_loss_weights = 1)
     NonAdaptiveLoss{Float64}(; pde_loss_weights = pde_loss_weights,
+                               asl_loss_weights = asl_loss_weights,
                                bc_loss_weights = bc_loss_weights,
                                additional_loss_weights = additional_loss_weights)
 end
 
 function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
                                          adaloss::NonAdaptiveLoss,
-                                         pde_loss_functions, bc_loss_functions)
-    function null_nonadaptive_loss(θ, pde_losses, bc_losses)
+                                         pde_loss_functions, asl_loss_functions,
+                                         bc_loss_functions)
+    function null_nonadaptive_loss(θ, pde_losses, asl_losses, bc_losses)
         nothing
     end
 end
````
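
A usage sketch of the extended constructor; the `asl_loss_weights` keyword is the one added by this commit, and the weight values are illustrative:

```julia
using NeuralPDE

# Fixed (non-adaptive) weights for each loss class; `asl_loss_weights`
# scales the new symbolic loss terms relative to the PDE and BC losses.
adaloss = NonAdaptiveLoss(; pde_loss_weights = 1.0,
                          asl_loss_weights = 0.5,
                          bc_loss_weights = 10.0,
                          additional_loss_weights = 1.0)
```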
````diff
@@ -58,6 +62,7 @@ end
 GradientScaleAdaptiveLoss(reweight_every;
                           weight_change_inertia = 0.9,
                           pde_loss_weights = 1,
+                          asl_loss_weights = 1,
                           bc_loss_weights = 1,
                           additional_loss_weights = 1)
 ```
@@ -90,30 +95,34 @@ mutable struct GradientScaleAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss
     reweight_every::Int64
     weight_change_inertia::T
     pde_loss_weights::Vector{T}
+    asl_loss_weights::Vector{T}
     bc_loss_weights::Vector{T}
     additional_loss_weights::Vector{T}
     SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss{T}(reweight_every;
                                                                 weight_change_inertia = 0.9,
                                                                 pde_loss_weights = 1,
+                                                                asl_loss_weights = 1,
                                                                 bc_loss_weights = 1,
                                                                 additional_loss_weights = 1) where {T <: Real}
         new(convert(Int64, reweight_every), convert(T, weight_change_inertia),
-            vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
-            vectorify(additional_loss_weights, T))
+            vectorify(pde_loss_weights, T), vectorify(asl_loss_weights, T),
+            vectorify(bc_loss_weights, T), vectorify(additional_loss_weights, T))
     end
 end
 # default to Float64
 SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss(reweight_every;
                                                          weight_change_inertia = 0.9,
                                                          pde_loss_weights = 1,
+                                                         asl_loss_weights = 1,
                                                          bc_loss_weights = 1,
                                                          additional_loss_weights = 1)
     GradientScaleAdaptiveLoss{Float64}(reweight_every;
                                        weight_change_inertia = weight_change_inertia,
                                        pde_loss_weights = pde_loss_weights,
+                                       asl_loss_weights = asl_loss_weights,
                                        bc_loss_weights = bc_loss_weights,
                                        additional_loss_weights = additional_loss_weights)
 end
````
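
Correspondingly for the gradient-scale variant, a short sketch; the new `asl_loss_weights` start value is carried through the same `vectorify` path as the others, and all values here are illustrative:

```julia
using NeuralPDE

# Reweight from gradient statistics every 50 iterations; the inertia term
# exponentially smooths the proposed weights.
adaloss = GradientScaleAdaptiveLoss(50;
                                    weight_change_inertia = 0.9,
                                    pde_loss_weights = 1.0,
                                    asl_loss_weights = 1.0,
                                    bc_loss_weights = 1.0)
```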
````diff
@@ -136,7 +145,7 @@ function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
 
     nonzero_divisor_eps = adaloss_T isa Float64 ? Float64(1e-11) :
                           convert(adaloss_T, 1e-7)
-    bc_loss_weights_proposed = pde_grads_max ./
+    bc_loss_weights_proposed = pde_asl_grads_max ./
                                (bc_grads_mean .+ nonzero_divisor_eps)
     adaloss.bc_loss_weights .= weight_change_inertia .*
                                adaloss.bc_loss_weights .+
````
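
In effect, this hunk folds the symbolic losses into the boundary-condition reweighting. Assuming `pde_asl_grads_max` is the maximum absolute loss gradient taken over both the PDE and symbolic loss terms (computed earlier in this function, outside the hunk), the update is the usual exponential smoothing:

```latex
% Hedged restatement of the update above: \beta is weight_change_inertia
% and \epsilon is nonzero_divisor_eps.
w_{bc} \leftarrow \beta\, w_{bc} + (1-\beta)\,
    \frac{\max\bigl(\max_i |\nabla_\theta L^{pde}_i|,\ \max_j |\nabla_\theta L^{asl}_j|\bigr)}
         {\overline{|\nabla_\theta L^{bc}|} + \epsilon}
```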
````diff
@@ -160,8 +169,10 @@ end
 ```julia
 function MiniMaxAdaptiveLoss(reweight_every;
                              pde_max_optimiser = Flux.ADAM(1e-4),
+                             asl_max_optimiser = Flux.ADAM(1e-4),
                              bc_max_optimiser = Flux.ADAM(0.5),
                              pde_loss_weights = 1,
+                             asl_loss_weights = 1,
                              bc_loss_weights = 1,
                              additional_loss_weights = 1)
 ```
````
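
A constructor sketch matching the signature documented above (optimiser settings illustrative; `Flux.ADAM` is what the defaults on this branch use):

```julia
using NeuralPDE, Flux

# Minimax reweighting: each weight vector is ascended by its own optimiser
# every `reweight_every` iterations, so hard-to-fit terms gain weight.
adaloss = MiniMaxAdaptiveLoss(50;
                              pde_max_optimiser = Flux.ADAM(1e-4),
                              asl_max_optimiser = Flux.ADAM(1e-4),
                              bc_max_optimiser = Flux.ADAM(0.5))
```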
````diff
@@ -191,65 +202,81 @@ https://arxiv.org/abs/2009.04544
 """
 mutable struct MiniMaxAdaptiveLoss{T <: Real,
                                    PDE_OPT <: Flux.Optimise.AbstractOptimiser,
+                                   ASL_OPT <: Flux.Optimise.AbstractOptimiser,
                                    BC_OPT <: Flux.Optimise.AbstractOptimiser} <:
        AbstractAdaptiveLoss
     reweight_every::Int64
     pde_max_optimiser::PDE_OPT
+    asl_max_optimiser::ASL_OPT
     bc_max_optimiser::BC_OPT
     pde_loss_weights::Vector{T}
+    asl_loss_weights::Vector{T}
     bc_loss_weights::Vector{T}
     additional_loss_weights::Vector{T}
     SciMLBase.@add_kwonly function MiniMaxAdaptiveLoss{T,
-                                                       PDE_OPT, BC_OPT}(reweight_every;
+                                                       PDE_OPT, ASL_OPT,
+                                                       BC_OPT}(reweight_every;
             pde_max_optimiser = Flux.ADAM(1e-4),
+            asl_max_optimiser = Flux.ADAM(1e-4),
             bc_max_optimiser = Flux.ADAM(0.5),
             pde_loss_weights = 1,
+            asl_loss_weights = 1,
             bc_loss_weights = 1,
             additional_loss_weights = 1) where {T <: Real,
                                                 PDE_OPT <: Flux.Optimise.AbstractOptimiser,
+                                                ASL_OPT <: Flux.Optimise.AbstractOptimiser,
                                                 BC_OPT <: Flux.Optimise.AbstractOptimiser}
         new(convert(Int64, reweight_every), convert(PDE_OPT, pde_max_optimiser),
+            convert(ASL_OPT, asl_max_optimiser),
             convert(BC_OPT, bc_max_optimiser),
-            vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
-            vectorify(additional_loss_weights, T))
+            vectorify(pde_loss_weights, T), vectorify(asl_loss_weights, T),
+            vectorify(bc_loss_weights, T), vectorify(additional_loss_weights, T))
     end
 end
 
-# default to Float64, ADAM, ADAM
+# default to Float64, ADAM, ADAM, ADAM
 SciMLBase.@add_kwonly function MiniMaxAdaptiveLoss(reweight_every;
                                                    pde_max_optimiser = Flux.ADAM(1e-4),
+                                                   asl_max_optimiser = Flux.ADAM(1e-4),
                                                    bc_max_optimiser = Flux.ADAM(0.5),
                                                    pde_loss_weights = 1,
+                                                   asl_loss_weights = 1,
                                                    bc_loss_weights = 1,
                                                    additional_loss_weights = 1)
     MiniMaxAdaptiveLoss{Float64, typeof(pde_max_optimiser),
+                        typeof(asl_max_optimiser),
                         typeof(bc_max_optimiser)}(reweight_every;
                                                   pde_max_optimiser = pde_max_optimiser,
+                                                  asl_max_optimiser = asl_max_optimiser,
                                                   bc_max_optimiser = bc_max_optimiser,
                                                   pde_loss_weights = pde_loss_weights,
+                                                  asl_loss_weights = asl_loss_weights,
                                                   bc_loss_weights = bc_loss_weights,
                                                   additional_loss_weights = additional_loss_weights)
 end
 
 function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
                                          adaloss::MiniMaxAdaptiveLoss,
-                                         pde_loss_functions, bc_loss_functions)
+                                         pde_loss_functions, asl_loss_functions,
+                                         bc_loss_functions)
     pde_max_optimiser = adaloss.pde_max_optimiser
+    asl_max_optimiser = adaloss.asl_max_optimiser
     bc_max_optimiser = adaloss.bc_max_optimiser
     iteration = pinnrep.iteration
 
-    function run_minimax_adaptive_loss(θ, pde_losses, bc_losses)
+    function run_minimax_adaptive_loss(θ, pde_losses, asl_losses, bc_losses)
         if iteration[1] % adaloss.reweight_every == 0
             Flux.Optimise.update!(pde_max_optimiser, adaloss.pde_loss_weights,
                                   -pde_losses)
+            Flux.Optimise.update!(asl_max_optimiser, adaloss.asl_loss_weights,
+                                  -asl_losses)
             Flux.Optimise.update!(bc_max_optimiser, adaloss.bc_loss_weights, -bc_losses)
             logvector(pinnrep.logger, adaloss.pde_loss_weights,
                       "adaptive_loss/pde_loss_weights", iteration[1])
+            logvector(pinnrep.logger, adaloss.asl_loss_weights,
+                      "adaptive_loss/asl_loss_weights", iteration[1])
             logvector(pinnrep.logger, adaloss.bc_loss_weights,
                       "adaptive_loss/bc_loss_weights", iteration[1])
````
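
Presumably these adaptive losses are consumed through the discretizer as before. A hedged end-to-end sketch, in which the `adaptive_loss` keyword of `PhysicsInformedNN` is assumed to accept the extended structs and the chain/strategy are placeholders:

```julia
using NeuralPDE, Flux

# Illustrative wiring: hand the adaptive loss to the discretizer so that
# generate_adaptive_loss_function receives the PDE, ASL, and BC losses.
chain = Flux.Chain(Flux.Dense(2, 16, tanh), Flux.Dense(16, 1))
discretization = PhysicsInformedNN(chain, GridTraining(0.1);
                                   adaptive_loss = MiniMaxAdaptiveLoss(50))
```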
