#!/bin/python
import csv
import os.path
import re
import statistics
import sys

# Expect exactly one argument: the common prefix of the per-run log files.
# Only the first 11 characters (the length of "perf_UC_LAN") are kept,
# presumably so a full log name such as perf_UC_LAN_SERVER_100 also works.
# An older, date-stamped naming scheme (perf_2020_01_01-1337) used 20.
if len(sys.argv) == 2:
    file_prefix = sys.argv[1][:11]
    print("File prefix", file_prefix)
else:
    print("Usage: ./eval_perf.py perf_UC_LAN")
    sys.exit(1)

def get_timing(file_ref, prefix, label):
    """Print and return the mean timing recorded under *label* in a log.

    Every line of *file_ref* that starts with *label* contributes its second
    whitespace-separated field, parsed as a float, to the mean.

    Args:
        file_ref: seekable text file (or file-like object) holding the log;
            it is rewound before scanning.
        prefix: short party tag printed in front of the label.
        label: line prefix identifying the timing entry.

    Returns:
        The arithmetic mean of the matching values.

    Raises:
        statistics.StatisticsError: if no line starts with *label*.
    """
    file_ref.seek(0)
    values = [float(line.split()[1]) for line in file_ref if line.startswith(label)]
    if not values:
        # statistics.mean would raise with a generic message; name the label.
        raise statistics.StatisticsError("no '%s' entries found" % label)
    mean = statistics.mean(values)  # computed once, printed and returned
    print("%s] %s: %9.2f" % (prefix, label.rjust(36, " "), mean))
    return mean

def get_mem(file_ref, label):
    """Print and return the byte count of the first "[SEND] <label>" entry.

    Args:
        file_ref: seekable text file (or file-like object) holding the log;
            it is rewound before scanning.
        label: name following "[SEND] " at the start of the wanted line.

    Returns:
        The integer in the third whitespace-separated field of the first line
        starting with "[SEND] <label>", or 0 when no such line exists (in
        which case nothing is printed).
    """
    file_ref.seek(0)
    marker = "[SEND] %s" % label
    # Only the first matching entry is used, so stop scanning once found.
    nbytes = next(
        (int(line.split()[2]) for line in file_ref if line.startswith(marker)),
        None,
    )
    if nbytes is None:
        return 0
    print("%s %9.2f MiB (%10d bytes)" % ((label + ":").ljust(22, " "), nbytes / 1024 / 1024, nbytes))
    return nbytes

def _mean_field(file_ref, label, index=2):
    """Return the mean of whitespace field *index* over lines starting with *label*.

    *file_ref* is rewound before scanning. Raises statistics.StatisticsError
    when no line matches.
    """
    file_ref.seek(0)
    return statistics.mean(
        [float(line.split()[index]) for line in file_ref if line.startswith(label)]
    )


