package edu.unika.aifb.kaon.datalog.magic;

import java.util.List;
import java.util.LinkedList;
import java.util.Set;
import java.util.HashSet;
import java.util.Map;
import java.util.HashMap;
import java.util.Iterator;

import edu.unika.aifb.kaon.datalog.*;
import edu.unika.aifb.kaon.datalog.program.*;

/**
 * This class represents a magic set rewriting of a disjunctive program.
 */
public class DisjunctiveMagicSetRewriting extends MagicSetRewriting {

    /**
     * Builds the magic set rewriting of a disjunctive program for the given query. The arguments
     * are validated and remembered, and the rewriting itself is then carried out by a
     * {@link Rewriter}.
     *
     * @param program                           the original program
     * @param queryPredicate                    the query predicate
     * @param boundVariables                    the array specifying which variables are bound; it must have one entry per argument of the query predicate
     * @param sips                              the sideways information passing strategy
     * @param predicateBindingPatterns          determines allowed binding patterns of a predicate
     * @throws DatalogException                 thrown if rewriting can't be created
     * @throws IllegalArgumentException         thrown if the length of <code>boundVariables</code> differs from the arity of the query predicate
     */
    public DisjunctiveMagicSetRewriting(Program program,Predicate queryPredicate,boolean[] boundVariables,SIPS sips,PredicateBindingPatterns predicateBindingPatterns) throws DatalogException {
        int queryArity=queryPredicate.getArity();
        if (boundVariables.length!=queryArity) {
            throw new IllegalArgumentException("The arity of the query predicate and the signature of the bound variables don't match.");
        }
        m_boundVariables=boundVariables;
        m_queryPredicate=queryPredicate;
        m_program=program;
        // The rewriter performs all of its work in its constructor, so the instance is not kept.
        new Rewriter(sips,predicateBindingPatterns);
    }

    /**
     * Encapsulates the algorithm for rewriting a program into a magic program.
     */
    protected class Rewriter {
        /** The program with eliminated integrity constraints. */
        protected Program m_noIntegrityConstraintsProgram;
        /** The set of IDB predicates. */
        protected Set m_idbPredicates;
        /** The factory for the predicates. */
        protected PredicateFactory m_predicateFactory;
        /** The map of predicates to their ESV counterparts. */
        protected Map m_esvPredicates;
        /** The ESV query predicate. */
        protected Predicate m_esvQueryPredicate;
        /** The ESV program. */
        protected Program m_esvProgram;
        /** The set of rules added while eliminating integrity constraints. */
        protected Set m_integrityConstraintRules;
        /** The set of rules that are in ESV but not in SV. */
        protected Set m_rulesESVMinusSV;
        /** The rewritten ESV program. */
        protected Program m_magicESVProgramMinusSVRules;

