Search
Curve fitting за Италија

Започнуваме со читање на податоците и пробување да направиме fitting на параметрите (како во претходната тетратка) со дефинирање на иницијални ниски (min) и високи (max) граници за секој параметар кој сакаме да спаѓа во fitting.

Еден од најважните параметри кои сеуште не го анализиравме е кога започнува пандемијата (outbreak shift). Датумот од кој податоците почнуваат да се собираат е Јануари 21, така што нашиот модел ќе смета дека вирусот започнал на тој датум Јануари 21 2020 година. За многу земји не се знае точно кога започнува бидејќи има асимптоматска манифестација кај некои пациенти, можеби почнал 2 недели пред можеби подоцна. Секоја земја не знае кој е нејзинот patient 0, првата личност која е заразена повеќето се само како претпоставки. Имавме идеја да го вклучиме ова во нашиoт модел (во равенките), меѓутоа одлучивме да биде само бројка (вкупно денови) бидејќи само со такви знаеме да работиме, а програмирање на вакви бројки во равенките изгледа исклучително тешко (integer programming) и е прекомплицирано за краток период, затоа не одлучивме да го користиме како дополнителен параметар во равенките само како бројка.

Сега почнуваме да ги собираме бројките од различни datasets, поставуваме некои иницијални вредности како 0 на почетокот од пандемијата. Истотака дефинираме х-вредности за fitting; лист од облик [0, 1, 2, 3, ..., вкупно денови]. За fitting, ни треба функција што зима една x-вредност како прв аргумент и сите останати параметри кои сакаме да им најдеме најдобро совпаѓање, и тоа ни ја враќа бројката починати предвидени за таа x-вредност генерирана со параметрите, така што автоматската функција ги споредува предвидувањата на моделот со вистинските податоци. Останува да се иницијализира моделот за curve fitting, да се распоредат сите параметри и да се постави методот. Испробувавме со различни методи differential_evolution меѓутоа многу чудни best-fit резултати добивме, затоа одлучивме да одиме со најпростиот за програмирање, методот: least squares.

Ова е резултатот кој го добиваме за Италија:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import mpld3
from scipy.integrate import odeint
import lmfit
from lmfit.lineshapes import gaussian, lorentzian
import warnings
from scipy.integrate import odeint
import plotly.graph_objects as go
import plotly.tools as tls
from plotly.offline import plot, iplot, init_notebook_mode
from IPython.core.display import display, HTML
init_notebook_mode(connected = True)
config={'showLink': False, 'displayModeBar': False}
pd.options.mode.chained_assignment = None  # default='warn'
mpld3.enable_notebook()
warnings.filterwarnings('ignore')
%matplotlib inline 

gamma = 1.0/9.0
sigma = 1.0/3.0

def deriv(y, t, beta, gamma, sigma, N, p_I_to_C, p_C_to_D, Beds):
    S, E, I, C, R, D = y

    dSdt = -beta(t) * I * S / N
    dEdt = beta(t) * I * S / N - sigma * E
    dIdt = sigma * E - 1/12.0 * p_I_to_C * I - gamma * (1 - p_I_to_C) * I
    dCdt = 1/12.0 * p_I_to_C * I - 1/7.5 * p_C_to_D * min(Beds(t), C) - max(0, C-Beds(t)) - (1 - p_C_to_D) * 1/6.5 * min(Beds(t), C)
    dRdt = gamma * (1 - p_I_to_C) * I + (1 - p_C_to_D) * 1/6.5 * min(Beds(t), C)
    dDdt = 1/7.5 * p_C_to_D * min(Beds(t), C) + max(0, C-Beds(t))
    return dSdt, dEdt, dIdt, dCdt, dRdt, dDdt

def logistic_R_0(t, R_0_start, k, x0, R_0_end):
    return (R_0_start-R_0_end) / (1 + np.exp(-k*(-t+x0))) + R_0_end

