package edu.unika.aifb.kaon.datalog.evaluation;

import java.util.Collection;
import java.util.List;
import java.util.ArrayList;
import java.util.Set;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Collections;

import edu.unika.aifb.kaon.datalog.program.*;

/**
 * Tracks the state of the compilation of a sequence of literals. The literals are compiled one at a time by
 * calling {@link #compileNextLiteral()}. After each call, m_variablesOfPassedValues contains the variables of the
 * values passed into the current literal, and m_variablesAfterCurrentLiteral contains the variables that remain
 * visible after the current literal. A variable is dropped from the visible set as soon as no later literal uses
 * it and it is not needed at the end; after the last literal the visible variables are exactly the needed
 * variables, in their original order.
 */
public class ExpressionCompiler {
    /** The array of literals being compiled. */
    protected Literal[] m_literals;
    /** The set of variables in each literal; the extra last entry holds the variables needed at the end of the expression. */
    protected Set[] m_literalVariables;
    /** The index of the current literal; -1 before the first call to {@link #compileNextLiteral()}. */
    protected int m_currentLiteralIndex;
    /** The list of variables of the values being passed to the current literal. */
    protected List m_variablesOfPassedValues;
    /** The list of variables produced after the compilation of the current literal. */
    protected List m_variablesAfterCurrentLiteral;
    /** The argument variables of the current literal by argument position (<code>null</code> where an argument has no variable). */
    protected List m_currentLiteralVariables;
    /** The list of variables that are needed at the end of expression. */
    protected List m_neededVariables;
    /** The set of the join variables, i.e. the passed-value variables that also occur in the current literal. */
    protected Set m_joinVariables;