        /**
         * Runs the complete rewriting. Integrity constraints are eliminated first, then the ESV
         * program is generated and rewritten by a {@link HornMagicSetRewriting}, and finally the
         * result is combined with the modified rules of the original program. The steps depend on
         * each other, so their order must not be changed.
         *
         * @param sips                          the SIPS to use
         * @param predicateBindingPatterns      determines allowed binding patterns of a predicate
         * @throws DatalogException             thrown if rewriting can't be created
         */
        public Rewriter(SIPS sips,PredicateBindingPatterns predicateBindingPatterns) throws DatalogException {
            // Set up the working state used by the individual rewriting steps.
            m_esvPredicates=new HashMap();
            m_rulesESVMinusSV=new HashSet();
            m_integrityConstraintRules=new HashSet();
            m_predicateFactory=new PredicateFactory();
            eliminateIntegrityConstraints();
            m_idbPredicates=getPredicatesInRuleHeads(m_noIntegrityConstraintsProgram);
            generateESVProgram();
            // The ESV counterparts of the head predicates introduced for integrity constraints
            // are handed to the Horn rewriting as additional free predicates.
            Set freeConstraintPredicates=new HashSet();
            for (Iterator constraintRules=m_integrityConstraintRules.iterator();constraintRules.hasNext();) {
                Rule constraintRule=(Rule)constraintRules.next();
                Predicate constraintPredicate=constraintRule.getHeadLiteral(0).getPredicate();
                freeConstraintPredicates.add(getESVPredicate(constraintPredicate));
            }
            HornMagicSetRewriting hornRewriting=new HornMagicSetRewriting(m_esvProgram,m_esvQueryPredicate,m_boundVariables,freeConstraintPredicates,sips,predicateBindingPatterns,m_rulesESVMinusSV);
            m_magicESVProgramMinusSVRules=hornRewriting.getMagicProgram();
            // The added predicates must be known before the disjunctive rewriting is generated,
            // because that step removes entries from the set.
            m_addedPredicates=new HashSet();
            m_addedPredicates.addAll(hornRewriting.getAddedPredicates());
            m_addedPredicates.addAll(m_predicateFactory.getAllPredicates());
            generateDisjunctiveRewriting(hornRewriting.getAdornedPredicatesByPredicate());
            m_seedPredicate=hornRewriting.getSeedPredicate();
            m_answerPredicate=m_queryPredicate;
            m_adornedPredicatesByPredicate=hornRewriting.getAdornedPredicatesByPredicate();
        }
        /**
         * Generates the rewriting of the disjunctive program and stores it in
         * <code>m_magicProgram</code>. The resulting program consists of the rules of the
         * rewritten ESV program, rules collecting the adorned versions of a predicate, and the
         * rules of the program without integrity constraints whose bodies are extended with one
         * literal per head literal.
         *
         * @param adornedPredicatesByPredicate      the map indexed by predicates containing the set of adorned predicates
         */
        protected void generateDisjunctiveRewriting(Map adornedPredicatesByPredicate) {
            List magicRules=new LinkedList();
            int rewrittenRuleCount=m_magicESVProgramMinusSVRules.getNumberOfRules();
            for (int ruleIndex=0;ruleIndex<rewrittenRuleCount;ruleIndex++)
                magicRules.add(m_magicESVProgramMinusSVRules.getRule(ruleIndex));
            for (Iterator predicates=adornedPredicatesByPredicate.keySet().iterator();predicates.hasNext();) {
                Predicate plainPredicate=(Predicate)predicates.next();
                Set adornments=(Set)adornedPredicatesByPredicate.get(plainPredicate);
                if (adornments.size()!=1) {
                    // Several (or no) adorned versions: add one rule per adorned version that
                    // copies its tuples into the unadorned predicate. All of these rules share
                    // the same head and the same argument variables X0, X1, ...
                    int arity=plainPredicate.getArity();
                    Term[] arguments=new Term[arity];
                    for (int position=0;position<arity;position++)
                        arguments[position]=new Variable("X"+position);
                    Literal[] collectingHead=new Literal[] { new Literal(plainPredicate,true,arguments) };
                    for (Iterator adorned=adornments.iterator();adorned.hasNext();) {
                        Predicate adornedPredicate=(Predicate)adorned.next();
                        Literal[] collectingBody=new Literal[] { new Literal(adornedPredicate,true,arguments) };
                        magicRules.add(new Rule(null,collectingHead,collectingBody));
                    }
                }
                else {
                    // Exactly one adorned version: no collecting rule is needed. The ESV map
                    // entry that currently points to the unadorned predicate is redirected to
                    // the adorned one, and the unadorned predicate is no longer reported as added.
                    Predicate onlyAdornment=(Predicate)adornments.iterator().next();
                    for (Iterator entries=m_esvPredicates.entrySet().iterator();entries.hasNext();) {
                        Map.Entry entry=(Map.Entry)entries.next();
                        if (entry.getValue().equals(plainPredicate)) {
                            m_addedPredicates.remove(entry.getValue());
                            entry.setValue(onlyAdornment);
                            break;
                        }
                    }
                }
            }
            // Modify the rules of the program without integrity constraints: the body of each rule
            // is prefixed with one literal per head literal, built over the predicate currently
            // registered for the head predicate in the ESV map.
            int originalRuleCount=m_noIntegrityConstraintsProgram.getNumberOfRules();
            for (int ruleIndex=0;ruleIndex<originalRuleCount;ruleIndex++) {
                Rule original=m_noIntegrityConstraintsProgram.getRule(ruleIndex);
                // Rules that were created from integrity constraints lose their artificial head again.
                Literal[] newHead=m_integrityConstraintRules.contains(original) ? new Literal[0] : original.getHeadLiterals();
                int headLength=original.getHeadLength();
                int bodyLength=original.getBodyLength();
                Literal[] newBody=new Literal[headLength+bodyLength];
                for (int position=0;position<headLength;position++) {
                    Literal head=original.getHeadLiteral(position);
                    Predicate guardPredicate=getESVPredicate(head.getPredicate());
                    newBody[position]=new Literal(guardPredicate,true,head.getTerms());
                }
                System.arraycopy(original.getBodyLiterals(),0,newBody,headLength,bodyLength);
                magicRules.add(new Rule(original.getLabel(),newHead,newBody));
            }
            m_magicProgram=new Program((Rule[])magicRules.toArray(new Rule[magicRules.size()]));
        }
        /**
         * Generates the ESV program from the program without integrity constraints and stores it
         * in <code>m_esvProgram</code>; the ESV counterpart of the query predicate is stored in
         * <code>m_esvQueryPredicate</code>.
         * <p>
         * A rule with <code>n</code> head literals yields <code>n</code> SV rules and
         * <code>n*(n-1)</code> further ESV rules, that is <code>n*n</code> rules in total. A rule
         * without a head therefore yields no rules at all.
         */
        protected void generateESVProgram() {
            int numberOfRules=m_noIntegrityConstraintsProgram.getNumberOfRules();
            // The array must be sized with exactly the expression that is used below to advance the
            // write index. Reserving a slot for a headless rule would leave a null entry at the end
            // of the array, since generateESVRules() emits nothing for such a rule.
            int numberOfESVRules=0;
            for (int i=0;i<numberOfRules;i++) {
                int headLength=m_noIntegrityConstraintsProgram.getRule(i).getHeadLength();
                numberOfESVRules+=headLength*headLength;
            }
            int index=0;
            Rule[] esvRules=new Rule[numberOfESVRules];
            for (int i=0;i<numberOfRules;i++) {
                Rule rule=m_noIntegrityConstraintsProgram.getRule(i);
                int headLength=rule.getHeadLength();
                generateESVRules(rule,esvRules,index);
                index+=headLength*headLength;
            }
            m_esvProgram=new Program(esvRules);
            m_esvQueryPredicate=getESVPredicate(m_queryPredicate);
        }
        /**
         * Generates the ESV rules from a program rule and writes them into consecutive slots of
         * the supplied array. For a rule with <code>n</code> head literals, <code>n</code> SV rules
         * are written first, followed by <code>n*(n-1)</code> rules that additionally contain one
         * of the other head literals in the body. The latter are also recorded in
         * <code>m_rulesESVMinusSV</code>.
         *
         * @param rule                              the rule for which the ESV rules are generated
         * @param esvRules                          the array of ESV rules
         * @param index                             the index where to start putting the ESV rules
         */
        protected void generateESVRules(Rule rule,Rule[] esvRules,int index) {
            // Rewrite the body: literals over IDB predicates are moved to the ESV predicate,
            // all other literals are reused unchanged.
            int bodyLength=rule.getBodyLength();
            Literal[] rewrittenBody=new Literal[bodyLength];
            for (int position=0;position<bodyLength;position++) {
                Literal original=rule.getBodyLiteral(position);
                Predicate originalPredicate=original.getPredicate();
                if (!m_idbPredicates.contains(originalPredicate))
                    rewrittenBody[position]=original;
                else
                    rewrittenBody[position]=new Literal(getESVPredicate(originalPredicate),original.isPositive(),original.getTerms(),false);
            }
            int headLength=rule.getHeadLength();
            int nextSlot=index;
            // SV rules: one rule per head literal, all sharing the rewritten body.
            for (int headPosition=0;headPosition<headLength;headPosition++) {
                Literal head=rule.getHeadLiteral(headPosition);
                Literal esvHead=new Literal(getESVPredicate(head.getPredicate()),true,head.getTerms());
                String label=rule.getLabel()+"_SV["+esvHead.toString()+"]";
                esvRules[nextSlot]=new Rule(label,new Literal[] { esvHead },rewrittenBody);
                nextSlot++;
            }
            // Remaining ESV rules: one rule per ordered pair of distinct head literals. The first
            // literal of the pair becomes the head, the second is put in front of the rewritten body.
            for (int mainPosition=0;mainPosition<headLength;mainPosition++) {
                Literal mainHead=rule.getHeadLiteral(mainPosition);
                Literal esvMainHead=new Literal(getESVPredicate(mainHead.getPredicate()),true,mainHead.getTerms());
                // The head array is shared by all rules generated for this main head literal.
                Literal[] esvRuleHead=new Literal[] { esvMainHead };
                for (int sidePosition=0;sidePosition<headLength;sidePosition++) {
                    if (sidePosition==mainPosition)
                        continue;
                    Literal sideHead=rule.getHeadLiteral(sidePosition);
                    Literal esvSideHead=new Literal(getESVPredicate(sideHead.getPredicate()),true,sideHead.getTerms());
                    Literal[] esvRuleBody=new Literal[bodyLength+1];
                    esvRuleBody[0]=esvSideHead;
                    System.arraycopy(rewrittenBody,0,esvRuleBody,1,bodyLength);
                    String label=rule.getLabel()+"_ESV["+esvMainHead.toString()+"]["+esvSideHead.toString()+"]";
                    Rule esvRule=new Rule(label,esvRuleHead,esvRuleBody);
                    esvRules[nextSlot]=esvRule;
                    nextSlot++;
                    m_rulesESVMinusSV.add(esvRule);
                }
            }
        }
        /**
         * Returns the ESV predicate for the given predicate. If no ESV predicate has been
         * registered for the predicate yet, one named <code>ESV_</code> followed by the simple name
         * of the predicate, and with the same arity, is obtained from the predicate factory and
         * registered.
         *
         * @param predicate                         the predicate for which the ESV predicate should be returned
         * @return                                  the ESV predicate
         */
        protected Predicate getESVPredicate(Predicate predicate) {
            Predicate registered=(Predicate)m_esvPredicates.get(predicate);
            if (registered!=null)
                return registered;
            String esvName="ESV_"+predicate.getSimpleName();
            Predicate created=m_predicateFactory.getPredicate(esvName,predicate.getArity());
            m_esvPredicates.put(predicate,created);
            return created;
        }
        /**
         * Eliminates the integrity constraints by introducing new predicates. Each rule of the
         * original program that has no head, and whose body mentions at least one variable, is
         * replaced by a rule with the same body and a head over a fresh predicate
         * <code>ic0</code>, <code>ic1</code>, ... whose arguments are the variables of the body.
         * The replacement rules are recorded in <code>m_integrityConstraintRules</code>. All other
         * rules, including headless rules without variables, are taken over unchanged. The
         * result is stored in <code>m_noIntegrityConstraintsProgram</code>.
         */
        protected void eliminateIntegrityConstraints() {
            int ruleCount=m_program.getNumberOfRules();
            Rule[] rewrittenRules=new Rule[ruleCount];
            int nextConstraintNumber=0;
            for (int ruleIndex=0;ruleIndex<ruleCount;ruleIndex++) {
                Rule current=m_program.getRule(ruleIndex);
                if (current.getHeadLength()==0) {
                    // Collect the variables occurring in the body of the constraint.
                    Set bodyVariables=new HashSet();
                    int bodyLength=current.getBodyLength();
                    for (int literalIndex=0;literalIndex<bodyLength;literalIndex++) {
                        Literal literal=current.getBodyLiteral(literalIndex);
                        int arity=literal.getArity();
                        for (int position=0;position<arity;position++) {
                            Term argument=literal.getTerm(position);
                            // For a unary function call the term it is applied to is inspected (one level only).
                            Term candidate=argument instanceof UnaryFunctionCall ? ((UnaryFunctionCall)argument).getCallTerm() : argument;
                            if (candidate instanceof Variable)
                                bodyVariables.add(candidate);
                        }
                    }
                    if (!bodyVariables.isEmpty()) {
                        Term[] headArguments=(Term[])bodyVariables.toArray(new Term[bodyVariables.size()]);
                        Predicate constraintPredicate=m_predicateFactory.getPredicate("ic"+nextConstraintNumber,headArguments.length);
                        nextConstraintNumber++;
                        current=new Rule(current.getLabel()+"_ic",new Literal(constraintPredicate,true,headArguments),current.getBodyLiterals());
                        m_integrityConstraintRules.add(current);
                    }
                }
                rewrittenRules[ruleIndex]=current;
            }
            m_noIntegrityConstraintsProgram=new Program(rewrittenRules);
        }
    }
}