def Model(days, agegroups, beds_per_100k, R_0_start, k, x0, R_0_end, prob_I_to_C, prob_C_to_D, s):

    def beta(t):
        return logistic_R_0(t, R_0_start, k, x0, R_0_end) * gamma

    N = sum(agegroups)
    
    def Beds(t):
        beds_0 = beds_per_100k / 100_000 * N
        return beds_0 + s*beds_0*t  # 0.003

    y0 = N-1.0, 1.0, 0.0, 0.0, 0.0, 0.0
    t = np.linspace(0, days-1, days)
    ret = odeint(deriv, y0, t, args=(beta, gamma, sigma, N, prob_I_to_C, prob_C_to_D, Beds))
    S, E, I, C, R, D = ret.T
    R_0_over_time = [beta(i)/gamma for i in range(len(t))]

    return t, S, E, I, C, R, D, R_0_over_time, Beds, prob_I_to_C, prob_C_to_D

def fitter(x, R_0_start, k, x0, R_0_end, prob_I_to_C, prob_C_to_D, s):
    ret = Model(days, agegroups, beds_per_100k, R_0_start, k, x0, R_0_end, prob_I_to_C, prob_C_to_D, s)
    return ret[6][x]

# read data
beds = pd.read_csv("https://raw.githubusercontent.com/hf2000510/infectious_disease_modelling/master/data/beds.csv", header=0)
agegroups = pd.read_csv("https://raw.githubusercontent.com/hf2000510/infectious_disease_modelling/master/data/agegroups.csv")
probabilities = pd.read_csv("https://raw.githubusercontent.com/hf2000510/infectious_disease_modelling/master/data/probabilities.csv")
covid_data = pd.read_csv("https://tinyurl.com/t59cgxn", parse_dates=["Date"], skiprows=[1])

beds_lookup = dict(zip(beds["Country"], beds["ICU_Beds"]))
agegroup_lookup = dict(zip(agegroups['Location'], agegroups[['0_9', '10_19', '20_29', '30_39', '40_49', '50_59', '60_69', '70_79', '80_89', '90_100']].values))

# prob_I_to_C_1 = list(probabilities.prob_I_to_ICU_1.values)
# prob_I_to_C_2 = list(probabilities.prob_I_to_ICU_2.values)
# prob_C_to_Death_1 = list(probabilities.prob_ICU_to_Death_1.values)
# prob_C_to_Death_2 = list(probabilities.prob_ICU_to_Death_2.values)

# parameters
data = covid_data[covid_data["Country/Region"] == "Italy"]["Value"].values[::-1]
agegroups = agegroup_lookup["Italy"]
beds_per_100k = beds_lookup["Italy"]
outbreak_shift = 30
params_init_min_max = {"R_0_start": (3.0, 2.0, 5.0), "k": (2.5, 0.01, 5.0), "x0": (90, 0, 120), "R_0_end": (0.9, 0.3, 3.5),
                       "prob_I_to_C": (0.05, 0.01, 0.1), "prob_C_to_D": (0.5, 0.05, 0.8),
                       "s": (0.003, 0.001, 0.01)}  # form: {parameter: (initial guess, minimum value, max value)}

days = outbreak_shift + len(data)
if outbreak_shift >= 0:
    y_data = np.concatenate((np.zeros(outbreak_shift), data))
else:
    y_data = y_data[-outbreak_shift:]

x_data = np.linspace(0, days - 1, days, dtype=int)  # x_data is just [0, 1, ..., max_days] array

mod = lmfit.Model(fitter)

for kwarg, (init, mini, maxi) in params_init_min_max.items():
    mod.set_param_hint(str(kwarg), value=init, min=mini, max=maxi, vary=True)

params = mod.make_params()

result = mod.fit(y_data, params, method="least_squares", x=x_data)

result.plot_fit(datafmt="-");

result.best_values
{'R_0_start': 3.5224449165096123,
 'k': 3.6657130934336335,
 'x0': 84.86985361722793,
 'R_0_end': 0.5587523495025978,
 'prob_I_to_C': 0.0671130060064688,
 'prob_C_to_D': 0.582176957630886,
 's': 0.0034397806282748444}

Не е голема грешката, што е доста добро. Останува да ги разгледаме параметрите кои ги предвиди нашиот fitting модел и дали изгледаат реалистично. $x_0=84$, ако податоците почнуваат од 21 Јан и outbreak shift е наместен на 30-ти ден, следува дека 84 ден е 15 Март. $x_0$ е датата (односно бројот на изминати денови) каде имаме најголемо паѓање (decline) на вредноста на $R_0$, т.е. нашиот модел смета дека карантинот е воведен тој период што е горе-доле тука со вистинската дата (9 Март е датумот каде поголемиот дел од провинциите се под карантин, според [10]).