    /**
     * Initializes this class.
     *
     * @param literals                          the literals being compiled
     * @param passedVariables                   the array of variables passed to the literals
     * @param neededVariables                   the array of variables needed at the end
     */
    public ExpressionCompiler(Literal[] literals,Variable[] passedVariables,Variable[] neededVariables) {
        m_literals=literals;
        m_variablesOfPassedValues=new ArrayList();
        for (int i=0;i<passedVariables.length;i++)
            m_variablesOfPassedValues.add(passedVariables[i]);
        m_variablesAfterCurrentLiteral=new ArrayList(m_variablesOfPassedValues);
        m_currentLiteralVariables=new ArrayList();
        m_currentLiteralIndex=-1;
        m_neededVariables=new ArrayList();
        for (int i=0;i<neededVariables.length;i++)
            m_neededVariables.add(neededVariables[i]);
        m_joinVariables=new HashSet();
        // One set per literal plus a final entry holding the needed variables, so that the "is this variable used
        // later?" check in compileNextLiteral() keeps the variables needed at the end of the expression.
        m_literalVariables=new Set[m_literals.length+1];
        for (int i=0;i<m_literals.length;i++)
            m_literalVariables[i]=getVariablesOfLiteral(m_literals[i]);
        m_literalVariables[m_literals.length]=new HashSet(m_neededVariables);
    }
    /**
     * Collects the non-<code>null</code> argument variables of a literal.
     *
     * @param literal                           the literal
     * @return                                  the set of variables occurring in the literal
     */
    private static Set getVariablesOfLiteral(Literal literal) {
        Set variables=new HashSet();
        for (int i=0;i<literal.getArity();i++) {
            Variable variable=literal.getArgumentVariable(i);
            if (variable!=null)
                variables.add(variable);
        }
        return variables;
    }
    /**
     * Compiles the next literal. The variables visible after the previous literal become the passed variables, the
     * current literal's variables are appended to the visible set, variables not needed any more are pruned, and
     * the join variables are recomputed. Calling this past the last literal only shifts the visible variables into
     * the passed variables and clears the current literal's variables and the join variables.
     */
    public void compileNextLiteral() {
        m_currentLiteralIndex++;
        m_variablesOfPassedValues.clear();
        m_variablesOfPassedValues.addAll(m_variablesAfterCurrentLiteral);
        m_currentLiteralVariables.clear();
        m_joinVariables.clear();
        if (m_currentLiteralIndex<m_literals.length) {
            Literal literal=m_literals[m_currentLiteralIndex];
            for (int i=0;i<literal.getArity();i++) {
                Variable variable=literal.getArgumentVariable(i);
                // null entries are kept on purpose: list positions must match argument positions
                m_currentLiteralVariables.add(variable);
                if (variable!=null && !m_variablesAfterCurrentLiteral.contains(variable))
                    m_variablesAfterCurrentLiteral.add(variable);
            }
            if (m_currentLiteralIndex==m_literals.length-1) {
                // in the case of the last literal we need the variables in their original order
                m_variablesAfterCurrentLiteral.clear();
                m_variablesAfterCurrentLiteral.addAll(m_neededVariables);
            }
            else
                removeVariablesNotUsedAfter(m_currentLiteralIndex);
            m_joinVariables.addAll(m_variablesOfPassedValues);
            m_joinVariables.retainAll(m_literalVariables[m_currentLiteralIndex]);
        }
    }
    /**
     * Removes from m_variablesAfterCurrentLiteral every variable that no literal after the given one uses and that
     * is not needed at the end of the expression.
     *
     * @param literalIndex                      the index of the literal just compiled
     */
    private void removeVariablesNotUsedAfter(int literalIndex) {
        Iterator iterator=m_variablesAfterCurrentLiteral.iterator();
        while (iterator.hasNext()) {
            Variable variable=(Variable)iterator.next();
            if (!isUsedAfter(variable,literalIndex))
                iterator.remove();
        }
    }
    /**
     * Checks whether a variable occurs in any literal after the given one, or among the needed variables.
     *
     * @param variable                          the variable
     * @param literalIndex                      the index of the literal just compiled
     * @return                                  <code>true</code> if the variable is still required later
     */
    private boolean isUsedAfter(Variable variable,int literalIndex) {
        // m_literalVariables has one extra trailing entry with the needed variables, so it is checked here as well
        for (int j=literalIndex+1;j<m_literalVariables.length;j++)
            if (m_literalVariables[j].contains(variable))
                return true;
        return false;
    }
    /**
     * Checks whether the entire expression has been compiled.
     *
     * @return                                  <code>true</code> if entire expression has been compiled
     */
    public boolean expressionCompiled() {
        return m_currentLiteralIndex>=m_literals.length-1;
    }
    /**
     * Returns the indices of join positions of passed values, ordered by position in the passed values.
     *
     * @return                                  the list of indices
     */
    public int[] getPassedValuesJoinIndices() {
        return getIndicesOfVariables(m_variablesOfPassedValues,m_joinVariables);
    }
    /**
     * Returns the indices of non-join positions of passed values.
     *
     * @return                                  the list of indices
     */
    public int[] getPassedValuesNonJoinIndices() {
        Set nonJoinVariables=new HashSet(m_variablesOfPassedValues);
        nonJoinVariables.removeAll(m_joinVariables);
        return getIndicesOfVariables(m_variablesOfPassedValues,nonJoinVariables);
    }
    /**
     * Returns the indices of join positions of the current literal. The indices are ordered by the order of the join
     * variables in the passed values, so the i-th entry corresponds to the i-th entry of
     * {@link #getPassedValuesJoinIndices()}.
     *
     * @return                                  the list of indices
     */
    public int[] getCurrentLiteralJoinIndices() {
        return getIndicesOfVariablesInOrder(m_currentLiteralVariables,m_joinVariables,m_variablesOfPassedValues);
    }
    /**
     * Returns the indices of non-join positions of the current literal, i.e. the positions whose argument is not a
     * variable or is a variable that is not a join variable. Must only be called while a literal is current.
     *
     * @return                                  the list of indices
     */
    public int[] getCurrentLiteralNonJoinIndices() {
        Literal literal=m_literals[m_currentLiteralIndex];
        int numberOfNonJoinIndices=0;
        for (int i=0;i<literal.getArity();i++) {
            Variable variable=literal.getArgumentVariable(i);
            if (variable==null || !m_joinVariables.contains(variable))
                numberOfNonJoinIndices++;
        }
        int[] result=new int[numberOfNonJoinIndices];
        int index=0;
        for (int i=0;i<literal.getArity();i++) {
            Variable variable=literal.getArgumentVariable(i);
            if (variable==null || !m_joinVariables.contains(variable))
                result[index++]=i;
        }
        return result;
    }
    /**
     * Returns the array of pairs of indices to copy from the passed values.
     *
     * @return                                  pairs <code>{sourceIndex,targetIndex}</code> to copy from the passed values
     */
    public int[][] getPassedValuesIndicesToCopy() {
        return getPairsOfMatchingIndices(m_variablesOfPassedValues,m_variablesAfterCurrentLiteral,Collections.EMPTY_SET);
    }
    /**
     * Returns the array of pairs of indices to copy from the current literal values. Variables already provided by
     * the passed values are skipped.
     *
     * @return                                  pairs <code>{sourceIndex,targetIndex}</code> to copy from the current literal
     */
    public int[][] getCurrentLiteralIndicesToCopy() {
        return getPairsOfMatchingIndices(m_currentLiteralVariables,m_variablesAfterCurrentLiteral,m_variablesOfPassedValues);
    }
    /**
     * Returns <code>true</code> if there are some values passed to the current literal.
     *
     * @return                                  <code>true</code> if some values are passed to the current literal
     */
    public boolean hasPassedValues() {
        return !m_variablesOfPassedValues.isEmpty();
    }
    /**
     * Returns the values for the constants bound in the current literal. The i-th value belongs to the i-th position
     * returned by {@link #getBoundConstantPositions()}. Must only be called while a literal is current.
     *
     * @return                                  the values for the constants bound in the current literal
     */
    public Object[] getBoundConstantValues() {
        Literal literal=m_literals[m_currentLiteralIndex];
        Object[] result=new Object[countBoundConstants(literal)];
        int count=0;
        // Same predicate and same (descending) iteration order as getBoundConstantPositions(), so both arrays stay
        // aligned and the array is filled exactly.
        for (int i=literal.getArity()-1;i>=0;--i)
            if (literal.isArgumentBoundToConstant(i))
                result[count++]=literal.getArgumentConstant(i).getValue();
        return result;
    }
    /**
     * Returns the positions of the arguments bound to constants in the current literal, in descending order.
     * Must only be called while a literal is current.
     *
     * @return                                  the indices of the constants bound in the current literal
     */
    public int[] getBoundConstantPositions() {
        Literal literal=m_literals[m_currentLiteralIndex];
        int[] result=new int[countBoundConstants(literal)];
        int count=0;
        for (int i=literal.getArity()-1;i>=0;--i)
            if (literal.isArgumentBoundToConstant(i))
                result[count++]=i;
        return result;
    }
    /**
     * Counts the arguments of a literal that are bound to constants.
     *
     * @param literal                           the literal
     * @return                                  the number of arguments bound to constants
     */
    private static int countBoundConstants(Literal literal) {
        int count=0;
        for (int i=literal.getArity()-1;i>=0;--i)
            if (literal.isArgumentBoundToConstant(i))
                count++;
        return count;
    }
    /**
     * Returns the current predicate. Must only be called while a literal is current.
     *
     * @return                                  the last compiled predicate
     */
    public Predicate getCurrentPredicate() {
        return m_literals[m_currentLiteralIndex].getPredicate();
    }
    /**
     * Returns the current literal. Must only be called while a literal is current.
     *
     * @return                                  the current literal
     */
    public Literal getCurrentLiteral() {
        return m_literals[m_currentLiteralIndex];
    }
    /**
     * Returns the join filter to apply to the last compiled literal. The filter checks that the positions of the
     * second (current literal) tuple holding the same variable contain equal values.
     *
     * @return                                  the join filter to apply
     */
    public JoinTupleFilter getJoinTupleFilter() {
        int[][] tuple2filterPairs=getSameVariablePairs(m_currentLiteralVariables);
        if (tuple2filterPairs.length==0)
            return Filters.ALWAYS_TRUE_JOIN_TUPLE_FILTER;
        else
            return createSameValuesJoinTupleFilter(tuple2filterPairs);
    }
    /**
     * Returns the tuple filter to apply to the last compiled literal. The filter checks that the positions holding
     * the same variable contain equal values.
     *
     * @return                                  the tuple filter
     */
    public TupleFilter getTupleFilter() {
        int[][] filterPairs=getSameVariablePairs(m_currentLiteralVariables);
        if (filterPairs.length==0)
            return Filters.ALWAYS_TRUE_TUPLE_FILTER;
        else
            return createSameValuesTupleFilter(filterPairs);
    }
    /**
     * Creates a join filter that accepts a pair of tuples if the second tuple has equal values at every pair of
     * positions. Created in a static context so the filter does not retain a reference to this compiler.
     *
     * @param tuple2filterPairs                 pairs of positions of the second tuple that must hold equal values
     * @return                                  the join filter
     */
    private static JoinTupleFilter createSameValuesJoinTupleFilter(final int[][] tuple2filterPairs) {
        return new JoinTupleFilter() {
            public boolean shouldJoin(Object[] tuple1,Object[] tuple2) {
                return valuesEqualAtPairs(tuple2,tuple2filterPairs);
            }
        };
    }
    /**
     * Creates a tuple filter that accepts a tuple if it has equal values at every pair of positions. Created in a
     * static context so the filter does not retain a reference to this compiler.
     *
     * @param filterPairs                       pairs of positions that must hold equal values
     * @return                                  the tuple filter
     */
    private static TupleFilter createSameValuesTupleFilter(final int[][] filterPairs) {
        return new TupleFilter() {
            public boolean evaluate(Object[] tuple) {
                return valuesEqualAtPairs(tuple,filterPairs);
            }
        };
    }
    /**
     * Checks whether the tuple holds equal values (by <code>equals</code>, with two <code>null</code>s being equal)
     * at every pair of positions.
     *
     * @param tuple                             the tuple
     * @param pairs                             the pairs of positions to compare
     * @return                                  <code>true</code> if all pairs hold equal values
     */
    private static boolean valuesEqualAtPairs(Object[] tuple,int[][] pairs) {
        for (int i=pairs.length-1;i>=0;--i) {
            Object value1=tuple[pairs[i][0]];
            Object value2=tuple[pairs[i][1]];
            if (value1!=value2 && (value1==null || !value1.equals(value2)))
                return false;
        }
        return true;
    }
    /**
     * Computes the pairs of indices of same variables. Every pair <code>{i,j}</code> with <code>i&lt;j</code> whose
     * variables are equal and non-<code>null</code> is returned.
     *
     * @param variables                         the variables
     * @return                                  the array of pairs of indices of same variables
     */
    protected int[][] getSameVariablePairs(List variables) {
        List samePairs=new ArrayList();
        for (int i=0;i<variables.size();i++) {
            Variable variable1=(Variable)variables.get(i);
            if (variable1!=null) {
                for (int j=i+1;j<variables.size();j++) {
                    Variable variable2=(Variable)variables.get(j);
                    if (variable1.equals(variable2))
                        samePairs.add(new int[] { i,j });
                }
            }
        }
        int[][] samePairsArray=new int[samePairs.size()][];
        samePairs.toArray(samePairsArray);
        return samePairsArray;
    }
    /**
     * Returns the indices of the positions in the passed list whose variable is contained in the passed set. Only the
     * first occurrence of each variable is reported, in list order. The result contains exactly the matched indices.
     *
     * @param variables                         the list of variables
     * @param variablesToInclude                the set of variables to include
     * @return                                  the array of indices
     */
    protected int[] getIndicesOfVariables(List variables,Set variablesToInclude) {
        Set includedVariables=new HashSet();
        int[] indices=new int[variablesToInclude.size()];
        int count=0;
        for (int i=0;i<variables.size();i++) {
            Variable variable=(Variable)variables.get(i);
            if (variable!=null && variablesToInclude.contains(variable) && includedVariables.add(variable))
                indices[count++]=i;
        }
        // variablesToInclude may contain variables absent from the list; never return padding zeros as indices
        return trimToSize(indices,count);
    }
    /**
     * Returns the indices in the passed list of the variables contained in the passed set, ordered by the order of
     * the variables in <code>variablesOrder</code>. Each variable is reported once, at its first index in
     * <code>variables</code>. The entry is <code>-1</code> if such a variable does not occur in <code>variables</code>.
     *
     * @param variables                         the list of variables
     * @param variablesToInclude                the set of variables to include
     * @param variablesOrder                    the list of variables determining the order of indices
     * @return                                  the array of indices
     */
    protected int[] getIndicesOfVariablesInOrder(List variables,Set variablesToInclude,List variablesOrder) {
        Set includedVariables=new HashSet();
        int[] indices=new int[variablesToInclude.size()];
        int count=0;
        for (int i=0;i<variablesOrder.size();i++) {
            Variable variable=(Variable)variablesOrder.get(i);
            if (variable!=null && variablesToInclude.contains(variable) && includedVariables.add(variable))
                indices[count++]=variables.indexOf(variable);
        }
        // variablesToInclude may contain variables absent from variablesOrder; never return padding zeros as indices
        return trimToSize(indices,count);
    }
    /**
     * Returns the first <code>size</code> elements of the array, reusing the array when it already has that length.
     *
     * @param array                             the array
     * @param size                              the number of leading elements to keep
     * @return                                  an array of length <code>size</code>
     */
    private static int[] trimToSize(int[] array,int size) {
        if (size==array.length)
            return array;
        int[] result=new int[size];
        System.arraycopy(array,0,result,0,size);
        return result;
    }
    /**
     * Returns the pairs of matching indices from the supplied list of variables. For every needed variable at
     * position <code>i</code> that is not skipped and occurs in <code>variables</code>, the pair
     * <code>{firstIndexInVariables,i}</code> is returned, in the order of <code>neededVariables</code>.
     *
     * @param variables                         the list of variables
     * @param neededVariables                   the list of needed variables (should not contain repeated variables)
     * @param variablesToSkip                   the variables not to copy
     * @return                                  the array of pairs of indices of matching variables
     */
    protected int[][] getPairsOfMatchingIndices(List variables,List neededVariables,Collection variablesToSkip) {
        List pairs=new ArrayList();
        for (int i=0;i<neededVariables.size();i++) {
            Variable variable=(Variable)neededVariables.get(i);
            if (variable!=null && !variablesToSkip.contains(variable)) {
                int originalIndex=variables.indexOf(variable);
                if (originalIndex!=-1)
                    pairs.add(new int[] { originalIndex,i });
            }
        }
        int[][] result=new int[pairs.size()][];
        pairs.toArray(result);
        return result;
    }
}
