package edu.unika.aifb.kaon.datalog.evaluation;

import java.util.List;
import java.util.ArrayList;
import java.util.Set;
import java.util.HashSet;
import java.util.Map;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Collections;
import java.util.Comparator;

import edu.unika.aifb.kaon.datalog.program.*;

/**
 * Represents the precedence (rule-goal) graph of a datalog program.
 * The graph contains one node per rule and one node per predicate. Every predicate occurring in a rule body points to the rule (the edge
 * is negative if the body literal is negative), and every rule points to each predicate it can possibly generate. The graph is partitioned
 * into strongly connected components, which are produced in topological order; the components that contain rules form the sets of
 * mutually recursive rules exposed by this class.
 */
public class PrecedenceGraph {
    /** The program for which the precedence graph is computed. */
    protected Program m_program;
    /** The ordering of the predicates. */
    protected GroundLiteralOrdering m_ordering;
    /** Maps each <code>Rule</code> of the program to the <code>Set</code> of predicates that can possibly be generated by the rule. */
    protected Map m_predicatesGeneratedByRules;
    /** The list of sets of mutually recursive rules, in topological order. */
    protected List m_setsOfMutuallyRecursiveRules;
    /** The list of <code>Boolean</code> objects that determine whether the corresponding set of mutually recursive rules contains a cycle. */
    protected List m_componentHasCycle;
    /** The list of <code>Boolean</code> objects that determine whether the corresponding set of mutually recursive rules contains a negative cycle. */
    protected List m_componentHasNegativeCycle;
    /** Set to <code>true</code> if the program is stratified. */
    protected boolean m_isStratified;