Сега останува де видиме како изгледа проширениот SIR модел за Италија користејќи ги овие параметри кои ги добиваме од нашиот обид за fitting:

import numpy as np
import pandas as pd
import lmfit
import warnings
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import mpld3
from lmfit.lineshapes import gaussian, lorentzian
from scipy.integrate import odeint
import plotly.graph_objects as go
import plotly.tools as tls
from plotly.offline import plot, iplot, init_notebook_mode
from IPython.core.display import display, HTML
init_notebook_mode(connected = True)
config={'showLink': False, 'displayModeBar': False}
mpld3.enable_notebook()
%matplotlib inline 
warnings.filterwarnings('ignore')

def Model(days, agegroups, beds_per_100k, R_0_start, k, x0, R_0_end, prob_I_to_C, prob_C_to_D, s):

    def beta(t):
        return logistic_R_0(t, R_0_start, k, x0, R_0_end) * gamma

    N = sum(agegroups)
    
    def Beds(t):
        beds_0 = beds_per_100k / 100_000 * N
        return beds_0 + s*beds_0*t  # 0.003

    y0 = N-1.0, 1.0, 0.0, 0.0, 0.0, 0.0
    t = np.linspace(0, days-1, days)
    ret = odeint(deriv, y0, t, args=(beta, gamma, sigma, N, prob_I_to_C, prob_C_to_D, Beds))
    S, E, I, C, R, D = ret.T
    R_0_over_time = [beta(i)/gamma for i in range(len(t))]

    return t, S, E, I, C, R, D, R_0_over_time, Beds, prob_I_to_C, prob_C_to_D

def deriv(y, t, beta, gamma, sigma, N, p_I_to_C, p_C_to_D, Beds):
    S, E, I, C, R, D = y

    dSdt = -beta(t) * I * S / N
    dEdt = beta(t) * I * S / N - sigma * E
    dIdt = sigma * E - 1/12.0 * p_I_to_C * I - gamma * (1 - p_I_to_C) * I
    dCdt = 1/12.0 * p_I_to_C * I - 1/7.5 * p_C_to_D * min(Beds(t), C) - max(0, C-Beds(t)) - (1 - p_C_to_D) * 1/6.5 * min(Beds(t), C)
    dRdt = gamma * (1 - p_I_to_C) * I + (1 - p_C_to_D) * 1/6.5 * min(Beds(t), C)
    dDdt = 1/7.5 * p_C_to_D * min(Beds(t), C) + max(0, C-Beds(t))
    return dSdt, dEdt, dIdt, dCdt, dRdt, dDdt

def logistic_R_0(t, R_0_start, k, x0, R_0_end):
    return (R_0_start-R_0_end) / (1 + np.exp(-k*(-t+x0))) + R_0_end

def beta(t):
        return logistic_R_0(t, R_0_start, k, x0, R_0_end) * gamma

    
def Beds(t):
    beds_0 = beds_per_100k / 100_000 * N
    return beds_0 + s*beds_0*t  # 0.003

def logistic_R_0(t, R_0_start, k, x0, R_0_end):
    return (R_0_start-R_0_end) / (1 + np.exp(-k*(-t+x0))) + R_0_end

full_days = 500
first_date = np.datetime64(covid_data.Date.min()) - np.timedelta64(outbreak_shift,'D')
x_ticks = pd.date_range(start=first_date, periods=full_days, freq="D")
t, S, E, I, C, R, D, R_0_over_time, Beds, prob_I_to_C, prob_C_to_D = Model(full_days, 
                                                                           agegroup_lookup["Italy"], 
                                                                           beds_lookup["Italy"], 
                                                                           result.best_values['R_0_start'], 
                                                                           result.best_values['k'], 
                                                                           result.best_values["x0"], 
                                                                           result.best_values["R_0_end"],
                                                                           result.best_values["prob_I_to_C"],
                                                                           result.best_values["prob_C_to_D"],
                                                                           result.best_values["s"])

data_com = []

data_1 = go.Scatter(x = x_ticks, 
                  y = S, 
                  mode = 'lines',
                  visible = True,
                  line=dict(color="blue",
                            width=2),
                  name = "Подлежни; <i>S(t)</i>",
                  hovertemplate = '<i> %{y:.0f} </i> луѓе')


