import csv
import numpy as np
import scipy as sp
from scipy import stats
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from model import Model
from worlds import neumann, moore, voronoi, social


def collect(world, num, runs, ticks, sre, srp, src, setting, mechanism, filename):
    """Run the model `runs` times and append each run's data to csv files.

    For every run a fresh world is built, a Model is run on it, and one row
    per data series (rumors, truths, blanks) is appended to
    ``../Data/<filename>_<series>_new.csv``.

    world : world type that gets created each run
        ('neumann', 'moore', 'voronoi' or 'social')
    num : number of agents
    runs : number of separate model runs
    ticks : number of ticks for each run
    sre : success rate of critical evaluation
    srp : success rate of plausibility checking
    src : success rate of competence checking
    setting : determines starting configuration
    mechanism : determines agent mechanism
    filename : filename without file extension

    Raises ValueError if `world` is not one of the known world types.
    """
    builders = {
        'neumann': neumann,
        'moore': moore,
        'voronoi': voronoi,
        'social': social,
    }
    # validate before doing any work so a typo fails fast with a clear message
    if world not in builders:
        raise ValueError('unknown world type: {!r}'.format(world))
    build = builders[world]

    for _ in range(runs):
        # the world is rebuilt every run; voronoi worlds are random, so each
        # run gets a fresh layout
        model = Model(*build(num), setting)
        model.run(ticks, sre, srp, src, mechanism, show=False, mode='s')

        # append one row per run to each series file (rumors, truths, blanks)
        for key in ('rumors', 'truths', 'blanks'):
            path = '../Data/{}_{}_new.csv'.format(filename, key)
            with open(path, 'a+', newline='') as f:
                csv.writer(f).writerow(model.data[key])


def comparison(indices=range(1, 9)):
    """Print Kendall's tau between mean rumor curves of different worlds.

    For each dataset number i, the per-tick mean of the rumor files with
    prefixes moo/neu/vor/soc (presumably the Moore, von Neumann, Voronoi and
    social worlds) is computed. Kendall's tau and its p-value are then
    printed for the 'moo' curve against each of the other three. Nothing
    is plotted.

    indices : dataset numbers to compare (default 1 through 8)
    """
    def mean_curve(prefix, i):
        # Read headerless, like stackplot/lineplot read their data and like
        # collect() writes its rows (these files are assumed to share that
        # format). pandas' default header inference would otherwise treat
        # the first run as column names and drop it from the mean.
        path = '../Data/{}{}_rumors.csv'.format(prefix, i)
        return list(pd.read_csv(path, header=None).mean())

    for i in indices:
        reference = mean_curve('moo', i)
        others = [
            ('MN', mean_curve('neu', i)),
            ('MV', mean_curve('vor', i)),
            ('MS', mean_curve('soc', i)),
        ]
        print('----------')
        for tag, curve in others:
            tau, p = stats.kendalltau(reference, curve)
            print('Tau {}:'.format(tag), tau, p)