    /**
     * Creates an instance of this class and immediately computes the graph, its strongly connected components and the stratification flag.
     *
     * @param program                                   the datalog program
     * @param possibleDisjunctionRestsByPredicates      the map of sets of predicates that can occur in disjunctions with the key predicate as the active literal (may be <code>null</code> or empty)
     * @param groundLiteralOrdering                     the ordering of the predicates
     */
    public PrecedenceGraph(Program program,Map possibleDisjunctionRestsByPredicates,GroundLiteralOrdering groundLiteralOrdering) {
        m_program=program;
        m_ordering=groundLiteralOrdering;
        Map sortedHeadPredicatesForRules=computeSortedHeadPredicatesForRules();
        Map sortedPossibleDisjunctionRestsByPredicates=new HashMap();
        // a null map is treated exactly like an empty one
        if (possibleDisjunctionRestsByPredicates!=null && !possibleDisjunctionRestsByPredicates.isEmpty())
            copyDisjunctionRestsByPredicates(possibleDisjunctionRestsByPredicates,sortedPossibleDisjunctionRestsByPredicates);
        computeDisjunctionRestsByPredicates(sortedHeadPredicatesForRules,sortedPossibleDisjunctionRestsByPredicates);
        computesPredicatesGeneratedByRules(sortedHeadPredicatesForRules,sortedPossibleDisjunctionRestsByPredicates);
        Set nodes=new HashSet();
        buildGraph(nodes);
        // Kosaraju's algorithm: a DFS over outgoing edges assigns finishing times, then a DFS over incoming edges
        // in decreasing finishing-time order extracts the strongly connected components.
        computeFinishingTime(nodes);
        clearVisitedFlag(nodes);
        List connectedComponents=new ArrayList();
        generateStronglyConnectedComponents(nodes,connectedComponents);
        filterRules(connectedComponents);
    }
    /**
     * Returns <code>true</code> if the program is stratified, i.e. no set of mutually recursive rules contains a cycle over negation.
     *
     * @return                                          <code>true</code> if the program is stratified
     */
    public boolean isStratified() {
        return m_isStratified;
    }
    /**
     * Returns the number of sets of mutually recursive rules.
     *
     * @return                                          the number of sets of mutually recursive rules
     */
    public int getNumberOfSetsOfMutuallyRecursiveRulesSets() {
        return m_setsOfMutuallyRecursiveRules.size();
    }
    /**
     * Returns the requested set of mutually recursive rules.
     *
     * @param setIndex                                  the index of the set of mutually recursive rules
     * @return                                          the set of <code>Rule</code> objects with the given index
     */
    public Set getMutuallyRecursiveRulesSet(int setIndex) {
        return (Set)m_setsOfMutuallyRecursiveRules.get(setIndex);
    }
    /**
     * Returns <code>true</code> if the specified set of mutually recursive rules contains a cycle.
     *
     * @param setIndex                                  the index of the set of mutually recursive rules
     * @return                                          <code>true</code> if the specified set of mutually recursive rules contains a cycle
     */
    public boolean getMutuallyRecursiveRulesSetHasCycle(int setIndex) {
        return ((Boolean)m_componentHasCycle.get(setIndex)).booleanValue();
    }
    /**
     * Returns <code>true</code> if the specified set of mutually recursive rules contains a cycle over negation.
     *
     * @param setIndex                                  the index of the set of mutually recursive rules
     * @return                                          <code>true</code> if the specified set of mutually recursive rules contains a cycle over negation
     */
    public boolean getMutuallyRecursiveRulesSetHasNegatedCycle(int setIndex) {
        return ((Boolean)m_componentHasNegativeCycle.get(setIndex)).booleanValue();
    }
    /**
     * Returns the set of predicates that can be generated by the given rule.
     *
     * @param rule                                      the rule
     * @return                                          the set of predicates generated by the rule, or <code>null</code> if the rule is not part of the program
     */
    public Set getPredicatesGeneratedByRule(Rule rule) {
        return (Set)m_predicatesGeneratedByRules.get(rule);
    }
    /**
     * Filters the list of strongly connected components to create a list of sets of mutually recursive rules. Components without any rule
     * are dropped. As a side-effect this method also determines whether each kept component contains a (negative) cycle and whether the
     * program is stratified.
     *
     * @param connectedComponents                       the list of strongly connected components (sets of <code>Node</code> objects)
     */
    protected void filterRules(List connectedComponents) {
        m_setsOfMutuallyRecursiveRules=new ArrayList();
        m_componentHasCycle=new ArrayList();
        m_componentHasNegativeCycle=new ArrayList();
        m_isStratified=true;
        for (int i=0;i<connectedComponents.size();i++) {
            Set connectedComponent=(Set)connectedComponents.get(i);
            Set mutuallyRecursiveRules=null;
            boolean hasCycle=false;
            boolean hasNegativeCycle=false;
            Iterator iterator=connectedComponent.iterator();
            while (iterator.hasNext()) {
                Node node=(Node)iterator.next();
                if (node.m_object instanceof Rule) {
                    if (mutuallyRecursiveRules==null) {
                        mutuallyRecursiveRules=new HashSet();
                        m_setsOfMutuallyRecursiveRules.add(mutuallyRecursiveRules);
                    }
                    mutuallyRecursiveRules.add(node.m_object);
                }
                // once a negative edge inside the component is known, both flags are final; only the rules still need to be collected
                if (!hasNegativeCycle) {
                    Iterator edges=node.m_outgoingEdges.iterator();
                    while (edges.hasNext()) {
                        Edge edge=(Edge)edges.next();
                        if (connectedComponent.contains(edge.m_to)) {
                            // both endpoints lie in the same strongly connected component, so this edge lies on a cycle
                            hasCycle=true;
                            if (!edge.m_positive) {
                                hasNegativeCycle=true;
                                break;
                            }
                        }
                    }
                }
            }
            if (mutuallyRecursiveRules!=null) {
                m_componentHasCycle.add(Boolean.valueOf(hasCycle));
                m_componentHasNegativeCycle.add(Boolean.valueOf(hasNegativeCycle));
                m_isStratified=m_isStratified && !hasNegativeCycle;
            }
        }
    }
    /**
     * Generates the strongly connected components of the graph. Nodes are processed in decreasing order of their finishing time, which
     * produces the components in topological order. The visited flags of all nodes must be cleared before calling this method.
     *
     * @param nodes                                     the nodes
     * @param connectedComponents                       the list to which the sets representing the connected components are appended
     */
    protected void generateStronglyConnectedComponents(Set nodes,List connectedComponents) {
        List nodeList=new ArrayList(nodes);
        // the comparator sorts ascending, so the list is walked backwards
        Collections.sort(nodeList,FinishingTimeComparator.INSTANCE);
        for (int i=nodeList.size()-1;i>=0;--i) {
            Node node=(Node)nodeList.get(i);
            if (!node.m_visited) {
                Set connectedComponent=new HashSet();
                connectedComponents.add(connectedComponent);
                generateStronglyConnectedComponent(node,connectedComponent);
            }
        }
    }
    /**
     * Generates one strongly connected component of the graph by collecting all not yet visited nodes reachable from the given node over
     * incoming edges. An explicit stack is used instead of recursion so that long dependency chains cannot overflow the call stack.
     *
     * @param node                                      the node
     * @param connectedComponent                        the connected component
     */
    protected void generateStronglyConnectedComponent(Node node,Set connectedComponent) {
        List stack=new ArrayList();
        node.m_visited=true;
        stack.add(node);
        while (!stack.isEmpty()) {
            Node current=(Node)stack.remove(stack.size()-1);
            connectedComponent.add(current);
            Iterator edges=current.m_incomingEdges.iterator();
            while (edges.hasNext()) {
                Node fromNode=((Edge)edges.next()).m_from;
                if (!fromNode.m_visited) {
                    // mark on push so that each node is pushed at most once
                    fromNode.m_visited=true;
                    stack.add(fromNode);
                }
            }
        }
    }
    /**
     * Clears the visited flag of the nodes.
     *
     * @param nodes                                     the set of all nodes
     */
    protected void clearVisitedFlag(Set nodes) {
        Iterator iterator=nodes.iterator();
        while (iterator.hasNext()) {
            Node node=(Node)iterator.next();
            node.m_visited=false;
        }
    }
    /**
     * Performs the depth-first traversal of the graph over the outgoing edges starting at the given node and computes the finishing time
     * of every node reached. An explicit stack of edge iterators is used instead of recursion so that long dependency chains cannot
     * overflow the call stack; nodes finish in exactly the order a recursive traversal would produce.
     *
     * @param node                                      the node where the traversal starts
     * @param time                                      the current time of the algorithm; incremented once for every finished node
     */
    protected void computeFinishingTime(Node node,IntegerHolder time) {
        List nodeStack=new ArrayList();
        List iteratorStack=new ArrayList();
        node.m_visited=true;
        nodeStack.add(node);
        iteratorStack.add(node.m_outgoingEdges.iterator());
        while (!nodeStack.isEmpty()) {
            int top=nodeStack.size()-1;
            Iterator edges=(Iterator)iteratorStack.get(top);
            if (edges.hasNext()) {
                Node toNode=((Edge)edges.next()).m_to;
                if (!toNode.m_visited) {
                    // descend into the successor, exactly where the recursive version would have made its call
                    toNode.m_visited=true;
                    nodeStack.add(toNode);
                    iteratorStack.add(toNode.m_outgoingEdges.iterator());
                }
            }
            else {
                // all successors are done: the node finishes now
                Node finishedNode=(Node)nodeStack.remove(top);
                iteratorStack.remove(top);
                finishedNode.m_finishingTime=time.m_value++;
            }
        }
    }
    /**
     * Performs the depth-first traversal of the whole graph over the outgoing edges and computes the finishing time of every node.
     *
     * @param nodes                                     the set of all nodes
     */
    protected void computeFinishingTime(Set nodes) {
        IntegerHolder time=new IntegerHolder();
        Iterator iterator=nodes.iterator();
        while (iterator.hasNext()) {
            Node node=(Node)iterator.next();
            if (!node.m_visited)
                computeFinishingTime(node,time);
        }
    }
    /**
     * Builds the precedence graph from the program. For every rule a rule node is created with a positive edge to each predicate the rule
     * can generate, and an edge from each body literal's predicate to the rule whose polarity is that of the literal.
     *
     * @param nodes                                     the set to which all created nodes are added
     */
    protected void buildGraph(Set nodes) {
        Map nodesByPredicate=new HashMap();
        for (int i=m_program.getNumberOfRules()-1;i>=0;--i) {
            Rule rule=m_program.getRule(i);
            Node ruleNode=new Node(rule);
            nodes.add(ruleNode);
            Iterator iterator=getPredicatesGeneratedByRule(rule).iterator();
            while (iterator.hasNext()) {
                Predicate predicate=(Predicate)iterator.next();
                Node headNode=getNodeForPredicate(nodes,nodesByPredicate,predicate);
                // the Edge constructor registers the edge with both of its nodes
                new Edge(ruleNode,headNode,true);
            }
            for (int j=0;j<rule.getBodyLength();j++) {
                Literal literal=rule.getBodyLiteral(j);
                Node bodyNode=getNodeForPredicate(nodes,nodesByPredicate,literal.getPredicate());
                new Edge(bodyNode,ruleNode,literal.isPositive());
            }
        }
    }
    /**
     * Returns the node for the given predicate. If the node doesn't exist, then it is created and registered.
     *
     * @param nodes                                     the set of all nodes
     * @param nodesByPredicate                          the map indexing nodes by their predicate
     * @param predicate                                 the predicate
     * @return                                          the node for the given predicate
     */
    protected Node getNodeForPredicate(Set nodes,Map nodesByPredicate,Predicate predicate) {
        Node node=(Node)nodesByPredicate.get(predicate);
        if (node==null) {
            node=new Node(predicate);
            nodes.add(node);
            nodesByPredicate.put(predicate,node);
        }
        return node;
    }
    /**
     * Computes, for every rule, the list of its head predicates sorted according to the predicate ordering. Repeated head predicates are
     * kept. Rules with an empty head are mapped to <code>null</code>.
     *
     * @return                                          the map from rules to the first node of their sorted head predicate lists
     */
    protected Map computeSortedHeadPredicatesForRules() {
        PredicateListNode dummyListNode=new PredicateListNode(null,null);
        Map sortedHeadPredicatesForRules=new HashMap();
        for (int i=0;i<m_program.getNumberOfRules();i++) {
            Rule rule=m_program.getRule(i);
            // the dummy node is reused as the list head for every rule
            dummyListNode.m_next=null;
            for (int j=0;j<rule.getHeadLength();j++) {
                Predicate predicate=rule.getHeadLiteral(j).getPredicate();
                insertPredicate(dummyListNode,predicate);
            }
            sortedHeadPredicatesForRules.put(rule,dummyListNode.m_next);
        }
        return sortedHeadPredicatesForRules;
    }
    /**
     * Copies the map of possible disjunction rests by predicates, converting every set of predicates into a sorted list.
     *
     * @param possibleDisjunctionRestsByPredicates      the map of sets of predicates that can occur in disjunctions with the key predicate as the active literal
     * @param sortedDisjunctionRestsByPredicates        the map receiving the sorted lists of predicates that can occur in disjunctions with the key predicate as the active literal
     */
    protected void copyDisjunctionRestsByPredicates(Map possibleDisjunctionRestsByPredicates,Map sortedDisjunctionRestsByPredicates) {
        PredicateListNode dummyListNode=new PredicateListNode(null,null);
        Iterator keys=possibleDisjunctionRestsByPredicates.keySet().iterator();
        while (keys.hasNext()) {
            dummyListNode.m_next=null;
            Predicate activePredicate=(Predicate)keys.next();
            Iterator iterator=((Set)possibleDisjunctionRestsByPredicates.get(activePredicate)).iterator();
            while (iterator.hasNext()) {
                Predicate predicate=(Predicate)iterator.next();
                insertPredicate(dummyListNode,predicate);
            }
            sortedDisjunctionRestsByPredicates.put(activePredicate,dummyListNode.m_next);
        }
    }
    /**
     * Computes the transitive closure of predicates in disjunctions. For every rule, the sorted head predicates are merged with the
     * disjunction rests of all body predicates; then, for every predicate of that list up to and including the minimal head predicate,
     * the predicates following it are merged into its own disjunction rests. This is repeated until no list changes.
     *
     * @param sortedHeadPredicatesForRules              the map of sorted head predicates for rules
     * @param sortedDisjunctionRestsByPredicates        the map of lists of predicates that can occur in disjunctions with the key predicate as the active literal; updated in place
     */
    protected void computeDisjunctionRestsByPredicates(Map sortedHeadPredicatesForRules,Map sortedDisjunctionRestsByPredicates) {
        PredicateListNode dummyListNode=new PredicateListNode(null,null);
        PredicateListNode dummyListNode2=new PredicateListNode(null,null);
        boolean change=true;
        while (change) {
            change=false;
            for (int i=0;i<m_program.getNumberOfRules();i++) {
                Rule rule=m_program.getRule(i);
                // the merge into an empty list copies the nodes, so the lists stored in the maps are never aliased here
                dummyListNode.m_next=null;
                PredicateListNode headNode=(PredicateListNode)sortedHeadPredicatesForRules.get(rule);
                Predicate minimalHeadPredicate;
                if (headNode!=null) {
                    minimalHeadPredicate=headNode.m_predicate;
                    mergeSortedPredicateLists(dummyListNode,headNode);
                }
                else
                    minimalHeadPredicate=null;
                for (int j=0;j<rule.getBodyLength();j++) {
                    PredicateListNode predicateNodeList=(PredicateListNode)sortedDisjunctionRestsByPredicates.get(rule.getBodyLiteral(j).getPredicate());
                    if (predicateNodeList!=null)
                        mergeSortedPredicateLists(dummyListNode,predicateNodeList);
                }
                PredicateListNode currentNode=dummyListNode.m_next;
                while (currentNode!=null) {
                    Predicate disjunctionPredicate=currentNode.m_predicate;
                    // predicates ordered after the minimal head predicate can never become the active literal of this disjunction
                    if (minimalHeadPredicate!=null && m_ordering.compareTo(disjunctionPredicate,minimalHeadPredicate)>0)
                        break;
                    if (currentNode.m_next!=null) {
                        PredicateListNode disjunctionPredicateDisjunctions=(PredicateListNode)sortedDisjunctionRestsByPredicates.get(disjunctionPredicate);
                        dummyListNode2.m_next=disjunctionPredicateDisjunctions;
                        if (mergeSortedPredicateLists(dummyListNode2,currentNode.m_next))
                            change=true;
                        // the merge may have inserted a new first node, so the list head is stored back
                        sortedDisjunctionRestsByPredicates.put(disjunctionPredicate,dummyListNode2.m_next);
                    }
                    currentNode=currentNode.m_next;
                }
            }
        }
    }
    /**
     * Merges the second sorted predicate list into the first one. Elements of the second list that compare equal to the element of the
     * first list they are checked against are skipped. New nodes are created for inserted elements; the second list is not modified.
     *
     * @param head1                                     the head of the first list (starting with a dummy node); modified in place
     * @param head2                                     the first node of the second list (may be <code>null</code>)
     * @return                                          <code>true</code> if the first list changed
     */
    protected boolean mergeSortedPredicateLists(PredicateListNode head1,PredicateListNode head2) {
        boolean change=false;
        PredicateListNode pointer1=head1.m_next;
        PredicateListNode lastNode=head1;
        PredicateListNode pointer2=head2;
        while (pointer1!=null && pointer2!=null) {
            int comparison=m_ordering.compareTo(pointer1.m_predicate,pointer2.m_predicate);
            if (comparison<0) {
                lastNode=pointer1;
                pointer1=pointer1.m_next;
            }
            else if (comparison==0)
                pointer2=pointer2.m_next;
            else {
                lastNode.m_next=new PredicateListNode(pointer2.m_predicate,lastNode.m_next);
                lastNode=lastNode.m_next;
                pointer2=pointer2.m_next;
                change=true;
            }
        }
        // whatever remains of the second list is ordered after every element of the first list
        while (pointer2!=null) {
            lastNode.m_next=new PredicateListNode(pointer2.m_predicate,null);
            lastNode=lastNode.m_next;
            pointer2=pointer2.m_next;
            change=true;
        }
        return change;
    }
    /**
     * Inserts a predicate into a sorted list. The predicate is placed before the first element that compares greater, so duplicates are
     * kept and equal elements retain their insertion order.
     *
     * @param head                                      the head of the list (starting with a dummy node)
     * @param predicate                                 the predicate
     */
    protected void insertPredicate(PredicateListNode head,Predicate predicate) {
        PredicateListNode last=head;
        PredicateListNode current=head.m_next;
        while (current!=null) {
            if (m_ordering.compareTo(current.m_predicate,predicate)>0) {
                last.m_next=new PredicateListNode(predicate,last.m_next);
                return;
            }
            last=current;
            current=current.m_next;
        }
        last.m_next=new PredicateListNode(predicate,null);
    }
    /**
     * Computes the predicates generated by rules. A rule generates its minimal head predicate, plus every disjunction rest of its body
     * predicates that is ordered strictly before the minimal head predicate (or all of them if the rule has an empty head).
     *
     * @param sortedHeadPredicatesForRules              the map of sorted head predicates for rules
     * @param sortedDisjunctionRestsByPredicates        the map of lists of predicates that can occur in disjunctions with the key predicate as the active literal
     */
    protected void computesPredicatesGeneratedByRules(Map sortedHeadPredicatesForRules,Map sortedDisjunctionRestsByPredicates) {
        m_predicatesGeneratedByRules=new HashMap();
        for (int i=0;i<m_program.getNumberOfRules();i++) {
            Rule rule=m_program.getRule(i);
            Set result=new HashSet();
            Predicate minimalHeadPredicate;
            PredicateListNode headNode=(PredicateListNode)sortedHeadPredicatesForRules.get(rule);
            if (headNode!=null) {
                minimalHeadPredicate=headNode.m_predicate;
                result.add(minimalHeadPredicate);
            }
            else
                minimalHeadPredicate=null;
            for (int j=0;j<rule.getBodyLength();j++) {
                PredicateListNode listNode=(PredicateListNode)sortedDisjunctionRestsByPredicates.get(rule.getBodyLiteral(j).getPredicate());
                while (listNode!=null) {
                    Predicate predicate=listNode.m_predicate;
                    // the list is sorted, so the first predicate not before the minimal head ends the scan
                    if (minimalHeadPredicate!=null && m_ordering.compareTo(predicate,minimalHeadPredicate)>=0)
                        break;
                    result.add(predicate);
                    listNode=listNode.m_next;
                }
            }
            m_predicatesGeneratedByRules.put(rule,result);
        }
    }