data_2 = go.Scatter(x = x_ticks, 
                  y = E, 
                  mode = 'lines',
                  visible = True,
                  line=dict(color="#CCCC00",
                            width=2),
                  name = "Изложени; <i>E(t)</i>",
                  hovertemplate = '<i> %{y:.0f} </i> луѓе')


data_3 = go.Scatter(x = x_ticks, 
                  y = I,
                  mode = 'lines',
                  visible = True, 
                  line=dict(color="red",
                            width=2),
                  name = "Заразени; <i>I(t)</i>",
                  hovertemplate = '<i> %{y:.0f} </i> луѓе')


data_4 = go.Scatter(x = x_ticks,
                  y = R, 
                  mode = 'lines',
                  visible = True,
                  line=dict(color="green",
                            width=2),
                  name = "Оздравени; <i>R(t)</i>",
                  hovertemplate = '<i> %{y:.0f} </i> луѓе')


data_5 = go.Scatter(x = x_ticks,
                  y =  D, 
                  visible = True, 
                  mode = 'lines',
                  line=dict(color="#696969",
                            width=2),
                  name = "Починати; <i>D(t)</i>",
                  hovertemplate = '<i> %{y:.0f} </i> луѓе')

data_6 = [go.Scatter(x = x_ticks,
                  y =  C, 
                  visible = True, 
                  mode = 'lines',
                  line=dict(color="tomato",
                            width=2),
                  name = "Критични; <i>C(t)</i>",
                  hovertemplate = '<i> %{y:.0f} </i> луѓе')]

data_com.append(data_1)
data_com.append(data_2)
data_com.append(data_3)
data_com.append(data_4)
data_com.append(data_5)

data_R0 = [dict(visible = True,
                    x = x_ticks,
                    y = R_0_over_time,
                    name = 'R<sub>0</sub>',
                    hoverinfo = "all",
                    hovertemplate = "%{y: .1f}",
                    line=dict(width=1.5,
                              dash = 'dot',
                              color = 'rgb(0, 204, 150)'),
                    xaxis='x2',
                    yaxis='y2',
                    showlegend = False
               )]

total_CFR = [0] + [100 * D[i] / sum(sigma*E[:i]) if sum(sigma*E[:i])>0 else 0 for i in range(1, len(t))]
daily_CFR = [0] + [100 * ((D[i]-D[i-1]) / ((R[i]-R[i-1]) + (D[i]-D[i-1]))) if max((R[i]-R[i-1]), (D[i]-D[i-1]))>10 else 0 for i in range(1, len(t))]

data_alpha1 = [dict(visible = False,
                    x = x_ticks,
                    y = total_CFR,
                    name = 'Вкупно &#120572;',
                    hoverinfo = "all",
                    hovertemplate = "%{y: .2f}",
                    line=dict(width=1.5,
                              dash = 'dot',
                              color = 'rgb(238,144,238)'),
                    xaxis='x2',
                    yaxis='y2',
                    showlegend = False,
               )]

data_alpha2 = [dict(visible = False,
                    x = x_ticks,
                    y = daily_CFR,
                    name = 'Дневно &#120572; ',
                    hoverinfo = "all",
                    hovertemplate = "%{y: .2f}",
                    line=dict(width=1.5,
                              dash = 'dot',
                              color = 'palegreen'),
                    xaxis='x2',
                    yaxis='y2',
                    showlegend = False,
               )]

newDs = [0] + [D[i]-D[i-1] for i in range(1, len(t))]

data_death1 = [dict(visible = False,
                    x = x_ticks,
                    y = newDs,
                    name = 'Вкупно починати',
                    hoverinfo = "all",
                    hovertemplate = "%{y}",
                    line=dict(width=1.5,
                              dash = 'dot',
                              color = 'gray'),
                    xaxis='x2',
                    yaxis='y2',
                    showlegend = False,
               )]

data_death2 = [dict(visible = False,
                    x = x_ticks,
                    y = [max(0, C[i]-Beds(i)) for i in range(len(t))],
                    name = 'Починати од недостиг ресурси',
                    hoverinfo = "all",
                    hovertemplate = "%{y}",
                    line=dict(width=1.5,
                              dash = 'dot',
                              color = 'tomato'),
                    xaxis='x2',
                    yaxis='y2',
                    showlegend = False,
               )]

    
