#!/bin/python

# This script evaluates the performance files generated by ./run_PFE_perf.sh
# and writes the results to CSV files.

import csv
import os.path
import re
import statistics
import sys

# Exactly one command-line argument (the performance-file prefix) is required.
if len(sys.argv) != 2:
    print("Usage: ./eval_perf.py perf_MPC_LAN")
    sys.exit(1)

# Only the first 12 characters of the argument are used as the file prefix.
file_prefix = sys.argv[1][:12]
print("File prefix", file_prefix)

def get_timing(file_ref, prefix, label):
    """Print and return the mean of the "[TIME] <label>" values in file_ref.

    The file is rewound and scanned in full. For every line starting with
    "[TIME] <label>", the third whitespace-separated token is parsed as a
    float and divided by 1000; the result is reported as milliseconds.

    prefix is only used to tag the printed line.
    Raises statistics.StatisticsError when no line matches.
    """
    marker = "[TIME] %s" % label
    file_ref.seek(0)

    samples = []
    for row in file_ref:
        if row.startswith(marker):
            samples.append(float(row.split()[2]) / 1000)

    average = statistics.mean(samples)
    print("%s] %s: %9.2f ms" % (prefix, label.rjust(36, " "), average))
    return average

def get_mem(file_ref, label):
    """Return the first "[SEND] <label>" value found in file_ref, or 0.

    The file is rewound and scanned in full. For every line starting with
    "[SEND] <label>", the third whitespace-separated token is parsed as an
    int. When at least one line matches, the first value is printed (in MiB
    and in bytes) and returned; otherwise nothing is printed and 0 is
    returned.
    """
    marker = "[SEND] %s" % label
    file_ref.seek(0)

    byte_counts = []
    for row in file_ref:
        if row.startswith(marker):
            byte_counts.append(int(row.split()[2]))

    if not byte_counts:
        return 0

    first = byte_counts[0]
    print("%s %9.2f MiB (%10d bytes)" % ((label + ":").ljust(20, " "), first/1024/1024, first))
    return first

print("NO PFE")

csv_fieldnames = ['g', 'total']

print("Writing mem_MPC.csv and time_MPC.csv...")
# The csv module requires files handed to a writer to be opened with
# newline=''; otherwise platforms that translate "\n" on write produce
# "\r\r\n" row endings, which readers interpret as blank rows.
with open("mem_MPC.csv", 'w', newline='') as file_mem_csv, \
        open("time_MPC.csv", 'w', newline='') as file_time_csv:
    csv_mem = csv.DictWriter(file_mem_csv, fieldnames=csv_fieldnames)
    csv_mem.writeheader()
    csv_time = csv.DictWriter(file_time_csv, fieldnames=csv_fieldnames)
    csv_time.writeheader()

    # One pair of log files (<prefix>_SERVER_<g>, <prefix>_CLIENT_<g>) is
    # evaluated per gate count; gate counts with a missing file are skipped.
    for g in [100, 1000, 10000, 100000, 1000000]:
        print("\n\nPerformance stats for g = %d gates" % g)
        server_path = "%s_SERVER_%s" % (file_prefix, g)
        client_path = "%s_CLIENT_%s" % (file_prefix, g)

        if not os.path.isfile(server_path):
            print("%s does not exist" % server_path)
            continue

        if not os.path.isfile(client_path):
            print("%s does not exist" % client_path)
            continue

        with open(server_path) as file_server, open(client_path) as file_client:
            # Each line starting with "Timings:" is counted as one run.
            detected_runs_s = sum(1 for line in file_server if line.startswith("Timings:"))
            detected_runs_c = sum(1 for line in file_client if line.startswith("Timings:"))
            print("Detected number of runs: %d (SERVER), %d (CLIENT)" % (detected_runs_s, detected_runs_c))
            if detected_runs_s == 0 or detected_runs_c == 0:
                continue

            ### Timing
            # Mean over all "Total = <value> ..." lines; the value is the third
            # whitespace-separated token and is reported as milliseconds.
            file_server.seek(0)
            time_total_s = statistics.mean([float(line.split()[2]) for line in file_server if line.startswith("Total =")])
            file_client.seek(0)
            time_total_c = statistics.mean([float(line.split()[2]) for line in file_client if line.startswith("Total =")])
            print("\n[S/C] Total time: %.2f ms / %.2f ms" % (time_total_s, time_total_c))

            # Only the client-side total goes into the CSV.
            csv_time.writerow({'g': g, 'total': time_total_c})

            ### Communication
            # Use the first "Total Sent / Rcv" line of the client log and add
            # tokens 4 and 7 of it (presumably the sent and received byte
            # counts -- confirm against the log format). The line is located
            # and split once instead of re-scanning the file per token.
            file_client.seek(0)
            comm_fields = [line.split() for line in file_client if line.startswith("Total Sent / Rcv")][0]
            mem_total_c = int(comm_fields[4]) + int(comm_fields[7])
            print("\nTotal communication:     %.2f MiB" % (mem_total_c /1024/1024))

            csv_mem.writerow({'g': g, 'total': mem_total_c})
