import pickle
from collections import deque, defaultdict
import datetime
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
import pandas as pd
import os
from tqdm import tqdm


# Windows path to the folder that holds the raw .pkl experiment recordings.
input_folder = r"C:\Users\simon\Desktop\Uni\Master\Master_Thesis\Versuchsdaten\Versuchsauswertung"

def sort_for_time(data, time_groups, id_tracker, group_size=4):
    """
    Walk *data* newest-first, collect 'machine'/'force' packets into
    per-timestamp groups, and emit each group once it holds *group_size*
    distinct sensors.

    Relies on the module-level ``processed_times`` deque (timestamps already
    emitted; NOTE: membership test is an O(n) scan) and ``num_entries``
    (bounds the deque's length).

    Returns (completed_groups, time_groups, id_tracker); the two dicts are
    the same (mutated) objects that were passed in.
    """
    completed_groups = []

    for packet in reversed(data):
        if packet['sensortype'] not in ('machine', 'force'):
            continue

        time_key = packet["time"]
        if time_key in processed_times:
            continue  # this second was already emitted

        packet_id = f"{packet['sensortype']}{packet['sensoridx']}"
        if packet_id in id_tracker[time_key]:
            continue  # duplicate sensor for this timestamp

        id_tracker[time_key].add(packet_id)
        time_groups[time_key].append(packet)

        if len(time_groups[time_key]) == group_size:
            # Group complete: emit it sorted and retire the timestamp.
            completed_groups.append(
                sorted(time_groups[time_key], key=lambda p: p['time']))
            processed_times.append(time_key)
            if len(processed_times) > num_entries * 5:
                processed_times.popleft()
            del time_groups[time_key]
            del id_tracker[time_key]

    return completed_groups, time_groups, id_tracker


def sort_buffer_by_time(buffer):
    """Return *buffer* as a deque ordered by each group's first packet time.

    Empty groups compare as ``datetime.max`` and therefore sort to the back.
    """
    def first_packet_time(group):
        if group:
            return group[0]['time']
        return datetime.datetime.max

    return deque(sorted(buffer, key=first_packet_time))


def process_sorted_buffer(buffer):
    """
    Drain the time-sorted *buffer*, emitting packet groups second by second.

    Groups arriving at the expected timestamp are passed through; for every
    missing second a ``str(timestamp)`` placeholder is emitted instead.
    Reads and updates the module-level ``last_time``.

    NOTE(review): if a group's time ever lies *before* ``last_time`` this
    loop never terminates, because ``last_time`` only grows — confirm the
    input really is time-sorted with second granularity.
    """
    global last_time
    emitted = []

    while buffer:
        group_time = buffer[0][0]['time']

        if last_time is not None and group_time != last_time:
            # Gap: stand in for the missing second and advance the clock.
            emitted.append(str(last_time))
            last_time += datetime.timedelta(seconds=1)
            continue

        # Drain every group that shares this timestamp.
        while buffer and buffer[0][0]['time'] == group_time:
            emitted.append(buffer.popleft())
        last_time = group_time + datetime.timedelta(seconds=1)

    return emitted


