package net.derkholm.nmica.trainer;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import net.derkholm.nmica.maths.NativeMath;
import net.derkholm.nmica.matrix.Matrix1D;
import net.derkholm.nmica.matrix.MatrixTools;
import net.derkholm.nmica.matrix.ObjectMatrix1D;
import net.derkholm.nmica.model.ContributionGroup;
import net.derkholm.nmica.model.ContributionItem;
import net.derkholm.nmica.model.ContributionPrior;
import net.derkholm.nmica.model.ContributionSampler;
import net.derkholm.nmica.model.Datum;
import net.derkholm.nmica.model.Facette;
import net.derkholm.nmica.model.FacetteMap;
import net.derkholm.nmica.model.MixPolicy;
import net.derkholm.nmica.model.MultiICAModel;
import net.derkholm.nmica.model.PenalizedVariate;
import net.derkholm.nmica.model.SimpleContributionItem;

/* loaded from: input_file:net/derkholm/nmica/trainer/MCMCTrainer.class */
/**
 * Markov chain Monte Carlo trainer for a {@link MultiICAModel}.
 *
 * <p>Each call to {@link #next()} runs one decorrelation sweep over the current
 * {@link TrainableState}. A sweep has two phases:
 * <ol>
 *   <li>Randomly selected per-datum mixing vectors are resampled.</li>
 *   <li>Randomly chosen contribution items are resampled.</li>
 * </ol>
 * Every proposal goes through an accept/reject test of the form
 * {@code Math.random() > NativeMath.exp2(beta * deltaLikelihood ...)}. Likelihood and prior
 * values are therefore treated as base-2 log quantities.
 *
 * <p>After a sweep, {@code beta} is multiplied by 1.2 (capped at {@code maxBeta}) once the
 * likelihood has fallen on more than two consecutive calls.
 */
public class MCMCTrainer implements TrainableStateContext {
    /** Facette/contribution-group layout taken from the model at construction time. */
    private final FacetteMap facetteMap;
    /** Data set taken from the model at construction time. */
    private final Datum[] data;
    /** Number of model components taken from the model at construction time. */
    private final int components;
    /** One prior per contribution group, indexed like {@link #cgs}. */
    private final ContributionPrior[] priors;
    /** One sampler per contribution group, used to propose replacement contribution items. */
    private final ContributionSampler[] samplers;
    /** Policy used to sample and score per-datum mixing vectors. */
    private final MixPolicy mixPolicy;
    /** Facettes of {@link #facetteMap}; array position is the facette index. */
    private transient Facette[] facettes;
    /** Contribution groups of {@link #facetteMap}; array position is the group index. */
    private transient ContributionGroup[] cgs;
    /** Reverse lookup: facette -> position in {@link #facettes}. */
    private transient Map<Facette, Integer> facetteIndices;
    /** Reverse lookup: contribution group -> position in {@link #cgs}. */
    private transient Map<ContributionGroup, Integer> contributionIndices;
    /** Facette index -> index of the contribution group it belongs to. */
    private transient int[] facetteIndicesToContributionIndex;
    /** Contribution-group index -> indices of its facettes (jagged array). */
    private transient int[][] contributionIndicesToFacetteIndices;
    /** The live state that {@link #next()} mutates in place. */
    private transient TrainableState model;
    /** Set by {@link #setReplaceComponentProb(double)}; not read anywhere else in this class. */
    private double replaceComponentProb = 0.0d;
    /** Exposed via getter/setter; not otherwise read in this class. */
    private boolean ignoreMixturePrior = true;
    /** Likelihood returned by the previous call to {@link #next()}. */
    private double lastHood = Double.NEGATIVE_INFINITY;
    /** Number of consecutive {@link #next()} calls whose likelihood fell. */
    private int drops = 0;
    /** Multiplier applied to likelihood differences in the accept/reject tests. */
    private double beta = 1.0d;
    /** Upper bound for {@link #beta} when it is raised automatically. */
    private double maxBeta = 1000.0d;
    /** Number of completed {@link #next()} calls. */
    private int step = 0;
    private transient EvaluationManager evalManager = new LocalEvaluationManager();
    /** Not read anywhere in this class. */
    private int minMixtureMoves = 100;
    /** Minimum number of accepted contribution moves per sweep. */
    private int minContributionMoves = 100;
    /** Not read anywhere in this class. */
    private int minMixtureProposals = 100;
    /** Minimum number of proposed contribution moves per sweep. */
    private int minContributionProposals = 100;
    /** Number of mixture-update sessions per sweep. */
    private int mixtureDecopSessions = 5;
    /** Probability that each datum is selected for a mixture update within one session. */
    private double mixtureFractionPerSession = 0.5d;

