#!/bin/python

# This script evaluates the performance files generated by ./run_PFE_perf.sh
# and writes the results to CSV files.

import csv
import os.path
import re
import statistics
import sys

# Exactly one command-line argument is expected: the prefix of the log files.
if not len(sys.argv) == 2:
    print("Usage: ./eval_perf.py perf_ECC_LAN")
    sys.exit(1)

# Only the first 12 characters of the argument are kept (the length of
# "perf_ECC_LAN"); anything after them is dropped.  An earlier naming scheme
# with timestamps ("perf_2020_01_01-1337") used a 20-character prefix instead.
file_prefix = sys.argv[1][0:12]
print("File prefix", file_prefix)

def get_timing(file_ref, prefix, label):
    """Print and return the mean of all ``[TIME] <label>`` entries of a log.

    The file is rewound first, so the current read position does not matter.
    Every line that starts with ``"[TIME] <label>"`` contributes its third
    whitespace-separated field, divided by 1000.  The result is reported as
    milliseconds (the raw log values are presumably microseconds -- confirm
    against the program that writes the logs).

    Note that matching is done by line prefix, so a label that is itself a
    prefix of another label in the log would also pick up that label's lines.

    Args:
        file_ref: Seekable text file object of the log to scan.
        prefix: Short tag printed in front of the result (e.g. "S" or "C").
        label: Name of the timing entry to look for.

    Returns:
        The mean of the matching values after division by 1000.

    Raises:
        statistics.StatisticsError: If the log has no entry for ``label``.
    """
    file_ref.seek(0)
    marker = "[TIME] %s" % label
    values = [float(line.split()[2]) / 1000 for line in file_ref if line.startswith(marker)]
    if not values:
        # statistics.mean() would raise the same exception type, but without
        # saying which label or which log was the problem.
        raise statistics.StatisticsError(
            "no '[TIME] %s' entries found in the %s log" % (label, prefix))

    # Compute the mean once and reuse it for both the report and the result.
    mean_ms = statistics.mean(values)
    print("%s] %s: %9.2f ms" % (prefix, label.rjust(36, " "), mean_ms))
    return mean_ms

def get_mem(file_ref, label):
    """Print and return the size of the first ``[SEND] <label>`` log entry.

    The file is rewound and scanned completely; the third
    whitespace-separated field of every line starting with
    ``"[SEND] <label>"`` is parsed as an integer.  Only the first match is
    reported (in MiB and bytes) and returned.

    Args:
        file_ref: Seekable text file object of the log to scan.
        label: Name of the transmitted item to look for.

    Returns:
        The byte count of the first matching entry, or 0 if there is none
        (nothing is printed in that case).
    """
    file_ref.seek(0)
    marker = "[SEND] %s" % label

    sizes = []
    for entry in file_ref:
        if entry.startswith(marker):
            sizes.append(int(entry.split()[2]))

    if not sizes:
        return 0

    first = sizes[0]
    print("%s %9.2f MiB (%10d bytes)" % ((label + ":").ljust(20, " "), first/1024/1024, first))
    return first

def get_cryptosystem(file_prefix):
    """Read the cryptosystem name from the ``<file_prefix>_CLIENT_1000`` log.

    The first line starting with ``"KM11_CRYPTOSYSTEM:"`` is taken, stripped
    of surrounding whitespace, and returned with the leading
    ``"KM11_CRYPTOSYSTEM: KM11_CRYPTOSYSTEM_"`` text removed.

    The script is terminated with exit status 1 (after printing a message)
    if the log file does not exist or contains no such line.

    Args:
        file_prefix: Common prefix of the log file names.

    Returns:
        The remainder of the header line, e.g. the part after
        ``KM11_CRYPTOSYSTEM_``.
    """
    client_log = "%s_CLIENT_1000" % (file_prefix)
    if not os.path.isfile(client_log):
        print("%s_CLIENT_1000 does not exist" % (file_prefix))
        sys.exit(1)

    header = []
    with open(client_log) as file_client:
        for line in file_client:
            if line.startswith("KM11_CRYPTOSYSTEM:"):
                header.append(line.strip())

    if not header:
        print("CRYPTOSYSTEM NOT DETECTED")
        sys.exit(1)

    return header[0].replace('KM11_CRYPTOSYSTEM: KM11_CRYPTOSYSTEM_', '')

def _mean_of_field(file_ref, line_prefix):
    """Return the mean of the third field of all lines starting with line_prefix.

    The file is rewound first.  Raises statistics.StatisticsError if no line
    matches.
    """
    file_ref.seek(0)
    return statistics.mean([float(line.split()[2]) for line in file_ref if line.startswith(line_prefix)])


cryptosystem = get_cryptosystem(file_prefix)
print("Cryptosystem: %s" % cryptosystem)

# Columns of both output tables: the gate count followed by the per-phase
# values and the overall total.
csv_fieldnames = ['g', 'setupN', 'setupF', 'onlineX', 'total']