def sort_data_type(sorted_data_groups, time_groups, list_time_groups):
    """
    Split the time-ordered packet groups into per-channel sample lists.

    Parameters
    ----------
    sorted_data_groups : list
        Output of process_sorted_buffer(): packet groups interleaved with
        ``str(timestamp)`` placeholders for seconds without a complete group.
    time_groups : list
        Incomplete packet groups, indexed in parallel with *list_time_groups*.
    list_time_groups : list of str
        Stringified timestamps of the incomplete groups.

    Returns
    -------
    tuple of lists
        (spindle_speed, position, force_axial, force_radial_1,
        force_radial_2).  Missing seconds are padded with NaN: 20000 force
        samples (20 kHz) and 250 machine samples (250 Hz) per second.

    NOTE(review): the sensoridx -> channel mapping differs between the two
    branches below (placeholder branch: 0=axial, 2=radial_2; complete-group
    branch: 2=axial, 0=radial_2).  One of the two is almost certainly
    wrong — confirm against the sensor wiring before trusting the output.
    """
    spindle_speed = []
    position = []
    force_axial = []
    force_radial_1 = []
    force_radial_2 = []
    Force = True      # True while no force data was found for the current second
    Maschine = True   # True while no machine data was found for the current second
    for sorted_data in sorted_data_groups:
        if isinstance(sorted_data, str):
            # Placeholder second: try to salvage a partially complete group.
            Force = True
            Maschine = True

            if sorted_data in list_time_groups:

                index = list_time_groups.index(sorted_data)
                value = time_groups[index]
                # Only use force data when all three force sensors are present.
                num_for = 0
                for_compl = False
                for packet in value:
                    if packet['sensortype'] == 'force':
                        num_for += 1
                if num_for == 3:
                    for_compl = True

                for packet in value:
                    if packet['sensortype'] == 'machine':
                        Maschine = False
                        for entry in packet['data']:
                            spindle_speed.append(
                                entry['axes'][0]['speedorfeedact'])  # axis 0 is the spindle speed
                            position.append(entry['axes'][0]['positionact'])  # should be the angle of interest
                        print('Maschinendaten eingefügt')
                    if packet['sensortype'] == 'force' and for_compl:
                        Force = False
                        if packet['sensoridx'] == 0:
                            force_axial.extend(packet['data'])
                        elif packet['sensoridx'] == 1:
                            force_radial_1.extend(packet['data'])
                        elif packet['sensoridx'] == 2:
                            force_radial_2.extend(packet['data'])
                        print('Kräfte eingefügt')
                if Force:
                    # No usable force data this second: pad one second at 20 kHz.
                    force_axial.extend(np.full(20000, np.nan).tolist())
                    force_radial_1.extend(np.full(20000, np.nan).tolist())
                    force_radial_2.extend(np.full(20000, np.nan).tolist())
                    print('Kraft Fehlt')
                if Maschine:
                    # No machine data this second: pad one second at 250 Hz.
                    spindle_speed.extend(np.full(250, np.nan).tolist())
                    position.extend(np.full(250, np.nan).tolist())
                    print('Maschine Fehlt')
            else:
                # Second is completely missing: pad every channel with NaN.
                force_axial.extend(np.full(20000, np.nan).tolist())
                force_radial_1.extend(np.full(20000, np.nan).tolist())
                force_radial_2.extend(np.full(20000, np.nan).tolist())
                spindle_speed.extend(np.full(250, np.nan).tolist())
                position.extend(np.full(250, np.nan).tolist())
        else:
            # Complete group: append every channel directly.
            for packet in sorted_data:
                if packet['sensortype'] == 'machine':
                    for entry in packet['data']:
                        spindle_speed.append(entry['axes'][0]['speedorfeedact'])  # axis 0 is the spindle speed
                        position.append(entry['axes'][0]['positionact'])  # should be the angle of interest
                if packet['sensortype'] == 'force':
                    # NOTE(review): mapping here is mirrored vs. the branch above.
                    if packet['sensoridx'] == 2:
                        force_axial.extend(packet['data'])
                    elif packet['sensoridx'] == 1:
                        force_radial_1.extend(packet['data'])
                    elif packet['sensoridx'] == 0:
                        force_radial_2.extend(packet['data'])
    return spindle_speed, position, force_axial, force_radial_1, force_radial_2


def scale_force(force):
    """
    Convert raw force samples to newtons.

    Each sample is normalized by 2**15 (= 32768, the magnitude of a signed
    16-bit integer) and then multiplied by 1500: the force sensor was
    calibrated so that a normalized value of 1.0 corresponds to 1500 N.

    TODO: replace 1500 with the actual full-scale force of the KMP platform
    (the maximum force it can measure).
    """
    full_scale = 2 ** 15
    calibration_newtons = 1500
    return [(sample / full_scale) * calibration_newtons for sample in force]

