package edu.unika.aifb.kaon.datalog.evaluation;

import java.util.Set;
import java.util.HashSet;
import java.util.List;
import java.util.ArrayList;
import java.util.Map;
import java.util.HashMap;
import java.util.Iterator;

import edu.unika.aifb.kaon.datalog.*;
import edu.unika.aifb.kaon.datalog.program.*;

/**
 * Compiles a rule. The compiler is initialized with a manager of extensional databases. The body of the
 * rule is read from left to right, and consecutive literals that can be evaluated in the same extensional
 * database are collected. These literals are then passed to the appropriate extensional database, and a
 * query operator evaluating them is obtained.
 */
public class RuleCompiler {
    /** The manager for the database extensions; used to look up the extensional database of a predicate. */
    protected ExtensionalManager m_extensionalManager;
    /** The head literals of the rule (a clone of the rule's array; entries are replaced when function calls are rewritten). */
    protected Literal[] m_headLiterals;
    /** The body literals of the rule (a clone of the rule's array; entries are replaced when function calls are rewritten or a disjunction variable is appended). */
    protected Literal[] m_bodyLiterals;
    /** The counter used to generate the names of new <code>internal$</code> and <code>disjunction$</code> variables. */
    protected int m_newVariableCounter;
    /** The variables of the query operator; the position of a variable in this list is its index in the operator. */
    protected List m_operatorVariables;
    /** The variables containing information about disjunctions (created in {@link #compile()}). */
    protected Set m_disjunctionInfoVariables;
    /** The unification information, keyed by the <code>Integer</code> index of the body literal after which it applies. */
    protected Map m_unificationInfos;
    /** Query operator for the rule; assigned by {@link #compile()}. */
    protected QueryOperator m_queryOperator;