print("Writing mem_%s.csv..." % cryptosystem)
# newline='' is what the csv module requires for files it writes to; without
# it, platforms that translate "\n" emit "\r\r\n" row endings (blank rows).
with open("mem_%s.csv" % cryptosystem, 'w', newline='') as file_mem_csv, \
        open("time_%s.csv" % cryptosystem, 'w', newline='') as file_time_csv:
    csv_mem = csv.DictWriter(file_mem_csv, fieldnames=csv_fieldnames)
    csv_mem.writeheader()
    csv_time = csv.DictWriter(file_time_csv, fieldnames=csv_fieldnames)
    csv_time.writeheader()

    # One pair of logs (<prefix>_SERVER_<g>, <prefix>_CLIENT_<g>) per gate
    # count; gate counts whose logs are missing or empty are skipped.
    for g in [100, 1000, 10000, 100000, 1000000]:
        print("\n\nPerformance stats for g = %d gates" % g)
        server_log = "%s_SERVER_%s" % (file_prefix, g)
        client_log = "%s_CLIENT_%s" % (file_prefix, g)

        if not os.path.isfile(server_log):
            print("%s does not exist" % server_log)
            continue

        if not os.path.isfile(client_log):
            print("%s does not exist" % client_log)
            continue

        with open(server_log) as file_server, open(client_log) as file_client:
            # Every completed run writes one "Timings:" line.
            detected_runs_s = len([line for line in file_server if line.startswith("Timings:")])
            detected_runs_c = len([line for line in file_client if line.startswith("Timings:")])
            print("Detected number of runs: %d (SERVER), %d (CLIENT)" % (detected_runs_s, detected_runs_c))
            if detected_runs_s == 0 or detected_runs_c == 0:
                continue

            # NOTE(review): "Total =" / "Online =" values are used as-is and
            # reported as ms, whereas get_timing() divides "[TIME]" values by
            # 1000 -- presumably the log uses different units for the two;
            # confirm against the program that writes the logs.
            time_total_s = _mean_of_field(file_server, "Total =")
            time_total_c = _mean_of_field(file_client, "Total =")
            print("\n[S/C] Total time: %.2f ms / %.2f ms" % (time_total_s, time_total_c))

            # Online phase
            time_onlineX_s = _mean_of_field(file_server, "Online =")
            time_onlineX_c = _mean_of_field(file_client, "Online =")
            print("[S/C] time_onlineX: %.2f ms / %.2f ms" % (time_onlineX_s, time_onlineX_c))

            # setup_N phase.  time_sum_c / time_sum_s accumulate every timing
            # attributed to the client / server, so that the unaccounted
            # remainder can be reported below.
            time_sum_c = 0
            time_sum_s = 0
            time_sum_c += get_timing(file_client, "C", "TIME_CREATE_BLINDING_VALUES")
            time_sum_c += get_timing(file_client, "C", "TIME_PRECOMPUTE_B")
            time_sum_s += get_timing(file_server, "S", "TIME_CREATE_ENCWK")
            time_wait_encwk = get_timing(file_client, "C", "TIME_WAIT_FOR_ENCWK")
            time_sum_c += time_wait_encwk

            # setupN is the client-side sum up to this point: blinding
            # values + precomputation + waiting for the encrypted wire keys.
            time_setupN = time_sum_c

            time_wait_encgg = get_timing(file_server, "S", "TIME_SEND_ENCWK_WAIT_FOR_ENCGG")
            time_sum_s += time_wait_encgg

            # setup_F phase
            time_create_encgg = get_timing(file_client, "C", "TIME_CREATE_ENCGG")
            time_sum_c += time_create_encgg
            # Printed for information only; not part of any sum.
            get_timing(file_client, "C", "TIME_SEND_ENCGG_WAIT_FOR_OUTPUT_KEYS")
            time_send_encgg_wait_gc = get_timing(file_client, "C", "TIME_SEND_ENCGG_WAIT_FOR_GC")
            time_sum_c += time_send_encgg_wait_gc
            time_create_gc = get_timing(file_server, "S", "TIME_CREATE_GC")
            time_sum_s += time_create_gc

            time_setupF = time_create_encgg + time_send_encgg_wait_gc

            print("[S/C] Remaining:   %.2f ms / %.2f ms" % (time_total_s - time_sum_s, time_total_c - time_sum_c))
            print("Not in any phase: %.2f ms" % (time_total_c - time_setupN - time_setupF - time_onlineX_c))

            # The time table uses the client-side measurements.
            csv_time.writerow({'g': g, 'setupN': time_setupN, 'setupF': time_setupF, 'onlineX': time_onlineX_c, 'total': time_total_c})

            ### Communication
            # Total traffic is the sum of fields 4 and 7 of the first
            # "Total Sent / Rcv" line of the client log (presumably the sent
            # and received byte counts).  An IndexError here means the line
            # is missing from the log.
            file_client.seek(0)
            total_fields = [line.split() for line in file_client if line.startswith("Total Sent / Rcv")][0]
            mem_total_c = int(total_fields[4]) + int(total_fields[7])
            print("\nTotal communication:     %.2f MiB" % (mem_total_c /1024/1024))

            # Per-phase traffic; get_mem() yields 0 for items that are not
            # present in the log.
            mem_setupN  = get_mem(file_server, "BFVpublickey")
            mem_setupN += get_mem(file_server, "m_bEncWireKeys")
            mem_setupN += get_mem(file_server, "MEM_OUTPUT_KEYS")
            mem_setupF  = get_mem(file_client, "m_bEncGG")
            mem_setupF += get_mem(file_server, "m_vGarbledCircuit")
            mem_onlineX = get_mem(file_server, "MEM_INPUT_KEYS")
            mem_sum = mem_setupN + mem_setupF + mem_onlineX

            print("Missing communication:    %d bytes" % (mem_total_c - mem_sum))

            csv_mem.writerow({'g': g, 'setupN': mem_setupN, 'setupF': mem_setupF, 'onlineX': mem_onlineX, 'total': mem_total_c})
