import glob
import datetime
import argparse
import random
import math
import numpy as np
import pandas as pd
import matplotlib.colors as mcolors
from dateutil.relativedelta import relativedelta

from utils import *


def match_one_txtfile(h5_file, txt_folder="mapMatch_result/rlt_/", sample_size=400, group_size=10, selected_tripIDs=None, outputfolder="visualization/map_match/"):
    """Plot a random sample of map-matched trips from one result file.

    The stem of ``h5_file`` (file name without ``.h5``) selects the text file
    ``txt_folder/<stem>.txt``.  Every sampled line of that file is parsed as
    ``tripID,edge edge ...`` where the edges are coarse edge ids separated by
    single spaces.  For each sampled trip the raw observations, the matched
    route geometry and the projected points are collected, and after every
    ``group_size`` trips the collected data is handed to
    ``plot_tool.plot_trace`` with an output path under
    ``outputfolder/<stem>/``.

    The function reads these module-level names: ``all_raw_df``,
    ``streetmap``, ``edgesDf``, ``coarse2full_edge``, ``background`` and
    ``plot_tool``.

    Args:
        h5_file: Path of the h5 file passed to ``read_h5``.
        txt_folder: Folder (with trailing slash) holding ``<stem>.txt``.
        sample_size: Number of lines to sample; clamped to the number of
            non-blank lines available after the first line.
        group_size: Number of trips plotted on one page.
        selected_tripIDs: Optional container of trip ids; sampled trips that
            are not in it are skipped.
        outputfolder: Folder (with trailing slash) that receives
            ``<stem>/<page>.html``.

    Returns:
        None.  Returns early, after printing a message, when the output
        folder already holds ``ceil(sample_size / group_size)`` html files.
    """
    nb = h5_file.split("/")[-1].split(".h5")[0]
    txt_file = txt_folder + str(nb) + ".txt"
    page_dir = outputfolder + str(nb) + "/"

    # Read everything up front so the handle is not held open while plotting.
    with open(txt_file, 'r') as file:
        lines = file.readlines()

    os.makedirs(page_dir, exist_ok=True)
    # Skip the file only when a previous run already produced every page.
    if len(glob.glob(page_dir + "*.html")) == math.ceil(sample_size / group_size):
        print(txt_file+"is visited.")
        return

    all_df = read_h5(h5_file, time_converse=False)
    all_df["speed_"] = all_df['speed_'].round(3)
    all_df["timestamp"] = pd.to_timedelta(all_df.secs, unit='s') + datetime.datetime(2020, 10, 1)
    all_df.drop(["secs"], axis=1, inplace=True)

    multipleTraces, multipleRoutes, multipleProjection = {}, {}, {}
    rd = 0  # number of trips plotted so far

    # Line 0 is never sampled (presumably a header -- confirm against the
    # writer of the txt files).  Blank lines cannot be parsed, so drop them,
    # and clamp the sample so short files do not make random.sample raise.
    candidates = [i for i in range(1, len(lines)) if lines[i].strip()]
    random_integers = random.sample(candidates, min(sample_size, len(candidates)))

    for i in tqdm(random_integers, desc="Plot "+txt_file, unit="item"):
        fields = lines[i].strip().split(",")
        tripID = int(fields[0])
        if selected_tripIDs is not None and tripID not in selected_tripIDs:
            continue

        df = all_df[all_df.tripID == tripID]
        _, _, projected_lon, projected_lat = get_accurate_start_end_point(df, streetmap, edgesDf)
        multipleProjection[tripID] = [projected_lon, projected_lat]

        raw_df = all_raw_df[all_raw_df.tripID == tripID]

        # One hover text per raw observation.  Positions whose value is also a
        # label of df get the matched details, the others only id and time.
        matched_index = df.index.unique()  # hoisted: loop-invariant
        obs_info = []
        for index in range(len(raw_df)):
            if index in matched_index:
                row = df.loc[index]
                obs_info.append('tripID: '+str(row["tripID"])+
                    "<br>Timestamps: "+str(row["timestamp"])+
                    "<br>Coarse edge: "+str(row["edge"])+
                    "<br>Uturn: "+str(row["uturn"])+
                    "<br>Speed: "+str(row["speed_"])+
                    "<br>Direction: "+str(row["dir"])+
                    '<br>Fraction: '+str(row["frcalong"]))
            else:
                row = raw_df.iloc[index]
                obs_info.append('tripID: '+str(row["tripID"])+
                    "<br>Timestamps: "+str(row["timestamp"]))
        obs_info.append(' ')

        # Layout: [lons, lats, colors, hover texts]; each list ends with one
        # extra terminator entry (None / "#000000" / ' ').
        multipleTraces[tripID] = [
            raw_df.lon.values.tolist() + [None],
            raw_df.lat.values.tolist() + [None],
            colorFader(raw_df.timestamp, c2='#FDC9C9', c1='#920808') + ["#000000"],
            obs_info,
        ]

        # Expand the matched coarse edges into their full-resolution edges.
        edges = [int(e) for e in fields[1].split(" ")]
        full_edges = np.concatenate([coarse2full_edge[e] for e in edges])
        selected_edge = streetmap.loc[full_edges]
        geo_list, route_info = [], []
        for _, row in selected_edge.iterrows():
            coords = list(row.geometry.coords)
            geo_list.extend(list(c) for c in coords)
            route_info.extend('Road type:'+row["type"]+
                              '<br>tripID: '+str(tripID)+
                              '<br>Coarse edge: '+str(row["c_edge"]) for _ in coords)
            # A [None, None] point / ' ' text separates consecutive edges.
            geo_list.append([None, None])
            route_info.append(' ')
        geo_list = np.asarray(geo_list)
        # Layout: [lons, lats, None, hover texts].
        multipleRoutes[tripID] = [geo_list[:, 0], geo_list[:, 1], None, route_info]

        if (rd+1) % group_size == 0:
            plot_tool.plot_trace(outputpath=page_dir+str(rd//group_size)+".html", background=background, multipleRoutes=multipleRoutes, multipleTraces=multipleTraces, multipleProjection=multipleProjection)
            multipleTraces, multipleRoutes, multipleProjection = {}, {}, {}
        rd += 1

    # Flush the last, partially filled group.
    # NOTE(review): the "+1" leaves a gap in the page numbering (full pages
    # are 0..k-1, this one is k+1); file names are kept as they were.
    if rd % group_size != 0:
        plot_tool.plot_trace(outputpath=page_dir+str(rd//group_size+1)+".html", background=background, multipleRoutes=multipleRoutes, multipleTraces=multipleTraces, multipleProjection=multipleProjection)


        #     if (rd+1)%group_size==0:
        #         plot_tool.plot_trace(outputpath=outputfolder+str(nb)+"/"+str(rd//group_size)+"_"+str(starttrip)+".html", background=background, multipleRoutes=multipleRoutes, multipleTraces=multipleTraces, multipleProjection=multipleProjection)
        #         starttrip = tripID+1
        #         multipleTraces, multipleRoutes, multipleProjection={}, {}, {}
        #     rd += 1
        # if rd%group_size!=0:
        #     plot_tool.plot_trace(outputpath=outputfolder+str(nb)+"/"+str(rd//group_size+1)+"_"+str(starttrip)+".html", background=background, multipleRoutes=multipleRoutes, multipleTraces=multipleTraces, multipleProjection=multipleProjection)

def selected_roads_plot(input_folder="mapMatch_result/rlt_/",
                        outputfolder="visualization/map_match/",
                        all_tripID=None):
    """Plot the roads used by the map-matching results.

    Every ``*.txt`` file in ``input_folder`` except ``visited.txt`` is read
    together with its ``viterbi/<stem>.h5`` companion.  Each line after the
    first is parsed as ``tripID,edge edge ...`` (coarse edge ids separated by
    single spaces).  For every coarse edge the number of traversals is divided
    by the number of distinct days on which it was traversed.  The function
    then requests three plots through ``plot_tool``:
    ``heatmap_selected_roads.html`` (roads coloured by that average),
    ``selected_roads.html`` (the same roads without colour) and
    ``projection_heatmap.html`` (projected points of all h5 files).

    The function reads these module-level names: ``streetmap``, ``edgesDf``,
    ``coarse2full_edge``, ``full2coarse_edge``, ``background`` and
    ``plot_tool``.

    Args:
        input_folder: Folder (with trailing slash) holding the txt files and
            the ``viterbi/`` sub-folder.
        outputfolder: Folder (with trailing slash) for the html output.
        all_tripID: Accepted for backward compatibility; not needed because
            the per-edge day lists are now created on demand.

    Raises:
        ValueError: If no edge could be read from the text files.
    """
    # parse the map matching result to get unique coarse roads and full roads
    edges = []
    # coarse edge id -> days on which it was traversed.  Keyed by *edge* and
    # filled lazily; pre-seeding it with trip ids made lookups fail for every
    # edge id that was not also a trip id.
    edge_day_dict = {}
    print("Read text files.")
    for txt_file in tqdm(glob.glob(input_folder+"*.txt")):
        file_name = txt_file.split("/")[-1]
        if file_name == "visited.txt":
            continue
        h5_file = input_folder+"viterbi/"+file_name.split(".txt")[0]+".h5"
        all_df = read_h5(h5_file, time_converse=False)
        all_df["timestamp"] = pd.to_timedelta(all_df.secs, unit='s')+datetime.datetime(2020, 10, 1)
        all_df["day"] = (all_df['timestamp'] - pd.Timestamp('2020-10-01')).dt.days
        ## don't consider the same tripID in different days
        all_df = all_df.drop_duplicates(subset=["tripID"], keep="first")
        tripID_time_dict = dict(zip(all_df.tripID, all_df.day))

        with open(txt_file, 'r') as file:
            for line in file.readlines()[1:]:
                l = line.strip()
                if not l:
                    # e.g. a trailing empty line; nothing to parse
                    continue
                fields = l.split(",")
                day_ = tripID_time_dict[int(fields[0])]
                edges_to_add = [int(i) for i in fields[1].split(" ")]
                for edge in edges_to_add:
                    edge_day_dict.setdefault(edge, []).append(day_)
                edges += edges_to_add

    if not edges:
        raise ValueError("No map-matched edges found in "+input_folder)

    unique_coarse_edges, counts = np.unique(edges, return_counts=True)
    # traversals per distinct day
    counts = [n/len(np.unique(edge_day_dict[e])) for e, n in zip(unique_coarse_edges, counts)]
    coarse_count = dict(zip(unique_coarse_edges, counts))

    # Min-max normalise for the colour map; a zero span (all edges equally
    # busy) would otherwise divide by zero and produce NaN colours.
    count_arr = np.asarray(counts, dtype=float)
    span = count_arr.max() - count_arr.min()
    if span > 0:
        normalized_counts = (count_arr - count_arr.min()) / span
    else:
        normalized_counts = np.zeros_like(count_arr)
    cmap = plt.get_cmap('cool')
    hex_colors = [mcolors.to_hex(cmap(norm)) for norm in normalized_counts]
    coarse_color = dict(zip(unique_coarse_edges, hex_colors))

    full_edges = np.concatenate([coarse2full_edge[i] for i in unique_coarse_edges])
    unique_full_edges = np.unique(full_edges)
    selected_edge = streetmap.loc[unique_full_edges]

    print("Plotting heatmap for selected roads.")
    geo_list, info, color = [], [], []
    for index, row in tqdm(selected_edge.iterrows()):
        lst = [np.asarray(j) for j in row.geometry.coords]
        lst = add_intermediate_coords(lst)
        geo_list += lst+[[None, None]]
        # Only the first and last vertex are network nodes.
        nodes = [row["source"]]+[None]*(len(lst)-2)+[row["target"]]
        daily = str(coarse_count[full2coarse_edge[index]])
        info.extend('Road type:'+row["type"]+
                    '<br>Node: '+str(nodes[i])+
                    '<br>Average Daily Count: '+daily+
                    '<br>Source: '+str(row['source'])+
                    '<br>Target: '+str(row['target']) for i in range(len(lst)))
        info.append(' ')
        color += [coarse_color[full2coarse_edge[index]]]*len(lst) + ["#000000"]
    geo_list = np.asarray(geo_list)
    lons, lats = geo_list[:,0], geo_list[:,1]
    roads = [lons, lats, color, info]

    plot_tool.plot_trace(outputpath=outputfolder+"heatmap_selected_roads.html", marker_size=5, line_marker_size=5, background=background, routes=roads)

    print("Plotting selected roads.")
    geo_list, info = [], []
    for index, row in tqdm(selected_edge.iterrows()):
        lst = [np.asarray(j) for j in row.geometry.coords]
        geo_list += lst+[[None, None]]
        # Pad with None for interior vertices so geometries with more than
        # two points do not index past the [source, target] pair.
        nodes = [row["source"]]+[None]*(len(lst)-2)+[row["target"]]
        daily = str(coarse_count[full2coarse_edge[index]])
        info.extend('Road type:'+row["type"]+
                    '<br>Node: '+str(nodes[i])+
                    '<br>Count: '+daily+
                    '<br>Source: '+str(row['source'])+
                    '<br>Target: '+str(row['target']) for i in range(len(lst)))
        info.append(' ')
    geo_list = np.asarray(geo_list)
    lons, lats = geo_list[:,0], geo_list[:,1]
    roads = [lons, lats, None, info]
    plot_tool.plot_trace(outputpath=outputfolder+"selected_roads.html", marker_size=5, line_marker_size=5, background=background, routes=roads)

    print("Plotting projected location.")
    lon, lat = [], []
    for h5_file in glob.glob(input_folder+"viterbi/*.h5"):
        df = read_h5(h5_file, time_converse=False)
        _, _, projected_lon, projected_lat = get_accurate_start_end_point(df, streetmap, edgesDf)
        lon.append(projected_lon)
        lat.append(projected_lat)
    lon, lat = np.concatenate(lon), np.concatenate(lat)
    plot_tool.heatmap_plot(data=pd.DataFrame({"lon":lon, "lat":lat}), outputpath=outputfolder+"projection_heatmap.html")
    return

def get_background(street_map, coarse_street_map):
    """Load the street maps and build the background road layer.

    Args:
        street_map: Path read with ``gpd.read_file``; the result is indexed
            by its ``edge`` column and must provide ``c_edge``, ``type``,
            ``oneway``, ``source``, ``target`` and ``geometry``.
        coarse_street_map: Path of a csv read with ``pd.read_csv`` and
            indexed by its ``edge`` column.

    Returns:
        A tuple ``(background, coarse2full_edge, full2coarse_edge, streetmap,
        edgesDf)`` where ``background`` is ``[lons, lats, None, info]`` with
        one hover text per vertex, ``coarse2full_edge`` maps every coarse
        edge id to the list of full edge ids assigned to it and
        ``full2coarse_edge`` is the inverse mapping taken from ``c_edge``.

    Raises:
        KeyError: If a ``c_edge`` value is missing from the coarse map index.
    """
    streetmap = gpd.read_file(street_map)
    streetmap.set_index("edge", inplace=True)
    edgesDf = pd.read_csv(coarse_street_map)
    edgesDf.set_index("edge", inplace=True)

    coarse2full_edge = {i: [] for i in edgesDf.index}
    full2coarse_edge = dict(streetmap.c_edge)
    for full_edge, coarse_edge in full2coarse_edge.items():
        coarse2full_edge[coarse_edge].append(full_edge)

    geo_list, info = [], []
    for index, row in streetmap.iterrows():
        coords = list(row.geometry.coords)
        geo_list.extend(list(j) for j in coords)
        # A [None, None] point / ' ' text separates consecutive edges.
        geo_list.append([None, None])
        # Only the first and last vertex are network nodes; pad the interior
        # with None so geometries with more than two points do not index
        # past the [source, target] pair.
        nodes = [row["source"]]+[None]*(len(coords)-2)+[row["target"]]
        # extend() instead of `info = info + [...]`: the latter copies the
        # whole list for every edge of the map.
        info.extend('Road type:'+row["type"]+
                    '<br>Node: '+str(nodes[i])+
                    '<br>Full edge: '+str(index)+
                    '<br>Coarse edge: '+str(row["c_edge"])+
                    '<br>Oneway:'+str(row['oneway'])+
                    '<br>Source: '+str(row['source'])+
                    '<br>Target: '+str(row['target']) for i in range(len(coords)))
        info.append(' ')
    geo_list = np.asarray(geo_list)
    lons, lats = geo_list[:,0], geo_list[:,1]
    background = [lons, lats, None, info]
    return background, coarse2full_edge, full2coarse_edge, streetmap, edgesDf

def penality_roads_plot(input_str, outputpath):
    """Plot the full-resolution roads behind a list of coarse edges.

    Args:
        input_str: Comma-separated coarse edge ids, e.g. ``"3,17,42"``.
        outputpath: Output path forwarded to ``plot_tool.plot_trace``.

    Reads the module-level names ``coarse2full_edge``, ``streetmap``,
    ``background`` and ``plot_tool``.
    """
    coarse_ids = [int(token) for token in input_str.split(",")]
    full_ids = np.unique(np.concatenate([coarse2full_edge[cid] for cid in coarse_ids]))
    chosen_roads = streetmap.loc[full_ids]

    points, hover = [], []
    for edge_id, road in chosen_roads.iterrows():
        vertices = [np.asarray(pt) for pt in road.geometry.coords]
        points.extend(vertices)
        # A [None, None] point / ' ' text separates consecutive edges.
        points.append([None, None])
        hover.extend(
            'Road type:'+road["type"]+
            '<br>Coarse edge: '+str(road["c_edge"])+
            '<br>Full edge: '+str(edge_id)+
            '<br>Source: '+str(road["source"])+
            '<br>Target: '+str(road["target"])
            for _ in vertices
        )
        hover.append(' ')

    points = np.asarray(points)
    plot_tool.plot_trace(
        outputpath=outputpath,
        background=background,
        routes=[points[:, 0], points[:, 1], None, hover],
    )

def count_trajectory(data:pd.DataFrame):
    """Return the number of rows per ``tripID``, each capped at 300.

    The result is a numpy array with one entry per trip, in the order in
    which ``groupby("tripID")`` yields the groups.
    """
    rows_per_trip = [len(trip) for _, trip in data.groupby("tripID")]
    return np.asarray([min(300, n_rows) for n_rows in rows_per_trip])

def count_time_interval(data:pd.DataFrame):
    """Return the gaps between consecutive points of each trip, capped at 10.

    Rows are ordered by ``tripID`` and ``timestamp``; the difference to the
    previous row of the same trip is divided by 60.  The first row of every
    trip has no predecessor and therefore yields NaN.  ``data`` itself is
    not modified.
    """
    ordered = data.sort_values(by=['tripID', 'timestamp'])
    gaps = ordered.groupby('tripID')['timestamp'].diff()
    scaled = gaps.to_numpy() / 60
    # NaN compares False, so the leading NaN of each trip is kept as is.
    return np.where(scaled > 10, 10, scaled)

def count_duration(data:pd.DataFrame):
    """Return the time span of each trip divided by 60, capped at 300.

    The span is the difference between the largest and the smallest
    ``timestamp`` of the trip.  The result is a numpy array with one entry
    per ``tripID``, in the order in which ``groupby`` yields the groups.
    """
    ordered = data.sort_values(by=['tripID', 'timestamp'])
    spans = [
        (trip["timestamp"].max() - trip["timestamp"].min()) / 60
        for _, trip in ordered.groupby("tripID")
    ]
    return np.asarray([min(300, span) for span in spans])



if __name__ == "__main__":
    # Command-line entry point: --function_name selects which plot to make.
    # The names assigned here (plot_tool, background, streetmap, ...) are
    # module-level on purpose: the plotting functions above read them.
    parser = argparse.ArgumentParser()

    parser.add_argument("--input_file", type=str)
    parser.add_argument("--function_name", type=str)
    parser.add_argument("--ptsDf_group_size", type=int, default=2000)
    parser.add_argument("--output_path", type=str, default="visualization/map_match/punish_roads.html")
    parser.add_argument("--outputfolder", type=str, default="visualization/map_match/")
    parser.add_argument("--raw_file", type=str, default="data/stepII.h5")
    parser.add_argument("--street_map", type=str, default="mapMatch_result/full_roads.shp")
    parser.add_argument("--coarse_street_map", type=str, default="mapMatch_result/coarse_roads.csv")

    args = parser.parse_args()
    outputfolder = args.outputfolder

    # makedirs rather than mkdir: the default folder is nested, and mkdir
    # fails when the parent directory does not exist yet.
    os.makedirs(outputfolder, exist_ok=True)

    plot_tool = Plot_html()
    plot_plt = Plot_plt()

    if args.function_name == "data_plot":
        # Histograms of the raw trace table; no street map is needed here.
        data = tracetable(args.input_file)
        traj_duration = count_duration(data=data)
        plot_plt.hist_density_plot(data=traj_duration, x_label="Duration of routes (min)", y_label="distribution", title=None, bin=100, outputpath=args.outputfolder+"duration.png")
        traj_count = count_trajectory(data=data)
        plot_plt.hist_density_plot(data=traj_count, x_label="Nb of data points per route", y_label="distribution", title=None, bin=100, outputpath=args.outputfolder+"nb_datapoint.png")
        interval_count = count_time_interval(data=data)
        plot_plt.hist_density_plot(data=interval_count, x_label="Data point interval (min)", y_label="distribution", title=None, bin=100, outputpath=args.outputfolder+"time_interval.png")

    elif args.function_name == "shapefile_plot":
        shapefile = gpd.read_file(args.input_file)
        print(shapefile)
        box = plot_tool.shp_plot_box(shapefile=shapefile, colorby="Layer")
        # Name the page after the input file, whatever its extension length.
        page_name = os.path.splitext(os.path.basename(args.input_file))[0]+".html"
        plot_tool.plot_map_objs(outputpath=args.outputfolder+page_name, line_box=box)

    else:
        # get full roads and coarse roads
        background, coarse2full_edge, full2coarse_edge, streetmap, edgesDf = get_background(args.street_map, args.coarse_street_map)

        if args.function_name == "penality_roads_plot":
            penality_roads_plot(input_str=args.input_file, outputpath=args.output_path)

        elif args.function_name == "selected_roads_plot":
            all_raw_df = read_h5(args.raw_file)
            all_tripID = all_raw_df.tripID.unique()
            selected_roads_plot(input_folder=args.input_file, outputfolder=args.outputfolder, all_tripID=all_tripID)