    /**
     * Creates a compiler for the supplied rule. The head and body literal arrays of the rule
     * are cloned (shallow copies), so replacing entries during compilation does not touch the
     * arrays returned by the rule.
     *
     * @param extensionalManager                the manager for the database extensions
     * @param rule                              the rule being compiled
     */
    public RuleCompiler(ExtensionalManager extensionalManager,Rule rule) {
        this.m_extensionalManager=extensionalManager;
        this.m_headLiterals=(Literal[])rule.getHeadLiterals().clone();
        this.m_bodyLiterals=(Literal[])rule.getBodyLiterals().clone();
    }
    /**
     * Compiles the rule and stores the resulting query operator. The unification infos, the
     * disjunction info variables and the operator variables are reset on every call.
     * <p>
     * For a rule without body literals the operator is a {@link SingletonQueryOperator} created
     * with an empty object array; no literal is preprocessed in that case. Otherwise function
     * calls are rewritten first in the body and then in the head, disjunction variables are
     * appended to the body literals, and finally the operator is built.
     *
     * @throws DatalogException                 thrown if there is an error
     */
    public void compile() throws DatalogException {
        m_unificationInfos=new HashMap();
        m_disjunctionInfoVariables=new HashSet();
        m_operatorVariables=new ArrayList();
        if (m_bodyLiterals.length==0) {
            m_queryOperator=new SingletonQueryOperator(new Object[0]);
            return;
        }
        // The three collections below are shared by the body pass and the head pass, so that
        // the head sees the variables and function calls already encountered in the body.
        Map variablesByFunctionCall=new HashMap();
        Set variablesSeen=new HashSet();
        Set unboundNegatedVariables=new HashSet();
        processFunctionCalls(variablesByFunctionCall,variablesSeen,unboundNegatedVariables,m_bodyLiterals,true);
        processFunctionCalls(variablesByFunctionCall,variablesSeen,unboundNegatedVariables,m_headLiterals,false);
        processDisjunctiveBodyPredicates();
        createOperator();
    }
    /**
     * Tells whether the disjunction infos will be passed to the head, which is the case when
     * at least one disjunction info variable has been created.
     *
     * @return                                  <code>true</code> if the disjunction infos will be passed to the head
     */
    public boolean getDisjunctionInfosPassed() {
        return m_disjunctionInfoVariables.size()>0;
    }
    /**
     * Returns the positions of the disjunction variables within the list of operator variables.
     * The order of the entries follows the iteration order of the set of disjunction variables.
     *
     * @return                                  the array of indices where disjunction variables occur
     */
    public int[] getDisjunctionInfoPositions() {
        int[] positions=new int[m_disjunctionInfoVariables.size()];
        int next=0;
        for (Iterator variables=m_disjunctionInfoVariables.iterator();variables.hasNext();next++) {
            Variable disjunctionVariable=(Variable)variables.next();
            positions[next]=m_operatorVariables.indexOf(disjunctionVariable);
        }
        return positions;
    }
    /**
     * Returns the predicate of the head literal at the given index.
     *
     * @param headLiteralIndex                  the index of the head literal
     * @return                                  the predicate of the head
     */
    public Predicate getHeadPredicate(int headLiteralIndex) {
        Literal headLiteral=m_headLiterals[headLiteralIndex];
        return headLiteral.getPredicate();
    }
    /**
     * Returns the predicates of all head literals, in the order of the head literals.
     *
     * @return                                  the predicates of the head
     */
    public Predicate[] getHeadPredicates() {
        int headCount=m_headLiterals.length;
        Predicate[] headPredicates=new Predicate[headCount];
        for (int index=0;index<headCount;index++) {
            Literal headLiteral=m_headLiterals[index];
            headPredicates[index]=headLiteral.getPredicate();
        }
        return headPredicates;
    }
    /**
     * Returns the query operator returning all necessary variables in the rule. The operator
     * is assigned by {@link #compile()}.
     *
     * @return                                  the query operator
     */
    public QueryOperator getQueryOperator() {
        return this.m_queryOperator;
    }
    /**
     * Preprocesses the function calls in the literal array. Every literal that has at least one
     * argument bound to a unary function call is replaced in the array by a new literal in which
     * each such argument is substituted by
     * <ul>
     * <li>the variable already associated with the same function call, if there is one,</li>
     * <li>a constant holding the instantiated call, if the call term is a constant, or</li>
     * <li>a new <code>internal$</code> variable, for which unification information is registered.</li>
     * </ul>
     * After a literal has been handled, all of its variables are added to the visible variables.
     *
     * @param functionCallsToVariables          the mapping of function calls to variables
     * @param visibleVariables                  the set of visible variables
     * @param negatedNonBoundVariables          the set of variables that weren't bound in a previous literal
     * @param literals                          the array of literals to be preprocessed
     * @param isInBody                          <code>true</code> if these are the body literals
     * @throws DatalogException                 thrown if there is an error
     */
    protected void processFunctionCalls(Map functionCallsToVariables,Set visibleVariables,Set negatedNonBoundVariables,Literal[] literals,boolean isInBody) throws DatalogException {
        for (int literalIndex=0;literalIndex<literals.length;literalIndex++) {
            Literal current=literals[literalIndex];
            boolean containsFunctionCall=false;
            for (int position=0;position<current.getArity();position++) {
                Variable checked=current.getArgumentVariable(position);
                if (current.isArgumentBoundToUnaryFunctionCall(position)) {
                    containsFunctionCall=true;
                    Term innerTerm=current.getArgumentUnaryFunctionCall(position).getCallTerm();
                    if (innerTerm instanceof Variable)
                        checked=(Variable)innerTerm;
                }
                if (checked==null)
                    continue;
                // This check makes sure that if a variable occurs in a negative literal and is not bound at that
                // point, it is rejected when it later occurs in a positive literal (the head literals are run
                // through the same check). This corresponds to the definition of safety (each variable in the
                // rule occurs in a positive body literal): a variable that is never bound may occur only in a
                // negative body literal, which means that the value of the variable is not important.
                // NOTE: the rule must be written in a way that, when read left-to-right, variables are bound.
                if (!current.isPositive()) {
                    if (!visibleVariables.contains(checked))
                        negatedNonBoundVariables.add(checked);
                }
                else if (negatedNonBoundVariables.contains(checked))
                    throw new DatalogException("Variable '"+checked.getVariableName()+"' occurs negated in a literal before its positive occurence, making thus the negated variable flounder.");
            }
            if (containsFunctionCall) {
                Term[] rewrittenTerms=new Term[current.getArity()];
                for (int position=0;position<current.getArity();position++) {
                    UnaryFunctionCall call=current.getArgumentUnaryFunctionCall(position);
                    if (call!=null) {
                        Variable knownVariable=(Variable)functionCallsToVariables.get(call);
                        if (knownVariable!=null)
                            rewrittenTerms[position]=knownVariable;
                        else {
                            Term callTerm=call.getCallTerm();
                            if (callTerm instanceof Constant)
                                rewrittenTerms[position]=new Constant(instantiateFunctionCall(call,(Constant)callTerm));
                            else {
                                // The new variable is registered before the safety decision below is made.
                                Variable freshVariable=new Variable("internal$"+(m_newVariableCounter++));
                                rewrittenTerms[position]=freshVariable;
                                functionCallsToVariables.put(call,freshVariable);
                                if (visibleVariables.contains(callTerm)) {
                                    // The call term is already visible: the call is instantiated after the
                                    // previous body literal, or after the last body literal for the head.
                                    int unificationIndex=isInBody ? literalIndex-1 : m_bodyLiterals.length-1;
                                    addUnificationInfo(unificationIndex,call,true,(Variable)callTerm,freshVariable);
                                }
                                else if (isInBody) {
                                    // The call term is not visible yet: it is the instantiated variable of a
                                    // unification registered after this literal, with the new variable as argument.
                                    addUnificationInfo(literalIndex,call,false,freshVariable,(Variable)callTerm);
                                }
                                else
                                    throw new DatalogException("The rule is unsafe.");
                            }
                        }
                    }
                    Constant constantArgument=current.getArgumentConstant(position);
                    if (constantArgument!=null)
                        rewrittenTerms[position]=constantArgument;
                    Variable variableArgument=current.getArgumentVariable(position);
                    if (variableArgument!=null)
                        rewrittenTerms[position]=variableArgument;
                }
                current=new Literal(current.getPredicate(),current.isPositive(),rewrittenTerms);
                literals[literalIndex]=current;
            }
            // From now on all variables of the (possibly rewritten) literal count as visible.
            for (int position=0;position<current.getArity();position++) {
                Variable argumentVariable=current.getArgumentVariable(position);
                if (argumentVariable!=null)
                    visibleVariables.add(argumentVariable);
            }
        }
    }
    /**
     * Processes the disjunctions in the body of the rule. Every body literal whose predicate
     * carries disjunction info in its extensional database is replaced by a literal with one
     * additional trailing term: a new <code>disjunction$</code> variable, which is also recorded
     * in the set of disjunction info variables.
     *
     * @throws DatalogException                 thrown if no database can be found for a body predicate
     */
    protected void processDisjunctiveBodyPredicates() throws DatalogException {
        for (int literalIndex=0;literalIndex<m_bodyLiterals.length;literalIndex++) {
            Literal bodyLiteral=m_bodyLiterals[literalIndex];
            Predicate predicate=bodyLiteral.getPredicate();
            ExtensionalDatabase database=m_extensionalManager.getExtensionalDatabase(predicate);
            if (database==null)
                throw new DatalogException("Can't find a database for predicate '"+predicate.getFullName()+"'.");
            if (!database.getContainsDisjunctionInfo(predicate))
                continue;
            int arity=bodyLiteral.getArity();
            Term[] extendedTerms=new Term[arity+1];
            for (int position=0;position<arity;position++)
                extendedTerms[position]=bodyLiteral.getTerm(position);
            Variable disjunctionVariable=new Variable("disjunction$"+(++m_newVariableCounter));
            extendedTerms[arity]=disjunctionVariable;
            m_disjunctionInfoVariables.add(disjunctionVariable);
            m_bodyLiterals[literalIndex]=new Literal(predicate,bodyLiteral.isPositive(),extendedTerms,true);
        }
    }
    /**
     * Adds information about unification. The function call is recorded in the unification info
     * stored under the given index; the info object is created on first use.
     *
     * @param index                             the index under which the unification info is stored
     * @param unaryFunctionCall                 the function call
     * @param instantiateCall                   determines whether the call should be instantiated
     * @param argument                          the argument for which the call was instantiated
     * @param instantiatedVariable              the variable being instantiated
     */
    protected void addUnificationInfo(int index,UnaryFunctionCall unaryFunctionCall,boolean instantiateCall,Variable argument,Variable instantiatedVariable) {
        Integer literalKey=new Integer(index);
        Object existing=m_unificationInfos.get(literalKey);
        UnificationInfo info;
        if (existing!=null)
            info=(UnificationInfo)existing;
        else {
            info=new UnificationInfo();
            m_unificationInfos.put(literalKey,info);
        }
        info.addFunctionCall(unaryFunctionCall,instantiateCall,argument,instantiatedVariable);
    }
    /**
     * Builds the query operator for a rule that has at least one body literal and stores it in
     * <code>m_queryOperator</code>. The body is walked from left to right. A run of literals is
     * handed to the current extensional database as one segment when unification information is
     * registered for the current literal, when the current literal is the last one, or when the
     * current database cannot evaluate the predicate of the next literal. If unification
     * information is registered for a literal, a unification operator is stacked on top of the
     * operator obtained so far.
     *
     * @throws DatalogException                 thrown if there is an error
     */
    protected void createOperator() throws DatalogException {
        Set[] neededVariables=computeNeededVariables();
        List boundSoFar=new ArrayList();
        QueryOperator operator=null;
        ExtensionalDatabase database=getExtensionalDatabase(0,m_bodyLiterals[0].getPredicate());
        int segmentStart=0;
        for (int current=0;current<m_bodyLiterals.length;current++) {
            UnificationInfo unification=(UnificationInfo)m_unificationInfos.get(new Integer(current));
            updateOperatorVariables(current,neededVariables,unification);
            boolean atLastLiteral=(current==m_bodyLiterals.length-1);
            // The database is asked about the next predicate only if neither of the first two conditions holds.
            boolean closeSegment=unification!=null || atLastLiteral || !database.canEvaluatePredicate(m_bodyLiterals[current+1].getPredicate());
            if (closeSegment) {
                operator=processSegment(operator,database,m_bodyLiterals,segmentStart,current,boundSoFar,m_operatorVariables);
                segmentStart=current+1;
                // Whatever the segment produced is bound for everything that follows.
                boundSoFar.clear();
                boundSoFar.addAll(m_operatorVariables);
                if (!atLastLiteral)
                    database=getExtensionalDatabase(current+1,m_bodyLiterals[current+1].getPredicate());
            }
            if (unification!=null) {
                // Second pass without the unification info: variables that were kept only as
                // unification arguments are dropped now unless a later literal or the head needs them.
                updateOperatorVariables(current,neededVariables,null);
                unification.addInstantiatedVariables(m_operatorVariables);
                operator=unification.getUnificationOperator(operator,boundSoFar,m_operatorVariables);
                boundSoFar.clear();
                boundSoFar.addAll(m_operatorVariables);
            }
        }
        m_queryOperator=operator;
    }
    /**
     * Returns the renamer for the tuples of the given head literal. This object can build the
     * appropriate tuple for the given head predicate: it is created from the values and positions
     * of the constants in the head literal, and from pairs
     * <code>{index in the operator variables, position in the head literal}</code> for the
     * variables of the head literal.
     *
     * @param headLiteralIndex                  the index of the head literal
     * @return                                  the tuple renamer
     */
    public TupleRename getTupleRename(int headLiteralIndex) {
        Literal head=m_headLiterals[headLiteralIndex];
        int constantCount=0;
        for (int position=0;position<head.getArity();position++) {
            if (head.isArgumentBoundToConstant(position))
                constantCount++;
        }
        Object[] constantValues=new Object[constantCount];
        int[] constantPositions=new int[constantCount];
        // One row per argument that is not bound to a constant.
        int[][] variableCopies=new int[head.getArity()-constantCount][2];
        int nextConstant=0;
        int nextVariable=0;
        for (int position=0;position<head.getArity();position++) {
            Constant constantArgument=head.getArgumentConstant(position);
            if (constantArgument!=null) {
                constantValues[nextConstant]=constantArgument.getValue();
                constantPositions[nextConstant]=position;
                nextConstant++;
            }
            Variable variableArgument=head.getArgumentVariable(position);
            if (variableArgument!=null) {
                variableCopies[nextVariable][0]=m_operatorVariables.indexOf(variableArgument);
                variableCopies[nextVariable][1]=position;
                nextVariable++;
            }
        }
        return new TupleRename(variableCopies,constantValues,constantPositions);
    }
    /**
     * Returns the tuple renamers for all head literals, in the order of the head literals.
     *
     * @return                                  the renamers for the tuple
     */
    public TupleRename[] getTupleRenames() {
        int headCount=m_headLiterals.length;
        TupleRename[] renames=new TupleRename[headCount];
        int headIndex=0;
        while (headIndex<headCount) {
            renames[headIndex]=getTupleRename(headIndex);
            headIndex++;
        }
        return renames;
    }
    /**
     * Updates the list of operator variables for the body literal at the given index. First the
     * variables of that literal are appended unless already present. Then every operator variable
     * is removed that is neither needed for the supplied unification nor contained in one of the
     * variable sets at positions after <code>index</code>.
     *
     * @param index                             the index of the body literal
     * @param literalVariables                  the sets of literal variables
     * @param unificationInfo                   the information about the unification (may be <code>null</code>)
     */
    protected void updateOperatorVariables(int index,Set[] literalVariables,UnificationInfo unificationInfo) {
        Literal bodyLiteral=m_bodyLiterals[index];
        for (int position=0;position<bodyLiteral.getArity();position++) {
            Variable argumentVariable=bodyLiteral.getArgumentVariable(position);
            if (argumentVariable==null || m_operatorVariables.contains(argumentVariable))
                continue;
            m_operatorVariables.add(argumentVariable);
        }
        for (Iterator candidates=m_operatorVariables.iterator();candidates.hasNext();) {
            Variable candidate=(Variable)candidates.next();
            boolean stillNeeded=unificationInfo!=null && unificationInfo.isNeededForUnification(candidate);
            for (int later=index+1;!stillNeeded && later<literalVariables.length;later++)
                stillNeeded=literalVariables[later].contains(candidate);
            if (!stillNeeded)
                candidates.remove();
        }
    }
    /**
     * Processes the specified literals by passing them to the supplied extensional database.
     * The literals from <code>startLiteralIndex</code> to <code>endLiteralIndex</code> (both
     * inclusive) are copied into a new array, and the two variable lists are converted to arrays.
     *
     * @param bindingsOperator                  the operator of bindings
     * @param extensionalDatabase               the extensional database in which the literals are evaluated
     * @param bodyLiterals                      the body literals
     * @param startLiteralIndex                 the index of the first literal
     * @param endLiteralIndex                   the index of the last literal
     * @param boundVariables                    the list of bound variables
     * @param requestedVariables                the list of variables to generate
     * @return                                  the operator from the extensional database
     * @throws DatalogException                 thrown if there is an error
     */
    protected QueryOperator processSegment(QueryOperator bindingsOperator,ExtensionalDatabase extensionalDatabase,Literal[] bodyLiterals,int startLiteralIndex,int endLiteralIndex,List boundVariables,List requestedVariables) throws DatalogException {
        Literal[] segment=new Literal[endLiteralIndex-startLiteralIndex+1];
        for (int offset=0;offset<segment.length;offset++)
            segment[offset]=bodyLiterals[startLiteralIndex+offset];
        Variable[] bound=(Variable[])boundVariables.toArray(new Variable[boundVariables.size()]);
        Variable[] requested=(Variable[])requestedVariables.toArray(new Variable[requestedVariables.size()]);
        return extensionalDatabase.createQueryOperator(segment,bindingsOperator,bound,requested);
    }
    /**
     * Locates the extensional database that the manager reports for the supplied predicate.
     * This method never returns <code>null</code>: a missing database is reported by an exception.
     *
     * @param literalIndex                      the index of literal (not used by this implementation)
     * @param predicate                         the predicate for which the extensional database is sought
     * @return                                  the extensional database for supplied predicate
     * @throws DatalogException                 thrown if no extensional database exists for the predicate
     */
    protected ExtensionalDatabase getExtensionalDatabase(int literalIndex,Predicate predicate) throws DatalogException {
        ExtensionalDatabase database=m_extensionalManager.getExtensionalDatabase(predicate);
        if (database!=null)
            return database;
        throw new DatalogException("Cannot locate extensional database for predicate '"+predicate.getFullName()+"'.");
    }
    /**
     * Computes the sets of variables needed by the literals of the rule. The result has one set
     * per body literal, holding the variables of that literal and the arguments of the
     * unification info registered under the same index (if any), followed by one final set that
     * holds the variables of all head literals and all disjunction info variables.
     *
     * @return                                  sets of variables mentioned in literals
     */
    protected Set[] computeNeededVariables() {
        int bodyCount=m_bodyLiterals.length;
        Set[] neededPerLiteral=new Set[bodyCount+1];
        for (int literalIndex=0;literalIndex<bodyCount;literalIndex++) {
            Set needed=new HashSet();
            Literal bodyLiteral=m_bodyLiterals[literalIndex];
            for (int position=0;position<bodyLiteral.getArity();position++) {
                Variable argumentVariable=bodyLiteral.getArgumentVariable(position);
                if (argumentVariable!=null)
                    needed.add(argumentVariable);
            }
            UnificationInfo unification=(UnificationInfo)m_unificationInfos.get(new Integer(literalIndex));
            if (unification!=null)
                unification.addAllArguments(needed);
            neededPerLiteral[literalIndex]=needed;
        }
        Set neededByHead=new HashSet();
        for (int headIndex=0;headIndex<m_headLiterals.length;headIndex++) {
            Literal headLiteral=m_headLiterals[headIndex];
            for (int position=0;position<headLiteral.getArity();position++) {
                Variable argumentVariable=headLiteral.getArgumentVariable(position);
                if (argumentVariable!=null)
                    neededByHead.add(argumentVariable);
            }
        }
        neededByHead.addAll(m_disjunctionInfoVariables);
        neededPerLiteral[bodyCount]=neededByHead;
        return neededPerLiteral;
    }
    /**
     * Returns the value of the function call instantiated for the given constant. If the argument
     * of the call is itself a unary function call, that call is instantiated recursively first;
     * otherwise the value of the constant is used. The result is wrapped in a
     * {@link UnaryFunctionCallValue} carrying the function name of the call.
     *
     * @param unaryFunctionCall                 the function call for which the term is needed
     * @param constant                          the constant
     * @return                                  the value of the call
     */
    protected Object instantiateFunctionCall(UnaryFunctionCall unaryFunctionCall,Constant constant) {
        Term callArgument=unaryFunctionCall.getArgument();
        Object argumentValue=(callArgument instanceof UnaryFunctionCall)
            ? instantiateFunctionCall((UnaryFunctionCall)callArgument,constant)
            : constant.getValue();
        return new UnaryFunctionCallValue(unaryFunctionCall.getFunctionName(),argumentValue);
    }