def interpolate_c_p_and_n(spindle_speed, position, force_axial_clean):
    """
    Resample spindle speed and position onto the force signal's time base.

    The machine channels are recorded at 250 Hz; the force signal (AFTER
    downsampling) is longer, so both machine channels are linearly
    interpolated up to ``len(force_axial_clean)`` samples.

    Parameters
    ----------
    spindle_speed, position : sequence of float
        Machine channels sampled at 250 Hz.
    force_axial_clean : sequence
        Reference signal; only its length is used.

    Returns
    -------
    tuple
        (spindle_speed_interp, position_interp), each of length
        ``len(force_axial_clean)``.  When either machine channel is empty,
        both results are NaN-filled lists.
    """
    target_length = len(force_axial_clean)
    if len(spindle_speed) == 0 or len(position) == 0:
        print("Empty input detected. Filling with NaN.")
        # Bug fix: this branch used to return a 3-tuple while the normal
        # path (and the caller's 2-value unpack) expects exactly 2 values.
        return [np.nan] * target_length, [np.nan] * target_length

    # Map the 250 Hz sample positions onto the target index space, then
    # linearly interpolate at every target index.
    indices = np.linspace(0, target_length - 1, num=len(spindle_speed))
    new_indices = np.arange(target_length)
    spindle_speed_interp = np.interp(new_indices, indices, spindle_speed)
    position_interp = np.interp(new_indices, indices, position)
    return spindle_speed_interp, position_interp


def downsample_signal(data, original_rate, new_rate):
    """
    Downsamples the input signal from the original rate to the new rate.
    """
    decimation_factor = original_rate // new_rate
    return signal.decimate(data, decimation_factor, ftype='iir', zero_phase=True).tolist()

# Output directory for the cropped CSV exports, created next to the raw data.
output_folder = fr"{input_folder}\grob_bereinigt"
os.makedirs(output_folder, exist_ok=True)

# Count the total number of .pkl files in the input folder (drives the
# progress bar and the main loop below).
all_files = [f for f in os.listdir(input_folder) if f.endswith(".pkl")]
total_files = len(all_files)

