package net.derkholm.nmica.trainer;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import net.derkholm.nmica.maths.MathsTools;
import net.derkholm.nmica.maths.NativeMath;
import net.derkholm.nmica.matrix.Matrix1D;
import net.derkholm.nmica.matrix.MatrixTools;
import net.derkholm.nmica.matrix.ObjectMatrix1D;
import net.derkholm.nmica.model.ContributionGroup;
import net.derkholm.nmica.model.ContributionItem;
import net.derkholm.nmica.model.ContributionPrior;
import net.derkholm.nmica.model.ContributionSampler;
import net.derkholm.nmica.model.Datum;
import net.derkholm.nmica.model.Facette;
import net.derkholm.nmica.model.FacetteMap;
import net.derkholm.nmica.model.MixPolicy;
import net.derkholm.nmica.model.MultiICAModel;
import net.derkholm.nmica.model.SimpleContributionItem;
import net.derkholm.nmica.model.SimpleMultiICAModel;
import net.derkholm.nmica.utils.CollectTools;
import org.biojava.bio.BioError;

/* loaded from: input_file:net/derkholm/nmica/trainer/Trainer.class */
/**
 * Base class for an ensemble trainer over {@link MultiICAModel}s.
 *
 * <p>The trainer keeps an ensemble of {@code ensembleSize} {@link TrainableState}s, each mapped to its
 * likelihood. Each call to {@link #next()} does four things:
 * <ol>
 *   <li>Removes the lowest-likelihood member.</li>
 *   <li>Reports that member as a {@link WeightedModel} and folds its weight into
 *       {@link #getAccumulatedEvidence()}.</li>
 *   <li>Replaces it, preferably with a fresh sample drawn from the priors whose likelihood is at least
 *       that of the removed member.</li>
 *   <li>Otherwise, replaces it with a perturbed copy of a surviving member, handed to
 *       {@link #decorrelateState}.</li>
 * </ol>
 * NOTE(review): this looks like a nested-sampling scheme with log2 weights. Confirm against the
 * project documentation.
 *
 * <p>Derived lookup tables and the live ensemble are transient. They are rebuilt in
 * {@code readObject}; the ensemble is serialized as an array of {@link FrozenModelState}s.
 */
public abstract class Trainer implements Serializable, TrainableStateContext {
    private static final long serialVersionUID = 8340439207998109127L;

    /**
     * Number of consecutive steps without an accepted direct sample after which {@link #next()}
     * stops trying to sample replacements directly from the priors.
     */
    private static final int DIRECT_SAMPLE_ATTEMPT_LIMIT = 20;

    // Configuration supplied at construction time.
    private final FacetteMap facetteMap;
    private final Datum[] data;
    private final int components;
    private final ContributionPrior[] priors;
    private final ContributionSampler[] samplers;
    private final MixPolicy mixPolicy;
    private final int ensembleSize;

    // Optional fixed leading components per contribution group; row g may be null until init().
    private ContributionItem[][] seedValues;

    // Lookup tables derived from facetteMap (rebuilt by makeIndices()).
    private transient Facette[] facettes;
    private transient ContributionGroup[] cgs;
    private transient Map<Facette, Integer> facetteIndices;
    private transient Map<ContributionGroup, Integer> contributionIndices;
    private transient int[] facetteIndicesToContributionIndex;
    private transient int[][] contributionIndicesToFacetteIndices;

    // Live ensemble: every key is a TrainableState, mapped to its likelihood.
    private transient Map<MultiICAModel, Double> modelToLikelihood;
    // Lazily computed, sorted copy of the ensemble likelihoods; reset to null whenever the ensemble changes.
    private transient double[] sortedLikelihoods;
    // Snapshot of the surviving members taken before decorrelateState; read by getUncleVector().
    private transient TrainableState[] models;
    // seedValues[g].length for each group g.
    private transient int[] seedCounts = null;

    private double crossOverProb = 0.0d;
    private double replaceComponentProb = 0.0d;
    private boolean ignoreMixturePrior = false;
    private int samplesToRemove = 1;
    private int step = 0;
    private int stepsSinceSuccessfulDirectSample = 0;
    private double accumulatedEvidence = Double.NEGATIVE_INFINITY;
    private transient EvaluationManager evalManager = new LocalEvaluationManager();

    /**
     * Read-only view of one (contribution group, component) slot across a snapshot of ensemble
     * states. It is static because it never touches the enclosing trainer.
     */
    private static final class UncleVector implements ObjectMatrix1D {
        private final TrainableState[] states;
        private final int group;
        private final int component;