    /**
     * Snapshot of the values produced by one {@link MCMCTrainer#next()} call.
     *
     * <p>{@link #model} refers to the trainer's live state object, not a copy. Later
     * calls to {@code next()} will keep mutating it.
     */
    public static class MTModel {
        /** The trainer's live model state. */
        public final MultiICAModel model;
        /** Likelihood at the end of the sweep. */
        public final double likelihood;
        /** Prior of the state at the end of the sweep, as computed by {@code prior(...)}. */
        public final double prior;
        /** The {@code beta} value that was used for the sweep. */
        public final double beta;

        private MTModel(MultiICAModel multiICAModel, double d, double d2, double d3) {
            this.model = multiICAModel;
            this.likelihood = d;
            this.prior = d2;
            this.beta = d3;
        }
    }

    /**
     * Sets the multiplier applied to likelihood differences in the accept/reject tests.
     *
     * @param d new beta value
     */
    public void setBeta(double d) {
        this.beta = d;
    }

    /**
     * Builds the lookup tables between facettes, contribution groups and their array indices.
     * Must run before any of the index-conversion methods are used.
     */
    private void makeIndices() {
        this.facettes = this.facetteMap.getFacettes();
        this.cgs = this.facetteMap.getContributionGroups();

        this.facetteIndices = new HashMap<Facette, Integer>();
        for (int i = 0; i < this.facettes.length; i++) {
            this.facetteIndices.put(this.facettes[i], Integer.valueOf(i));
        }

        this.contributionIndices = new HashMap<ContributionGroup, Integer>();
        for (int i = 0; i < this.cgs.length; i++) {
            this.contributionIndices.put(this.cgs[i], Integer.valueOf(i));
        }

        this.facetteIndicesToContributionIndex = new int[this.facettes.length];
        for (int i = 0; i < this.facettes.length; i++) {
            this.facetteIndicesToContributionIndex[i] =
                    contributionGroupToIndex(this.facetteMap.getContributionForFacette(this.facettes[i]));
        }

        // Jagged array: each group gets its own row, sized by how many facettes it has.
        this.contributionIndicesToFacetteIndices = new int[this.cgs.length][];
        for (int i = 0; i < this.cgs.length; i++) {
            Facette[] groupFacettes = this.facetteMap.getFacettesForContribution(this.cgs[i]);
            int[] indices = new int[groupFacettes.length];
            for (int j = 0; j < groupFacettes.length; j++) {
                indices[j] = facetteToIndex(groupFacettes[j]);
            }
            this.contributionIndicesToFacetteIndices[i] = indices;
        }
    }

    /**
     * Creates a trainer whose state is initialized from the given model.
     *
     * @param multiICAModel supplies the facette map, data set, component count and initial state
     * @param contributionPriorArr one prior per contribution group
     * @param contributionSamplerArr one sampler per contribution group
     * @param mixPolicy policy for sampling and scoring mixing vectors
     */
    public MCMCTrainer(MultiICAModel multiICAModel, ContributionPrior[] contributionPriorArr, ContributionSampler[] contributionSamplerArr, MixPolicy mixPolicy) {
        this.facetteMap = multiICAModel.getFacetteMap();
        this.data = multiICAModel.getDataSet();
        this.components = multiICAModel.getComponents();
        makeIndices();
        this.priors = contributionPriorArr;
        this.samplers = contributionSamplerArr;
        this.mixPolicy = mixPolicy;
        this.model = new TrainableState(this, multiICAModel);
    }

