# Source code for sefef.scoring

# -*- coding: utf-8 -*-
"""
sefef.scoring
-------------

This module contains functions to compute both deterministic and probabilistic metrics according to the horizon of the forecast.

:copyright: (c) 2024 by Ana Sofia Carmo
:license: BSD 3-clause License, see LICENSE for more details.
"""

# third-party
import pandas as pd
import numpy as np
import sklearn
import sklearn.metrics
import plotly.graph_objects as go


# local
from sefef.visualization import COLOR_PALETTE


class Scorer:
    ''' Computes deterministic and probabilistic performance metrics for seizure forecasts.

    Attributes
    ----------
    metrics2compute : list<str>
        List of metrics to compute. The metrics can be either deterministic or probabilistic and metric names should
        be the ones from the following list:
            - Deterministic: "Sen" (i.e. sensitivity), "FPR" (i.e. false positive rate), "TiW" (i.e. time in
              warning), "AUC_TiW" (i.e. area under the curve of Sen vs TiW).
            - Probabilistic: "resolution", "reliability" or "calibration", "BS" (i.e. Brier score), "skill" or "BSS"
              (i.e. Brier skill score).
    sz_onsets : array-like, shape (#seizures, ), dtype "int64"
        Contains the Unix timestamps, in seconds, for the start of each seizure onset.
    forecast_horizon : int
        Forecast horizon in seconds, i.e. time in the future for which the forecasts are valid.
    performance : dict
        Dictionary where the keys are the metrics' names (as in "metrics2compute") and the value is the corresponding
        performance. It is initialized as an empty dictionary and re-populated by every call to "compute_metrics".
    reference_method : str, defaults to "prior_prob"
        Method to compute the reference forecasts. Only "prior_prob" is implemented.
    hist_prior_prob : float64, defaults to None
        Prior probability, aka historical likelihood (relative frequency) of seizures in train data. It is stored on
        the instance; none of the methods in this class currently read it.
    metrics2function : dict
        Maps each valid metric name to the internal method that computes it.

    Methods
    -------
    compute_metrics(forecasts, timestamps):
        Computes metrics in "metrics2compute" for the probabilities in "forecasts" and populates the "performance"
        attribute. This method uses techniques described in [Mason2004]_ and [Stephenson2008]_.
    reliability_diagram(forecasts, timestamps, binning_method, num_bins) :
        Plots forecasted vs observed probability per probability bin.

    Raises
    ------
    ValueError :
        Raised when a metric name in "metrics2compute" is not a valid metric or when "reference_method" is not a
        valid method.

    References
    ----------
    .. [Mason2004] S. J. Mason, “On Using ‘Climatology’ as a Reference Strategy in the Brier and Ranked Probability
       Skill Scores,” Jul. 2004, Accessed: Nov. 06, 2024. [Online]. Available:
       https://journals.ametsoc.org/view/journals/mwre/132/7/1520-0493_2004_132_1891_oucaar_2.0.co_2.xml
    .. [Stephenson2008] Stephenson, D. B. , C. A. S. Coelho, and I. T. Jolliffe. "Two Extra Components in the Brier
       Score Decomposition", Weather and Forecasting 23, 4 (2008): 752-757, doi:
       https://doi.org/10.1175/2007WAF2006116.1
    '''

    def __init__(self, metrics2compute, sz_onsets, forecast_horizon, reference_method='prior_prob',
                 hist_prior_prob=None):
        self.metrics2compute = metrics2compute
        self.sz_onsets = np.array(sz_onsets)
        self.forecast_horizon = forecast_horizon
        self.reference_method = reference_method
        self.hist_prior_prob = hist_prior_prob

        self.performance = {}

        # "calibration" and "BSS" are aliases of "reliability" and "skill", respectively.
        self.metrics2function = {
            'Sen': self._compute_Sen,
            'FPR': self._compute_FPR,
            'TiW': self._compute_TiW,
            'AUC_TiW': self._compute_AUC,
            'resolution': self._compute_resolution,
            'reliability': self._compute_reliability,
            'calibration': self._compute_reliability,
            'BS': self._compute_BS,
            'skill': self._compute_skill,
            'BSS': self._compute_skill,
        }

    def compute_metrics(self, forecasts, timestamps, threshold=0.5, binning_method='quantile', num_bins=10,
                        draw_diagram=True):
        ''' Computes metrics in "metrics2compute" for the probabilities in "forecasts" and populates the
        "performance" attribute.

        Forecasts that are NaN (and their timestamps) are discarded before any metric is computed. Results of
        previous calls are discarded, so the returned dictionary only reflects the data passed to this call.

        Parameters
        ----------
        forecasts : array-like, shape (#forecasts, ), dtype "float64"
            Contains the predicted probabilites of seizure occurrence for the period with duration equal to the
            forecast horizon and starting at the timestamps in "timestamps".
        timestamps : array-like, shape (#forecasts, ), dtype "int64"
            Contains the Unix timestamps, in seconds, for the start of the period for which the forecasts (in
            "forecasts") are valid.
        threshold : float64, defaults to 0.5
            Probability value to apply as the high-likelihood threshold. Used by "Sen", "FPR" and "TiW"; "AUC_TiW"
            sweeps its own thresholds and ignores this value.
        binning_method : str, defaults to "quantile"
            Method used to determine the probability bins used to compute probabilistic metrics. Available methods
            are:
                - "uniform": bin edges are equally spaced between 0 and 1.
                - "quantile": bin edges are percentiles of the forecasts, so that bins are populated with an
                  approximately equal number of forecasts.
        num_bins : int64, defaults to 10
            Number of bins used to compute probabilistic metrics. If None, it is calculated as
            np.ceil(#forecasts^(1/3)), otherwise "num_bins" is used as the number of bins.
        draw_diagram : bool, defaults to True
            Whether to draw the reliability diagram after computing all required metrics.

        Returns
        -------
        performance : dict
            Dictionary where the keys are the metrics' names (as in "metrics2compute") and the value is the
            corresponding performance.

        Raises
        ------
        ValueError :
            Raised when a metric name in "metrics2compute" is not a valid metric. Names are checked before anything
            is computed.
        '''
        deterministic = ('Sen', 'FPR', 'TiW')
        probabilistic = ('resolution', 'reliability', 'calibration', 'BS', 'skill', 'BSS')
        valid_names = deterministic + ('AUC_TiW',) + probabilistic

        # Fail fast: do not compute (and store) some metrics only to raise on a later invalid name.
        for metric_name in self.metrics2compute:
            if metric_name not in valid_names:
                raise ValueError(f'{metric_name} is not a valid metric.')

        timestamps = np.array(timestamps)
        forecasts = np.array(forecasts)

        not_nan = ~np.isnan(forecasts)
        timestamps = timestamps[not_nan]
        forecasts = forecasts[not_nan]  # TODO: VERIFY THIS

        # "_compute_BS" and "_compute_skill" reuse values already stored in "performance". Start from an empty
        # dictionary so that values from a previous call (other forecasts/bins) are never mixed into this one.
        self.performance = {}

        # Both are independent of the metric being computed, so compute each at most once.
        counts = None
        bin_edges = None

        for metric_name in self.metrics2compute:
            compute = self.metrics2function[metric_name]
            if metric_name in deterministic:
                if counts is None:
                    counts = self._get_counts(forecasts, timestamps, threshold)
                tp, fp, fn = counts
                self.performance[metric_name] = compute(tp, fp, fn, forecasts)
            elif metric_name == 'AUC_TiW':
                self.performance[metric_name] = compute(forecasts, timestamps, threshold)
            else:
                if bin_edges is None:
                    bin_edges = self._get_bins_indx(forecasts, binning_method, num_bins)
                self.performance[metric_name] = compute(forecasts, timestamps, bin_edges)

        if draw_diagram:
            self.reliability_diagram(forecasts, timestamps, binning_method=binning_method, num_bins=num_bins)

        return self.performance

    # Deterministic metrics

    def _sz_in_window(self, timestamps_start_forecast):
        '''Internal method that returns a boolean matrix, shape (#seizures, #forecasts), that is True where the
        seizure onset falls within [start, start + forecast_horizon - 1] of the corresponding forecast.'''
        timestamps_end_forecast = timestamps_start_forecast + self.forecast_horizon - 1
        onsets = self.sz_onsets[:, np.newaxis]
        return ((onsets >= timestamps_start_forecast[np.newaxis, :])
                & (onsets <= timestamps_end_forecast[np.newaxis, :]))

    def _get_counts(self, forecasts, timestamps_start_forecast, threshold):
        '''Internal method that computes counts of true positives (tp), false positives (fp), and false negatives
        (fn), according to the occurrence (or not) of a seizure event within the forecast horizon.

        "tp" and "fn" are counted per seizure (a seizure is a tp if at least one forecast whose window contains it
        is >= threshold), whereas "fp" is counted per forecast (forecasts >= threshold whose window contains no
        seizure).'''
        sz_in_window = self._sz_in_window(timestamps_start_forecast)
        high_likelihood = forecasts >= threshold

        tp = np.sum(np.any(sz_in_window & high_likelihood, axis=1))
        fn = len(self.sz_onsets) - tp
        fp = np.sum(high_likelihood[~np.any(sz_in_window, axis=0)])

        return tp, fp, fn

    def _compute_Sen(self, tp, fp, fn, forecasts):
        '''Internal method that computes sensitivity, providing a measure of the model's ability to correctly
        identify pre-ictal periods.'''
        return tp / (tp + fn)

    def _compute_FPR(self, tp, fp, fn, forecasts):
        '''Internal method that computes the false positive rate, i.e. the proportion of time that the user
        incorrectly spends in alert.'''
        return fp / len(forecasts)

    def _compute_TiW(self, tp, fp, fn, forecasts):
        '''Internal method that computes the time in warning, i.e. the proportion of time that the user spends in
        alert (i.e. in a high likelihood state, independently of the "goodness" of the forecast).'''
        return (tp + fp) / len(forecasts)

    def _compute_AUC(self, forecasts, timestamps, threshold):
        '''Internal method that computes the area under the Sen vs TiW curve, abstracting the need for threshold
        optimization. Computed as the numerical integration of Sen vs TiW using the trapezoidal rule.

        "threshold" is accepted for signature compatibility but is not used: every unique forecasted value is used
        as a threshold.'''
        # use unique forecasted values as thresholds
        thresholds = np.unique(forecasts)

        counts = [self._get_counts(forecasts, timestamps, thr) for thr in thresholds]
        sen = np.array([self._compute_Sen(tp, fp, fn, forecasts) for tp, fp, fn in counts])
        tiw = np.array([self._compute_TiW(tp, fp, fn, forecasts) for tp, fp, fn in counts])

        # add point (0, 0) to curve since the auc() function computes the area strictly based on the given points
        return sklearn.metrics.auc(np.append(tiw, 0.), np.append(sen, 0.))

    # Probabilistic metrics

    def _get_bins_indx(self, forecasts, binning_method, num_bins):
        '''Internal method that computes the edges of the probability bins, to be used with
        np.digitize(..., right=True).

        - "uniform": returns num_bins + 1 equally spaced edges between 0 and 1. Since the edge at 0 is kept,
          forecasts that are exactly 0 fall in a bin of their own.
        - "quantile": returns num_bins edges at equally spaced percentiles of the forecasts (the 0th percentile is
          dropped), so that each bin contains approximately the same number of observations.

        If not provided, the number of bins is determined by n^(1/3), as proposed in np.histogram_bin_edges.

        Raises ValueError when "binning_method" is not one of the methods above.'''
        if num_bins is None:
            num_bins = np.ceil(len(forecasts)**(1/3)).astype('int64')

        if binning_method == 'uniform':
            bin_edges = np.linspace(0, 1, num_bins + 1)
        elif binning_method == 'quantile':
            percentile = np.linspace(0, 100, num_bins + 1)
            bin_edges = np.percentile(forecasts, percentile)[1:]  # remove edge corresponding to 0th percentile
        else:
            raise ValueError(f'{binning_method} is not a valid binning method')

        return bin_edges

    def _compute_resolution(self, forecasts, timestamps, bin_edges):
        '''Internal method that computes the resolution, i.e. the ability of the model to differentiate between
        individual observed probabilities and the average observed probability.
        "y_avg": observed relative frequency of true events for all forecasts;
        "y_k_avg": observed relative frequency of true events for the kth probability bin.'''
        binned_data = np.digitize(forecasts, bin_edges, right=True)
        y_avg = len(self.sz_onsets) / len(forecasts)

        resolution = []
        for k in np.unique(binned_data):
            binned_indx = np.where(binned_data == k)
            # threshold=0. makes "tp" the number of seizures that fall within any forecast window of this bin
            events_in_bin, _, _ = self._get_counts(forecasts[binned_indx], timestamps[binned_indx], threshold=0.)
            y_k_avg = events_in_bin / len(forecasts[binned_indx])
            resolution += [len(forecasts[binned_indx]) * ((y_k_avg - y_avg) ** 2)]

        return np.sum(resolution) * (1/len(forecasts))

    def _compute_reliability(self, forecasts, timestamps, bin_edges):
        '''Internal method that computes reliability, i.e. the agreement between forecasted and observed
        probabilities through the Brier score.
        "y_k_avg": observed relative frequency of true events for the kth probability bin.'''
        binned_data = np.digitize(forecasts, bin_edges, right=True)

        reliability = []
        for k in np.unique(binned_data):
            binned_indx = np.where(binned_data == k)
            # threshold=0. makes "tp" the number of seizures that fall within any forecast window of this bin
            events_in_bin, _, _ = self._get_counts(forecasts[binned_indx], timestamps[binned_indx], threshold=0.)
            y_k_avg = events_in_bin / len(forecasts[binned_indx])
            reliability += [len(forecasts[binned_indx]) * ((np.mean(forecasts[binned_indx]) - y_k_avg) ** 2)]

        return np.sum(reliability) * (1/len(forecasts))

    def _compute_uncertainty(self, forecasts, timestamps, bin_edges):
        '''Internal method that computes uncertainty. Only the number of forecasts is used; "timestamps" and
        "bin_edges" are ignored.
        "y_avg": observed relative frequency of true events for all forecasts'''
        y_avg = len(self.sz_onsets) / len(forecasts)
        return y_avg * (1-y_avg)

    def _compute_WBV(self, forecasts, timestamps, bin_edges):
        '''Internal method that computes within-bin variance.'''
        binned_data = np.digitize(forecasts, bin_edges, right=True)

        wbv = []
        for k in np.unique(binned_data):
            binned_indx = np.where(binned_data == k)
            wbv += [np.sum((forecasts[binned_indx] - np.mean(forecasts[binned_indx]))**2)]

        return np.sum(wbv) * (1/len(forecasts))

    def _compute_WBC(self, forecasts, timestamps, bin_edges):
        '''Internal method that computes within-bin covariance.
        "y_ki": whether a seizure onset falls within the window of the ith forecast of the kth bin.'''
        binned_data = np.digitize(forecasts, bin_edges, right=True)

        wbc = []
        for k in np.unique(binned_data):
            binned_indx = np.where(binned_data == k)
            y_ki = np.any(self._sz_in_window(timestamps[binned_indx]), axis=0)
            wbc += [np.sum((y_ki - np.mean(y_ki)) * (forecasts[binned_indx] - np.mean(forecasts[binned_indx])))]

        return np.sum(wbc) * (2/len(forecasts))

    def _compute_BS(self, forecasts, timestamps, bin_edges):
        '''Internal method that computes the Brier score, through the decomposition proposed in [Stephenson2008]_.
        Reliability and resolution are taken from "performance" when already stored there.'''
        if 'reliability' in self.performance.keys():
            reliability = self.performance['reliability']
        else:
            reliability = self._compute_reliability(forecasts, timestamps, bin_edges)

        if 'resolution' in self.performance.keys():
            resolution = self.performance['resolution']
        else:
            resolution = self._compute_resolution(forecasts, timestamps, bin_edges)

        uncertainty = self._compute_uncertainty(forecasts, timestamps, bin_edges)
        wbv = self._compute_WBV(forecasts, timestamps, bin_edges)
        wbc = self._compute_WBC(forecasts, timestamps, bin_edges)

        return (reliability - resolution + uncertainty + wbv - wbc)

    def _compute_skill(self, forecasts, timestamps, bin_edges):
        '''Internal method that computes the Brier skill score against a reference forecast. Simplification of BS of
        reference forecast as described in [Mason2004]_. The Brier score is taken from "performance" when already
        stored there.'''
        if 'BS' in self.performance.keys():
            bs = self.performance['BS']
        else:
            bs = self._compute_BS(forecasts, timestamps, bin_edges)

        ref_forecasts = self._get_reference_forecasts(timestamps)

        return 1 - bs / self._compute_uncertainty(ref_forecasts, None, None)

    def _get_reference_forecasts(self, timestamps):
        '''Internal method that returns a reference forecast according to the specified method.
        "y_avg": observed relative frequency of true events for all forecasts.

        Raises ValueError when "reference_method" is not a valid method.'''
        if self.reference_method == 'prior_prob':
            y_avg = len(self.sz_onsets) / len(timestamps)
            return y_avg * np.ones_like(timestamps)
        else:
            raise ValueError(f'{self.reference_method} is not a valid method to compute the reference forecasts.')

    def reliability_diagram(self, forecasts, timestamps, binning_method, num_bins):
        '''Method that plots the reliability diagram (forecasted_proba vs observed_proba), along with the
        no-resolution and perfect-reliability lines.

        Parameters
        ----------
        forecasts : array-like, shape (#forecasts, ), dtype "float64"
            Predicted probabilities of seizure occurrence.
        timestamps : array-like, shape (#forecasts, ), dtype "int64"
            Unix timestamps, in seconds, for the start of the period for which each forecast is valid.
        binning_method : str
            Either "uniform" or "quantile" (see "compute_metrics").
        num_bins : int64
            Number of bins. If None, it is calculated as np.ceil(#forecasts^(1/3)).
        '''
        # This method is public: accept plain sequences, since the code below relies on ndarray indexing.
        forecasts = np.asarray(forecasts)
        timestamps = np.asarray(timestamps)

        fig = go.Figure()

        y_avg = len(self.sz_onsets) / len(forecasts)

        bin_edges = self._get_bins_indx(forecasts, binning_method, num_bins)
        binned_data = np.digitize(forecasts, bin_edges, right=True)

        # Prepend 0 so that consecutive edges delimit every bin; rows are indexed by the bins' midpoints.
        bin_edges = np.insert(bin_edges, 0, 0.)
        diagram_data = pd.DataFrame(columns=['observed_proba', 'forecasted_proba'],
                                    index=(bin_edges[:-1] + bin_edges[1:]) / 2)

        for k in np.unique(binned_data):
            binned_indx = np.where(binned_data == k)
            events_in_bin, _, _ = self._get_counts(
                forecasts[binned_indx], timestamps[binned_indx], threshold=0.)
            y_k_avg = events_in_bin / len(forecasts[binned_indx])
            diagram_data.iloc[k, :] = [y_k_avg, np.mean(forecasts[binned_indx])]

        fig.add_trace(go.Scatter(
            x=diagram_data.loc[:, 'forecasted_proba'],
            y=diagram_data.loc[:, 'observed_proba'],
            mode='lines',
            line=dict(width=3, color=COLOR_PALETTE[1]),
            name='Reliability curve'
        ))
        fig.add_trace(go.Scatter(
            x=diagram_data.loc[:, 'forecasted_proba'],
            y=diagram_data.loc[:, 'observed_proba'],
            mode='markers',
            marker=dict(size=10, color=COLOR_PALETTE[1]),
            name='Bin average'
        ))
        fig.add_trace(go.Scatter(
            x=[0, 1],
            y=[0, 1],
            line=dict(width=3, color=COLOR_PALETTE[0], dash='dash'),
            # showlegend=False,
            mode='lines',
            name='Perfect reliability'
        ))
        fig.add_trace(go.Scatter(
            x=[0, 1],
            y=[y_avg, y_avg],
            line=dict(width=3, color='lightgrey', dash='dash'),
            mode='lines',
            name='No resolution'
        ))

        # Config plot layout
        # Both axes share the same range: smallest to largest value across observed and forecasted probabilities.
        axis_range = [diagram_data.min().min(), diagram_data.max().max()]
        fig.update_yaxes(
            title='observed probability',
            tickfont=dict(size=12),
            showline=True, linewidth=2, linecolor=COLOR_PALETTE[2],
            showgrid=False,
            range=axis_range
        )
        fig.update_xaxes(
            title='forecasted probability',
            tickfont=dict(size=12),
            showline=True, linewidth=2, linecolor=COLOR_PALETTE[2],
            showgrid=False,
            range=axis_range,
        )
        fig.update_layout(
            title=f'Reliability diagram (binning method: {binning_method})',
            showlegend=True,
            plot_bgcolor='white',
        )
        fig.show()