        UncleVector(TrainableState[] states, int group, int component) {
            this.states = states;
            this.group = group;
            this.component = component;
        }

        @Override // net.derkholm.nmica.matrix.ObjectMatrix1D
        public int size() {
            return states.length;
        }

        @Override // net.derkholm.nmica.matrix.ObjectMatrix1D
        public Object get(int i) {
            return states[i].getContribution(group, component);
        }

        @Override // net.derkholm.nmica.matrix.ObjectMatrix1D
        public void set(int i, Object obj) {
            throw new UnsupportedOperationException("This is a read-only matrix for accessing other model states");
        }
    }

    /**
     * An ensemble member removed by {@link Trainer#next()}, together with its likelihood and the
     * prior-mass term assigned to it at that step.
     */
    public static class WeightedModel {
        private final MultiICAModel model;
        private final double likelihood;
        private final double priorMass;

        WeightedModel(MultiICAModel multiICAModel, double d, double d2) {
            this.model = multiICAModel;
            this.likelihood = d;
            this.priorMass = d2;
        }

        /** @return the removed model (a {@link SimpleMultiICAModel} copy, as built by {@code next()}) */
        public MultiICAModel getModel() {
            return this.model;
        }

        /** @return the likelihood the model had when it was removed */
        public double getLikelihood() {
            return this.likelihood;
        }

        /** @return the prior-mass term computed for the step at which the model was removed */
        public double getPriorMass() {
            return this.priorMass;
        }

        /** @return {@code priorMass + likelihood}, i.e. the weight in the same (log) units as both terms */
        public double getWeight() {
            return this.priorMass + this.likelihood;
        }
    }

    /**
     * Rebuilds every transient lookup table from {@link #facetteMap}: the facette and
     * contribution-group arrays, their index maps, and the facette-to-group and group-to-facettes
     * index tables.
     */
    private void makeIndices() {
        facettes = facetteMap.getFacettes();
        cgs = facetteMap.getContributionGroups();

        facetteIndices = new HashMap<Facette, Integer>();
        for (int f = 0; f < facettes.length; f++) {
            facetteIndices.put(facettes[f], Integer.valueOf(f));
        }
        contributionIndices = new HashMap<ContributionGroup, Integer>();
        for (int c = 0; c < cgs.length; c++) {
            contributionIndices.put(cgs[c], Integer.valueOf(c));
        }

        facetteIndicesToContributionIndex = new int[facettes.length];
        for (int f = 0; f < facettes.length; f++) {
            facetteIndicesToContributionIndex[f] =
                    contributionGroupToIndex(facetteMap.getContributionForFacette(facettes[f]));
        }

        // One row per contribution group. Rows can differ in length, so the second dimension is
        // left unsized (the decompiled `new int[n]` did not even compile).
        contributionIndicesToFacetteIndices = new int[cgs.length][];
        for (int c = 0; c < cgs.length; c++) {
            Facette[] groupFacettes = facetteMap.getFacettesForContribution(cgs[c]);
            int[] row = new int[groupFacettes.length];
            for (int k = 0; k < groupFacettes.length; k++) {
                row[k] = facetteToIndex(groupFacettes[k]);
            }
            contributionIndicesToFacetteIndices[c] = row;
        }
    }

    /**
     * Creates a trainer. The ensemble itself is not built until {@link #init()} is called.
     *
     * @param facetteMap             mapping between facettes and contribution groups
     * @param datumArr               the data set
     * @param i                      number of components per model
     * @param contributionPriorArr   one prior per contribution group, in group-index order
     * @param contributionSamplerArr samplers, exposed through {@link #getSamplers()}
     * @param mixPolicy              prior and sampler for the mixing weights
     * @param i2                     number of models kept in the ensemble
     */
    public Trainer(FacetteMap facetteMap, Datum[] datumArr, int i, ContributionPrior[] contributionPriorArr, ContributionSampler[] contributionSamplerArr, MixPolicy mixPolicy, int i2) {
        this.facetteMap = facetteMap;
        this.data = datumArr;
        this.components = i;
        this.priors = contributionPriorArr;
        this.samplers = contributionSamplerArr;
        this.mixPolicy = mixPolicy;
        this.ensembleSize = i2;
        makeIndices();
        // One (initially null) seed row per contribution group; filled by setSeedContributions/init().
        this.seedValues = new ContributionItem[this.cgs.length][];
    }