    /** Replaces the evaluation manager (a {@code LocalEvaluationManager} by default). */
    public void setEvaluationManager(EvaluationManager evaluationManager) {
        this.evalManager = evaluationManager;
    }

    @Override // net.derkholm.nmica.trainer.TrainableStateContext
    public EvaluationManager getEvaluationManager() {
        return this.evalManager;
    }

    public void setIgnoreMixturePrior(boolean z) {
        this.ignoreMixturePrior = z;
    }

    public boolean getIgnoreMixturePrior() {
        return this.ignoreMixturePrior;
    }

    public void setReplaceComponentProb(double d) {
        this.replaceComponentProb = d;
    }

    /**
     * Returns the index of a facette in {@link #getFacettes()}.
     *
     * @throws IllegalArgumentException if the facette is not part of this model
     */
    int facetteToIndex(Facette facette) throws IllegalArgumentException {
        Integer num = this.facetteIndices.get(facette);
        if (num == null) {
            // String concatenation (not toString()) so a null key still yields this message.
            throw new IllegalArgumentException("This model doesn't know anything about " + facette);
        }
        return num.intValue();
    }

    /**
     * Returns the index of a contribution group.
     *
     * @throws IllegalArgumentException if the group is not part of this model
     */
    @Override // net.derkholm.nmica.trainer.TrainableStateContext
    public int contributionGroupToIndex(ContributionGroup contributionGroup) throws IllegalArgumentException {
        Integer num = this.contributionIndices.get(contributionGroup);
        if (num == null) {
            throw new IllegalArgumentException("This model doesn't know anything about " + contributionGroup);
        }
        return num.intValue();
    }

    /** Maps a facette index to the index of its contribution group. */
    @Override // net.derkholm.nmica.trainer.TrainableStateContext
    public int facetteIndexToContributionIndex(int i) {
        return this.facetteIndicesToContributionIndex[i];
    }

    /** Maps a contribution-group index to the indices of its facettes (internal array, not a copy). */
    int[] contributionIndexToFacetteIndices(int i) {
        return this.contributionIndicesToFacetteIndices[i];
    }

    @Override // net.derkholm.nmica.trainer.TrainableStateContext
    public FacetteMap getFacetteMap() {
        return this.facetteMap;
    }

    /** Returns the facettes in index order (internal array, not a copy). */
    public Facette[] getFacettes() {
        return this.facettes;
    }

    @Override // net.derkholm.nmica.trainer.TrainableStateContext
    public int getComponents() {
        return this.components;
    }

    @Override // net.derkholm.nmica.trainer.TrainableStateContext
    public Datum[] getDataSet() {
        return this.data;
    }

    public ContributionPrior[] getPriors() {
        return this.priors;
    }

    public ContributionSampler[] getSamplers() {
        return this.samplers;
    }

    @Override // net.derkholm.nmica.trainer.TrainableStateContext
    public MixPolicy getMixPolicy() {
        return this.mixPolicy;
    }

    /**
     * Sums {@code priors[g].probability(item)} over every contribution item of every group.
     * The values are presumably log-scale, since the sum is compared through {@code exp2}
     * elsewhere.
     *
     * @return the summed prior, or {@link Double#NEGATIVE_INFINITY} if the sum is NaN
     */
    protected double prior(TrainableState trainableState) {
        double total = 0.0d;
        for (int group = 0; group < this.cgs.length; group++) {
            ObjectMatrix1D contributions = trainableState.getContributions(group);
            for (int component = 0; component < this.components; component++) {
                total += this.priors[group].probability(((ContributionItem) contributions.get(component)).getItem());
            }
        }
        return Double.isNaN(total) ? Double.NEGATIVE_INFINITY : total;
    }

    /** Returns a new committed state whose mixtures and contributions are copied from the given state. */
    private TrainableState copyState(TrainableState trainableState) {
        TrainableState copy = new TrainableState(this);
        for (int i = 0; i < this.data.length; i++) {
            MatrixTools.copy(copy.getMixture(i), trainableState.getMixture(i));
        }
        for (int i = 0; i < this.cgs.length; i++) {
            MatrixTools.copy(copy.getContributions(i), trainableState.getContributions(i));
        }
        copy.commit();
        return copy;
    }