data = data_com + data_R0 + data_alpha1 + data_alpha2  + data_death1 + data_death2 + data_6

# Setup the layout of the figure
layout = go.Layout(
    updatemenus=[
        dict(
            active = 0, 
            x=0.29,
            y=1.0,
            yanchor="top",
            buttons=list([
                dict(label="(1) R<sub>0</sub> преку време",
                             method="update",
                             args=[{"visible": [True, True, True, True, True, True, False, False, False, False, True]},
                                   {'annotations': [
                                               dict(text="Прикажи: ", 
                                                     showarrow=False,
                                                     x=0.02, 
                                                     y=1.06,
                                                     xref = 'paper',
                                                     yref="paper"),
                                               dict(x=0.15,
                                                    y=0.76,
                                                    showarrow=False,
                                                    text='R<sub>0</sub> преку време',
                                                    font=dict(size=10),
                                                    xref='paper',
                                                    yref='paper'),
                                                dict(x=0.025,
                                                    y=0.48,
                                                    showarrow=False,
                                                    textangle = -90,
                                                    text='R<sub>0</sub>',
                                                    font=dict(size=10),
                                                    xref='paper',
                                                    yref='paper')] ,
                                    'xaxis2': dict(linecolor='#000',
                                                   domain=[0.08, 0.38],
                                                   anchor='y2',
                                                   mirror = True,
                                                   side='bottom',
                                                   ticks='',
                                                   showline=True),
                                   'yaxis2': dict(autorange=False,
                                                  range=[0, 4],
                                                  tickvals = [0.5, 1.5, 2.5, 3.5],
                                                  linecolor='#000',
                                                  domain=[0.30, 0.70],
                                                  anchor='x2',
                                                  mirror = True,
                                                  ticks='',
                                                  showline=True)}
                                  ]) ,
                
                dict(label="(2) Смртност (%)",
                             method="update",
                             args=[{"visible": [True, True, True, True, True, False, True, True, False, False, True]},
                                   {'annotations': [
                                               dict(text="Прикажи: ", 
                                                     showarrow=False,
                                                     x=0.02, 
                                                     y=1.06,
                                                     xref = 'paper',
                                                     yref="paper"),
                                               dict(x=0.15,
                                                    y=0.76,
                                                    showarrow=False,
                                                    text='&#120572;% преку време',
                                                    font=dict(size=10),
                                                    xref='paper',
                                                    yref='paper'),
                                                dict(x=0.01,
                                                    y=0.48,
                                                    showarrow=False,
                                                    textangle = -90,
                                                    text='&#120572;%',
                                                    font=dict(size=10),
                                                    xref='paper',
                                                    yref='paper')],
                                    'xaxis2': dict(linecolor='#000',
                                                   domain=[0.08, 0.38],
                                                   anchor='y2',
                                                   mirror = True,
                                                   side='bottom',
                                                   ticks='',
                                                   showline=True,
                                                   tickfont = dict(size=9)),
                                   'yaxis2': dict(autorange = False,
                                                  range = [0, 4],
                                                  tickvals = [0, 0.6, 1.2,  1.8,  2.4,  3, 3.6],
                                                  linecolor='#000',
                                                  domain=[0.30, 0.70],
                                                  anchor='x2',
                                                  mirror = True,
                                                  ticks='',
                                                  showline=True)}
                                  ]) , 
                dict(label="(3) Смртни случаеви",
                             method="update",
                             args=[{"visible": [True, True, True, True, True, False, False, False, True, True, True]},
                                   {'annotations': [
                                               dict(text="Прикажи: ", 
                                                     showarrow=False,
                                                     x=0.02, 
                                                     y=1.06,
                                                     xref = 'paper',
                                                     yref="paper"),
                                               dict(x=0.125,
                                                    y=0.76,
                                                    showarrow=False,
                                                    text='Починати дневно',
                                                    font=dict(size=10),
                                                    xref='paper',
                                                    yref='paper'),
                                               dict(x=0.01,
                                                    y=0.48,
                                                    showarrow=False,
                                                    textangle = -90,
                                                    text='#',
                                                    font=dict(size=10),
                                                    xref='paper',
                                                    yref='paper')],
                                    'xaxis2': dict(linecolor='#000',
                                                   domain=[0.08, 0.38],
                                                   anchor='y2',
                                                   mirror = True,
                                                   side='bottom',
                                                   ticks='',
                                                   showline=True,
                                                  tickfont = dict(size=9)),
                                   'yaxis2': dict(autorange = False,
                                                  range = [0, 900],
                                                  tickvals = [0, 200, 400, 600, 800],
                                                  linecolor='#000',
                                                  domain=[0.30, 0.70],
                                                  anchor='x2',
                                                  mirror = True,
                                                  ticks='',
                                                  showline=True)}
                                  ]) 
                ]),
            direction="down"
            )],
                  title = "Проширен SIR модел; Италија",
                  title_x = 0.45, 
                  xaxis_title='Дата (Период)',
                  annotations=[
                               dict(text="Прикажи: ", 
                                     showarrow=False,
                                     x=0.02, 
                                     y=1.06,
                                     xref = 'paper',
                                     yref="paper"),
                               dict(x=0.15,
                                    y=0.76,
                                    showarrow=False,
                                    text='R<sub>0</sub> преку време',
                                    font=dict(size=10),
                                    xref='paper',
                                    yref='paper'),
                                dict(x=0.025,
                                    y=0.48,
                                    showarrow=False,
                                    textangle = -90,
                                    text='R<sub>0</sub>',
                                    font=dict(size=10),
                                    xref='paper',
                                    yref='paper')],
    
                  xaxis=dict(type = 'date',
                             autorange = False,
                             mirror=False,
                             ticks='outside',
                             showline=True,
                             linecolor='#000',
                             rangeslider_visible=True,
                             rangeslider_range = ['2019-12-23', '2021-05-05'],
                             range = ['2019-12-23', '2021-05-05'],
                             tickfont = dict(size=11)),
                  yaxis_title='Број на луѓе',
                  yaxis=dict(range=[-10_00_000,70_000_000], 
                             mirror=False,
                             ticks='outside', 
                             showline=True,
                             showspikes = True,
                             linecolor='#000',
                             tickvals = [0,10_000_000,20_000_000,30_000_000,40_000_000,50_000_000 ,60_000_000], 
                             tickfont = dict(size=11)),
                  xaxis2=dict(linecolor='#000',
                                domain=[0.09, 0.38],
                                anchor='y2',
                                mirror = True,
                                side='bottom',
                                ticks='',
                                showline=True),
                  yaxis2=dict(autorange=False,
                              range=[0, 4],
                              tickvals = [0.5, 1.5, 2.5, 3.5],
                              linecolor='#000',
                            domain=[0.30, 0.70],
                            anchor='x2',
                            mirror = True,
                            ticks='',
                            showline=True),
                  plot_bgcolor='#fff', 
                  hovermode = 'x unified',
                  width = 700, 
                  height = 450,
                  font = dict(size = 10),
                  margin=go.layout.Margin(l=50,
                                         r=50,
                                         b=60,
                                         t=35))