    /**
     * Fixes the first {@code contributionItemArr.length} components of the given group to the
     * supplied items. This affects models created by direct sampling after {@link #init()}.
     *
     * @throws IllegalArgumentException if the group is unknown or more seeds than components are given
     */
    public void setSeedContributions(ContributionGroup contributionGroup, ContributionItem[] contributionItemArr) throws IllegalArgumentException {
        // contributionGroupToIndex already throws IllegalArgumentException for unknown groups.
        int group = contributionGroupToIndex(contributionGroup);
        if (contributionItemArr.length > this.components) {
            throw new IllegalArgumentException(String.format("Specified %d seeds for a model with %d components", Integer.valueOf(contributionItemArr.length), Integer.valueOf(this.components)));
        }
        // Copy, so that later edits to the caller's array cannot change the seeds.
        this.seedValues[group] = contributionItemArr.clone();
    }

    /**
     * Serializes the non-transient fields, then the ensemble as a {@code FrozenModelState[]}.
     * Writes {@code null} when {@link #init()} has not been called yet.
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.defaultWriteObject();
        FrozenModelState[] frozen = null;
        if (modelToLikelihood != null) {
            TrainableState[] states = modelToLikelihood.keySet().toArray(new TrainableState[0]);
            frozen = new FrozenModelState[states.length];
            for (int k = 0; k < states.length; k++) {
                frozen[k] = new FrozenModelState(states[k]);
            }
        }
        objectOutputStream.writeObject(frozen);
    }

    /**
     * Restores the trainer. It does four things:
     * <ul>
     *   <li>Rebuilds the lookup tables.</li>
     *   <li>Installs a fresh {@link LocalEvaluationManager}.</li>
     *   <li>Normalises the seed rows and recomputes the seed counts.</li>
     *   <li>Thaws the ensemble, if one was written.</li>
     * </ul>
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        makeIndices();
        this.evalManager = new LocalEvaluationManager();
        if (this.seedValues == null) {
            this.seedValues = new ContributionItem[this.cgs.length][];
        }
        // seedCounts is transient (null here); it must be allocated before it is written.
        this.seedCounts = new int[this.cgs.length];
        for (int g = 0; g < this.cgs.length; g++) {
            if (this.seedValues[g] == null) {
                this.seedValues[g] = new ContributionItem[0];
            }
            this.seedCounts[g] = this.seedValues[g].length;
        }
        FrozenModelState[] frozen = (FrozenModelState[]) objectInputStream.readObject();
        if (frozen != null) {
            this.modelToLikelihood = new HashMap<MultiICAModel, Double>();
            for (FrozenModelState frozenState : frozen) {
                TrainableState state = new TrainableState(this, frozenState);
                this.modelToLikelihood.put(state, Double.valueOf(state.likelihood()));
            }
        }
    }

    /** Replaces the evaluation manager (a {@link LocalEvaluationManager} by default). */
    public void setEvaluationManager(EvaluationManager evaluationManager) {
        this.evalManager = evaluationManager;
    }

    @Override // net.derkholm.nmica.trainer.TrainableStateContext
    public EvaluationManager getEvaluationManager() {
        return this.evalManager;
    }

    /** When true, {@link #prior} skips the mixing-weight prior terms. */
    public void setIgnoreMixturePrior(boolean z) {
        this.ignoreMixturePrior = z;
    }

    public boolean getIgnoreMixturePrior() {
        return this.ignoreMixturePrior;
    }

    /** Probability that {@link #next()} tries to resample a single free component. */
    public void setReplaceComponentProb(double d) {
        this.replaceComponentProb = d;
    }

    public int getSamplesToRemove() {
        return this.samplesToRemove;
    }

    public void setSamplesToRemove(int i) {
        this.samplesToRemove = i;
    }

    /** Probability that {@link #next()} tries a crossover (checked only if no component replacement was tried). */
    public void setCrossOverProb(double d) {
        this.crossOverProb = d;
    }

    /**
     * @return the index of {@code facette} in {@link #getFacettes()}
     * @throws IllegalArgumentException if the facette is unknown
     */
    int facetteToIndex(Facette facette) throws IllegalArgumentException {
        Integer index = this.facetteIndices.get(facette);
        if (index == null) {
            throw new IllegalArgumentException("This model doesn't know anything about " + facette.toString());
        }
        return index.intValue();
    }

