5
5
package fit_test
6
6
7
7
import (
8
+ "fmt"
8
9
"image/color"
9
10
"log"
10
11
"math"
@@ -13,7 +14,10 @@ import (
13
14
"go-hep.org/x/hep/hbook"
14
15
"go-hep.org/x/hep/hplot"
15
16
"gonum.org/v1/gonum/floats"
17
+ "gonum.org/v1/gonum/mat"
16
18
"gonum.org/v1/gonum/optimize"
19
+ "gonum.org/v1/gonum/stat"
20
+ "gonum.org/v1/gonum/stat/distuv"
17
21
"gonum.org/v1/plot/plotter"
18
22
"gonum.org/v1/plot/vg"
19
23
)
@@ -289,3 +293,130 @@ func ExampleCurve1D_powerlaw() {
289
293
}
290
294
}
291
295
}
296
+
297
+ func ExampleCurve1D_hessian () {
298
+ var (
299
+ cst = 3.0
300
+ mean = 30.0
301
+ sigma = 20.0
302
+ want = []float64 {cst , mean , sigma }
303
+ )
304
+
305
+ xdata , ydata , err := readXY ("testdata/gauss-data.txt" )
306
+ if err != nil {
307
+ log .Fatal (err )
308
+ }
309
+
310
+ // use a small sample
311
+ xdata = xdata [:min (25 , len (xdata ))]
312
+ ydata = ydata [:min (25 , len (ydata ))]
313
+
314
+ gauss := func (x , cst , mu , sigma float64 ) float64 {
315
+ v := (x - mu )
316
+ return cst * math .Exp (- v * v / sigma )
317
+ }
318
+
319
+ f1d := fit.Func1D {
320
+ F : func (x float64 , ps []float64 ) float64 {
321
+ return gauss (x , ps [0 ], ps [1 ], ps [2 ])
322
+ },
323
+ X : xdata ,
324
+ Y : ydata ,
325
+ Ps : []float64 {10 , 10 , 10 },
326
+ }
327
+ res , err := fit .Curve1D (f1d , nil , & optimize.NelderMead {})
328
+ if err != nil {
329
+ log .Fatal (err )
330
+ }
331
+
332
+ if err := res .Status .Err (); err != nil {
333
+ log .Fatal (err )
334
+ }
335
+ if got := res .X ; ! floats .EqualApprox (got , want , 1e-3 ) {
336
+ log .Fatalf ("got= %v\n want=%v\n " , got , want )
337
+ }
338
+
339
+ inv := mat .NewSymDense (len (res .Location .X ), nil )
340
+ f1d .Hessian (inv , res .Location .X )
341
+ // fmt.Printf("hessian: %1.2e\n", mat.Formatted(inv, mat.Prefix(" ")))
342
+
343
+ popt := res .Location .X
344
+ pcov := mat .NewDense (len (popt ), len (popt ), nil )
345
+ {
346
+ var chol mat.Cholesky
347
+ if ok := chol .Factorize (inv ); ! ok {
348
+ log .Fatalf ("cov-matrix not positive semi-definite" )
349
+ }
350
+
351
+ err := chol .InverseTo (inv )
352
+ if err != nil {
353
+ log .Fatalf ("could not inverse matrix: %+v" , err )
354
+ }
355
+ pcov .Copy (inv )
356
+ }
357
+
358
+ // compute goodness-of-fit.
359
+ gof := newGoF (f1d .X , f1d .Y , popt , func (x float64 ) float64 {
360
+ return f1d .F (x , popt )
361
+ })
362
+
363
+ pcov .Scale (gof .SSE / float64 (len (f1d .X )- len (popt )), pcov )
364
+
365
+ // fmt.Printf("pcov: %1.2e\n", mat.Formatted(pcov, mat.Prefix(" ")))
366
+
367
+ var (
368
+ n = float64 (len (f1d .X )) // number of data points
369
+ ndf = n - float64 (len (popt )) // number of degrees of freedom
370
+ t = distuv.StudentsT {
371
+ Mu : 0 ,
372
+ Sigma : 1 ,
373
+ Nu : ndf ,
374
+ }.Quantile (0.5 * (1 + 0.95 ))
375
+ )
376
+
377
+ for i , p := range popt {
378
+ sigma := math .Sqrt (pcov .At (i , i ))
379
+ fmt .Printf ("c%d: %1.5e [%1.5e, %1.5e] -- truth: %g\n " , i , p , p - sigma * t , p + sigma * t , want [i ])
380
+ }
381
+ // Output:
382
+ //c0: 2.99999e+00 [2.99999e+00, 3.00000e+00] -- truth: 3
383
+ //c1: 3.00000e+01 [3.00000e+01, 3.00000e+01] -- truth: 30
384
+ //c2: 2.00000e+01 [2.00000e+01, 2.00000e+01] -- truth: 20
385
+ }
386
+
387
+ type GoF struct {
388
+ SSE float64 // Sum of squares due to error
389
+ Rsquare float64 // R-Square is the square of the correlation between the response values and the predicted response values
390
+ NdF int // Number of degrees of freedom
391
+ AdjRsquare float64 // Degrees of freedom adjusted R-Square
392
+ RMSE float64 // Root mean squared error
393
+ }
394
+
395
+ func newGoF (xs , ys , ps []float64 , f func (float64 ) float64 ) GoF {
396
+ switch {
397
+ case len (xs ) != len (ys ):
398
+ panic ("invalid lengths" )
399
+ }
400
+
401
+ var gof GoF
402
+
403
+ var (
404
+ ye = make ([]float64 , len (ys ))
405
+ nn = float64 (len (xs ) - 1 )
406
+ vv = float64 (len (xs ) - len (ps ))
407
+ )
408
+
409
+ for i , x := range xs {
410
+ ye [i ] = f (x )
411
+ dy := ys [i ] - ye [i ]
412
+ gof .SSE += dy * dy
413
+ gof .RMSE += dy * dy
414
+ }
415
+
416
+ gof .Rsquare = stat .RSquaredFrom (ye , ys , nil )
417
+ gof .AdjRsquare = 1 - ((1 - gof .Rsquare ) * nn / vv )
418
+ gof .RMSE = math .Sqrt (gof .RMSE / float64 (len (ys )- len (ps )))
419
+ gof .NdF = len (ys ) - len (ps )
420
+
421
+ return gof
422
+ }
0 commit comments