    /**
     * Returns a new committed state that takes element {@code i} of every mixture and contribution
     * matrix from {@code trainableState2} and every other element from {@code trainableState}.
     */
    private TrainableState crossover(TrainableState trainableState, TrainableState trainableState2, int i) {
        TrainableState child = new TrainableState(this);
        for (int d = 0; d < this.data.length; d++) {
            crossMatrix(child.getMixture(d), trainableState.getMixture(d), trainableState2.getMixture(d), i);
        }
        for (int g = 0; g < this.cgs.length; g++) {
            crossMatrix(child.getContributions(g), trainableState.getContributions(g), trainableState2.getContributions(g), i);
        }
        child.commit();
        return child;
    }

    /** Fills {@code target} from {@code second} at {@code crossIndex} and from {@code first} elsewhere. */
    private void crossMatrix(Matrix1D target, Matrix1D first, Matrix1D second, int crossIndex) {
        for (int j = 0; j < target.size(); j++) {
            target.set(j, (j == crossIndex ? second : first).get(j));
        }
    }

    /** Object-matrix counterpart of {@link #crossMatrix(Matrix1D, Matrix1D, Matrix1D, int)}. */
    private void crossMatrix(ObjectMatrix1D target, ObjectMatrix1D first, ObjectMatrix1D second, int crossIndex) {
        for (int j = 0; j < target.size(); j++) {
            target.set(j, (j == crossIndex ? second : first).get(j));
        }
    }

    /** Does nothing in this class. */
    public void init() {
    }

    /**
     * Runs one decorrelation sweep over the live state and updates the beta schedule.
     *
     * @return a snapshot of the sweep's likelihood, prior and beta. Its {@code model} field is
     *     the live state object, which the next call will keep mutating.
     */
    public MTModel next() {
        this.step++;
        double likelihood = decorrelateState(this.model, this.beta);
        MTModel snapshot = new MTModel(this.model, likelihood, prior(this.model), this.beta);
        if (likelihood < this.lastHood) {
            this.drops++;
            // Only start raising beta after more than two consecutive drops.
            if (this.drops > 2 && this.beta < this.maxBeta) {
                this.beta = Math.min(this.maxBeta, this.beta * 1.2d);
            }
        } else {
            this.drops = 0;
        }
        this.lastHood = likelihood;
        return snapshot;
    }