    /**
     * @return the index of {@code contributionGroup} in the facette map's contribution groups
     * @throws IllegalArgumentException if the group is unknown
     */
    @Override // net.derkholm.nmica.trainer.TrainableStateContext
    public int contributionGroupToIndex(ContributionGroup contributionGroup) throws IllegalArgumentException {
        Integer index = this.contributionIndices.get(contributionGroup);
        if (index == null) {
            throw new IllegalArgumentException("This model doesn't know anything about " + contributionGroup.toString());
        }
        return index.intValue();
    }

    @Override // net.derkholm.nmica.trainer.TrainableStateContext
    public int facetteIndexToContributionIndex(int i) {
        return this.facetteIndicesToContributionIndex[i];
    }

    /** @return the facette indices belonging to contribution group {@code i} (internal array; do not modify) */
    int[] contributionIndexToFacetteIndices(int i) {
        return this.contributionIndicesToFacetteIndices[i];
    }

    @Override // net.derkholm.nmica.trainer.TrainableStateContext
    public FacetteMap getFacetteMap() {
        return this.facetteMap;
    }

    public Facette[] getFacettes() {
        return this.facettes;
    }

    @Override // net.derkholm.nmica.trainer.TrainableStateContext
    public int getComponents() {
        return this.components;
    }

    @Override // net.derkholm.nmica.trainer.TrainableStateContext
    public Datum[] getDataSet() {
        return this.data;
    }

    public ContributionPrior[] getPriors() {
        return this.priors;
    }

    public ContributionSampler[] getSamplers() {
        return this.samplers;
    }

    @Override // net.derkholm.nmica.trainer.TrainableStateContext
    public MixPolicy getMixPolicy() {
        return this.mixPolicy;
    }

    /**
     * Sums the prior terms of a state. This is the mixing prior of every datum (unless ignored)
     * plus the contribution prior of every component in every group. The terms are presumably
     * log-probabilities, since they are summed; a NaN total becomes negative infinity.
     * NOTE(review): the decompiler reports this was originally protected; it is kept public for
     * compatibility.
     */
    public double prior(TrainableState trainableState) {
        double total = 0.0d;
        if (!this.ignoreMixturePrior) {
            for (int d = 0; d < this.data.length; d++) {
                total += this.mixPolicy.prior(trainableState.getMixture(d));
            }
        }
        for (int g = 0; g < this.cgs.length; g++) {
            ObjectMatrix1D contributions = trainableState.getContributions(g);
            for (int c = 0; c < this.components; c++) {
                total += this.priors[g].probability(((ContributionItem) contributions.get(c)).getItem());
            }
        }
        return Double.isNaN(total) ? Double.NEGATIVE_INFINITY : total;
    }

    /**
     * Draws a new committed state from the priors. Seeded components are copied from
     * {@code seedValues}; the remaining components come from each group's prior.
     */
    private TrainableState directSampleModel() {
        TrainableState state = new TrainableState(this);
        for (int d = 0; d < this.data.length; d++) {
            this.mixPolicy.variate(state.getMixture(d));
        }
        for (int g = 0; g < this.cgs.length; g++) {
            ObjectMatrix1D contributions = state.getContributions(g);
            for (int c = 0; c < this.components; c++) {
                if (c < this.seedCounts[g]) {
                    contributions.set(c, this.seedValues[g][c]);
                } else {
                    contributions.set(c, new SimpleContributionItem(this.priors[g].variate()));
                }
            }
        }
        state.commit();
        return state;
    }

    /** Returns a new committed state whose mixtures and contributions are copied from {@code source}. */
    private TrainableState copyState(TrainableState source) {
        TrainableState copy = new TrainableState(this);
        for (int d = 0; d < this.data.length; d++) {
            MatrixTools.copy(copy.getMixture(d), source.getMixture(d));
        }
        for (int g = 0; g < this.cgs.length; g++) {
            MatrixTools.copy(copy.getContributions(g), source.getContributions(g));
        }
        copy.commit();
        return copy;
    }

    /**
     * Returns a new committed state equal to {@code base}, except that every mixture and
     * contribution vector takes its entry at {@code index} from {@code donor}.
     */
    private TrainableState crossover(TrainableState base, TrainableState donor, int index) {
        TrainableState child = new TrainableState(this);
        for (int d = 0; d < this.data.length; d++) {
            crossMatrix(child.getMixture(d), base.getMixture(d), donor.getMixture(d), index);
        }
        for (int g = 0; g < this.cgs.length; g++) {
            crossMatrix(child.getContributions(g), base.getContributions(g), donor.getContributions(g), index);
        }
        child.commit();
        return child;
    }

