import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from collections import Counter
from agent import Agent


class Model:
    """Agent based model of epistemic vigilance.

    G : networkx graph; one agent is created for every node
    pos : position dictionary passed to the networkx drawing functions
    setting : pair (name, mode) that selects the starting configuration
    """
    def __init__(self, G, pos, setting):
        """Store the world and populate it according to ``setting``.

        G : networkx graph the agents live on
        pos : position dictionary kept for the drawing routine
        setting : pair (name, mode); name selects the matching init_*
                  method and mode is forwarded to it unchanged
        """
        self.G = G  # mutable
        self.pos = pos  # mutable
        self.agents = []
        self.data = dict(truths=[], rumors=[], blanks=[])

        # a name that matches none of the known settings leaves the
        # model without any agents
        known = ('degrees', 'centrals', 'plauwalls', 'compwalls', 'schelling')
        for name in known:
            if setting[0] == name:
                getattr(self, 'init_' + name)(setting[1])


    def init_degrees(self, mode):
        """Set up a single critic on the node of highest closeness
        centrality; every other agent is gullible and one randomly
        picked gullible starts with a falsehood. This setting is used
        in combination with the base mechanism critical evaluation.

        mode : 'min' strips all but one edge from the critic's node,
               any other value (e.g. 'max') leaves the graph untouched
        """
        for node in self.G.nodes:
            self.agents.append(Agent(self, node, 0, 0, 0))

        # closeness centrality decides where the critic sits
        closeness = nx.closeness_centrality(self.G)
        critic_node = max(closeness, key=closeness.get)
        critic_edges = list(self.G.edges(critic_node))

        if mode == 'min':
            # draw len - 1 distinct edge positions and delete those edges,
            # so exactly one edge of the critic survives
            doomed = np.random.choice(
                len(critic_edges), len(critic_edges) - 1, replace=False)
            self.G.remove_edges_from(
                [edge for k, edge in enumerate(critic_edges) if k in doomed])

        # critic gets b = i = 1, gullibles get b = i = 0
        for agent in self.agents:
            flag = 1 if agent.id == critic_node else 0
            agent.b = flag
            agent.i = flag

        # starting falsehood goes to a random agent other than the critic
        gullibles = [agent for agent in self.agents
                     if not agent.id == critic_node]
        np.random.shuffle(gullibles)
        gullibles[0].m = -1


    def init_centrals(self, mode):
        """Creates the following social setting: A critic is placed
        on one of the nodes with the most frequent node degree.
        This setting is used in combination with the base mechanism
        critical evaluation.

        mode : 'min' or 'max' controls centrality selection of node

        Raises ValueError for any other mode. The check runs before any
        agent is created, so a bad mode no longer leaves a half
        initialised model behind (previously the method crashed with an
        unbound local after the agents had already been appended).
        """
        if mode not in ('min', 'max'):
            raise ValueError(
                "mode must be 'min' or 'max', got {!r}".format(mode))

        for index in self.G.nodes:
            a = Agent(self, index, 0, 0, 0)
            self.agents.append(a)

        # get degree of every node
        # get frequency distribution over degrees
        # get most frequent degree (first one encountered on ties)
        # get all nodes that have the most frequent degree
        nd = dict(self.G.degree)
        fd = Counter(nd.values())
        md = max(fd, key=fd.get)
        nm = {node for node, degree in nd.items() if degree == md}

        # get centralities of all nodes
        # filter centralities based on nodes with most frequent degree
        cs = nx.closeness_centrality(self.G)
        fc = {n: c for n, c in cs.items() if n in nm}

        # mode determines what centrality of node is selected
        # min selects the minimum centrality among the candidates
        # max selects the maximum centrality among the candidates
        pick = min if mode == 'min' else max
        cn = pick(fc, key=fc.get)

        # create either critics or gullibles
        for a in self.agents:
            if a.id == cn:
                a.b = 1  # critic
                a.i = 1
            else:
                a.b = 0  # gullible
                a.i = 0

        # place starting message: falsehood
        # remove critic then choose random gullible
        f = [a for a in self.agents if a.id != cn]
        np.random.shuffle(f)
        f[0].m = -1


    def init_plauwalls(self, mode):
        """Set up a critic on the node of highest closeness centrality
        and surround it with truth biased plausibles (all its direct
        neighbors). Everybody else is gullible and one random gullible
        starts with a falsehood. This setting is used in combination
        with the mechanism addition plausibility checking.

        mode : '' place empty string here (the value is not read)
        """
        for node in self.G.nodes:
            self.agents.append(Agent(self, node, 0, 0, 0))

        # most central node and the wall of neighbors around it
        centrality = nx.closeness_centrality(self.G)
        hub = max(centrality, key=centrality.get)
        wall = list(self.G.neighbors(hub))

        # assign roles and collect the gullibles on the way
        outsiders = []
        for agent in self.agents:
            if agent.id == hub:
                agent.b, agent.i = 1, 1  # critic
            elif agent.id in wall:
                agent.b, agent.i = 1, 0  # plausible
            else:
                agent.b, agent.i = 0, 0  # gullible
                outsiders.append(agent)

        # starting falsehood goes to a random gullible
        np.random.shuffle(outsiders)
        outsiders[0].m = -1


    def init_compwalls(self, mode):
        """Set up a critic on the node of highest closeness centrality
        and surround it with rumor biased plausibles (all its direct
        neighbors). The critic starts with a truth, one random gullible
        starts with a falsehood. This setting is used in combination
        with the mechanism addition plausibility checking.
        NOTE(review): the method name suggests competence checking is
        meant here -- confirm.

        mode : '' place empty string here (the value is not read)
        """
        for node in self.G.nodes:
            self.agents.append(Agent(self, node, 0, 0, 0))

        # most central node and the wall of neighbors around it
        centrality = nx.closeness_centrality(self.G)
        core = max(centrality, key=centrality.get)
        ring = list(self.G.neighbors(core))

        # assign roles
        for agent in self.agents:
            if agent.id == core:
                agent.b, agent.i = 1, 1  # critic
                agent.m = 1  # critic carries the initial truth
            elif agent.id in ring:
                agent.b, agent.i = -1, 0  # plausible
            else:
                # NOTE(review): gullibles get i = 1 here but i = 0 in the
                # other settings -- confirm this is intended
                agent.b, agent.i = 0, 1  # gullible

        # starting falsehood goes to a random gullible
        candidates = [agent for agent in self.agents
                      if not (agent.id == core or agent.id in ring)]
        np.random.shuffle(candidates)
        candidates[0].m = -1


    def init_schelling(self, mode):
        """Creates the following social setting: Places a mix of
        truth biased critics, rumor biased critics, truth biased plausibles,
        rumor biased plausibles, and gullibles in world. Number of starting
        rumors and truths can also be set. Agents are structured based on
        Schelling segregate algorithm determined by number of schelling steps
        and tolerance level of agents. Initialization requires at least a few
        gullible agents so that the algorithm has room to switch unhappy
        agents (an IndexError is raised when an unhappy agent finds none).

        mode : sequence with nine entries, read by position
            0 number of truth biased critics
            1 number of rumor biased critics
            2 number of truth biased plausibles
            3 number of rumor biased plausibles
            4 number of gullibles (every agent not covered by entries 0-3
              becomes gullible, so this entry does not limit their number)
            5 number of starting truths
            6 number of starting rumors
            7 maximum number of schelling steps
            8 tolerance, an agent is happy when the fraction of neighbors
              sharing its belief is strictly greater than this value
        """
        for index in self.G.nodes:
            a = Agent(self, index, 0, 0, 0)
            self.agents.append(a)

        # get slice stops (cumulative counts) for mix of agents
        # get slice stops (cumulative counts) for true and false messages
        # get number of maximum schelling steps
        # get schelling tolerance level of agents
        mix = [sum(mode[0:i]) for i in range(0, 6)]
        mes = [sum(mode[5:i]) for i in range(5, 8)]
        steps = mode[7]
        tolerance = mode[8]

        # shuffle agents
        # cut the random id sequence into one id list per agent type
        np.random.shuffle(self.agents)
        ids = [a.id for a in self.agents]
        mids = [ids[lo:hi] for lo, hi in zip(mix, mix[1:])]  # mix ids

        # shuffle again so that messages are independent of agent types
        # cut the random id sequence into one id list per message type
        np.random.shuffle(self.agents)
        ids = [a.id for a in self.agents]
        meds = [ids[lo:hi] for lo, hi in zip(mes, mes[1:])]  # mes ids

        # create either critics or gullibles or plausibles
        for a in self.agents:
            if a.id in mids[0]:
                a.b = 1  # truth biased critic
                a.i = 1
            elif a.id in mids[1]:
                a.b = -1  # rumor biased critic
                a.i = 1
            elif a.id in mids[2]:
                a.b = 1  # truth biased plausible
                a.i = 0
            elif a.id in mids[3]:
                a.b = -1  # rumor biased plausible
                a.i = 0
            else:
                a.b = 0  # gullible are happy
                a.i = 0
                a.h = True

        # create starting messages either true or false
        for a in self.agents:
            if a.id in meds[0]:
                a.m = 1  # truth
            elif a.id in meds[1]:
                a.m = -1  # rumor
            else:
                a.m = 0  # none

        # schelling segregation algorithm
        # segregation is performed a maximum of schelling steps
        for _ in range(steps):

            # break if there are no more unhappy agents
            if not any(a.h == False for a in self.agents):
                break

            # randomly activate agents
            np.random.shuffle(self.agents)
            for a in self.agents:
                nb = a.neighbors()
                total = len(nb)

                # calculate neighbor fraction depending on beliefs
                # only agents with beliefs not zero are relevant
                # an agent without neighbors has nobody who shares its
                # belief, so its fraction is 0 instead of dividing by zero
                if a.b == 1:
                    pos = len([n for n in nb if n.b == 1])
                    fraction = pos / total if total else 0
                elif a.b == -1:
                    neg = len([n for n in nb if n.b == -1])
                    fraction = neg / total if total else 0
                else:
                    fraction = 1

                if fraction > tolerance:
                    a.h = True
                    continue

                # a believer that is unhappy now must be flagged as such,
                # otherwise an agent that was happy once would keep its
                # stale flag and the early exit above could trigger while
                # unhappy agents remain; gullibles keep their happy flag
                if a.b != 0:
                    a.h = False

                # find potential switch candidates (gullibles)
                # pick random candidate and switch locations
                switchables = [s for s in self.agents if s.b == 0]
                np.random.shuffle(switchables)
                sw = switchables[0]

                # agents trade places by swapping their node ids
                sw.id, a.id = a.id, sw.id


    def run(self, ticks, sre, srp, src, mechanism, show, mode):
        """Dispatch to the run_* method that belongs to ``mechanism``.

        ticks : number of ticks
        sre : success rate of critical evaluation
        srp : success rate of plausibility checking
        src : success rate of competence checking
        mechanism : 'crival', 'plaucheck' or 'compcheck'; any other
                    value runs nothing
        show : boolean plots network in each tick
        mode : 's' for small or 'l' for large plot
        """
        # each mechanism only receives the success rates it uses
        table = (
            ('crival', self.run_crival, (ticks, sre, show, mode)),
            ('plaucheck', self.run_plaucheck, (ticks, sre, srp, show, mode)),
            ('compcheck', self.run_compcheck,
             (ticks, sre, srp, src, show, mode)),
        )
        for name, runner, arguments in table:
            if mechanism == name:
                runner(*arguments)


    def run_crival(self, ticks, sre, show, mode):
        """Simulate ``ticks`` rounds of critical evaluation.

        ticks : number of ticks
        sre : success rate of critical evaluation
        show : boolean plots network in each tick
        mode : 's' for small or 'l' for large plot
        """
        def snapshot(tick):
            # record the metrics and optionally plot the current state
            self.calc()
            if show:
                self.draw(tick, mode)

        # tick 0 is the untouched starting configuration
        snapshot(0)

        for tick in range(1, ticks + 1):
            # agents act once per tick in a fresh random order
            np.random.shuffle(self.agents)
            for agent in self.agents:
                agent.crival(sre)
            snapshot(tick)


    def run_plaucheck(self, ticks, sre, srp, show, mode):
        """Simulate ``ticks`` rounds of plausibility checking.

        ticks : number of ticks
        sre : success rate of critical evaluation
        srp : success rate of plausibility checking
        show : boolean plots network in each tick
        mode : 's' for small or 'l' for large plot
        """
        # tick 0: metrics (and plot) of the starting configuration
        self.calc()
        if show:
            self.draw(0, mode)

        for done in range(ticks):
            current = done + 1

            # agents act once per tick in a fresh random order
            np.random.shuffle(self.agents)
            for agent in self.agents:
                agent.plaucheck(sre, srp)

            self.calc()
            if show:
                self.draw(current, mode)


    def run_compcheck(self, ticks, sre, srp, src, show, mode):
        """Simulate ``ticks`` rounds of competence checking.

        ticks : number of ticks
        sre : success rate of critical evaluation
        srp : success rate of plausibility checking
        src : success rate of competence checking
        show : boolean plots network in each tick
        mode : 's' for small or 'l' for large plot
        """
        def activate_all():
            # agents act once per tick in a fresh random order
            np.random.shuffle(self.agents)
            for agent in self.agents:
                agent.compcheck(sre, srp, src)

        # tick 0: metrics (and plot) of the starting configuration
        self.calc()
        if show:
            self.draw(0, mode)

        for tick in range(1, ticks + 1):
            activate_all()
            self.calc()
            if show:
                self.draw(tick, mode)


    def draw(self, tick, mode):
        """Draws the network of current tick in specified size and saves
        it as '../images/<tick>.tif' and '../images/<tick>.png'.

        tick : current tick or time step, used as axis label and file name
        mode : 's' for small or 'l' for large plot

        Raises ValueError for any other mode. The check happens before a
        figure is created; previously an unknown mode saved whatever
        figure was current and then failed on the undefined figure handle.
        """
        if mode not in ('s', 'l'):
            raise ValueError(
                "mode must be 's' or 'l', got {!r}".format(mode))

        agents = self.agents
        ids = [a.id for a in agents]  # all agent ids
        tmess = [a.id for a in agents if a.m == 1]  # true messages
        fmess = [a.id for a in agents if a.m == -1]  # false messages

        # these agent ids get customized
        tcids = [a.id for a in agents if a.b == 1 and a.i == 1]  # truth biased critics
        rcids = [a.id for a in agents if a.b == -1 and a.i == 1]  # rumor biased critics
        tpids = [a.id for a in agents if a.b == 1 and a.i == 0]  # truth biased plausibles
        rpids = [a.id for a in agents if a.b == -1 and a.i == 0]  # rumor biased plausibles

        # small plot
        if mode == 's':
            fig, ax = plt.subplots(1, 1, figsize=(2.5, 2.5), facecolor='white')

            # define colormap for small plot: trimmed and reversed 'binary'
            cmap = plt.get_cmap('binary')
            cmap = ListedColormap(cmap(np.linspace(0.2, 0.8, 256)))
            cmap = ListedColormap(cmap.colors[::-1])

            # draw all nodes in default color (value 0 for every node)
            delues = [0] * len(ids)
            nx.draw_networkx_nodes(
                self.G, self.pos, nodelist=ids, node_size=30,
                node_color=delues, cmap=cmap, vmin=-1, vmax=1, ax=ax)

            # overpaint nodes that carry a message
            nx.draw_networkx_nodes(
                self.G, self.pos, nodelist=tmess,
                node_size=160, node_color='c', ax=ax)
            nx.draw_networkx_nodes(
                self.G, self.pos, nodelist=fmess,
                node_size=160, node_color='crimson', ax=ax)

            # customize critics (truth biased, then rumor biased):
            # large hollow ring
            for critics in (tcids, rcids):
                nx.draw_networkx_nodes(
                    self.G, self.pos, nodelist=critics, node_size=160,
                    node_color='none', linewidths=2.5, edgecolors='k', ax=ax)

            # customize plausibles (truth biased, then rumor biased):
            # small hollow ring
            for plausibles in (tpids, rpids):
                nx.draw_networkx_nodes(
                    self.G, self.pos, nodelist=plausibles, node_size=60,
                    node_color='none', linewidths=1.5, edgecolors='k', ax=ax)

            # draw all edges with transparency
            nx.draw_networkx_edges(
                self.G, self.pos, width=1.1, alpha=0.2, ax=ax)

        # large plot
        else:
            fig, ax = plt.subplots(1, 1, figsize=(5, 5), facecolor='white')

            # colormap for background beliefs: trimmed and reversed 'binary'
            bmap = plt.get_cmap('binary')
            bmap = ListedColormap(bmap(np.linspace(0.3, 0.7, 256)))
            bmap = ListedColormap(bmap.colors[::-1])

            # draw all nodes based on color of beliefs
            belues = [a.b for a in agents]
            nx.draw_networkx_nodes(
                self.G, self.pos, nodelist=ids, node_size=78,
                node_color=belues, cmap=bmap, vmin=-1, vmax=1, ax=ax)

            # overpaint nodes that carry a message
            nx.draw_networkx_nodes(
                self.G, self.pos, nodelist=tmess,
                node_size=80, node_color='c', ax=ax)
            nx.draw_networkx_nodes(
                self.G, self.pos, nodelist=fmess,
                node_size=80, node_color='crimson', ax=ax)

            # customize critics (truth biased, then rumor biased)
            for critics in (tcids, rcids):
                nx.draw_networkx_nodes(
                    self.G, self.pos, nodelist=critics, node_size=80,
                    node_color='none', linewidths=2, edgecolors='k', ax=ax)

            # add legend, empty scatter calls only create the handles
            labelcmap = plt.get_cmap('binary')
            labelcolors = ['c', labelcmap(0.3), labelcmap(0.5),
                           labelcmap(0.7), 'crimson']
            labelnames = ['true message', 'true belief', 'no belief',
                          'false belief', 'false message']
            for color, name in zip(labelcolors, labelnames):
                ax.scatter([], [], color=color, label=name)
            plt.legend(loc='lower right', fancybox=False, edgecolor='k',
                       framealpha=1, markerscale=1.3)

        # customize figure
        ax.set_aspect('equal', adjustable='box')
        ax.set_xlabel('Tick {}'.format(tick), size=12)
        plt.margins(0.08)
        plt.tight_layout()

        plt.savefig('../images/{}.tif'.format(tick), dpi=200)
        plt.savefig('../images/{}.png'.format(tick), dpi=200)
        plt.close(fig)


    def calc(self):
        """Calculates metrics stored in data dictionary"""
        truths = 0
        rumors = 0
        blanks = 0
        for a in self.agents:
            if a.m == 1:
                truths += 1
            elif a.m == -1:
                rumors += 1
            else:
                blanks += 1
        self.data['truths'].append(truths)
        self.data['rumors'].append(rumors)
        self.data['blanks'].append(blanks)