# -*- coding: utf-8 -*-
"""
sefef.visualization
-------------------
This is a helper module for visualization.
:copyright: (c) 2024 by Ana Sofia Carmo
:license: BSD 3-clause License, see LICENSE for more details.
"""
# built-in
import os
from datetime import datetime
# third-party
from plotly.subplots import make_subplots
import matplotlib as mpl
import numpy as np
import plotly.graph_objects as go
import pandas as pd
import plotly.express.colors as pc
import plotly.express as px
COLOR_PALETTE = ['#4179A0', '#A0415D', '#44546A', '#44AA97', '#FFC000', '#0F3970', '#873C26']
[docs]
def hex_to_rgba(h, alpha):
'''Converts color value in hex format to rgba format with alpha transparency'''
return tuple([int(h.lstrip('#')[i:i+2], 16) for i in (0, 2, 4)] + [alpha])
[docs]
def plot_forecasts(forecasts, ts, sz_onsets, high_likelihood_thr, forecast_horizon, title='Seizure probability', folder_path=None, filename=None, show=True, return_plot=False, n_points=100):
''' Provide visualization of forecasts.
Parameters
----------
forecasts : array-like, shape (#forecasts, ), dtype "float64"
Contains the predicted probabilites of seizure occurrence for the period with duration "forecast_horizon" and starting at the timestamps in "result2".
ts : array-like, shape (#forecasts, ), dtype "int64"
Contains the Unix timestamps, in seconds, for the start of the period for which the forecasts (in "result1") are valid.
sz_onsets : array-like, shape (#sz onsets, )
Contains the unix timestamps (in seconds) of the onsts of seizures.
high_likelihood_thr : float64
Value between 0 and 1 corresponding to the threshold of high-likelihood.
'''
fig = go.Figure()
# get only sz_onsets for which there are forecasts
ts_forecast_end = ts + forecast_horizon
sz_onsets = sz_onsets[np.logical_and((sz_onsets[:, np.newaxis] >= ts[np.newaxis, :]),
(sz_onsets[:, np.newaxis] < ts_forecast_end[np.newaxis, :])).any(axis=1)]
sz_onsets_forecasts_ind = np.argwhere(np.logical_and(
(sz_onsets[:, np.newaxis] >= ts[np.newaxis, :]), (sz_onsets[:, np.newaxis] < ts_forecast_end[np.newaxis, :])))[:, 1]
# Draw background
color_gradient = px.colors.sequential.RdBu[1:-1][::-1]
color_gradient.remove('rgb(247,247,247)')
low_prob_values = np.linspace(0, high_likelihood_thr, int(n_points/2))
high_prob_values = np.linspace(high_likelihood_thr, 1, int(n_points/2))
low_prob_colors = [color_gradient[int(
val / high_likelihood_thr * (len(color_gradient)//2 - 1))] for val in low_prob_values]
high_prob_colors = [color_gradient[-len(color_gradient)//2:][int((val - high_likelihood_thr) / (1 - high_likelihood_thr) * (len(color_gradient)//2 - 1))]
for val in high_prob_values]
fig = _get_gradient(fig, low_prob_values, low_prob_colors, x0=pd.to_datetime(ts[0], unit='s'), x1=pd.to_datetime(
ts_forecast_end[-1], unit='s'))
fig = _get_gradient(fig, high_prob_values, high_prob_colors, x0=pd.to_datetime(ts[0], unit='s'), x1=pd.to_datetime(
ts_forecast_end[-1], unit='s'))
df_start = pd.DataFrame({'ts': ts, 'forecasts': forecasts})
df_end = df_start.copy()
df_end['ts'] += forecast_horizon - 1
df = pd.concat([df_start, df_end]).sort_values('ts').reset_index(drop=True)
fig.add_trace(go.Scatter(
x=pd.to_datetime(df['ts'], unit='s'), y=df['forecasts'],
mode='lines',
line_color='white',
line_width=3,
name='Forecast'
))
fig.add_trace(go.Scatter(
x=pd.to_datetime(sz_onsets, unit='s'), y=forecasts[sz_onsets_forecasts_ind],
mode='markers',
marker=dict(
color=COLOR_PALETTE[1],
size=12,
symbol='star',
),
name='Seizure'
))
fig.add_hline(y=high_likelihood_thr, line_width=1,
line_color='#FF0000', yref='y',)
non_nan_forecasts = forecasts[~np.isnan(forecasts)]
fig.update_yaxes(
title='Probability',
showgrid=False,
tickfont=dict(size=12),
range=[np.max([0, np.min(non_nan_forecasts) - np.std(non_nan_forecasts)]),
np.min([1, np.max(non_nan_forecasts) + np.std(non_nan_forecasts)])],
)
fig.update_xaxes(
title='Time',
showgrid=False,
)
fig.update_layout(
title=title,
showlegend=True,
plot_bgcolor='white',
legend_bgcolor='whitesmoke')
if folder_path is not None:
if not os.path.exists(folder_path):
os.mkdir(folder_path)
fig.write_image(os.path.join(folder_path, filename))
if show:
fig.show()
if return_plot:
return fig
def _get_gradient(fig, values, colors, x0, x1):
for i in range(len(colors)-1):
fig.add_shape(
type="rect",
x0=x0, x1=x1,
y0=values[i], y1=values[i+1],
fillcolor=colors[i],
yref='y',
line_width=0, # No border
layer="below",)
return fig
[docs]
def aggregate_plots(figs, folder_path=None, filename=None, show=True,):
''' Receives go.Figure objects created using "plot_forecasts" and aggregates them into a single Figure.
Parameters
----------
figs : go.Figure
Figures to aggregate into a single plot.
'''
fig = make_subplots(rows=1, cols=len(
figs), shared_yaxes=True, horizontal_spacing=0.05) # 0.005
forecasts = []
for ifig, figure in enumerate(figs):
print(f'Aggregating forecast plots ({ifig+1}/{len(figs)})', end='\r')
x0, x1 = datetime.max, datetime.min
for trace in figure.data:
trace['showlegend'] = (ifig == 0)
fig.add_trace(trace, row=1, col=(ifig+1))
x0, x1 = min(x0, min(trace['x'])), max(x1, max(trace['x']))
for shape in figure.layout.shapes or []:
if shape.type == 'rect':
new_shape = shape.to_plotly_json()
new_shape["xref"], new_shape["yref"] = f"x{ifig+1}", f"y{ifig+1}"
fig.add_shape(new_shape)
elif shape.type == 'line':
fig.add_hline(
y=shape['y0'], line=shape['line'], yref=f"y{ifig+1}")
for annotation in figure.layout.annotations or []:
annotation["xref"], annotation["yref"] = f"x{ifig+1}", f"y{ifig+1}"
fig.add_annotation(annotation)
fig.update_xaxes(
range=[x0, x1], row=1, col=(ifig+1))
fig.update_yaxes(
showticklabels=(ifig == 0),
row=1, col=(ifig+1))
forecasts += [figure['data'][0]['y']]
# Update layout
forecasts = np.concat(forecasts)
fig.update_layout(
title_text=figure['layout']['title']['text'],
showlegend=True, plot_bgcolor='white', legend_bgcolor='whitesmoke',
yaxis_title='Probability',
annotations=[
dict(
text="Time",
x=0.5, y=-0.11,
xref="paper",
yref="paper",
showarrow=False,
font=dict(size=14)
)],)
fig.update_xaxes(showgrid=False,)
non_nan_forecasts = forecasts[~np.isnan(forecasts)]
fig.update_yaxes(
showgrid=False,
range=[np.max([0, np.min(non_nan_forecasts) - np.std(non_nan_forecasts)]),
np.min([1, np.max(non_nan_forecasts) + np.std(non_nan_forecasts)])])
if folder_path is not None:
if not os.path.exists(folder_path):
os.mkdir(folder_path)
fig.write_image(os.path.join(folder_path, filename))
if show:
fig.show()