Skip to content

Commit 1315e8e

Browse files
authored
Merge pull request #268 from Serra314/Mtest_correction_and_MLL
Correction to M-test with resampling, and addition of new MLL test
2 parents ecaec64 + 3b24128 commit 1315e8e

File tree

9 files changed

+880
-10
lines changed

9 files changed

+880
-10
lines changed

csep/core/catalog_evaluations.py

+270-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
# Third-Party Imports
1+
# Python imports
22
import time
3+
from typing import Optional, TYPE_CHECKING
34

5+
# Third-Party Imports
46
import numpy
57
import scipy.stats
68

79
# PyCSEP imports
810
from csep.core.exceptions import CSEPEvaluationException
11+
from csep.core.catalogs import CSEPCatalog
912
from csep.models import (
1013
CatalogNumberTestResult,
1114
CatalogSpatialTestResult,
@@ -14,7 +17,10 @@
1417
CalibrationTestResult
1518
)
1619
from csep.utils.calc import _compute_likelihood
17-
from csep.utils.stats import get_quantiles, cumulative_square_diff
20+
from csep.utils.stats import get_quantiles, cumulative_square_diff, MLL_score
21+
22+
if TYPE_CHECKING:
23+
from csep.core.forecasts import CatalogForecast
1824

1925

2026
def number_test(forecast, observed_catalog, verbose=True):
@@ -55,6 +61,7 @@ def number_test(forecast, observed_catalog, verbose=True):
5561
obs_name=observed_catalog.name)
5662
return result
5763

64+
5865
def spatial_test(forecast, observed_catalog, verbose=True):
5966
""" Performs spatial test for catalog-based forecasts.
6067
@@ -140,6 +147,7 @@ def spatial_test(forecast, observed_catalog, verbose=True):
140147

141148
return result
142149

150+
143151
def magnitude_test(forecast, observed_catalog, verbose=True):
144152
""" Performs magnitude test for catalog-based forecasts """
145153
test_distribution = []
@@ -215,6 +223,7 @@ def magnitude_test(forecast, observed_catalog, verbose=True):
215223

216224
return result
217225

226+
218227
def pseudolikelihood_test(forecast, observed_catalog, verbose=True):
219228
""" Performs the spatial pseudolikelihood test for catalog forecasts.
220229
@@ -310,6 +319,7 @@ def pseudolikelihood_test(forecast, observed_catalog, verbose=True):
310319

311320
return result
312321

