# %% [markdown]
# ```text
# Copyright (c) 2024, Bastian Latsch, Technische Universität Darmstadt.
# Some Rights Reserved.
# Except where otherwise noted, this work is licensed under Creative Commons Attribution 4.0 International.
# To view a copy of this license, visit http://creativecommons.org/licenses/by/4.0/
#
# You are free to:
#    * Share — copy and redistribute the material in any medium or format for any purpose.
#    * Adapt — remix, transform, and build upon the material for any purpose, even commercially.
#
# Under the following terms:
#    * Attribution — You must give appropriate credit, provide a link to the license, and indicate if changes were made.
#
# For any reuse or distribution, you must make clear to others the license terms of this work.
# The best way to give credit is with a reference to the corresponding publication.
# Any of the above conditions can be waived if you get permission from the copyright holder.
# No warranties are given.
# ```

# %%
# Load README
# Passing `filename=` explicitly makes IPython read the file (and raise
# FileNotFoundError if it is missing). A bare string argument is only treated
# as a path when it happens to resolve; otherwise the literal text
# 'README.md' is rendered silently.
from IPython.display import Markdown, display
display(Markdown(filename='README.md'))

# %% [markdown]
# ## Read the data from the HDF5-formatted file

# %%
import h5py
import numpy as np

# File containing data
h5file = 'data.hdf5'

# Use the provided data structure
fe_sensors = ['heel', 'lateral', 'medial', 'toe']  # Insole
grf_components = ['vertical_right']  # Treadmill
speeds = tuple(range(50, 151, 25))  # 50, 75, 100, 125, 150 -> groups 'v050' ... 'v150'

# Initialize 2D object array holding one dataset per cell: row = speed, col = sensor
data = np.empty((len(speeds), len(fe_sensors) + len(grf_components)), dtype=object)

# Open the file once for the description and every dataset, instead of
# re-opening it for each of the len(speeds) * n_sensors reads.
with h5py.File(h5file, 'r') as file:

    # Display file description
    print(file.attrs['description'][0])

    # Get ferroelectret insole data for all sensors and all speeds (under '/insole')
    # Get ground reaction force (GRF) treadmill data for all speeds (under '/treadmill')
    for i, sensor in enumerate(fe_sensors + grf_components):
        group = 'insole' if sensor in fe_sensors else 'treadmill'

        for j, speed in enumerate(speeds):
            this_path = f'/{group}/{sensor}/v{speed:03}'  # e.g. '/insole/heel/v050'

            # Transposed so that axis 0 indexes strides, which the later mean over
            # axis 0 relies on (assumed on-disk layout: samples x strides; confirm)
            data[j, i] = np.array(file[this_path]).transpose()

# %% [markdown]
# ## Sample plot of the mean gait cycle per sensor and speed

# %%
import matplotlib.pyplot as plt

# Sensors in the column order of `data`
all_sensors = fe_sensors + grf_components

# Initialize mean and standard deviation over gait cycles (same layout as `data`)
data_mean = np.empty_like(data)
data_std = np.empty_like(data)

# Mean and standard deviation over all strides (axis 0) for every loaded dataset;
# cells left as None (no data) stay None in both result arrays.
for i in range(len(all_sensors)):
    for j in range(len(speeds)):
        if data[j, i] is not None:
            data_mean[j, i] = np.mean(data[j, i], axis=0)
            data_std[j, i] = np.std(data[j, i], axis=0)

# Cycle colors, one per speed (the cycle repeats if there are more speeds than colors)
colors = [
    (0.00, 0.51, 0.80),  # Blue
    (0.45, 0.06, 0.56),  # Purple
    (0.90, 0.00, 0.08),  # Red
    (0.98, 0.73, 0.08),  # Yellow
    (0.41, 0.56, 0.23)   # Green
]

# Figure with one tile per sensor showing the means.
# squeeze=False always returns a 2D array of Axes, so the tiles (and axs[-1]
# below) can be indexed even when there is only a single sensor.
fig, axs_grid = plt.subplots(len(all_sensors), 1, figsize=(10, 8), tight_layout=True, squeeze=False)
axs = axs_grid[:, 0]

for i, (sensor, ax) in enumerate(zip(all_sensors, axs)):
    ax.set_prop_cycle('color', colors)

    for j, speed in enumerate(speeds):
        if data_mean[j, i] is not None:
            # Normalize the time axis to 0-100 % of the stride
            x = np.linspace(0, 100, len(data_mean[j, i]))
            ax.plot(x, data_mean[j, i], label=f'v{speed:03}')

    # Tile specific formatting
    ax.set_ylabel(f'{sensor} (V)' if sensor in fe_sensors else 'vertical GRF (N)')
    ax.set_xlim([0, 100])

# Add legend and x-label to the last subplot
axs[-1].legend(title='Speed')
axs[-1].set_xlabel("Time (% stride)")

plt.show()