def stackplot(num, index, datanames, filename):
    """Plot agent counts per tick as a stackplot plus two error-bar panels.

    The left panel stacks the per-tick means of falsity, truth and blank
    agents. The right panels show the falsity (top) and truth (bottom) means
    at about 20 evenly spaced ticks. Each point has grey min/max whiskers and
    black +/- one standard deviation bars. The figure is saved as .tif and
    .png in the Images folder.

    num : number of agents that were used for simulation data
    index : number of leading ticks (columns) kept for the plots
    datanames : paths relative to the Data folder, with file extension, of
        the rumors, truths and blanks files (in that order)
    filename : output filename without file extension
    """
    # presumably one row per run and one column per tick, as written by
    # collect(); keep only the first `index` ticks
    rumors = pd.read_csv('../Data/' + datanames[0], header=None).iloc[:, 0:index]
    truths = pd.read_csv('../Data/' + datanames[1], header=None).iloc[:, 0:index]
    blanks = pd.read_csv('../Data/' + datanames[2], header=None).iloc[:, 0:index]

    # grey shades and labels for the stackplot layers
    cmap = plt.get_cmap('binary')
    colors = [cmap(0.2), cmap(0), cmap(0.5)]
    labels = ['falsity', 'truth', 'none']

    # per-tick means across runs, computed once and reused below
    n_ticks = truths.shape[1]
    r_mean, t_mean, b_mean = rumors.mean(), truths.mean(), blanks.mean()

    # Subsample about 20 ticks for the error-bar panels. The step is at least
    # 1 so data with fewer than 20 ticks does not produce a zero slice step
    # (which raises ValueError).
    step = max(1, n_ticks // 20)
    x = range(n_ticks)[::step]
    r = r_mean.iloc[::step]
    r_min = rumors.min().iloc[::step]
    r_max = rumors.max().iloc[::step]
    r_std = rumors.std().iloc[::step]
    t = t_mean.iloc[::step]
    t_min = truths.min().iloc[::step]
    t_max = truths.max().iloc[::step]
    t_std = truths.std().iloc[::step]

    # x tick marks at 0, 1/4, 1/2, 3/4 of the range and at the last tick
    xticks = [0, int(index / 4), int(index / 2), int(index * 3 / 4), int(index - 1)]

    fig = plt.figure(figsize=(10, 4.3), facecolor='white')
    gs = GridSpec(2, 5, figure=fig)
    ax1 = fig.add_subplot(gs[:, :3])
    ax2 = fig.add_subplot(gs[0, 3:])
    ax3 = fig.add_subplot(gs[1, 3:])

    ax1.stackplot(range(n_ticks), r_mean, t_mean, b_mean, colors=colors,
                  edgecolor='k', linewidth=0.5, labels=labels)
    ax1.set_xlabel('ticks', size=14)
    ax1.set_ylabel('agents', size=14)
    ax1.set_xticks(xticks)
    ax1.legend(frameon=False, prop={'size': 12})
    ax1.margins(0)

    # the two error-bar panels share their layout; only the data differs
    panels = (
        (ax2, r, r_min, r_max, r_std, 'falsity'),
        (ax3, t, t_min, t_max, t_std, 'truth'),
    )
    for ax, mean, lo, hi, std, label in panels:
        # grey whiskers span min..max, thick black bars span +/- one std
        ax.errorbar(x, mean, [mean - lo, hi - mean], fmt='.k', ecolor='gray', lw=1)
        ax.errorbar(x, mean, std, fmt='ok', lw=3)
        ax.set_ylabel('agents', size=14)
        ax.set_yticks([0, int(num / 2), num])
        ax.set_xticks(xticks)
        ax.legend(labels=[label], prop={'size': 12}, fancybox=False,
                  edgecolor='k', framealpha=1)
    ax3.set_xlabel('ticks', size=14)

    fig.tight_layout()
    fig.savefig('../Images/{}.tif'.format(filename), dpi=200)
    fig.savefig('../Images/{}.png'.format(filename), dpi=200)
    plt.close(fig)


def lineplot(num, index, percentages, datanames, filename):
    """Plot the share of falsities at one tick against the share of plausibles.

    For each data file, the mean and the mean +/- one standard deviation of
    column ``index`` are computed and divided by ``num``. A polynomial is
    fitted through each of these three series: degree 7, or lower when there
    are too few files. The fits are drawn together with the raw points. A
    second x axis below the first shows the complementary % gullibles. The
    figure is saved as .tif and .png in the Images folder and then closed.

    num : number of agents that were used for simulation data
    index : index (column label) of the single timestep to plot
    percentages : percent of plausibles for each file in datanames, used for
        the x axis; must have at least as many entries as datanames
    datanames : filenames relative to the Data folder, with file extension,
        one per percentage (presumably rumors files, given the y-axis label)
    filename : output filename without file extension
    """
    line = []  # mean
    sthi = []  # mean plus one standard deviation
    stlo = []  # mean minus one standard deviation
    rang = []  # x values: fraction of plausibles

    for i, name in enumerate(datanames):
        df = pd.read_csv('../Data/' + name, header=None)
        # compute each statistic once instead of once per series
        mean = df.mean()[index]
        std = df.std()[index]
        line.append(mean / num)
        sthi.append((mean + std) / num)
        stlo.append((mean - std) / num)
        rang.append(round(percentages[i] / 100, ndigits=2))

    # A degree-d polynomial fit needs more than d points to be well posed,
    # so lower the degree when there are fewer than 8 data files.
    deg = min(7, len(rang) - 1)
    line_poly = np.poly1d(np.polyfit(rang, line, deg))
    sthi_poly = np.poly1d(np.polyfit(rang, sthi, deg))
    stlo_poly = np.poly1d(np.polyfit(rang, stlo, deg))

    fig, ax = plt.subplots(1, 1, figsize=(8, 5), facecolor='white')

    ax.plot(rang, line_poly(rang), c='k', label='polyfit: mean')
    ax.plot(rang, sthi_poly(rang), c='k', lw=1.5, alpha=0.3, label='polyfit: +/- std')
    ax.plot(rang, stlo_poly(rang), c='k', lw=1.5, alpha=0.3)

    ax.scatter(rang, line, s=3, color='k', label='data', alpha=0.4)
    ax.scatter(rang, stlo, s=3, color='k', alpha=0.4)
    ax.scatter(rang, sthi, s=3, color='k', alpha=0.4)
    ax.fill_between(rang, sthi_poly(rang), stlo_poly(rang), color='k', alpha=0.1)

    ax.set_xlabel('% plausibles', size=14)
    ax.set_xticks([0, 0.2, 0.4, 0.6, 0.8, 1])
    ax.set_xticklabels([0, 20, 40, 60, 80, 100])  # map fractions to percentages
    ax.set_xlim([-0.01, 1.01])
    ax.set_ylabel('% falsities at tick {}'.format(index), size=14)
    ax.set_yticks([0, 0.1, 0.2, 0.3, 0.4, 0.5])
    ax.set_yticklabels([0, 10, 20, 30, 40, 50])  # map fractions to percentages
    ax.legend(prop={'size': 12}, fancybox=False, edgecolor='k', framealpha=1)

    # secondary x axis, pushed below the first, showing the complement
    twin = ax.twiny()
    twin.xaxis.set_ticks_position('bottom')
    twin.xaxis.set_label_position('bottom')
    twin.set_xlabel('% gullibles', size=14)
    twin.set_xticks([0, 0.2, 0.4, 0.6, 0.8, 1])
    # reversed tick labels: gullibles = 100% - plausibles
    twin.set_xticklabels([100, 80, 60, 40, 20, 0])
    twin.set_xlim([-0.01, 1.01])  # same limits as the main axis so ticks align
    twin.spines['bottom'].set_position(('outward', 50))

    fig.tight_layout()
    fig.savefig('../Images/{}.tif'.format(filename), dpi=200)
    fig.savefig('../Images/{}.png'.format(filename), dpi=200)
    # close the figure so repeated calls do not accumulate open figures
    plt.close(fig)


def worldplot(num, world):
    """Plot the interaction network of one world type and save it as an image.

    For 'voronoi' the Voronoi cell boundaries of the node positions are drawn
    first. The figure is saved as ``<world>.tif`` and ``<world>.png`` in the
    Images folder and then closed.

    num : number of agents that were used for simulation data
    world : string indicating which world to plot
        ('neumann', 'moore', 'voronoi' or 'social')

    Raises ValueError if `world` is not one of the known world types.
    """
    worlds = {
        'neumann': (neumann, 'von Neumann'),
        'moore': (moore, 'Moore'),
        'voronoi': (voronoi, 'Voronoi'),
        'social': (social, 'Watts-Strogatz'),
    }
    # Validate before creating a figure. Previously an unknown world fell
    # through to an UnboundLocalError on `w` and left the figure open.
    if world not in worlds:
        raise ValueError('unknown world type: {!r}'.format(world))
    build, label = worlds[world]

    fig, ax = plt.subplots(1, 1, figsize=(2.5, 2.5), facecolor='white')
    w = build(num)
    graph, pos = w[0], w[1]  # network and its node-position mapping
    ax.set_xlabel(label, size=12)

    if world == 'voronoi':
        # explicit submodule import: `import scipy` alone does not guarantee
        # that scipy.spatial is loaded on every scipy version
        from scipy.spatial import Voronoi, voronoi_plot_2d
        vor = Voronoi(list(pos.values()))
        voronoi_plot_2d(vor, show_points=False, show_vertices=False,
                        line_colors='k', line_alpha=0.3, ax=ax)

    color = plt.get_cmap('binary')(0.8)
    nx.draw_networkx_nodes(graph, pos, node_size=30, node_color=color, ax=ax)
    nx.draw_networkx_edges(graph, pos, width=1.1, alpha=0.8, ax=ax)
    ax.set_aspect('equal', adjustable='box')
    ax.margins(0.08)
    fig.tight_layout()
    # same 'Images' folder as stackplot/lineplot (was lowercase 'images',
    # a different directory on case-sensitive filesystems)
    fig.savefig('../Images/{}.tif'.format(world), dpi=200)
    fig.savefig('../Images/{}.png'.format(world), dpi=200)
    plt.close(fig)