# Create a progress bar for the total progress
# Per-file pipeline: load pickle -> sort by time -> extract channels ->
# scale + downsample forces -> plot -> interactive crop -> export CSV.
with tqdm(total=total_files, desc="Total Progress", unit="file") as total_progress:
    # Main processing loop
    for filename in all_files:
        input_filepath = os.path.join(input_folder, filename)

        # Load the data.
        # NOTE(review): pickle.load is only safe for trusted files; these
        # are the experiment's own recordings, so that holds here.
        with open(input_filepath, 'rb') as f:
            data = pickle.load(f)
        # NOTE(review): this f-string has no placeholder — presumably it was
        # meant to print the loaded path (e.g. input_filepath); confirm.
        print(f"\nData loaded from (unknown)")

        # Module-level state consumed by the sorting helpers above
        # (sort_for_time reads processed_times/num_entries,
        # process_sorted_buffer reads/writes last_time).
        processed_times = deque()
        num_entries = 5000  # Adjust as necessary --------------------------------------------------------------
        time_groups = defaultdict(list)
        id_tracker = defaultdict(set)
        buffer = deque()
        last_time = None

        # Sorting by time
        sorted_data_groups, time_groups, id_tracker = sort_for_time(data, time_groups, id_tracker)
        del data

        # Keep the incomplete groups (and their stringified timestamps) so
        # sort_data_type() can still salvage partial seconds.
        filtered_time_groups = [value for key, value in time_groups.items() if value]
        list_time_groups = [str(key) for key, value in time_groups.items() if value]

        buffer.extend(sorted_data_groups)
        buffer = sort_buffer_by_time(buffer)
        sorted_data_groups_final = process_sorted_buffer(buffer)
        del sorted_data_groups

        # Extract data types (one list per channel)
        spindle_speed, position, force_axial, force_radial_1, force_radial_2 = sort_data_type(
            sorted_data_groups_final, filtered_time_groups, list_time_groups)
        del sorted_data_groups_final

        # Convert raw 16-bit force samples to newtons.
        force_axial_scaled = scale_force(force_axial)
        force_radial_1_scaled = scale_force(force_radial_1)
        force_radial_2_scaled = scale_force(force_radial_2)
        del force_axial, force_radial_1, force_radial_2

        # Scale and interpolate data
        """
        Hier original_rate ggf anpassen. Ladungsverstärker kann bis zu 50.000 Abtastungen pro Sekunde
        Schauen, was KMP kann und dann entsprechend samplen. 
        KMP nimmt 20kHz auf, also passt es!!!
        """
        # Forces: 20 kHz -> 10 kHz (this also sets the plot/export time base).
        force_axial_filtered = downsample_signal(force_axial_scaled, 20000, 10000)
        force_radial_1_filtered = downsample_signal(force_radial_1_scaled, 20000, 10000)
        force_radial_2_filtered = downsample_signal(force_radial_2_scaled, 20000, 10000)
        del force_axial_scaled, force_radial_1_scaled, force_radial_2_scaled

        force_axial_clean = force_axial_filtered  # if the signal does not start at zero, comment this out and use the version above
        spindle_speed_interp, position_interp = interpolate_c_p_and_n(
            spindle_speed, position, force_axial_clean)

        plt.rcParams['font.family'] = 'Charter'

        # Channels to plot/export, paired with their axis titles
        # (the titles double as the CSV column headers below).
        variables = [
            (list(spindle_speed_interp), 'Spindle Speed (RPM)'),
            (list(position_interp), 'Position'),
            (force_axial_filtered, 'Axial Force (N)'),
            (force_radial_1_filtered, 'Radial Force 1 (N)'),
            (force_radial_2_filtered, 'Radial Force 2 (N)')
        ]

        fig, axs = plt.subplots(len(variables), 1, figsize=(8, 6), sharex=True)
        u = 0

        # Plot the data for each axis
        for var, name in variables:
            timet = [j / 10000 for j in range(len(var))]  # time base: 10 kHz after downsampling
            axs[u].plot(timet, var)
            axs[u].set_title(name, fontsize=10, loc='left', y=0.95)
            axs[u].grid(True)
            u += 1

        # Shared x-axis label
        axs[-1].set_xlabel('Time (s)')
        axs[-1].set_xticks(np.arange(0, max(timet), 10))  # last number sets the x-axis tick step

        fig.suptitle(f'Experimental Data Series: {os.path.splitext(filename)[0]}', fontsize=12)
        plt.tight_layout(h_pad=0, rect=[0, 0, 1, 1])
        # NOTE(review): plt.show() blocks until the window is closed unless
        # matplotlib runs in interactive mode — confirm the intended flow.
        plt.show()

        plt.pause(15)

        # User input for the crop window (in seconds)
        start_time = float(input("Gib die Startzeit in Sekunden ein: "))
        end_time = float(input("Gib die Endzeit in Sekunden ein: "))

        plt.close()

        # Collect the cropped series per channel for the CSV export.
        csv_data = {}

        # New plot with the cropped data
        fig, axs = plt.subplots(len(variables), 1, figsize=(8, 6), sharex=True)
        u = 0

        for var, name in variables:
            timet = [j / 10000 for j in range(len(var))]

            # NOTE(review): next() raises StopIteration when start/end lie
            # beyond the recording — consider clamping the user input.
            start_idx = next(i for i, t in enumerate(timet) if t >= start_time)
            end_idx = next(i for i, t in enumerate(timet) if t >= end_time)

            var_cut = var[start_idx:end_idx]
            timet_cut = timet[start_idx:end_idx]

            axs[u].plot(timet_cut, var_cut)
            axs[u].set_title(name, fontsize=10, loc='left', y=0.95)
            axs[u].grid(True)

            # Store the cropped data (for the CSV export)
            csv_data[name] = var_cut
            u += 1

        # Shared x-axis label
        axs[-1].set_xlabel('Time (s)')
        fig.suptitle(f'Experimental Data Series: {os.path.splitext(filename)[0]} - Cut Data', fontsize=12)
        plt.tight_layout(h_pad=0, rect=[0, 0, 1, 1])
        plt.show()
        plt.pause(15)
        # Time relative to the chosen start (timet_cut comes from the last loop pass)
        relative_time = [t - start_time for t in timet_cut]

        csv_data['Time (s)'] = relative_time

        # Build a pandas DataFrame and save it to the output folder
        output_filename = f'{os.path.splitext(filename)[0]}_gekürzt.csv'
        output_filepath = os.path.join(output_folder, output_filename)
        df = pd.DataFrame(csv_data)
        df.to_csv(output_filepath, index=False)
        print(f"Processed data saved to {output_filename}")

        # Update the total progress bar
        total_progress.update(1)

print("All files have been processed successfully.")