322+
313323
def calibration_test(evaluation_results, delta_1=False):
314324
""" Perform the calibration test by computing a Kilmogorov-Smirnov test of the observed quantiles against a uniform
315325
distribution.
@@ -345,3 +355,261 @@ def calibration_test(evaluation_results, delta_1=False):
345355
return result
346356

347357

358+
def resampled_magnitude_test(forecast: "CatalogForecast",
                             observed_catalog: CSEPCatalog,
                             verbose: bool = False,
                             seed: Optional[int] = None) -> CatalogMagnitudeTestResult:
    """
    Performs the resampled magnitude test for catalog-based forecasts (Serafini et al., 2024),
    which corrects the bias from the original M-test implementation to the total N of events.
    Calculates the (pseudo log-likelihood) test statistic distribution from the forecast's
    synthetic catalogs Λ_j as:

        D_j = Σ_k [log10(Λ_u(k) * N / N_u + 1) - log10(Λ̃_j(k) + 1)] ^ 2

    where k are the magnitude bins, Λ_u the union of all synthetic catalogs, N_u the total
    number of events in Λ_u, N the number of observed events, and Λ̃_j the resampled catalog
    containing exactly N events.

    The pseudo log-likelihood statistic from the observations is calculated as:

        D_o = Σ_k [log10(Λ_u(k) * N / N_u + 1) - log10(Ω(k) + 1)] ^ 2

    where Ω is the observed catalog.

    One statistic is drawn per synthetic catalog in the forecast, but each draw is sampled
    from the union magnitude histogram rather than from the individual catalog.

    Args:
        forecast (CatalogForecast): The forecast to be evaluated
        observed_catalog (CSEPCatalog): The observation/testing catalog.
        verbose (bool): Flag to display debug messages
        seed (int): Random number generator seed. When not None it is passed to
            ``numpy.random.seed`` and therefore re-seeds numpy's global random state.

    Returns:
        A CatalogMagnitudeTestResult object containing the statistic distribution and the
        observed statistic. If the observed catalog has no events, a result with
        ``status='not-valid'`` and an empty distribution is returned instead.

    Raises:
        CSEPEvaluationException: If ``forecast.region.magnitudes`` is None.
    """

    # Compare against None (not truthiness) so that seed=0 is honoured as a valid seed.
    if seed is not None:
        numpy.random.seed(seed)

    test_distribution = []

    if forecast.region.magnitudes is None:
        raise CSEPEvaluationException(
            "Forecast must have region.magnitudes member to perform magnitude test.")

    # short-circuit if zero events
    if observed_catalog.event_count == 0:
        print("Cannot perform magnitude test when observed event count is zero.")
        # prepare result
        result = CatalogMagnitudeTestResult(test_distribution=test_distribution,
                                            name='M-Test',
                                            observed_statistic=None,
                                            quantile=(None, None),
                                            status='not-valid',
                                            min_mw=forecast.min_magnitude,
                                            obs_catalog_repr=str(observed_catalog),
                                            obs_name=observed_catalog.name,
                                            sim_name=forecast.name)

        return result

    # compute expected rates for forecast if needed
    if forecast.expected_rates is None:
        forecast.get_expected_rates(verbose=verbose)

    # Magnitude histogram of the union of all synthetic catalogs
    union_histogram = numpy.zeros(len(forecast.magnitudes))
    for cat in forecast:
        union_histogram += cat.magnitude_counts()

    # NOTE(review): the half-bin width is read from the observed catalog's region while the
    # bins themselves come from the forecast -- presumably both share the same magnitude
    # binning; confirm.
    mag_half_bin = numpy.diff(observed_catalog.region.magnitudes)[0] / 2.

    n_union_events = numpy.sum(union_histogram)
    obs_histogram = observed_catalog.magnitude_counts()
    n_obs = numpy.sum(obs_histogram)

    # Rescale the union histogram so that it contains as many events as the observation
    union_scale = n_obs / n_union_events
    scaled_union_histogram = union_histogram * union_scale

    # Per-bin probabilities used to resample magnitudes from the union histogram
    probs = union_histogram / n_union_events

    # Loop invariants: the values that can be drawn (one per bin, shifted by half a bin so
    # they fall strictly inside their bin) and the histogram edges. The last edge is pushed
    # well above the largest bin so that the final bin is closed on the right.
    bin_centers = forecast.magnitudes + mag_half_bin
    extended_mag_max = max(forecast.magnitudes) + 10
    bin_edges = numpy.append(forecast.magnitudes, extended_mag_max)
    n_samples = int(n_obs)

    # compute the test statistic once per synthetic catalog in the forecast
    t0 = time.time()
    for i, _ in enumerate(forecast):
        # Draw a catalog of n_samples magnitudes from the union histogram
        mag_values = numpy.random.choice(bin_centers, p=probs, size=n_samples)
        mag_counts, _edges = numpy.histogram(mag_values, bins=bin_edges)

        n_events = numpy.sum(mag_counts)
        if n_events == 0:
            # nothing was drawn (n_samples == 0); no statistic can be computed
            continue
        scale = n_obs / n_events
        catalog_histogram = mag_counts * scale

        # compute magnitude test statistic for the resampled catalog
        test_distribution.append(
            cumulative_square_diff(numpy.log10(catalog_histogram + 1),
                                   numpy.log10(scaled_union_histogram + 1))
        )

        # output status at 1, 2, ..., 10, 20, ..., 100, 200, ... processed catalogs
        if verbose:
            tens_exp = numpy.floor(numpy.log10(i + 1))
            if (i + 1) % 10 ** tens_exp == 0:
                t1 = time.time()
                print(f'Processed {i + 1} catalogs in {t1 - t0} seconds', flush=True)

    # compute observed statistic
    obs_d_statistic = cumulative_square_diff(numpy.log10(obs_histogram + 1),
                                             numpy.log10(scaled_union_histogram + 1))

    # score evaluation
    delta_1, delta_2 = get_quantiles(test_distribution, obs_d_statistic)

    # prepare result
    result = CatalogMagnitudeTestResult(test_distribution=test_distribution,
                                        name='M-Test',
                                        observed_statistic=obs_d_statistic,
                                        quantile=(delta_1, delta_2),
                                        status='normal',
                                        min_mw=forecast.min_magnitude,
                                        obs_catalog_repr=str(observed_catalog),
                                        obs_name=observed_catalog.name,
                                        sim_name=forecast.name)

    return result
484+
485+
486+
def MLL_magnitude_test(forecast: "CatalogForecast",
                       observed_catalog: CSEPCatalog,
                       full_calculation: bool = False,
                       verbose: bool = False,
                       seed: Optional[int] = None) -> CatalogMagnitudeTestResult:
    """
    Implements the modified Multinomial log-likelihood ratio (MLL) magnitude test (Serafini et
    al., 2024). Calculates the test statistic distribution as:

        D̃_j = -2 * log( L(Λ_u + N_u / N_j + Λ̃_j + 1) /
                        [L(Λ_u + N_u / N_j) * L(Λ̃_j + 1)]
                      )

    where L is the multinomial likelihood function, Λ_u the union of all the forecasts'
    synthetic catalogs, N_u the total number of events in Λ_u, Λ̃_j the resampled catalog
    containing exactly N observed events. The observed statistic is defined as:

        D_o = -2 * log( L(Λ_u + N_u / N + Ω + 1) /
                        [L(Λ_u + N_u / N) * L(Ω + 1)]
                      )

    where Ω is the observed catalog.

    NOTE(review): the statistics are computed by ``MLL_score``, whose own docstring states
    the score with a leading factor of 2 rather than -2; confirm which sign convention is
    intended.

    Args:
        forecast (CatalogForecast): The forecast to be evaluated
        observed_catalog (CSEPCatalog): The observation/testing catalog.
        full_calculation (bool): Whether to sample from the entire stochastic catalogs or from
            its already processed magnitude histogram.
        verbose (bool): Flag to display debug messages
        seed (int): Random number generator seed. When not None it is passed to
            ``numpy.random.seed`` and therefore re-seeds numpy's global random state.

    Returns:
        A CatalogMagnitudeTestResult object containing the statistic distribution and the
        observed statistic. If the observed catalog has no events, a result with
        ``status='not-valid'`` and an empty distribution is returned instead.

    Raises:
        CSEPEvaluationException: If ``forecast.region.magnitudes`` is None.
    """

    # Compare against None (not truthiness) so that seed=0 is honoured as a valid seed.
    if seed is not None:
        numpy.random.seed(seed)

    test_distribution = []

    if forecast.region.magnitudes is None:
        raise CSEPEvaluationException(
            "Forecast must have region.magnitudes member to perform magnitude test.")

    # short-circuit if zero events
    if observed_catalog.event_count == 0:
        print("Cannot perform magnitude test when observed event count is zero.")
        # prepare result
        result = CatalogMagnitudeTestResult(test_distribution=test_distribution,
                                            name='M-Test',
                                            observed_statistic=None,
                                            quantile=(None, None),
                                            status='not-valid',
                                            min_mw=forecast.min_magnitude,
                                            obs_catalog_repr=str(observed_catalog),
                                            obs_name=observed_catalog.name,
                                            sim_name=forecast.name)

        return result

    # compute expected rates for forecast if needed
    if forecast.expected_rates is None:
        forecast.get_expected_rates(verbose=verbose)

    # calculate histograms of union forecast and total number of events
    Lambda_u_histogram = numpy.zeros(len(forecast.magnitudes))

    if not full_calculation:
        # NOTE(review): the half-bin width is read from the observed catalog's region while
        # the bins come from the forecast -- presumably both share the same binning; confirm.
        mag_half_bin = numpy.diff(observed_catalog.region.magnitudes)[0] / 2.

    # Collect per-catalog magnitudes and join them once after the loop; appending to a
    # growing array inside the loop would re-copy all previous magnitudes on every iteration.
    magnitude_chunks = []
    for cat in forecast:
        if full_calculation:
            magnitude_chunks.append(numpy.ravel(cat.get_magnitudes()))
        Lambda_u_histogram += cat.magnitude_counts()

    if full_calculation:
        # Pool of every magnitude in every synthetic catalog, sampled uniformly below
        Lambda_u = (numpy.concatenate(magnitude_chunks) if magnitude_chunks
                    else numpy.array([]))
    else:
        # Values that can be drawn (one per bin, shifted by half a bin so they fall strictly
        # inside their bin) and the probability of each bin in the union histogram
        bin_centers = forecast.magnitudes + mag_half_bin
        probs = Lambda_u_histogram / numpy.sum(Lambda_u_histogram)

    # calculate histograms of observations and observed number of events
    Omega_histogram = observed_catalog.magnitude_counts()
    n_obs = numpy.sum(Omega_histogram)
    n_samples = int(n_obs)

    # compute observed statistic
    obs_d_statistic = MLL_score(union_catalog_counts=Lambda_u_histogram,
                                catalog_counts=Omega_histogram)

    # Histogram edges; the last edge is pushed well above the largest bin so that the final
    # bin is closed on the right.
    extended_mag_max = max(forecast.magnitudes) + 10
    bin_edges = numpy.append(forecast.magnitudes, extended_mag_max)

    # compute the test statistic once per synthetic catalog in the forecast
    t0 = time.time()
    for i, _ in enumerate(forecast):
        # Resample a catalog of n_samples magnitudes from the union forecast
        if full_calculation:
            mag_values = numpy.random.choice(Lambda_u, size=n_samples)
        else:
            mag_values = numpy.random.choice(bin_centers, p=probs, size=n_samples)
        Lambda_j_histogram, _edges = numpy.histogram(mag_values, bins=bin_edges)

        # compute magnitude test statistic for the resampled catalog
        test_distribution.append(
            MLL_score(union_catalog_counts=Lambda_u_histogram,
                      catalog_counts=Lambda_j_histogram)
        )

        # output status at 1, 2, ..., 10, 20, ..., 100, 200, ... processed catalogs
        if verbose:
            tens_exp = numpy.floor(numpy.log10(i + 1))
            if (i + 1) % 10 ** tens_exp == 0:
                t1 = time.time()
                print(f'Processed {i + 1} catalogs in {t1 - t0} seconds', flush=True)

    # score evaluation
    delta_1, delta_2 = get_quantiles(test_distribution, obs_d_statistic)

    # prepare result
    result = CatalogMagnitudeTestResult(test_distribution=test_distribution,
                                        name='M-Test',
                                        observed_statistic=obs_d_statistic,
                                        quantile=(delta_1, delta_2),
                                        status='normal',
                                        min_mw=forecast.min_magnitude,
                                        obs_catalog_repr=str(observed_catalog),
                                        obs_name=observed_catalog.name,
                                        sim_name=forecast.name)

    return result

csep/utils/stats.py

+60-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import numpy
2-
import scipy.stats
32
import scipy.special
4-
# PyCSEP imports
5-
from csep.core import regions
3+
import scipy.stats
4+
65

76
def sup_dist(cdf1, cdf2):
87
"""
@@ -269,3 +268,61 @@ def get_Kagan_I1_score(forecasts, catalog):
269268
I_1[j] = numpy.dot(counts[non_zero_idx], numpy.log2(rate_den[non_zero_idx] / uniform_forecast)) / n_event
270269

