import corner
import numpy as np
import h5py
import pandas as pd
from matplotlib.ticker import NullFormatter
import matplotlib as mpl
import matplotlib.pyplot as plt


lab = "Delta_0_1_sigma_-0p5_0p5_samples_7e5_partial_NLO_cs2_2o3_"

# ----------
# Formatting
# ----------
# Computer Modern fonts for both text and math.  With the cmr10 serif font the
# ASCII hyphen is used for minus signs and tick labels are rendered through
# mathtext (presumably to avoid cmr10 missing-glyph / rendering issues).
mpl.rcParams.update({
    'mathtext.fontset'            : 'cm',
    'axes.unicode_minus'          : False,
    'axes.formatter.use_mathtext' :  True,
})
# Base font size is set below via plt.rc('font', size=SMALL_SIZE), so it is
# not repeated here.
plt.rcParams.update({
    'figure.figsize'    : [8.0, 5.0],
    'font.family'       : "serif",
    'font.serif'        : "cmr10",
    'xtick.major.size'  : 6,
    'xtick.minor.size'  : 3,
    'ytick.major.size'  : 6,
    'ytick.minor.size'  : 3,
    'axes.labelpad'     : 0,
})
LabelSize=14
MarkerSize=95
TickSize=20
LegendSize=14
nullfmt=NullFormatter() # no labels

SMALL_SIZE = 14
MEDIUM_SIZE = 16
BIGGER_SIZE = 18
BIGGEST_SIZE = 24

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

# ----------
# Import data
# ----------
filename = "pQCD_Ensemble_" + lab + "out.h5"
with h5py.File(filename, 'r') as hdf:
    dset = hdf['data']
    datas = dset[:]
    # Column names are stored as an HDF5 attribute on the dataset.  Depending
    # on how the file was written, h5py may hand them back as bytes; decode to
    # str so the string column lookups further down work in either case.
    col_names = [
        name.decode('utf-8') if isinstance(name, bytes) else str(name)
        for name in dset.attrs['column_names']
    ]

pQCD_data_N2LO_0mg = pd.DataFrame(datas, columns=col_names)


# ----------
# Plot
# ----------
SAVE = False  # NOTE(review): not read anywhere below; the PDF is saved unconditionally.

# One row per sample, one column per plotted parameter.
samples = pQCD_data_N2LO_0mg[['Delta(muB0)[GeV]','sigma']].values
labels = [
        r"$\Delta_{\rm CFL}^{\!\!*}$ [GeV]",
        r"$\sigma$",
]
ndim = len(samples[0])
nsamples = len(samples)

# Bottom layer: unweighted samples, 95% contour only, drawn in light gray.
# hist_kwargs is built fresh for every call because corner writes default
# 'color'/'histtype' entries into the dict it is given.
figure = corner.corner(
    samples,
    bins=30,
    levels=[0.95],
    color='lightgray',
    labelpad=3,
    hist_kwargs=dict(density=True),
)

# Overlay the two reweighted versions of the same samples on that figure.
for weight_column, layer_color in (('weight_2M', 'C1'), ('weight_MMAX', 'C2')):
    corner.corner(
        samples,
        bins=30,
        weights=pQCD_data_N2LO_0mg[weight_column],
        levels=[0.68, 0.95],
        color=layer_color,
        fig=figure,
        hist_kwargs=dict(density=True),
    )
# Arrange the figure's axes as an ndim x ndim (row, column) grid.
axes = np.array(figure.axes).reshape((ndim, ndim))
# NOTE(review): np.triu_indices(ndim) defaults to k=0, so the diagonal
# histogram panels are removed along with the empty upper triangle.  As a
# result the annotations placed on axes[0, 0] and axes[1, 1] below do not
# appear in the saved figure.  If the histograms are meant to stay, use
# np.triu_indices(ndim, k=1) (and label axes[1, 1]'s x-axis) -- TODO confirm.
for row, col in zip(*np.triu_indices(ndim)):
    axes[row, col].remove()

# Reference sigma value, marked on the sigma histogram and in the 2D panel.
sigma_ref = -0.23
axes[1, 1].axvline(sigma_ref, color='black', lw=2)
axes[1, 0].axhline(sigma_ref, color='black', lw=2)

# Weighted (weight_2M) 95% quantile of the first parameter, drawn as a dashed
# line with a rotated caption on the first-parameter histogram.
delta_95 = corner.quantile(
    samples[:, 0], 0.95, weights=pQCD_data_N2LO_0mg['weight_2M']
)[0]
axes[0, 0].axvline(delta_95, c='C1', ls='--')
axes[0, 0].text(
    x=delta_95 + 0.05,
    y=2,
    s="``conservative'' 95%",
    c='C1',
    rotation=90,
    fontsize=13.5,
)

# Caption just above the reference-sigma line in the 2D panel.
axes[1, 0].text(
    x=0.4,
    y=sigma_ref + 0.05,
    s=r'weak-coupling',
    c='black',
    fontsize=12.5,
)

# Final formatting things
# Style the 2D panel's ticks explicitly on axes[1, 0] instead of through
# plt.xticks/plt.yticks, which modify whichever axes pyplot currently treats
# as "current" (that depends on which axes were removed above).
panel = axes[1, 0]
panel.tick_params(direction="in")
panel.tick_params(axis='both', labelrotation=0, labelsize=12)
panel.set_xlabel(labels[0], labelpad=4, fontsize=13)
panel.set_ylabel(labels[1], labelpad=4, fontsize=13)

# NOTE(review): SAVE is defined earlier but not consulted; the PDF is always
# written.  Wrap the savefig call in `if SAVE:` if that flag is meant to gate it.
figure.tight_layout()
# Output name is the run label without its trailing '_' separator.
stem = lab[:-1] if lab.endswith('_') else lab
figure.savefig('corner_' + stem + '_zenodo.pdf', bbox_inches='tight')
plt.close(figure)