    /**
     * The node in the graph; it holds either a <code>Rule</code> or a <code>Predicate</code>.
     */
    protected static class Node {
        /** The object in the node. */
        public Object m_object;
        /** The finishing time of the DFS algorithm. */
        public int m_finishingTime;
        /** Set to <code>true</code> if the node has been visited. */
        public boolean m_visited;
        /** The set of incoming edges. */
        public Set m_incomingEdges;
        /** The set of outgoing edges. */
        public Set m_outgoingEdges;

        /**
         * Creates an instance of this class.
         *
         * @param object                                the object in the node
         */
        public Node(Object object) {
            m_object=object;
            m_incomingEdges=new HashSet();
            m_outgoingEdges=new HashSet();
        }
    }

    /**
     * The edge in the graph. Creating an edge registers it with both of its nodes; an edge equal to an already registered one is not added again.
     */
    protected static class Edge {
        /** From node. */
        public Node m_from;
        /** To node. */
        public Node m_to;
        /** Set to <code>true</code> if the edge is positive. */
        public boolean m_positive;
        /** The hash-code of the edge. */
        public int m_hashCode;

        /**
         * Creates an instance of this class and adds it to the outgoing edges of <code>from</code> and the incoming edges of <code>to</code>.
         *
         * @param from                                  the node from which the edge points
         * @param to                                    the node to which the edge points
         * @param positive                              set to <code>true</code> if the edge is positive
         */
        public Edge(Node from,Node to,boolean positive) {
            m_from=from;
            m_to=to;
            m_positive=positive;
            m_hashCode=7*m_from.hashCode()+m_to.hashCode();
            m_from.m_outgoingEdges.add(this);
            m_to.m_incomingEdges.add(this);
        }
        public int hashCode() {
            return m_hashCode;
        }
        public boolean equals(Object that) {
            if (this==that)
                return true;
            if (!(that instanceof Edge))
                return false;
            Edge thatEdge=(Edge)that;
            return m_from.equals(thatEdge.m_from) && m_to.equals(thatEdge.m_to) && m_positive==thatEdge.m_positive;
        }
    }

    /**
     * The class that allows passing an integer that can be changed by the callee.
     */
    protected static class IntegerHolder {
        public int m_value;
    }

    /**
     * The comparator ordering nodes by increasing finishing time.
     */
    protected static class FinishingTimeComparator implements Comparator {
        public static final Comparator INSTANCE=new FinishingTimeComparator();

        public int compare(Object o1,Object o2) {
            int time1=((Node)o1).m_finishingTime;
            int time2=((Node)o2).m_finishingTime;
            if (time1<time2)
                return -1;
            else if (time1>time2)
                return 1;
            else
                return 0;
        }
    }

    /**
     * The cell for the singly-linked list of predicates.
     */
    protected static class PredicateListNode {
        /** The predicate in the node. */
        public Predicate m_predicate;
        /** The link to the next node. */
        public PredicateListNode m_next;

        /**
         * Creates an instance of this class.
         *
         * @param predicate                             the predicate in the node
         * @param next                                  the next node (may be <code>null</code>)
         */
        public PredicateListNode(Predicate predicate,PredicateListNode next) {
            m_predicate=predicate;
            m_next=next;
        }
    }
}