# Plot function saves as html or with ipplot
fig_extend = go.Figure(data=data, layout=layout)
plot(fig_extend, filename = 'fig_extend.html', config = config)
display(HTML('fig_extend.html'))

Од 3-иот мал график (наслов: (3) Починати дневно) може да согледаме дека врвот на смртни случаеви како недостиг на ресурси е некаде на крајот од Март (25-26 Март), што ја зголемува смртната стапка $\alpha$ на некои 1.5%.

Еве еден интересен график, каде може да се менува периодот кои се набљудува (со rangeselector), бидејќи бројот на Подлежни (Susceptible) е многу голем одлучивме да го тргнеме за овој график да може поубаво да се гледа без потреба да се зумира:

import numpy as np
import pandas as pd
import lmfit
import warnings
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import mpld3
from lmfit.lineshapes import gaussian, lorentzian
from scipy.integrate import odeint
import plotly.graph_objects as go
import plotly.tools as tls
from plotly.offline import plot, iplot, init_notebook_mode
from IPython.core.display import display, HTML
init_notebook_mode(connected = True)
config={'showLink': False, 'displayModeBar': False}
mpld3.enable_notebook()
%matplotlib inline 
warnings.filterwarnings('ignore')

def Model(days, agegroups, beds_per_100k, R_0_start, k, x0, R_0_end, prob_I_to_C, prob_C_to_D, s):

    def beta(t):
        return logistic_R_0(t, R_0_start, k, x0, R_0_end) * gamma

    N = sum(agegroups)
    
    def Beds(t):
        beds_0 = beds_per_100k / 100_000 * N
        return beds_0 + s*beds_0*t  # 0.003

    y0 = N-1.0, 1.0, 0.0, 0.0, 0.0, 0.0
    t = np.linspace(0, days-1, days)
    ret = odeint(deriv, y0, t, args=(beta, gamma, sigma, N, prob_I_to_C, prob_C_to_D, Beds))
    S, E, I, C, R, D = ret.T
    R_0_over_time = [beta(i)/gamma for i in range(len(t))]

    return t, S, E, I, C, R, D, R_0_over_time, Beds, prob_I_to_C, prob_C_to_D