    /**
     * Contains information about unification that should be performed after a literal. The four
     * lists are filled in parallel by {@link #addFunctionCall}: entries at the same position
     * describe the same function call.
     */
    protected static class UnificationInfo {
        /** The function calls to be unified. */
        protected List m_unaryFunctionCalls;
        /** For each function call, a <code>Boolean</code> telling whether the call should be instantiated. */
        protected List m_instantiateCalls;
        /** For each function call, the argument variable. */
        protected List m_arguments;
        /** For each function call, the variable being instantiated. */
        protected List m_instantiatedVariables;

        /**
         * Creates an empty unification info.
         */
        public UnificationInfo() {
            m_unaryFunctionCalls=new ArrayList();
            m_instantiateCalls=new ArrayList();
            m_arguments=new ArrayList();
            m_instantiatedVariables=new ArrayList();
        }
        /**
         * Returns the registered function calls.
         *
         * @return                              the function calls in registration order
         */
        public UnaryFunctionCall[] getUnaryFunctionCalls() {
            return (UnaryFunctionCall[])m_unaryFunctionCalls.toArray(new UnaryFunctionCall[m_unaryFunctionCalls.size()]);
        }
        /**
         * Returns the instantiate flags of the registered function calls.
         *
         * @return                              the flags in registration order
         */
        public boolean[] getInstantiateCalls() {
            boolean[] flags=new boolean[m_instantiateCalls.size()];
            int next=0;
            for (Iterator values=m_instantiateCalls.iterator();values.hasNext();next++)
                flags[next]=((Boolean)values.next()).booleanValue();
            return flags;
        }
        /**
         * Returns, for each registered argument, its index in the supplied list.
         *
         * @param variables                     the list in which the arguments are looked up
         * @return                              the indices (<code>-1</code> for an argument not in the list)
         */
        public int[] getIndicesOfArguments(List variables) {
            return indicesOf(m_arguments,variables);
        }
        /**
         * Returns, for each registered instantiated variable, its index in the supplied list.
         *
         * @param variables                     the list in which the instantiated variables are looked up
         * @return                              the indices (<code>-1</code> for a variable not in the list)
         */
        public int[] getIndicesOfResults(List variables) {
            return indicesOf(m_instantiatedVariables,variables);
        }
        /**
         * Registers a function call together with its flag, argument and instantiated variable.
         *
         * @param unaryFunctionCall             the function call
         * @param instantiateCall               determines whether the call should be instantiated
         * @param argument                      the argument variable
         * @param instantiatedVariable          the variable being instantiated
         */
        public void addFunctionCall(UnaryFunctionCall unaryFunctionCall,boolean instantiateCall,Variable argument,Variable instantiatedVariable) {
            Boolean instantiateFlag=instantiateCall ? Boolean.TRUE : Boolean.FALSE;
            m_unaryFunctionCalls.add(unaryFunctionCall);
            m_instantiateCalls.add(instantiateFlag);
            m_arguments.add(argument);
            m_instantiatedVariables.add(instantiatedVariable);
        }
        /**
         * Creates the unification operator on top of the source operator. Every needed variable
         * that is also bound is copied from its bound position to its needed position; arguments
         * are located in the bound variables and results in the needed variables.
         *
         * @param sourceOperator                the operator supplying the bound variables
         * @param boundVariables                the variables produced by the source operator
         * @param neededVariables               the variables the new operator should produce
         * @return                              the unification operator
         */
        public QueryOperator getUnificationOperator(QueryOperator sourceOperator,List boundVariables,List neededVariables) {
            List copies=new ArrayList();
            int targetIndex=0;
            for (Iterator needed=neededVariables.iterator();needed.hasNext();targetIndex++) {
                Variable neededVariable=(Variable)needed.next();
                int sourceIndex=boundVariables.indexOf(neededVariable);
                if (sourceIndex>=0)
                    copies.add(new int[] { sourceIndex,targetIndex });
            }
            int[][] copiesArray=(int[][])copies.toArray(new int[copies.size()][]);
            UnaryFunctionCall[] calls=getUnaryFunctionCalls();
            int[] argumentIndices=getIndicesOfArguments(boundVariables);
            boolean[] instantiateFlags=getInstantiateCalls();
            int[] resultIndices=getIndicesOfResults(neededVariables);
            return new UnificationOperator(sourceOperator,copiesArray,calls,argumentIndices,instantiateFlags,resultIndices);
        }
        /**
         * Tells whether the variable is an argument of one of the registered function calls.
         *
         * @param variable                      the variable to check
         * @return                              <code>true</code> if the variable is a registered argument
         */
        public boolean isNeededForUnification(Variable variable) {
            return m_arguments.indexOf(variable)>=0;
        }
        /**
         * Appends to the supplied list every instantiated variable that is not yet contained in it.
         *
         * @param variables                     the list to extend
         */
        public void addInstantiatedVariables(List variables) {
            for (Iterator instantiated=m_instantiatedVariables.iterator();instantiated.hasNext();) {
                Variable instantiatedVariable=(Variable)instantiated.next();
                if (variables.contains(instantiatedVariable))
                    continue;
                variables.add(instantiatedVariable);
            }
        }
        /**
         * Adds all registered arguments to the supplied set.
         *
         * @param set                           the set receiving the arguments
         */
        public void addAllArguments(Set set) {
            set.addAll(m_arguments);
        }
        /**
         * Looks up each variable of <code>source</code> in <code>variables</code>.
         *
         * @param source                        the variables to look up
         * @param variables                     the list searched
         * @return                              the index of each variable in <code>variables</code>
         */
        private static int[] indicesOf(List source,List variables) {
            int[] indices=new int[source.size()];
            for (int position=0;position<source.size();position++) {
                Variable variable=(Variable)source.get(position);
                indices[position]=variables.indexOf(variable);
            }
            return indices;
        }
    }
}