    /**
     * Runs one sweep of mixture and contribution updates on {@code trainableState}.
     *
     * <p>Mixture phase: repeated {@link #mixtureDecopSessions} times.
     * <ol>
     *   <li>Each datum is selected with probability {@link #mixtureFractionPerSession}, and a new
     *       mixing vector is drawn for it (see {@link #proposeMixture}).</li>
     *   <li>The selected data are visited in random order. Each change is kept or rolled back by
     *       comparing the datum's new and old likelihood, scaled by {@code d}.</li>
     * </ol>
     *
     * <p>Contribution phase: random (group, component) items are replaced with samples from the
     * group's sampler. A proposal is rolled back outright if the resulting prior is not finite;
     * otherwise it is kept or rolled back by an accept/reject test on the scaled likelihood
     * change plus the sampler's balance penalty. The phase continues until at least
     * {@link #minContributionMoves} moves have been accepted and
     * {@link #minContributionProposals} have been proposed. It is skipped when there are no
     * contribution groups or no components.
     *
     * @param trainableState state to update in place
     * @param d multiplier applied to likelihood differences (the trainer passes its beta)
     * @return the likelihood of the state after the sweep
     */
    protected double decorrelateState(TrainableState trainableState, double d) {
        Datum[] dataSet = getDataSet();
        int componentCount = getComponents();
        ContributionGroup[] contributionGroups = getFacetteMap().getContributionGroups();
        ContributionSampler[] samplerArray = getSamplers();
        MixPolicy policy = getMixPolicy();

        double likelihood = trainableState.likelihood();
        trainableState.commit();

        for (int session = 0; session < this.mixtureDecopSessions; session++) {
            List<Integer> touched = new ArrayList<Integer>();
            for (int datum = 0; datum < dataSet.length; datum++) {
                if (Math.random() < this.mixtureFractionPerSession) {
                    touched.add(Integer.valueOf(datum));
                    proposeMixture(trainableState, policy, datum);
                }
            }
            // NOTE(review): return value unused; presumably this call refreshes the per-datum
            // likelihoods read just below — confirm against TrainableState before removing.
            trainableState.likelihood();
            Collections.shuffle(touched);
            for (Integer boxed : touched) {
                int datum = boxed.intValue();
                double candidate = likelihood
                        + (trainableState.getNewDatumLikelihood(datum) - trainableState.getOldDatumLikelihood(datum));
                if (Math.random() > NativeMath.exp2(d * (candidate - likelihood))) {
                    trainableState.rollbackDatum(datum);
                } else {
                    likelihood = candidate;
                }
            }
            trainableState.commit();
        }

        // Without groups or components there is nothing to propose; indexing would fail.
        if (contributionGroups.length == 0 || componentCount == 0) {
            return likelihood;
        }

        int accepted = 0;
        int proposals = 0;
        // The stop condition is evaluated on every iteration, including after rejections.
        while (contributionWorkRemaining(accepted, proposals) > 0) {
            int group = (int) Math.floor(Math.random() * contributionGroups.length);
            int component = (int) Math.floor(Math.random() * componentCount);
            ObjectMatrix1D contributions = trainableState.getContributions(group);
            PenalizedVariate sample =
                    samplerArray[group].sample(((ContributionItem) contributions.get(component)).getItem(), null);
            contributions.set(component, new SimpleContributionItem(sample.getVariate()));
            double balancePenalty = sample.getBalancePenalty();
            proposals++;

            double newPrior = prior(trainableState);
            if (Double.isNaN(newPrior) || Double.isInfinite(newPrior)) {
                trainableState.rollback();
                continue;
            }
            double newLikelihood = trainableState.likelihood();
            if (Math.random() > NativeMath.exp2((d * (newLikelihood - likelihood)) + balancePenalty)) {
                trainableState.rollback();
            } else {
                trainableState.commit();
                likelihood = newLikelihood;
                accepted++;
            }
        }
        return likelihood;
    }

    /**
     * Draws a new mixing vector for one datum, retrying until a draw passes the test
     * {@code newPrior > oldPrior || random < exp2(newPrior - oldPrior)}. Here {@code oldPrior}
     * is the mixture prior of the vector in place on entry. Rejected draws are undone with
     * {@code rollbackDatum} before the next attempt.
     */
    private static void proposeMixture(TrainableState state, MixPolicy policy, int datum) {
        Matrix1D mixture = state.getMixture(datum);
        double oldPrior = policy.prior(mixture);
        while (true) {
            policy.sample(mixture);
            double newPrior = policy.prior(mixture);
            boolean rejected = newPrior <= oldPrior && Math.random() >= NativeMath.exp2(newPrior - oldPrior);
            if (!rejected) {
                return;
            }
            state.rollbackDatum(datum);
        }
    }

    /**
     * Returns the combined shortfall of accepted moves and proposals against their minimums.
     * Zero means the contribution phase may stop.
     */
    private int contributionWorkRemaining(int accepted, int proposals) {
        return Math.max(0, this.minContributionMoves - accepted)
                + Math.max(0, this.minContributionProposals - proposals);
    }

    /**
     * Legacy decorrelation routine.
     *
     * <p>The original implementation could not be recovered by decompilation, so this
     * stub always throws.
     *
     * @throws UnsupportedOperationException always
     */
    protected double old_decorrelateState(TrainableState trainableState) {
        throw new UnsupportedOperationException("Method not decompiled: net.derkholm.nmica.trainer.MCMCTrainer.old_decorrelateState(net.derkholm.nmica.trainer.TrainableState):double");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always (a {@link RuntimeException} subtype)
     */
    public void terminate() {
        throw new UnsupportedOperationException("Termination conditions aren't implemented yet...");
    }

    /** Returns the number of completed {@link #next()} calls. */
    public int getCycle() {
        return this.step;
    }
}