    /** Fills {@code target} from {@code base}, except position {@code index}, which comes from {@code donor}. */
    private void crossMatrix(Matrix1D target, Matrix1D base, Matrix1D donor, int index) {
        for (int k = 0; k < target.size(); k++) {
            target.set(k, (k == index ? donor : base).get(k));
        }
    }

    /** Object-matrix variant of {@link #crossMatrix(Matrix1D, Matrix1D, Matrix1D, int)}. */
    private void crossMatrix(ObjectMatrix1D target, ObjectMatrix1D base, ObjectMatrix1D donor, int index) {
        for (int k = 0; k < target.size(); k++) {
            target.set(k, (k == index ? donor : base).get(k));
        }
    }

    /**
     * Builds the initial ensemble of {@code ensembleSize} directly sampled states. Samples whose
     * likelihood is negative infinity are redrawn; this loops forever if no finite sample is ever drawn.
     */
    public void init() {
        this.seedCounts = new int[this.cgs.length];
        for (int g = 0; g < this.cgs.length; g++) {
            if (this.seedValues[g] == null) {
                this.seedValues[g] = new ContributionItem[0];
            }
            this.seedCounts[g] = this.seedValues[g].length;
        }
        this.modelToLikelihood = new HashMap<MultiICAModel, Double>();
        for (int n = 0; n < this.ensembleSize; n++) {
            double likelihood = Double.NEGATIVE_INFINITY;
            TrainableState state = null;
            while (likelihood == Double.NEGATIVE_INFINITY) {
                state = directSampleModel();
                likelihood = state.likelihood();
            }
            this.modelToLikelihood.put(state, Double.valueOf(likelihood));
        }
    }

    /** @return {@code step * log2(ensembleSize / (ensembleSize + 1))} */
    public double getPriorResidue() {
        return this.step * NativeMath.log2((1.0d * this.ensembleSize) / (this.ensembleSize + 1));
    }

    /**
     * Performs one training step. It removes and returns the lowest-likelihood ensemble member,
     * updates the accumulated evidence, and inserts a replacement state.
     *
     * @throws BioError if no member could be selected for removal (empty ensemble, or all likelihoods NaN)
     */
    public WeightedModel next() {
        this.step++;

        double worstLikelihood = Double.POSITIVE_INFINITY;
        MultiICAModel worst = null;
        for (Map.Entry<MultiICAModel, Double> entry : this.modelToLikelihood.entrySet()) {
            double likelihood = entry.getValue().doubleValue();
            if (likelihood < worstLikelihood) {
                worstLikelihood = likelihood;
                worst = entry.getKey();
            }
        }
        if (worst == null) {
            throw new BioError("Assertion failed: wasn't able to remove a model");
        }
        this.modelToLikelihood.remove(worst);

        WeightedModel weightedModel = new WeightedModel(
                new SimpleMultiICAModel(worst),
                worstLikelihood,
                (-NativeMath.log2(this.ensembleSize)) + getPriorResidue());
        this.accumulatedEvidence = NativeMath.addLog2(this.accumulatedEvidence, weightedModel.getWeight());

        // Direct sampling is attempted until DIRECT_SAMPLE_ATTEMPT_LIMIT consecutive steps go by
        // without an accepted direct sample. After that it is never attempted again.
        TrainableState replacement = null;
        double directLikelihood = Double.NEGATIVE_INFINITY;
        if (this.stepsSinceSuccessfulDirectSample < DIRECT_SAMPLE_ATTEMPT_LIMIT) {
            replacement = directSampleModel();
            directLikelihood = replacement.likelihood();
        }
        // The null check matters when worstLikelihood is -infinity and direct sampling was skipped.
        // In that case the comparison alone would leave no replacement at all.
        if (replacement == null || directLikelihood < worstLikelihood) {
            replacement = perturbSurvivor(worstLikelihood);
            this.stepsSinceSuccessfulDirectSample++;
        } else {
            this.stepsSinceSuccessfulDirectSample = 0;
        }

        this.modelToLikelihood.put(replacement, Double.valueOf(replacement.likelihood()));
        this.sortedLikelihoods = null;
        return weightedModel;
    }