def deriv(y, t, beta, gamma, sigma, N, p_I_to_C, p_C_to_D, Beds):
    S, E, I, C, R, D = y

    dSdt = -beta(t) * I * S / N
    dEdt = beta(t) * I * S / N - sigma * E
    dIdt = sigma * E - 1/12.0 * p_I_to_C * I - gamma * (1 - p_I_to_C) * I
    dCdt = 1/12.0 * p_I_to_C * I - 1/7.5 * p_C_to_D * min(Beds(t), C) - max(0, C-Beds(t)) - (1 - p_C_to_D) * 1/6.5 * min(Beds(t), C)
    dRdt = gamma * (1 - p_I_to_C) * I + (1 - p_C_to_D) * 1/6.5 * min(Beds(t), C)
    dDdt = 1/7.5 * p_C_to_D * min(Beds(t), C) + max(0, C-Beds(t))
    return dSdt, dEdt, dIdt, dCdt, dRdt, dDdt

def logistic_R_0(t, R_0_start, k, x0, R_0_end):
    return (R_0_start-R_0_end) / (1 + np.exp(-k*(-t+x0))) + R_0_end

def beta(t):
        return logistic_R_0(t, R_0_start, k, x0, R_0_end) * gamma

    
def Beds(t):
    beds_0 = beds_per_100k / 100_000 * N
    return beds_0 + s*beds_0*t  # 0.003

def logistic_R_0(t, R_0_start, k, x0, R_0_end):
    return (R_0_start-R_0_end) / (1 + np.exp(-k*(-t+x0))) + R_0_end

full_days = 500
first_date = np.datetime64(covid_data.Date.min()) - np.timedelta64(outbreak_shift,'D')
x_ticks = pd.date_range(start=first_date, periods=full_days, freq="D")
t, S, E, I, C, R, D, R_0_over_time, Beds, prob_I_to_C, prob_C_to_D = Model(full_days, 
                                                                           agegroup_lookup["Italy"], 
                                                                           beds_lookup["Italy"], 
                                                                           result.best_values['R_0_start'], 
                                                                           result.best_values['k'], 
                                                                           result.best_values["x0"], 
                                                                           result.best_values["R_0_end"],
                                                                           result.best_values["prob_I_to_C"],
                                                                           result.best_values["prob_C_to_D"],
                                                                           result.best_values["s"])

data_com = []

data_1 = go.Scatter(x = x_ticks, 
                  y = S, 
                  mode = 'lines',
                  visible = True,
                  line=dict(color="blue",
                            width=2),
                  showlegend = False,
                  name = "Подлежни; <i>S(t)</i>",
                  hovertemplate = '<i> %{y:.0f} </i> луѓе')


data_2 = go.Scatter(x = x_ticks, 
                  y = E, 
                  mode = 'lines',
                  visible = True,
                  line=dict(color="#CCCC00",
                            width=2),
                  name = "Изложени; <i>E(t)</i>",
                  hovertemplate = '<i> %{y:.0f} </i> луѓе')


data_3 = go.Scatter(x = x_ticks, 
                  y = I,
                  mode = 'lines',
                  visible = True, 
                  line=dict(color="red",
                            width=2),
                  name = "Заразени; <i>I(t)</i>",
                  hovertemplate = '<i> %{y:.0f} </i> луѓе')


