@@ -15,6 +15,7 @@
"""
```julia
NonAdaptiveLoss{T}(; pde_loss_weights = 1,
+                    asl_loss_weights = 1,
                     bc_loss_weights = 1,
                     additional_loss_weights = 1)
```
@@ -24,31 +25,34 @@ change during optimization
"""
mutable struct NonAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss
    pde_loss_weights::Vector{T}
+   asl_loss_weights::Vector{T}
    bc_loss_weights::Vector{T}
    additional_loss_weights::Vector{T}
    SciMLBase.@add_kwonly function NonAdaptiveLoss{T}(; pde_loss_weights = 1,
+                                                       asl_loss_weights = 1,
                                                        bc_loss_weights = 1,
                                                        additional_loss_weights = 1) where {
                                                                                            T <:
                                                                                            Real
                                                                                           }
-       new(vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
+       new(vectorify(pde_loss_weights, T), vectorify(asl_loss_weights, T), vectorify(bc_loss_weights, T),
            vectorify(additional_loss_weights, T))
    end
end

# default to Float64
-SciMLBase.@add_kwonly function NonAdaptiveLoss(; pde_loss_weights = 1, bc_loss_weights = 1,
+SciMLBase.@add_kwonly function NonAdaptiveLoss(; pde_loss_weights = 1, asl_loss_weights = 1, bc_loss_weights = 1,
                                                 additional_loss_weights = 1)
    NonAdaptiveLoss{Float64}(; pde_loss_weights = pde_loss_weights,
+                              asl_loss_weights = asl_loss_weights,
                               bc_loss_weights = bc_loss_weights,
                               additional_loss_weights = additional_loss_weights)
end

function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
                                         adaloss::NonAdaptiveLoss,
-                                        pde_loss_functions, bc_loss_functions)
-   function null_nonadaptive_loss(θ, pde_losses, bc_losses)
+                                        pde_loss_functions, asl_loss_functions, bc_loss_functions)
+   function null_nonadaptive_loss(θ, pde_losses, asl_losses, bc_losses)
        nothing
    end
end
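For context, a minimal construction sketch showing where the new `asl_loss_weights` keyword sits alongside the existing ones (weight values here are illustrative, not recommendations):

```julia
# Fixed (non-adaptive) weights: one scalar or vector per loss family.
# Scalars are broadcast into weight vectors internally via vectorify.
adaloss = NonAdaptiveLoss(; pde_loss_weights = 1,
                            asl_loss_weights = 1,
                            bc_loss_weights = 10,
                            additional_loss_weights = 1)
```

Since this scheme never adapts, the generated `null_nonadaptive_loss` callback is a no-op and the weights act as constant multipliers on each loss family.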
@@ -58,6 +62,7 @@
GradientScaleAdaptiveLoss(reweight_every;
                          weight_change_inertia = 0.9,
                          pde_loss_weights = 1,
+                         asl_loss_weights = 1,
                          bc_loss_weights = 1,
                          additional_loss_weights = 1)
```
@@ -90,30 +95,34 @@ mutable struct GradientScaleAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss
    reweight_every::Int64
    weight_change_inertia::T
    pde_loss_weights::Vector{T}
+   asl_loss_weights::Vector{T}
    bc_loss_weights::Vector{T}
    additional_loss_weights::Vector{T}
    SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss{T}(reweight_every;
                                                                weight_change_inertia = 0.9,
                                                                pde_loss_weights = 1,
+                                                               asl_loss_weights = 1,
                                                                bc_loss_weights = 1,
                                                                additional_loss_weights = 1) where {
                                                                                                    T <:
                                                                                                    Real
                                                                                                   }
        new(convert(Int64, reweight_every), convert(T, weight_change_inertia),
-           vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
-           vectorify(additional_loss_weights, T))
+           vectorify(pde_loss_weights, T), vectorify(asl_loss_weights, T),
+           vectorify(bc_loss_weights, T), vectorify(additional_loss_weights, T))
    end
end
# default to Float64
SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss(reweight_every;
                                                         weight_change_inertia = 0.9,
                                                         pde_loss_weights = 1,
+                                                        asl_loss_weights = 1,
                                                         bc_loss_weights = 1,
                                                         additional_loss_weights = 1)
    GradientScaleAdaptiveLoss{Float64}(reweight_every;
                                       weight_change_inertia = weight_change_inertia,
                                       pde_loss_weights = pde_loss_weights,
+                                      asl_loss_weights = asl_loss_weights,
                                       bc_loss_weights = bc_loss_weights,
                                       additional_loss_weights = additional_loss_weights)
end
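A corresponding construction sketch (illustrative values). Here the weights are only starting points, since this scheme rescales `bc_loss_weights` from gradient statistics every `reweight_every` iterations:

```julia
adaloss = GradientScaleAdaptiveLoss(50;                 # reweight every 50 iterations
                                    weight_change_inertia = 0.9,
                                    pde_loss_weights = 1,
                                    asl_loss_weights = 1,
                                    bc_loss_weights = 1,
                                    additional_loss_weights = 1)
```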
@@ -136,7 +145,7 @@ function generate_adaptive_loss_function(pinnrep::PINNRepresentation,

    nonzero_divisor_eps = adaloss_T isa Float64 ? Float64(1e-11) :
                          convert(adaloss_T, 1e-7)
-   bc_loss_weights_proposed = pde_grads_max ./
+   bc_loss_weights_proposed = pde_asl_grads_max ./
                               (bc_grads_mean .+ nonzero_divisor_eps)
    adaloss.bc_loss_weights .= weight_change_inertia .*
                               adaloss.bc_loss_weights .+
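The proposed BC weight is the maximum PDE/ASL gradient magnitude divided by the mean BC gradient magnitude, blended with the previous weight as an exponential moving average (the complementary `(1 - weight_change_inertia) * proposed` term falls just past this hunk boundary). A small numeric sketch of that arithmetic, using hypothetical gradient statistics:

```julia
weight_change_inertia = 0.9
pde_asl_grads_max     = 8.0     # hypothetical: max gradient magnitude over PDE and ASL losses
bc_grads_mean         = 0.1     # hypothetical: mean gradient magnitude over BC losses
nonzero_divisor_eps   = 1e-11   # guards against division by zero

bc_loss_weights_proposed = pde_asl_grads_max / (bc_grads_mean + nonzero_divisor_eps)  # ≈ 80.0
bc_loss_weights = 10.0
bc_loss_weights = weight_change_inertia * bc_loss_weights +
                  (1 - weight_change_inertia) * bc_loss_weights_proposed
# ≈ 0.9 * 10.0 + 0.1 * 80.0 = 17.0: BC terms with comparatively weak gradients get
# upweighted, while the inertia keeps weights from jumping on one noisy estimate.
```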
@@ -160,8 +169,10 @@
```julia
function MiniMaxAdaptiveLoss(reweight_every;
                             pde_max_optimiser = Flux.ADAM(1e-4),
+                            asl_max_optimiser = Flux.ADAM(1e-4),
                             bc_max_optimiser = Flux.ADAM(0.5),
                             pde_loss_weights = 1,
+                            asl_loss_weights = 1,
                             bc_loss_weights = 1,
                             additional_loss_weights = 1)
```
@@ -191,65 +202,81 @@ https://arxiv.org/abs/2009.04544
"""
mutable struct MiniMaxAdaptiveLoss{T <: Real,
                                   PDE_OPT <: Flux.Optimise.AbstractOptimiser,
+                                  ASL_OPT <: Flux.Optimise.AbstractOptimiser,
                                   BC_OPT <: Flux.Optimise.AbstractOptimiser} <:
               AbstractAdaptiveLoss
    reweight_every::Int64
    pde_max_optimiser::PDE_OPT
+   asl_max_optimiser::ASL_OPT
    bc_max_optimiser::BC_OPT
    pde_loss_weights::Vector{T}
+   asl_loss_weights::Vector{T}
    bc_loss_weights::Vector{T}
    additional_loss_weights::Vector{T}
    SciMLBase.@add_kwonly function MiniMaxAdaptiveLoss{T,
-                                                      PDE_OPT, BC_OPT}(reweight_every;
+                                                      PDE_OPT, ASL_OPT, BC_OPT}(reweight_every;
                                                       pde_max_optimiser = Flux.ADAM(1e-4),
+                                                      asl_max_optimiser = Flux.ADAM(1e-4),
                                                       bc_max_optimiser = Flux.ADAM(0.5),
                                                       pde_loss_weights = 1,
+                                                      asl_loss_weights = 1,
                                                       bc_loss_weights = 1,
                                                       additional_loss_weights = 1) where {
                                                                                           T <:
                                                                                           Real,
                                                                                           PDE_OPT <:
                                                                                           Flux.Optimise.AbstractOptimiser,
+                                                                                          ASL_OPT <:
+                                                                                          Flux.Optimise.AbstractOptimiser,
                                                                                           BC_OPT <:
                                                                                           Flux.Optimise.AbstractOptimiser
                                                                                          }
-       new(convert(Int64, reweight_every), convert(PDE_OPT, pde_max_optimiser),
+       new(convert(Int64, reweight_every), convert(PDE_OPT, pde_max_optimiser), convert(ASL_OPT, asl_max_optimiser),
            convert(BC_OPT, bc_max_optimiser),
-           vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
-           vectorify(additional_loss_weights, T))
+           vectorify(pde_loss_weights, T), vectorify(asl_loss_weights, T),
+           vectorify(bc_loss_weights, T), vectorify(additional_loss_weights, T))
    end
end

# default to Float64, ADAM, ADAM
SciMLBase.@add_kwonly function MiniMaxAdaptiveLoss(reweight_every;
                                                   pde_max_optimiser = Flux.ADAM(1e-4),
+                                                  asl_max_optimiser = Flux.ADAM(1e-4),
                                                   bc_max_optimiser = Flux.ADAM(0.5),
                                                   pde_loss_weights = 1,
+                                                  asl_loss_weights = 1,
                                                   bc_loss_weights = 1,
                                                   additional_loss_weights = 1)
    MiniMaxAdaptiveLoss{Float64, typeof(pde_max_optimiser),
+                       typeof(asl_max_optimiser),
                        typeof(bc_max_optimiser)}(reweight_every;
                                                  pde_max_optimiser = pde_max_optimiser,
+                                                 asl_max_optimiser = asl_max_optimiser,
                                                  bc_max_optimiser = bc_max_optimiser,
                                                  pde_loss_weights = pde_loss_weights,
+                                                 asl_loss_weights = asl_loss_weights,
                                                  bc_loss_weights = bc_loss_weights,
                                                  additional_loss_weights = additional_loss_weights)
end

function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
                                         adaloss::MiniMaxAdaptiveLoss,
-                                        pde_loss_functions, bc_loss_functions)
+                                        pde_loss_functions, asl_loss_functions, bc_loss_functions)
    pde_max_optimiser = adaloss.pde_max_optimiser
+   asl_max_optimiser = adaloss.asl_max_optimiser
    bc_max_optimiser = adaloss.bc_max_optimiser
    iteration = pinnrep.iteration

-   function run_minimax_adaptive_loss(θ, pde_losses, bc_losses)
+   function run_minimax_adaptive_loss(θ, pde_losses, asl_losses, bc_losses)
        if iteration[1] % adaloss.reweight_every == 0
            Flux.Optimise.update!(pde_max_optimiser, adaloss.pde_loss_weights,
                                  -pde_losses)
+           Flux.Optimise.update!(asl_max_optimiser, adaloss.asl_loss_weights,
+                                 -asl_losses)
            Flux.Optimise.update!(bc_max_optimiser, adaloss.bc_loss_weights, -bc_losses)
            logvector(pinnrep.logger, adaloss.pde_loss_weights,
                      "adaptive_loss/pde_loss_weights", iteration[1])
+           logvector(pinnrep.logger, adaloss.asl_loss_weights,
+                     "adaptive_loss/asl_loss_weights", iteration[1])
            logvector(pinnrep.logger, adaloss.bc_loss_weights,
                      "adaptive_loss/bc_loss_weights",
                      iteration[1])
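To make the minimax step above concrete: each loss family's weights are updated by its own optimiser, with the negated losses passed in place of a gradient, which turns the descent step into ascent on the weights and so implements the max half of the minimax objective. A self-contained sketch of one such update with hypothetical loss values (`Flux.Optimise.update!` mutates its second argument in place):

```julia
using Flux

asl_max_optimiser = Flux.ADAM(1e-4)
asl_loss_weights  = [1.0, 1.0]
asl_losses        = [0.5, 2.0]   # hypothetical per-term ASL losses

# update!(opt, x, Δ) steps x downhill along Δ; passing -asl_losses steps the
# weights uphill instead, so weights on unsatisfied terms keep growing until
# the descent on θ brings those losses back down.
Flux.Optimise.update!(asl_max_optimiser, asl_loss_weights, -asl_losses)
@show asl_loss_weights
```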