    /**
     * Builds a replacement state from a random surviving member.
     *
     * <p>When more than one non-seeded component exists, one of two moves may be tried:
     * <ul>
     *   <li>Resample a random free component; rolled back if the likelihood falls below
     *       {@code threshold}.</li>
     *   <li>Cross over with another random survivor; kept only if its likelihood is at least
     *       {@code threshold}.</li>
     * </ul>
     * Finally the state is passed to {@link #decorrelateState}. Also refreshes {@link #models},
     * which {@link #getUncleVector} reads.
     */
    private TrainableState perturbSurvivor(double threshold) {
        int seeded = MathsTools.max(this.seedCounts);
        int free = this.components - seeded;
        Set<MultiICAModel> survivors = this.modelToLikelihood.keySet();
        this.models = survivors.toArray(new TrainableState[survivors.size()]);
        TrainableState candidate = copyState(this.models[(int) Math.floor(Math.random() * this.models.length)]);
        if (free > 1) {
            if (Math.random() < this.replaceComponentProb) {
                int target = seeded + ((int) Math.floor(Math.random() * free));
                for (int g = 0; g < this.cgs.length; g++) {
                    candidate.getContributions(g).set(target, new SimpleContributionItem(this.priors[g].variate()));
                }
                for (int d = 0; d < this.data.length; d++) {
                    this.mixPolicy.sampleComponent(candidate.getMixture(d), target);
                }
                if (candidate.likelihood() >= threshold) {
                    candidate.commit();
                } else {
                    candidate.rollback();
                }
            } else if (Math.random() < this.crossOverProb) {
                TrainableState partner = this.models[(int) Math.floor(Math.random() * this.models.length)];
                TrainableState crossed = crossover(candidate, partner, (int) Math.floor(Math.random() * this.components));
                if (crossed.likelihood() >= threshold) {
                    candidate = crossed;
                }
            }
        }
        decorrelateState(candidate, threshold, seeded, free);
        return candidate;
    }

    /**
     * Subclass hook that further moves a replacement state. {@link #next()} ignores the return
     * value. NOTE(review): presumably implementations keep the state's likelihood at or above
     * {@code d}; confirm in subclasses.
     *
     * @param trainableState the candidate replacement
     * @param d              likelihood of the member just removed
     * @param i              number of seeded (fixed) leading components
     * @param i2             number of free components ({@code components - i})
     */
    protected abstract double decorrelateState(TrainableState trainableState, double d, int i, int i2);

    /**
     * Returns a read-only view of component {@code i2} of contribution group {@code i}, taken
     * across the survivor snapshot made during the current {@link #next()} step.
     * NOTE(review): the decompiler reports this was originally protected; it is kept public.
     */
    public ObjectMatrix1D getUncleVector(int i, int i2) {
        return new UncleVector(this.models, i, i2);
    }

    /** Not implemented; always throws. */
    public void terminate() {
        throw new RuntimeException("Termination conditions aren't implemented yet...");
    }

    /** @return the {@code addLog2}-accumulated weights of all members removed so far */
    public double getAccumulatedEvidence() {
        return this.accumulatedEvidence;
    }

    /** @return a snapshot array of the current ensemble members */
    public MultiICAModel[] getCurrentEnsemble() {
        return this.modelToLikelihood.keySet().toArray(new MultiICAModel[0]);
    }

    /** @return the number of {@link #next()} steps taken */
    public int getCycle() {
        return this.step;
    }

    /** Returns the ensemble likelihoods in ascending order, cached until the ensemble changes. */
    private double[] getSortedLikelihoods() {
        if (this.sortedLikelihoods == null) {
            double[] values = CollectTools.toDoubleArray(this.modelToLikelihood.values());
            Arrays.sort(values);
            this.sortedLikelihoods = values;
        }
        return this.sortedLikelihoods;
    }

    public double getMinimumLikelihood() {
        return getSortedLikelihoods()[0];
    }

    public double getMaximumLikelihood() {
        double[] sorted = getSortedLikelihoods();
        return sorted[sorted.length - 1];
    }

    /** @return the difference between the entries at the 75% and 25% positions of the sorted likelihoods */
    public double getLikelihoodIQR() {
        double[] sorted = getSortedLikelihoods();
        return sorted[(int) (0.75d * sorted.length)] - sorted[(int) (0.25d * sorted.length)];
    }

    /** @return the highest-likelihood ensemble member, or null if none has a likelihood above negative infinity */
    public MultiICAModel getBestModel() {
        double bestLikelihood = Double.NEGATIVE_INFINITY;
        MultiICAModel best = null;
        for (Map.Entry<MultiICAModel, Double> entry : this.modelToLikelihood.entrySet()) {
            double likelihood = entry.getValue().doubleValue();
            if (likelihood > bestLikelihood) {
                bestLikelihood = likelihood;
                best = entry.getKey();
            }
        }
        return best;
    }
}