data_4 = go.Scatter(x = x_ticks,
                  y = R, 
                  mode = 'lines',
                  visible = True,
                  line=dict(color="green",
                            width=2),
                  name = "Оздравени; <i>R(t)</i>",
                  hovertemplate = '<i> %{y:.0f} </i> луѓе')


data_5 = go.Scatter(x = x_ticks,
                  y =  D, 
                  visible = True, 
                  mode = 'lines',
                  line=dict(color="#696969",
                            width=2),
                  name = "Починати; <i>D(t)</i>",
                  hovertemplate = '<i> %{y:.0f} </i> луѓе')

data_6 = [go.Scatter(x = x_ticks,
                  y =  C, 
                  visible = True, 
                  mode = 'lines',
                  line=dict(color="tomato",
                            width=2),
                  name = "Критични; <i>C(t)</i>",
                  hovertemplate = '<i> %{y:.0f} </i> луѓе')]

data_com.append(data_1)
data_com.append(data_2)
data_com.append(data_3)
data_com.append(data_4)
data_com.append(data_5)

  
data = data_com +  data_6

# Setup the layout of the figure
layout = go.Layout(
                  title = "Проширен SIR модел; Италија",
                  title_x = 0.48, 
                  xaxis_title='Дата (Период)',
                  xaxis=dict(type = 'date',
                             autorange = False,
                             mirror=False,
                             ticks='outside',
                             showline=True,
                             linecolor='#000',
                             rangeselector=dict(
                                                buttons=list([
                                                    dict(count=1,
                                                         label="1m",
                                                         step="month",
                                                         stepmode="backward"),
                                                    dict(count=9,
                                                         label="9m",
                                                         step="month",
                                                         stepmode="backward"),
                                                    dict(count=17,
                                                         label="17m",
                                                         step="month",
                                                         stepmode="backward"),
                                                    dict(count=1,
                                                         label="YTD",
                                                         step="year",
                                                         stepmode="todate"),
                                                    dict(count=2,
                                                         label="2y",
                                                         step="year",
                                                         stepmode="backward"),
                                                ])),
                             rangeslider_visible=True,
                             rangeslider_range = ['2019-12-23', '2021-05-05'],
                             range = ['2019-12-23', '2021-05-05'],
                             tickfont = dict(size=11)),
                  yaxis_title='Број на луѓе',
                  yaxis=dict(range=[-100_000,2_100_000], 
                             mirror=False,
                             ticks='outside', 
                             showline=True,
                             showspikes = True,
                             linecolor='#000',
                             tickvals = [0,500_000,1_000_000,1_500_000,2_000_000], 
                             tickfont = dict(size=11)),
                  xaxis2=dict(linecolor='#000',
                                domain=[0.09, 0.38],
                                anchor='y2',
                                mirror = True,
                                side='bottom',
                                ticks='',
                                showline=True),
                  yaxis2=dict(autorange=False,
                              range=[0.5, 4.5],
                              tickvals = [0.5, 1.5, 2, 2.5, 3, 3.5],
                              linecolor='#000',
                            domain=[0.50, 0.90],
                            anchor='x2',
                            mirror = True,
                            ticks='',
                            showline=True),
                  plot_bgcolor='#fff', 
                  hovermode = 'x unified',
                  width = 700, 
                  height = 450,
                  font = dict(size = 10),
                  margin=go.layout.Margin(l=50,
                                         r=50,
                                         b=60,
                                         t=45))

# Plot function saves as html or with ipplot
fig_extend2 = go.Figure(data=data, layout=layout)
plot(fig_extend2, filename = 'fig_extend2.html', config = config)
display(HTML('fig_extend2.html'))

Ова е "груба" препоставка која успеавме да ја добиеме до Април 2021, секако доколку воопшто моделот кореспондира со вистински податоци бидејќи има уште многу фактори кои може да се моделираат. Италија го помина најлошто и бројот на смртни случаеви брзо ќе почне да се намалува низ текот од следните месеци. Користениот модел секако тука е всушност можеби малку поедноставен, бидејќи $R_0$ го поставуваме да седи околу 0.6; доколку оваа бројка почне да расте (повеќе бранови доколку има или се намалаат мерките) бројките повторно ќе почнат да се зголемуваат-менуваат-намалуваат.