"""Module for orthogonal distance regression."""

import numpy as np

from uncertainties import unumpy
from itertools import product
from scipy.odr import RealData, Model, ODR
from scipy.stats import distributions

def residual(beta, x, var='B', bounds=False):
    """
    Compute objective function for orthogonal distance regression.
    
    Scalar residuals returned of list of length of x.
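
    Examples
    --------
    A minimal sketch with made-up toughness values (the numbers are
    illustrative only):

    >>> import numpy as np
    >>> beta = [0.5, 1.0, 1.0, 1.0]        # GIc, GIIc, n, m
    >>> x = np.array([[0.25], [0.5]])      # one (Gi, Gii) observation
    >>> residual(beta, x, var='B')         # 0.25/0.5 + 0.5/1.0 - 1
    array([0.])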
    """
    
    # Unpack multidimensional inputs
    GIc, GIIc, n, m = beta
    Gi, Gii = x
    
    # Assign penalty if exponents not within bounds
    if bounds:
        if not (1 <= n <= 10 and 1 <= m <= 10):
            # Return one penalty value per observation so the output
            # shape matches the regular residual array
            return np.full_like(Gi, 1e3)

    # Compute residual
    with np.errstate(invalid='ignore'):
        if var == 'A':
            res = ((Gi/GIc)**n + (Gii/GIIc)**m)**(2/(n+m)) - 1
        elif var == 'B':
            res = ((Gi/GIc)**n + (Gii/GIIc)**m) - 1
        elif var == 'C':
            res = ((Gi/GIc)**(1/n) + (Gii/GIIc)**(1/m)) - 1
        elif var == 'BK':
            res = (GIc + (GIIc - GIc)*(Gii/(Gi + Gii))**m)/(Gi + Gii) - 1
        else:
            raise NotImplementedError(f'Criterion type {var} not implemented.')

    # Return
    return res

def param_jacobian(beta, x, var='B', *args):
    """
    Compute derivatives with respect to parameters (GIc, GIIc, n, m).
    
    Jacobian = [df/dGIc, df/dGIIc, df/dn, df/dm].T

    Parameters
    ----------
    beta : list[float]
        Model parameters (GIc, GIIc, n, m).
    x : list[float]
        Variables (Gi, Gii).
    var : str, optional
        Residual variant. Default is 'B'.

    Returns
    -------
    np.ndarray
        Jacobian matrix.

    Raises
    ------
    NotImplementedError
        If residual variant is not implemented.
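
    Examples
    --------
    A quick sketch validating the analytic derivatives against a
    central finite difference (illustrative numbers):

    >>> import numpy as np
    >>> beta = np.array([0.5, 1.0, 2.0, 2.0])
    >>> x = np.array([[0.3], [0.4]])
    >>> jac = param_jacobian(beta, x, 'B')
    >>> jac.shape                          # one row per parameter
    (4, 1)
    >>> e = np.array([1e-6, 0, 0, 0])      # perturb GIc only
    >>> fd = (residual(beta + e, x) - residual(beta - e, x))/2e-6
    >>> np.allclose(jac[0], fd)
    True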
    """
    
    # Unpack multidimensional inputs
    GIc, GIIc, n, m = beta
    Gi, Gii = x
    
    # Calculate derivatives (each an np.ndarray of the same length as x)
    with np.errstate(invalid='ignore'):
        if var == 'A':
            dGIc = -(2*Gi*(Gi/GIc)**(-1+n)*((Gi/GIc)**n+(Gii/GIIc)**m)**(-1+2/(m+n))*n)/(GIc**2*(m+n))  # noqa: E501
            dGIIc = -((2*Gii*((Gi/GIc)**n+(Gii/GIIc)**m)**(-1+2/(m+n))*(Gii/GIIc)**(-1+m)*m)/(GIIc**2*(m+n)))  # noqa: E501
            dn = (((Gi/GIc)**n+(Gii/GIIc)**m)**(-1+2/(m+n))*( 2*(Gi/GIc)**n*(m+n)*np.log(Gi/GIc)-2*((Gi/GIc)**n+(Gii/GIIc)**m)*np.log((Gi/GIc)**n+(Gii/GIIc)**m)))/(m+n)**2  # noqa: E501
            dm = (((Gi/GIc)**n+(Gii/GIIc)**m)**(-1+2/(m+n))*(-2*((Gi/GIc)**n+(Gii/GIIc)**m)*np.log((Gi/GIc)**n+(Gii/GIIc)**m)+2*(Gii/GIIc)**m*(m+n)*np.log(Gii/GIIc)))/(m+n)**2  # noqa: E501
        elif var == 'B':
            dGIc = -(((Gi/GIc)**n*n)/GIc)
            dGIIc = -(((Gii/GIIc)**m*m)/GIIc)
            dn = (Gi/GIc)**n*np.log(Gi/GIc)
            dm = (Gii/GIIc)**m*np.log(Gii/GIIc)
        elif var == 'C':
            dGIc = -((Gi/GIc)**(1/n))/(GIc*n)
            dGIIc = -((Gii/GIIc)**(1/m))/(GIIc*m)
            dn = -((Gi/GIc)**(1/n) * np.log(Gi/GIc))/(n**2)
            dm = -((Gii/GIIc)**(1/m) * np.log(Gii/GIIc))/(m**2)
        elif var == 'BK':
            dGIc = (1 - (Gii/(Gi + Gii))**m)/(Gi + Gii)
            dGIIc = ((Gii/(Gi + Gii))**m)/(Gi + Gii)
            dn = np.zeros_like(Gi)
            dm = ((Gii/(Gi + Gii))**m * (GIIc - GIc)*np.log(Gii/(Gi + Gii)))/(Gi + Gii)
        else:
            raise NotImplementedError(f'Criterion type {var} not implemented.')

    # Stack derivatives (number of rows corresponds to the length of beta)
    return np.vstack([dGIc, dGIIc, dn, dm])

def value_jacobian(beta, x, var='B', *args):
    """
    Compute derivatives with respect to function arguments (Gi, Gii).
    
    Jacobian = [df/dGi, df/dGii].T

    Parameters
    ----------
    beta : list[float]
        Model parameters (GIc, GIIc, n, m).
    x : list[float]
        Variables (Gi, Gii).
    var : str, optional
        Residual variant. Default is 'B'.

    Returns
    -------
    np.ndarray
        Jacobian matrix.

    Raises
    ------
    NotImplementedError
        If residual variant is not implemented.
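
    Examples
    --------
    Analogous finite-difference sketch, here for the derivative with
    respect to Gi (illustrative numbers):

    >>> import numpy as np
    >>> beta = [0.5, 1.0, 2.0, 2.0]
    >>> x = np.array([[0.3], [0.4]])
    >>> jac = value_jacobian(beta, x, 'B')
    >>> e = np.array([[1e-6], [0.0]])      # perturb Gi only
    >>> fd = (residual(beta, x + e) - residual(beta, x - e))/2e-6
    >>> np.allclose(jac[0], fd)
    True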
    """
    
    # Unpack multidimensional inputs
    GIc, GIIc, n, m = beta
    Gi, Gii = x
    
    # Calculate derivatives (each an np.ndarray of the same length as x)
    if var == 'A':
        dGi = (2*(Gi/GIc)**n*((Gi/GIc)**n+(Gii/GIIc)**m)**(-1+2/(m+n))*n)/(Gi*(m+n))  # noqa: E501
        dGii = (2*((Gi/GIc)**n + (Gii/GIIc)**m)**(-1+2/(m+n))*(Gii/GIIc)**m*m)/(Gii*(m+n))  # noqa: E501
    elif var == 'B':
        dGi = ((Gi/GIc)**n*n)/Gi
        dGii = ((Gii/GIIc)**m*m)/Gii
    elif var == 'C':
        dGi = ((Gi/GIc)**(1/n))/(n*Gi)
        dGii = ((Gii/GIIc)**(1/m))/(m*Gii)
    elif var == 'BK':
        dGi = (-GIc + (1 + m)*(GIc - GIIc)*(Gii/(Gi + Gii))**m)/(Gi + Gii)**2
        dGii = (-GIc*Gii + (Gii - m*Gi)*(GIc - GIIc)*(Gii/(Gi + Gii))**m)/(Gii*(Gi + Gii)**2)
    else:
        raise NotImplementedError(f'Criterion type {var} not implemented.')
    
    # Stack derivatives
    return np.vstack([dGi, dGii])

def assemble_data(df, dim=1):
    """
    Compile ODR pack data object from data frame.
    
    See https://docs.scipy.org/doc/scipy/reference/
        generated/scipy.odr.RealData.html

    Parameters
    ----------
    df : pd.DataFrame
        Data frame with fracture toughness data.
    dim : int
        Dimensionality of the response function. Target is assumed zero
        and dim=1 indicates a scalar residual function. Default is 1.
        
    Returns
    -------
    data : scipy.odr.RealData
        ODR pack data object.
    ndof : int
        Number of degrees of freedom.
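
    Examples
    --------
    A hypothetical two-row frame of ufloat measurements (column names
    as expected by this module; values are made up):

    >>> import pandas as pd
    >>> from uncertainties import ufloat
    >>> df = pd.DataFrame({
    ...     'GIc': [ufloat(0.5, 0.05), ufloat(0.6, 0.05)],
    ...     'GIIc': [ufloat(1.0, 0.1), ufloat(1.1, 0.1)],
    ... })
    >>> data, ndof = assemble_data(df)
    >>> data.x.shape                       # (variables, observations)
    (2, 2)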
    """
    # Stack 2D data from experiments w/ uncertainties as input array
    exp = np.vstack(df[['GIc', 'GIIc']].apply(unumpy.nominal_values).values.T)
    std = np.vstack(df[['GIc', 'GIIc']].apply(unumpy.std_devs).values.T)
    
    # Compute the number of degrees of freedom as the number of
    # observations minus the number of fitted parameters
    ndof = exp.shape[1] - 4
    
    # Pack data in scipy ODR format and return together with DOFs
    return RealData(exp, y=dim, sx=std), ndof

def get_initial_guesses(gc0=0.7, exp=2, indi=False, var='B', verbose=False):
    """
    Assemble matrix of initial guesses.

    Parameters
    ----------
    gc0 : float, optional
        Initial guess for the fracture toughness. Default is 0.7.
    exp : int, list, or tuple, optional
        List of permitted exponents for the power law, int as the
        highest permitted exponent, or 2-tuple (n, m) of fixed
        exponents. Default is 2.
    indi : bool, optional
        If True, exponents of the power law are fitted independently.
        Default is False.
    var : str, optional
        Residual variant. Default is 'B'.
    verbose : bool, optional
        If True, print the initial guesses for the exponents.
        Default is False.

    Returns
    -------
    np.ndarray or list
        Matrix of initial guesses, one row per starting point
        (GIc0, GIIc0, n0, m0).
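
    Examples
    --------
    With the common-exponent default, each row is one starting point:

    >>> get_initial_guesses(gc0=0.7, exp=2)
    array([[0.7, 0.7, 1. , 1. ],
           [0.7, 0.7, 2. , 2. ]])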
    """
        
    # List of permitted exponents
    if isinstance(exp, tuple) and len(exp) == 2: 
        n0 = [exp[0]]
        m0 = [exp[1]]
    elif isinstance(exp, (list, np.ndarray)):
        n0 = m0 = exp
    else:
        if var == 'C':
            n0 = m0 = np.linspace(1/exp, 1, exp, endpoint=True)   # {n, m} in (0, 1]
        else:
            n0 = m0 = 1 + np.arange(exp)        # {n, m} in [1, inf)

    if verbose:
        print('Running the following initial guesses for the exponents (n, m):')
        print(n0)
        print()

    # Assemble parameter space
    if indi:
        # Independent exponents
        return list(product([gc0], [gc0], n0, m0))
    else:
        # Common exponent
        return np.column_stack([np.full([len(n0), 2], gc0), n0, n0])

def run_regression(
        data, model, beta0,
        sstol=1e-12, partol=1e-12,
        maxit=1000, ndigit=12,
        ifixb=[1, 1, 0, 0],
        fit_type=1, deriv=3,
        init=0, iteration=0, final=0):
    """
    Setup ODR object and run regression.
    
    See https://docs.scipy.org/doc/scipy/reference/generated/
        scipy.odr.ODR.html
        scipy.odr.ODR.set_job.html
        scipy.odr.ODR.set_iprint.html

    Parameters
    ----------
    data : scipy.odr.Data or scipy.odr.RealData
        Scipy ODRPACK data object.
    model : scipy.odr.Model
        Scipy ODRPACK model object.
    beta0 : list[float]
        List of initial parameter guesses.
    sstol : float, optional
        Tolerance for residual convergence (<1). Default is 1e-12.
    partol : float, optional
        Tolerance for parameter convergence (<1). Default is 1e-12.
    maxit : int, optional
        Maximum number of iterations. Default is 1000.
    ndigit : int, optional
        Number of reliable digits. Default is 12.
    ifixb : list[int], optional
        0 parameter fixed, 1 parameter free. Default is [1, 1, 0, 0].
    fit_type : int, optional
        0 explicit ODR, 1 implicit ODR. Default is 1.
    deriv : int, optional
        0 finite differences, 3 jacobians. Default is 3.
    init : int, optional
        No, short, or long initialization report. Default is 0.
    iteration : int, optional
        No, short, or long iteration report. Default is 0.
    final : int, optional
        No, short, or long final report. Default is 0.
        
    Returns
    -------
    scipy.odr.Output
        Optimization results object.
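
    Examples
    --------
    Typical wiring, mirroring odr() below (sketch; `data` and `model`
    built via assemble_data and scipy.odr.Model):

    >>> out = run_regression(data, model, beta0=[0.7, 0.7, 2, 2])  # doctest: +SKIP
    >>> out.info <= 3                      # doctest: +SKIP
    True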
    """

    # Setup ODR object
    odr = ODR(
        data,                   # Input data
        model,                  # Model
        beta0=beta0,            # Initial parameter guess
        sstol=sstol,            # Tolerance for residual convergence (<1)
        partol=partol,          # Tolerance for parameter convergence (<1)
        maxit=maxit,            # Maximum number of iterations
        ndigit=ndigit,          # Number of reliable digits
        ifixb=ifixb,            # 0 parameter fixed, 1 parameter free
    )

    # Set job options
    odr.set_job(
        fit_type=fit_type,      # 0 explicit ODR, 1 implicit ODR
        deriv=deriv             # 0 finite differences, 3 jacobians
    )

    # Define outputs
    odr.set_iprint(
        init=init,              # No, short, or long initialization report
        iter=iteration,         # No, short, or long iteration report
        final=final,            # No, short, or long final report
    )

    # Run optimization
    return odr.run()

def calc_fit_statistics(final, ndof):
    """
    Complement fit results dictionary with goodness of fit info.

    Check the scipy user forum (https://scipy-user.scipy.narkive.com/
    ZOHix6nj/scipy-odr-goodness-of-fit-and-parameter-estimation-for-
    explicit-orthogonal-distance-regression) for an explanation of ODR's
    goodness of fit estimation and Wikipedia https://en.wikipedia.org/
    wiki/Reduced_chi-squared_statistic) for an explanation of the
    reduced chi_nu^2 goodness-of-fit indicator.
    
    As a rule of thumb, when the variance of the measurement error is
    known a priori, a $\chi _{\nu }^{2}\gg 1$ indicates a poor model
    fit. A $\chi _{\nu }^{2}>1$ indicates that the fit has not fully
    captured the data (or that the error variance has been underestimated).
    In principle, a value of $\chi _{\nu }^{2}$ around 1 indicates that
    the extent of the match between observations and estimates is in
    accord with the error variance. A $\chi _{\nu }^{2}<1$ indicates that
    the model is "over-fitting" the data: either the model is improperly
    fitting noise, or the error variance has been overestimated.

    Parameters
    ----------
    final : scipy.odr.Output
        Optimization results object.
    ndof : int
        Number of degrees of freedom.

    Returns
    -------
    fit : dict
        Updated dictionary.
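
    Examples
    --------
    The p-value follows from the chi-squared survival function; a
    quick standalone check with made-up numbers:

    >>> from scipy.stats import distributions
    >>> round(float(distributions.chi2.sf(3.0, 5)), 3)  # chi^2 = 3, 5 DOF
    0.7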
    """
    # Initialize dictionary
    fit = {}
    # Best fit parameters
    fit['params'] = final.beta
    # Standard deviations
    fit['stddev'] = final.sd_beta
    # Reduced chi^2 (chi^2 per degree of freedom)
    fit['reduced_chi_squared'] = final.res_var
    # Goodness of fit (chi^2)
    fit['chi_squared'] = ndof*fit['reduced_chi_squared']
    # P-value (result is statistically significant if below 0.05)
    fit['p_value'] = distributions.chi2.sf(fit['chi_squared'], ndof)
    # Goodness of fit (R^2) (not valid for nonlinear regression)
    fit['R_squared'] = 1 - fit['chi_squared']/(ndof + fit['chi_squared'])
    # Write optimization result to dictionary
    fit['final'] = final
    
    # Return updated dictionary
    return fit

def odr(
        df, dim=1, gc0=.6, exp=2, var='B',
        indi=False, ifixb=[1, 1, 0, 0],
        print_results=True, verbose=False):
    """
    Perform orthogonal distance regression (ODR) on the data frame.
    
    Scipy.odr is a wrapper around a much older FORTRAN77 package known
    as ODRPACK. The documentation for ODRPACK can be found on the scipy
    website: https://docs.scipy.org/doc/external/odrpack_guide.pdf
    
    See also https://docs.scipy.org/doc/scipy/reference/
             generated/scipy.odr.Model.html

    Parameters
    ----------
    df : pd.DataFrame
        Data frame with energy release rates.
    dim : int
        Dimensionality of the response function. Target is assumed zero
        and dim=1 indicates a scalar residual function. Default is 1.
    gc0 : float
        Initial guesses for the fracture toughnesses. Default is 0.6.
    exp : int, list, or tuple, optional
        List of permitted exponents for the power law, int as the
        highest permitted exponent, or 2-tuple (n, m) of fixed
        exponents. Default is 2.
    var : str, optional
        Residual variant {'A', 'B', 'C', 'BK'}. Default is 'B'.
    indi : bool
        If True, exponents of the power law are fitted independently.
        Default is False.
    ifixb : list[int], optional
        0 parameter fixed, 1 parameter free. Default is [1, 1, 0, 0].
    print_results : bool
        If True, print fit results to console. Default is True.
    verbose : bool, optional
        If True, print the initial guesses for the exponents.
        Default is False.

    Returns
    -------
    fit : dict
        Dictionary with best-fit parameters, goodness-of-fit
        indicators, and the scipy.odr.Output object.
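
    Examples
    --------
    End-to-end sketch on a synthetic frame (column names and ufloat
    inputs as expected by assemble_data; numbers are illustrative):

    >>> import pandas as pd
    >>> from uncertainties import ufloat
    >>> df = pd.DataFrame({
    ...     'GIc': [ufloat(g, 0.05*g) for g in (0.52, 0.48, 0.31, 0.12)],
    ...     'GIIc': [ufloat(g, 0.05*g) for g in (0.05, 0.21, 0.57, 0.99)],
    ... })
    >>> fit = odr(df, print_results=False)  # doctest: +SKIP
    >>> fit['params']                       # doctest: +SKIP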
    """    
    # Assemble ODR pack data object
    data, ndof = assemble_data(df, dim)

    # Compile scipy ODR models
    model = Model(fcn=residual, fjacb=param_jacobian,
                  fjacd=value_jacobian, implicit=True,
                  extra_args=(var,))
    
    # Generate list of initial guesses
    guess = get_initial_guesses(gc0=gc0, exp=exp, indi=indi, var=var, verbose=verbose)

    # Run regression for all guesses and store result only if converged
    runs = [r for r in (run_regression(data, model, g, ifixb=ifixb)
                        for g in guess) if r.info <= 3]

    # Abort if no run converged
    if not runs:
        raise RuntimeError('No regression run converged. '
                           'Try different initial guesses.')

    # Determine run with smallest sum of squared errors
    final = runs[np.argmin([run.sum_square for run in runs])]
    
    # Compile fit results dictionary with goodness-of-fit info
    fit = calc_fit_statistics(final, ndof)
    fit['var'] = var
    
    # Print fit results to console
    if print_results:
        results(fit)
    
    # Return fit results dictionary
    return fit

def results(fit):
    """
    Print fit results to console.
    
    As a rule of thumb, when the variance of the measurement error is
    known a priori, a $\chi _{\nu }^{2}\gg 1$ indicates a poor model
    fit. A $\chi _{\nu }^{2}>1$ indicates that the fit has not fully
    captured the data (or that the error variance has been underestimated).
    In principle, a value of $\chi _{\nu }^{2}$ around 1 indicates that
    the extent of the match between observations and estimates is in
    accord with the error variance. A $\chi _{\nu }^{2}<1$ indicates that
    the model is "over-fitting" the data: either the model is improperly
    fitting noise, or the error variance has been overestimated.

    Parameters
    ----------
    fit : dict
        Dictionary with optimization results.
    """
    
    # Unpack variables
    GIc, GIIc, n, m = fit['final'].beta
    chi2 = fit['reduced_chi_squared']
    pval = fit['p_value']
    R2 = fit['R_squared']

    # Define the header and horizontal rules
    header = 'Variable      Value    Description'.upper()
    rule = '---'.join(['-' * s for s in [8, 5, 50]])

    # Print the header
    print(header)
    print(rule)

    # Print fit parameters
    print(f"GIc        {GIc:8.3f}    Mode I fracture toughness")
    print(f"GIIc       {GIIc:8.3f}    Mode II fracture toughness")
    print(f"n          {n:8.3f}    Interaction-law exponent")
    print(f"m          {m:8.3f}    Interaction-law exponent")
    print(rule)
    
    # Print goodness-of-fit indicators
    print(f"chi2       {chi2:8.3f}    Reduced chi^2 per DOF (goodness of fit)")
    print(f"p-value    {pval:8.1e}    p-value (statistically significant if below 0.05)")
    print(f"R2         {R2:8.3f}    R-squared (not valid for nonlinear regression)")
    print()