271270
return I_1
271+
272+
273+
def log_d_multinomial(x: numpy.ndarray, size: int, prob: numpy.ndarray):
    """
    Evaluates the logarithm of the multinomial probability mass function:

        log(size!) + Σ_k [x_k * log(prob_k) - log(x_k!)]

    Factorials are evaluated through the log-gamma function, so non-integer values of
    ``x`` and ``size`` are accepted. A category with zero count contributes nothing to the
    sum, even when its probability is also zero (i.e. 0 * log(0) is taken as 0).

    Args:
        x: Counts per category (array-like).
        size: Total number of trials. It is used as given and is not checked against the
            sum of ``x``.
        prob: Probability of each category (array-like, same shape as ``x``). It is not
            checked to sum to one.

    Returns:
        The log-probability as a scalar.
    """
    x = numpy.asarray(x)
    prob = numpy.asarray(prob)
    # xlogy(x, p) == x * log(p), except that it returns 0 where x == 0, which avoids the
    # 0 * -inf = nan of the naive product for empty categories with zero probability.
    return scipy.special.loggamma(size + 1) + numpy.sum(
        scipy.special.xlogy(x, prob) - scipy.special.loggamma(x + 1))
286+
287+
288+
def MLL_score(union_catalog_counts: numpy.ndarray, catalog_counts: numpy.ndarray):
    """
    Calculates the modified Multinomial log-likelihood (MLL) score, defined by Serafini et al.,
    (2024). It is built from a collection of catalogs Λ_u and a single catalog Ω:

        MLL_score = 2 * log( L(Λ_u + N_u / N_j + Ω + 1) /
                             [L(Λ_u + N_u / N_j) * L(Ω + 1)]
                           )

    where N_u and N_j are the total number of events in Λ_u and Ω, respectively. Each L(c)
    is the multinomial likelihood of the counts c evaluated at the probabilities c / sum(c),
    i.e. at the counts' own proportions.

    No guard is applied on the inputs: if ``catalog_counts`` sums to zero the ratio
    N_u / N_j is undefined and the returned score is not finite.

    Args:
        union_catalog_counts (numpy.ndarray): Number of events per bin in the union of the
            collection of catalogs, Λ_u. Any array-like is accepted.
        catalog_counts (numpy.ndarray): Number of events per bin in the single catalog, Ω,
            on the same bins as ``union_catalog_counts``. Any array-like is accepted.

    Returns:
        The MLL score for the collection of catalogs and the single catalog.
    """
    # Accept lists/tuples as well as arrays; arrays are passed through unchanged.
    union_catalog_counts = numpy.asarray(union_catalog_counts)
    catalog_counts = numpy.asarray(catalog_counts)

    N_u = numpy.sum(union_catalog_counts)
    N_j = numpy.sum(catalog_counts)
    events_ratio = N_u / N_j

    # The added offsets (N_u / N_j and 1) keep every bin strictly positive, so that all the
    # probabilities below are non-zero.
    union_catalog_counts_mod = union_catalog_counts + events_ratio
    catalog_counts_mod = catalog_counts + 1
    merged_catalog_j = union_catalog_counts_mod + catalog_counts_mod

    # Totals are needed both as the multinomial size and to normalise the probabilities.
    n_merged = numpy.sum(merged_catalog_j)
    n_union = numpy.sum(union_catalog_counts_mod)
    n_cat_j = numpy.sum(catalog_counts_mod)

    log_lik_merged = log_d_multinomial(x=merged_catalog_j,
                                       size=n_merged,
                                       prob=merged_catalog_j / n_merged)
    log_lik_union = log_d_multinomial(x=union_catalog_counts_mod,
                                      size=n_union,
                                      prob=union_catalog_counts_mod / n_union)
    log_like_cat_j = log_d_multinomial(x=catalog_counts_mod,
                                       size=n_cat_j,
                                       prob=catalog_counts_mod / n_cat_j)

    return 2 * (log_lik_merged - log_lik_union - log_like_cat_j)

0 commit comments

Comments
 (0)