
import networkx as nx
import random
from networkx.classes.function import set_node_attributes
import copy
import matplotlib.pyplot as plt
import re
from pathlib import Path
import math
from argument import Argument, WeightedArgument


def sigmoid(x):
    """Return the logistic function 1 / (1 + e^-x) of ``x``.

    Uses the numerically stable two-branch form: ``math.exp`` is only ever
    called on a non-positive argument, so very negative ``x`` (e.g. -1000)
    returns 0.0 instead of raising OverflowError.
    """
    if x >= 0:
        # exp(-x) <= 1 here, so no overflow is possible.
        return 1 / (1 + math.exp(-x))
    # For x < 0 rewrite as e^x / (1 + e^x); exp(x) < 1 cannot overflow.
    z = math.exp(x)
    return z / (1 + z)



class DebateGraph(nx.DiGraph):
    """An attack argumentation graph.

    Nodes are ``Argument`` objects and an edge ``(a, b)`` means that
    argument ``a`` attacks argument ``b``.

    Parameters:
        GPOV: the GPOV object that the graph belongs to.
        issue_id: identifier of the main issue of the debate. When given,
            an ``Argument`` is created for it and added as a node; when
            None, no issue is created here (subclasses are then expected
            to set ``self.issue`` themselves).
    """

    def __init__(self, GPOV=None, issue_id=None):
        # in non weighted graphs, arguments have a strength of None
        super().__init__()
        if issue_id is not None:
            issue = Argument(issue_id, None, self)
            self.issue = issue
            self.add_node(issue)
            self.issue_acceptability = True
        self.GPOV = GPOV

    def get_issue(self):
        """Return the issue argument of the debate."""
        return self.issue

    def get_name(self):
        """Return the id of the issue argument."""
        return self.issue.id

    def get_gpov_name(self):
        """Return the name reported by the GPOV this graph belongs to."""
        return self.GPOV.get_name()

    def get_size(self):
        """Return the number of arguments in the graph (issue included)."""
        return self.number_of_nodes()

    def get_edges_between(self, a, graph):
        """Return the edges incident to ``a`` whose other end is in ``graph``.

        Incoming edges (attacks on ``a``) are listed first, followed by
        outgoing edges (attacks made by ``a``).
        """
        in_edges = [e for e in self.in_edges(a) if e[0] in graph.nodes]
        out_edges = [e for e in self.out_edges(a) if e[1] in graph.nodes]
        return in_edges + out_edges

    def view_graph(self):
        """Print the issue, its acceptability flag, the arguments and edges."""
        print("Issue : ", self.issue)
        print(self.issue_acceptability)
        print("Arguments :")
        for arg in self.nodes:
            print(arg)
        print("Edges")
        for edge in self.edges:
            print(edge[0], " ===> ", edge[1])

    def draw(self, time=None, title=None, save=False):
        """Draw the graph with matplotlib and optionally save it as a PNG.

        Parameters:
            time: optional value naming the output sub-directory
                ``Figs/<time>/``. ``:`` characters are replaced by ``_`` to
                avoid filename errors. Without it, files go to ``Figs/``.
            title: optional plot title, also used as the PNG file name
                (``graph`` is used when saving without a title).
            save: when True, write the figure to disk before showing it.
        """
        # Default to Figs/ so that save=True also works when no time is given.
        folder = Path('Figs')
        if time is not None:
            folder = folder / str(time).replace(':', '_')
        if save:
            folder.mkdir(parents=True, exist_ok=True)
        plt.figure(figsize=(10, 5))
        ax = plt.gca()
        if title is not None:
            ax.set_title(title)
        nx.draw(self, pos=nx.spring_layout(self), labels={n: str(n) for n in self.nodes})
        if save:
            file_stem = title if title is not None else 'graph'
            plt.savefig(folder / f"{file_stem}.png", format="PNG")
        plt.show()

    def __str__(self) -> str:
        return str(self.issue) + str(self.nodes)

    def deep_copy(self):
        """Return a deep copy of the graph (arguments included)."""
        return copy.deepcopy(self)

    def get_oddity(self, arg):
        """Return the parity of the attack chain from ``arg`` to the issue.

        Returns 1 if the first simple path found from ``arg`` to the issue
        has an odd number of nodes (odd -> defense node) and 0 if it is even.
        Returns 1 for the issue itself, and None when ``arg`` has no path to
        the issue.
        """
        # Only the first path is used, so take it lazily instead of
        # enumerating every simple path (exponential in the worst case).
        path = next(iter(nx.all_simple_paths(self, source=arg, target=self.issue)), None)
        if path is not None:
            return len(path) % 2
        if arg == self.issue:
            return 1
        return None

    def get_oddity_new_arg(self, arg):
        """Return the oddity of an argument that is not yet in the graph.

        ``arg`` must attack an argument of the graph (``arg.attacked_arg``);
        used in the review process. Returns 1 if the attacked argument has
        oddity 0, otherwise 0 (including when it has no path to the issue).
        """
        return 1 if self.get_oddity(arg.attacked_arg) == 0 else 0

    def get_leafs(self):
        """Return the unattacked arguments (in-degree 0), in node order."""
        return [x for x in self.nodes() if self.in_degree(x) == 0]

    def _grounded_labelling(self):
        """Compute the grounded labelling of the attack graph.

        Returns ``(accepted, defeated)``. ``accepted`` lists the IN
        arguments in discovery order, with unattacked arguments first.
        ``defeated`` is the set of OUT arguments, i.e. those attacked by an
        IN argument. Arguments in neither are undecided (e.g. in an attack
        cycle). The loop stops once a full pass accepts nothing new, so it
        always terminates.
        """
        accepted = []
        accepted_set = set()
        defeated = set()
        while True:
            # An argument is IN once every one of its attackers is OUT. On
            # the first pass `defeated` is empty, so this picks the leaves.
            newly_in = [
                a for a in self.nodes
                if a not in accepted_set and a not in defeated
                and all(p in defeated for p in self.predecessors(a))
            ]
            if not newly_in:
                return accepted, defeated
            accepted.extend(newly_in)
            accepted_set.update(newly_in)
            for a in newly_in:
                defeated.update(self.successors(a))

    def is_issue_acceptable(self):
        """Return True iff the issue belongs to the grounded extension.

        Algorithm inspired by Modgil and Caminada. An issue that is
        undecided (e.g. part of an attack cycle) is reported as not
        acceptable; an unattacked issue is acceptable.
        """
        accepted, _ = self._grounded_labelling()
        return self.issue in accepted

    def get_acceptable_arguments(self):
        """Return the arguments of the grounded extension, without duplicates.

        Algorithm inspired by Modgil and Caminada. Unattacked arguments come
        first, followed by arguments in the order they become defended.
        """
        accepted, _ = self._grounded_labelling()
        return accepted

    def is_linked_to_issue(self, new_arg):
        """Return True if following ``attacked_arg`` links from ``new_arg``
        reaches the issue through arguments that are all in this graph.

        Arguments are matched by ``id``. Returns False when the chain leaves
        the graph, ends on an argument whose ``attacked_arg`` is None, or
        loops back on itself.
        """
        node_ids = {n.id for n in self.nodes}
        issue_id = self.issue.id
        seen = set()
        current = new_arg
        while True:
            target = current.attacked_arg
            if target is None or target.id not in node_ids:
                return False
            if target.id == issue_id:
                return True
            # Guard against attack cycles, which would otherwise loop forever.
            if target.id in seen:
                return False
            seen.add(target.id)
            current = target

                


    

class WeightedDebateGraph(DebateGraph):
    """A DebateGraph whose issue is a ``WeightedArgument``.

    Parameters:
        GPOV: the GPOV object that the graph belongs to.
        issue_strength: strength given to the issue argument (default 5).
    """

    def __init__(self, GPOV=None, issue_strength=5):
        # The parent is built without an issue_id, so it creates no issue;
        # the weighted issue is created and registered here instead.
        super().__init__(GPOV)
        self.issue = WeightedArgument(issue_strength, None, self)
        self.issue_acceptability = True
        self.add_node(self.issue)

