/*
 * Decompiled with CFR 0.152.
 */
package dr.math.distributions.gp;

import dr.inference.distribution.RandomField;
import dr.inference.model.AbstractModel;
import dr.inference.model.DesignMatrix;
import dr.inference.model.GradientProvider;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.distributions.RandomFieldDistribution;
import dr.math.distributions.gp.GaussianProcessKernel;
import java.util.Arrays;
import java.util.List;
import org.ejml.alg.dense.decomposition.chol.CholeskyDecompositionCommon_D64;
import org.ejml.data.DenseMatrix64F;
import org.ejml.factory.LinearSolverFactory;
import org.ejml.interfaces.linsol.LinearSolver;

/**
 * Multivariate normal random-field distribution whose covariance is built additively: the Gramian is the
 * sum, over all {@link BasisDimension}s, of {@code kernel.getScale() * kernel.getUnscaledCovariance(x_i, x_j)},
 * optionally plus a diagonal nugget.
 *
 * <p>Only first-order additivity is supported; the constructor rejects an {@code orderVariance}
 * parameter whose dimension is not 1.
 */
public class AdditiveGaussianProcessDistribution
extends RandomFieldDistribution {
    public static final String TYPE = "GaussianProcess";

    private final int order;
    private final int dim;
    private final Parameter orderVariance;
    /** Mean; {@code null} means a zero mean, dimension 1 means one value shared by every entry. */
    private final Parameter meanParameter;
    /** Diagonal nugget; {@code null} means none, dimension 1 means one value shared by every entry. */
    private final Parameter nuggetParameter;
    private final List<BasisDimension> bases;
    private final double[] mean;
    /** Scratch buffer holding {@code x - mean} inside {@link #logPdf(double[])}. */
    private final double[] tmp;
    private final DenseMatrix64F gramian;
    private final DenseMatrix64F precision;
    private final DenseMatrix64F variance;
    private double logDeterminant;
    private boolean meanKnown;
    private boolean precisionAndDeterminantKnown;
    private boolean gramianAndVarianceKnown;
    private static final boolean USE_CHOLESKY = true;

    /**
     * Creates the distribution and registers every parameter, design matrix and (model-backed) kernel
     * so that changes to them are reported to this model.
     *
     * @param name            model name passed to the superclass
     * @param dim             dimension of the field (length of the mean and size of the covariance)
     * @param orderVariance   per-order variance parameter; must currently have dimension 1
     * @param meanParameter   mean parameter, or {@code null} for a zero mean
     * @param nuggetParameter nugget added to the Gramian diagonal, or {@code null} for none
     * @param bases           kernel / design-matrix pairs whose covariance contributions are summed
     * @throws RuntimeException if {@code orderVariance} does not have dimension 1
     */
    public AdditiveGaussianProcessDistribution(String name, int dim, Parameter orderVariance,
                                               Parameter meanParameter, Parameter nuggetParameter,
                                               List<BasisDimension> bases) {
        super(name);
        this.order = orderVariance.getDimension();
        if (this.order != 1) {
            throw new RuntimeException("Not yet implemented");
        }
        this.dim = dim;
        this.orderVariance = orderVariance;
        this.meanParameter = meanParameter;
        this.nuggetParameter = nuggetParameter;
        this.bases = bases;
        this.mean = new double[dim];
        this.tmp = new double[dim];
        this.gramian = new DenseMatrix64F(dim, dim);
        this.precision = new DenseMatrix64F(dim, dim);
        this.variance = new DenseMatrix64F(dim, dim);

        addVariable(orderVariance);
        if (meanParameter != null) {
            addVariable(meanParameter);
        }
        if (nuggetParameter != null) {
            addVariable(nuggetParameter);
        }
        for (BasisDimension basis : bases) {
            GaussianProcessKernel kernel = basis.getKernel();
            // Only kernels that are models can notify us of hyper-parameter changes.
            if (kernel instanceof AbstractModel) {
                addModel((AbstractModel) kernel);
            }
            DesignMatrix design1 = basis.getDesignMatrix1();
            DesignMatrix design2 = basis.getDesignMatrix2();
            addVariable(design1);
            // The two-argument BasisDimension constructor shares one matrix; avoid registering it twice.
            if (design2 != design1) {
                addVariable(design2);
            }
        }
    }

    /** @return the additive order (currently always 1) */
    public int getOrder() {
        return this.order;
    }

    /** @return the per-order variance parameter */
    public Parameter getOrderVariance() {
        return this.orderVariance;
    }

    /** @return the list of basis dimensions (not copied) */
    List<BasisDimension> getBases() {
        return this.bases;
    }

    /** Recomputes the Gramian and the variance (Gramian plus optional diagonal nugget). */
    private void computeGramianAndVariance() {
        computeAdditiveGramian(gramian, bases, orderVariance);
        variance.set(gramian);
        if (nuggetParameter != null) {
            for (int i = 0; i < dim; ++i) {
                variance.add(i, i, getNugget(i));
            }
        }
    }

    /**
     * Inverts the variance via a symmetric-positive-definite solver and derives the log-determinant
     * from the diagonal of the resulting Cholesky factor.
     *
     * @throws RuntimeException if the variance cannot be decomposed, or the solver does not expose a
     *                          dense Cholesky decomposition
     */
    private void computePrecisionAndDeterminant() {
        DenseMatrix64F sigma = getVariance();
        LinearSolver<DenseMatrix64F> solver = LinearSolverFactory.symmPosDef(dim);
        if (!solver.setA(sigma)) {
            throw new RuntimeException("Unable to decompose matrix");
        }
        solver.invert(precision);

        // NOTE(review): the factory may pick a different solver implementation depending on the matrix
        // size; fail with an explicit message instead of a ClassCastException if the factor is not dense.
        Object decomposition = solver.getDecomposition();
        if (!(decomposition instanceof CholeskyDecompositionCommon_D64)) {
            throw new RuntimeException("Unsupported Cholesky decomposition: "
                    + (decomposition == null ? "null" : decomposition.getClass().getName()));
        }
        DenseMatrix64F factor = ((CholeskyDecompositionCommon_D64) decomposition).getT();
        // log|Sigma| = 2 * sum(log(diag(L))) for Sigma = L L^T.
        logDeterminant = 2.0 * computeLogDeterminantFromTriangularMatrix(factor);
    }

    /**
     * Sums the logs of the diagonal entries of a square, row-major triangular matrix.
     *
     * @param triangular square matrix; only its diagonal is read
     * @return the sum of {@code log(triangular[i][i])}
     */
    private static double computeLogDeterminantFromTriangularMatrix(DenseMatrix64F triangular) {
        int n = triangular.numCols;
        double[] data = triangular.getData();
        double sum = 0.0;
        // Diagonal entries of a row-major n x n matrix sit at stride n + 1.
        for (int k = 0; k < n * n; k += n + 1) {
            sum += Math.log(data[k]);
        }
        return sum;
    }

    /** @return the row-major precision data, recomputed if stale */
    private double[] getPrecision() {
        return getPrecisionAsMatrix().getData();
    }

    /** @return the precision matrix, recomputed if stale */
    protected DenseMatrix64F getPrecisionAsMatrix() {
        if (!precisionAndDeterminantKnown) {
            computePrecisionAndDeterminant();
            precisionAndDeterminantKnown = true;
        }
        return precision;
    }

    /** @return {@code log|variance|}, recomputed if stale */
    private double getLogDeterminant() {
        if (!precisionAndDeterminantKnown) {
            computePrecisionAndDeterminant();
            precisionAndDeterminantKnown = true;
        }
        return logDeterminant;
    }

    /** @return the Gramian (without nugget), recomputed if stale */
    private DenseMatrix64F getGramian() {
        if (!gramianAndVarianceKnown) {
            computeGramianAndVariance();
            gramianAndVarianceKnown = true;
        }
        return gramian;
    }

    /** @return the variance (Gramian plus nugget), recomputed if stale */
    private DenseMatrix64F getVariance() {
        if (!gramianAndVarianceKnown) {
            computeGramianAndVariance();
            gramianAndVarianceKnown = true;
        }
        return variance;
    }

    /** @return the nugget for entry {@code i}; a dimension-1 nugget is shared by all entries */
    private double getNugget(int i) {
        return nuggetParameter.getDimension() == 1
                ? nuggetParameter.getParameterValue(0)
                : nuggetParameter.getParameterValue(i);
    }

    /**
     * Returns the mean vector: zeros when there is no mean parameter, a constant vector when the
     * mean parameter has dimension 1, otherwise its first {@code dim} values.
     * The returned array is the internal cache and must not be modified.
     */
    @Override
    public double[] getMean() {
        if (!meanKnown) {
            if (meanParameter == null) {
                Arrays.fill(mean, 0.0);
            } else if (meanParameter.getDimension() == 1) {
                Arrays.fill(mean, meanParameter.getParameterValue(0));
            } else {
                for (int i = 0; i < mean.length; ++i) {
                    mean[i] = meanParameter.getParameterValue(i);
                }
            }
            meanKnown = true;
        }
        return mean;
    }

    @Override
    public String getType() {
        return TYPE;
    }

    @Override
    public double[][] getScaleMatrix() {
        throw new RuntimeException("Not yet implemented");
    }

    @Override
    public Variable<Double> getLocationVariable() {
        throw new RuntimeException("Not yet implemented");
    }

    /**
     * Multivariate normal log-density:
     * {@code -0.5 * (dim * log(2*pi) + log|Sigma|) - 0.5 * (x - mu)^T Sigma^{-1} (x - mu)}.
     *
     * @param x point of length at least {@code dim}
     */
    @Override
    public double logPdf(double[] x) {
        // Kernels that are not AbstractModels cannot report hyper-parameter changes, so the
        // covariance is conservatively recomputed on every evaluation.
        precisionAndDeterminantKnown = false;
        gramianAndVarianceKnown = false;

        double[] mu = getMean();
        double[] delta = tmp;
        double[] prec = getPrecision();
        for (int i = 0; i < dim; ++i) {
            delta[i] = x[i] - mu[i];
        }
        double quadratic = 0.0;
        for (int i = 0; i < dim; ++i) {
            for (int j = 0; j < dim; ++j) {
                quadratic += delta[i] * prec[i * dim + j] * delta[j];
            }
        }
        return -0.5 * ((double) dim * Math.log(Math.PI * 2) + getLogDeterminant()) - 0.5 * quadratic;
    }

    @Override
    public int getDimension() {
        return dim;
    }

    /** Gradient of the log-density with respect to the field value {@code x} (a {@code double[]}). */
    @Override
    public double[] getGradientLogDensity(Object x) {
        return MultivariateNormalDistribution.gradLogPdf((double[]) x, getMean(), getPrecision());
    }

    @Override
    public double[] getDiagonalHessianLogDensity(Object x) {
        throw new RuntimeException("Not yet implemented");
    }

    @Override
    public double[][] getHessianLogDensity(Object x) {
        throw new RuntimeException("Not yet implemented");
    }

    @Override
    public double[] nextRandom() {
        throw new RuntimeException("Not yet implemented");
    }

    /**
     * Invalidates the covariance caches when a registered kernel changes.
     *
     * @throws IllegalArgumentException if {@code model} is not one of the basis kernels
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        if (!containsKernel(model)) {
            throw new IllegalArgumentException("Unknown model");
        }
        precisionAndDeterminantKnown = false;
        gramianAndVarianceKnown = false;
        fireModelChanged();
    }

    /** @return {@code true} if {@code model} is (by reference) the kernel of some basis dimension */
    private boolean containsKernel(Model model) {
        for (BasisDimension basis : bases) {
            if (model == basis.getKernel()) {
                return true;
            }
        }
        return false;
    }

    /** @return {@code true} if {@code variable} is (by reference) a design matrix of some basis dimension */
    private boolean containsDesignMatrix(Variable variable) {
        for (BasisDimension basis : bases) {
            if (variable == basis.getDesignMatrix1() || variable == basis.getDesignMatrix2()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Invalidates the caches that depend on the changed variable: the mean cache for the mean
     * parameter, the covariance caches for the order variance, nugget and design matrices.
     *
     * @throws IllegalArgumentException if {@code variable} was not registered by this distribution
     */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType changeType) {
        if (meanParameter != null && variable == meanParameter) {
            meanKnown = false;
        } else if (variable == orderVariance
                || (nuggetParameter != null && variable == nuggetParameter)
                || containsDesignMatrix(variable)) {
            gramianAndVarianceKnown = false;
            precisionAndDeterminantKnown = false;
        } else {
            throw new IllegalArgumentException("Unknown variable");
        }
        fireModelChanged();
    }

    @Override
    protected void storeState() {
    }

    /**
     * Drops every cached quantity: after a rejected proposal the caches may describe the rejected
     * state, and restored inputs do not necessarily re-fire change events.
     */
    @Override
    protected void restoreState() {
        meanKnown = false;
        gramianAndVarianceKnown = false;
        precisionAndDeterminantKnown = false;
    }

    @Override
    protected void acceptState() {
    }

    /**
     * Returns a gradient provider for the mean parameter.
     *
     * <p>The log-density depends on {@code x - mean}, so the gradient with respect to the mean is the
     * negation of the gradient with respect to {@code x}. For a scalar (shared) mean the chain rule sums
     * the per-entry mean gradients.
     *
     * @throws RuntimeException if {@code parameter} is not the mean parameter
     */
    @Override
    public GradientProvider getGradientWrt(Parameter parameter) {
        if (parameter == this.meanParameter) {
            return new GradientProvider() {

                @Override
                public int getDimension() {
                    return AdditiveGaussianProcessDistribution.this.meanParameter.getDimension();
                }

                @Override
                public double[] getGradientLogDensity(Object x) {
                    AdditiveGaussianProcessDistribution outer = AdditiveGaussianProcessDistribution.this;
                    double[] gradientWrtX = MultivariateNormalDistribution.gradLogPdf(
                            (double[]) x, outer.getMean(), outer.getPrecision());
                    int meanDim = outer.meanParameter.getDimension();

                    if (meanDim == outer.dim) {
                        for (int i = 0; i < outer.dim; ++i) {
                            gradientWrtX[i] = -gradientWrtX[i];
                        }
                        return gradientWrtX;
                    }
                    if (meanDim == 1) {
                        double sum = 0.0;
                        for (int i = 0; i < outer.dim; ++i) {
                            sum += gradientWrtX[i];
                        }
                        // Same sign flip as the full-dimension branch.
                        return new double[]{-sum};
                    }
                    throw new IllegalArgumentException("Unknown mean parameter structure");
                }
            };
        }
        throw new RuntimeException("Unknown parameter");
    }

    /**
     * Fills {@code gramian} with the additive kernel covariance: entry (i, j) is the sum over bases of
     * {@code kernel.getScale() * kernel.getUnscaledCovariance(design1[i, 0], design2[j, 0])}.
     *
     * @param gramian       output matrix; zeroed first, its shape fixes the row/column ranges
     * @param bases         basis dimensions to sum over
     * @param orderVariance per-order variance; NOTE(review): not yet applied to the Gramian
     */
    public static void computeAdditiveGramian(DenseMatrix64F gramian, List<BasisDimension> bases, Parameter orderVariance) {
        gramian.zero();
        int rows = gramian.getNumRows();
        int cols = gramian.getNumCols();
        for (BasisDimension basis : bases) {
            GaussianProcessKernel kernel = basis.getKernel();
            DesignMatrix design1 = basis.getDesignMatrix1();
            DesignMatrix design2 = basis.getDesignMatrix2();
            double scale = kernel.getScale();
            for (int i = 0; i < rows; ++i) {
                double xi = design1.getParameterValue(i, 0);
                for (int j = 0; j < cols; ++j) {
                    double xj = design2.getParameterValue(j, 0);
                    gramian.add(i, j, scale * kernel.getUnscaledCovariance(xi, xj));
                }
            }
        }
        // TODO: incorporate orderVariance once higher-order additive terms are supported.
    }

    /**
     * One additive component: a kernel together with the design matrices supplying its row
     * ({@code design1}) and column ({@code design2}) inputs. Only column 0 of each matrix is read.
     */
    public static class BasisDimension {
        private final GaussianProcessKernel kernel;
        private final DesignMatrix design1;
        private final DesignMatrix design2;

        /** Creates a basis with separate row and column design matrices. */
        public BasisDimension(GaussianProcessKernel kernel, DesignMatrix design1, DesignMatrix design2) {
            this.kernel = kernel;
            this.design1 = design1;
            this.design2 = design2;
        }

        /** Creates a basis that uses the same design matrix for rows and columns. */
        public BasisDimension(GaussianProcessKernel kernel, DesignMatrix design) {
            this(kernel, design, design);
        }

        /**
         * Creates a basis whose design matrix is backed by a weight provider.
         * NOTE(review): that matrix does not yet implement {@code getParameterValue}, so computing a
         * Gramian from it throws.
         */
        public BasisDimension(GaussianProcessKernel kernel, RandomField.WeightProvider weights) {
            this(kernel, BasisDimension.makeDesignMatrixFromWeights(weights));
        }

        GaussianProcessKernel getKernel() {
            return this.kernel;
        }

        DesignMatrix getDesignMatrix1() {
            return this.design1;
        }

        DesignMatrix getDesignMatrix2() {
            return this.design2;
        }

        /** Wraps a weight provider as a single-column design matrix of the same length. */
        private static DesignMatrix makeDesignMatrixFromWeights(final RandomField.WeightProvider weights) {
            return new DesignMatrix("weights", false) {

                @Override
                public double getParameterValue(int row, int col) {
                    throw new RuntimeException("Not yet implemented");
                }

                @Override
                public int getDimension() {
                    return weights.getDimension();
                }

                @Override
                public int getRowDimension() {
                    return weights.getDimension();
                }

                @Override
                public int getColumnDimension() {
                    return 1;
                }

                @Override
                public Parameter getParameter(int index) {
                    throw new IllegalArgumentException("Not allowed");
                }
            };
        }
    }
}