# Aggregate per-gate-count timing and communication figures from the
# <prefix>_SERVER_<g> / <prefix>_CLIENT_<g> logs into two CSV files.
csv_fieldnames = ['g', 'setupN', 'setupF', 'onlineX', 'total']
with open("mem_UC.csv", 'w') as file_mem_csv, open("time_UC.csv", 'w') as file_time_csv:
    csv_mem = csv.DictWriter(file_mem_csv, fieldnames=csv_fieldnames)
    csv_mem.writeheader()
    csv_time = csv.DictWriter(file_time_csv, fieldnames=csv_fieldnames)
    csv_time.writeheader()

    for g in [100, 1000, 10000, 100000, 1000000]:
        print("\n\nPerformance stats for g = %d gates" % g)
        server_path = "%s_SERVER_%s" % (file_prefix, g)
        client_path = "%s_CLIENT_%s" % (file_prefix, g)
        if not os.path.isfile(server_path):
            print("%s does not exist" % server_path)
            continue
        if not os.path.isfile(client_path):
            print("%s does not exist" % client_path)
            continue

        with open(server_path) as file_server, open(client_path) as file_client:
            detected_runs_s = sum(1 for line in file_server if line.startswith("Timings:"))
            detected_runs_c = sum(1 for line in file_client if line.startswith("Timings:"))
            print("Detected number of runs: %d (SERVER), %d (CLIENT)" % (detected_runs_s, detected_runs_c))
            if not detected_runs_s or not detected_runs_c:
                continue

            time_total_c = _mean_field(file_client, "Total =")
            time_total_s = _mean_field(file_server, "Total =")
            print("\n[S/C] Total time: %.2f / %.2f " % (time_total_s, time_total_c))

            # in the setupN phase, OTExtension and garbling take place
            time_setupN_c = _mean_field(file_client, "OTExtension =") + _mean_field(file_client, "Garbling =")
            time_setupN_s = _mean_field(file_server, "OTExtension =") + _mean_field(file_server, "Garbling =")
            print("[S/C] time_setupN (garbling) %.2f / %.2f" % (time_setupN_s, time_setupN_c))

            # Per-layer timings keyed by (party, layer, kind). Each party's
            # values are read from that party's own log. Loop order keeps the
            # print order: per layer, S then C, computation then interaction.
            layer_times = {}
            for layer in range(3):
                for party, file_ref in (("S", file_server), ("C", file_client)):
                    for kind in ("COMPUTATION", "INTERACTION"):
                        label = "TIME_CIRCUIT_LAYER_%d_%s" % (layer, kind)
                        layer_times[party, layer, kind] = get_timing(file_ref, party, label)

            # programming bits are the input gates of the circuit (layer 0)
            time_setupF = layer_times["C", 0, "COMPUTATION"] + layer_times["C", 0, "INTERACTION"]
            print("time_setupF", time_setupF)

            # local evaluation of the GC takes place in the client on layer 2,
            # the last of the three circuit layers (0-2)
            time_onlineX = layer_times["C", 2, "COMPUTATION"] + layer_times["C", 2, "INTERACTION"]
            print("time_onlineX", time_onlineX)

            csv_time.writerow({'g': g, 'setupN': time_setupN_c, 'setupF': time_setupF, 'onlineX': time_onlineX, 'total': time_total_c})

            # Each party's remainder subtracts that party's own phase timings.
            time_setupF_s = layer_times["S", 0, "COMPUTATION"] + layer_times["S", 0, "INTERACTION"]
            time_onlineX_s = layer_times["S", 2, "COMPUTATION"] + layer_times["S", 2, "INTERACTION"]
            print("[S/C] Remaining: %.2f / %.2f" % (
                time_total_s - time_setupN_s - time_setupF_s - time_onlineX_s,
                time_total_c - time_setupN_c - time_setupF - time_onlineX))

            ### Communication
            # Only the first "Total Sent / Rcv" line (first run) is used.
            file_client.seek(0)
            sent_rcv = next(
                (line.split() for line in file_client if line.startswith("Total Sent / Rcv")),
                None,
            )
            if sent_rcv is None:
                print("No 'Total Sent / Rcv' line in %s" % client_path)
                continue
            # Fields 4 and 7 presumably hold the sent and received byte counts.
            mem_total_c = int(sent_rcv[4]) + int(sent_rcv[7])
            print("\nTotal communication:       %.2f MiB" % (mem_total_c / 1024 / 1024))

            mem_setupN  = get_mem(file_server, "m_vGarbledCircuit")
            mem_setupN += get_mem(file_server, "m_vUniversalGateTable")
            mem_setupN += get_mem(file_server, "m_vOutputShareSndBuf")
            mem_setupF  = get_mem(file_server, "ServerInputKeys")
            mem_setupF += get_mem(file_server, "ClientInputKeys")
            mem_onlineX = get_mem(file_server, "OT")

            mem_sum = mem_setupN + mem_setupF + mem_onlineX

            csv_mem.writerow({'g': g, 'setupN': mem_setupN, 'setupF': mem_setupF, 'onlineX': mem_onlineX, 'total': mem_total_c})

            print("Missing communication:      %d bytes" % (mem_total_c - mem_sum))
