package net.derkholm.nmica.trainer;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import net.derkholm.nmica.maths.NativeMath;
import net.derkholm.nmica.matrix.Matrix1D;
import net.derkholm.nmica.matrix.ObjectMatrix1D;
import net.derkholm.nmica.model.ContributionGroup;
import net.derkholm.nmica.model.ContributionItem;
import net.derkholm.nmica.model.ContributionPrior;
import net.derkholm.nmica.model.ContributionSampler;
import net.derkholm.nmica.model.Datum;
import net.derkholm.nmica.model.FacetteMap;
import net.derkholm.nmica.model.MixPolicy;
import net.derkholm.nmica.model.PenalizedVariate;
import net.derkholm.nmica.model.SimpleContributionItem;

/* loaded from: input_file:net/derkholm/nmica/trainer/QueuedDecopTrainer.class */
/**
 * Trainer whose state decorrelation runs in two phases.
 *
 * <ol>
 *   <li>A number of mixture sessions. In each session a random subset of data has its mixture
 *       resampled, then the selected data are visited in shuffled (queued) order and kept only
 *       while the running likelihood stays above the supplied bound.</li>
 *   <li>A series of Metropolis moves on individual contributions, interleaved with occasional
 *       component-permutation proposals. This continues until both a minimum number of accepted
 *       moves and a minimum number of proposals have been reached.</li>
 * </ol>
 */
public class QueuedDecopTrainer extends Trainer {
    private static final long serialVersionUID = -409176609194689042L;

    /** Probability, per loop iteration, of also proposing a component permutation. */
    private static final double PERMUTATION_PROPOSAL_PROBABILITY = 0.1d;

    /** Minimum number of accepted contribution moves per decorrelation. */
    private int minContributionMoves;
    /** Minimum number of contribution proposals per decorrelation. */
    private int minContributionProposals;
    /** Number of mixture-resampling sessions per decorrelation. */
    private int mixtureDecopSessions;
    /** Probability with which each datum is selected for resampling in a mixture session. */
    private double mixtureFractionPerSession;

    /**
     * Creates a trainer with default decorrelation settings.
     *
     * <p>The defaults are: 10 minimum moves, 20 minimum proposals, 2 mixture sessions, and a
     * selection fraction of 0.5. All arguments are forwarded unchanged to {@link Trainer}.
     */
    public QueuedDecopTrainer(FacetteMap facetteMap, Datum[] datumArr, int i, ContributionPrior[] contributionPriorArr, ContributionSampler[] contributionSamplerArr, MixPolicy mixPolicy, int i2) {
        super(facetteMap, datumArr, i, contributionPriorArr, contributionSamplerArr, mixPolicy, i2);
        this.minContributionMoves = 10;
        this.minContributionProposals = 20;
        this.mixtureDecopSessions = 2;
        this.mixtureFractionPerSession = 0.5d;
    }

    /** Sets the number of mixture-resampling sessions run per decorrelation. */
    public void setMixtureDecopSessions(int i) {
        this.mixtureDecopSessions = i;
    }

    /** Sets the probability that each datum is selected in a mixture session. */
    public void setMixtureFractionPerSession(double d) {
        this.mixtureFractionPerSession = d;
    }

    /** Sets the minimum number of accepted contribution moves per decorrelation. */
    public void setMinContributionMoves(int i) {
        this.minContributionMoves = i;
    }

    /** Returns the minimum number of accepted contribution moves per decorrelation. */
    public int getMinContributionMoves() {
        return this.minContributionMoves;
    }

    /** Returns the minimum number of contribution proposals per decorrelation. */
    public int getMinContributionProposals() {
        return this.minContributionProposals;
    }

    /** Sets the minimum number of contribution proposals per decorrelation. */
    public void setMinContributionProposals(int i) {
        this.minContributionProposals = i;
    }

    /**
     * Decorrelates {@code state} while keeping its likelihood above {@code minLikelihood}.
     *
     * @param state the state to modify in place; it is committed before the method returns
     * @param minLikelihood lower likelihood bound. Mixture changes are kept only if the running
     *     likelihood is strictly greater than it. Contribution moves are kept if the likelihood
     *     is not below it.
     * @param i must be {@code <= 0}; positive values are not implemented
     * @param i2 unused by this implementation
     * @return the likelihood of the final state as tracked by this method
     * @throws RuntimeException if {@code i > 0}
     */
    @Override // net.derkholm.nmica.trainer.Trainer
    protected double decorrelateState(TrainableState state, double minLikelihood, int i, int i2) {
        if (i > 0) {
            // NOTE(review): positive values of this argument are not handled by this trainer.
            throw new RuntimeException("FIXME");
        }
        double likelihood = state.likelihood();
        state.commit();
        likelihood = decorrelateMixtures(state, likelihood, minLikelihood);
        return decorrelateContributions(state, likelihood, minLikelihood);
    }

    /**
     * Runs {@code mixtureDecopSessions} mixture-resampling sessions and commits after each.
     *
     * @return the running likelihood after all sessions
     */
    private double decorrelateMixtures(TrainableState state, double likelihood, double minLikelihood) {
        Datum[] dataSet = getDataSet();
        MixPolicy mixPolicy = getMixPolicy();
        for (int session = 0; session < this.mixtureDecopSessions; session++) {
            List<Integer> queue = new ArrayList<Integer>();
            for (int datum = 0; datum < dataSet.length; datum++) {
                if (Math.random() < this.mixtureFractionPerSession) {
                    queue.add(Integer.valueOf(datum));
                    sampleMixtureFromPrior(state, mixPolicy, datum);
                }
            }
            // Result discarded. Presumably called so that the per-datum old/new likelihoods
            // read below are up to date -- confirm against TrainableState.
            state.likelihood();
            Collections.shuffle(queue);
            for (Integer queued : queue) {
                int datum = queued.intValue();
                // Update the running likelihood by this datum's change only.
                double proposed = likelihood
                        + (state.getNewDatumLikelihood(datum) - state.getOldDatumLikelihood(datum));
                if (proposed > minLikelihood) {
                    likelihood = proposed;
                } else {
                    state.rollbackDatum(datum);
                }
            }
            state.commit();
        }
        return likelihood;
    }

    /**
     * Resamples the mixture of {@code datum} with {@code mixPolicy.sample}.
     *
     * <p>Sampling repeats until a proposal passes a Metropolis test against the mixture prior
     * (base-2 log scale, via {@code NativeMath.exp2}). Each rejected proposal is undone with
     * {@code rollbackDatum} before the next attempt.
     */
    private static void sampleMixtureFromPrior(TrainableState state, MixPolicy mixPolicy, int datum) {
        Matrix1D mixture = state.getMixture(datum);
        double oldPrior = mixPolicy.prior(mixture);
        while (true) {
            mixPolicy.sample(mixture);
            double newPrior = mixPolicy.prior(mixture);
            if (newPrior > oldPrior || Math.random() < NativeMath.exp2(newPrior - oldPrior)) {
                // Accepted: without this exit the loop never terminated.
                return;
            }
            state.rollbackDatum(datum);
        }
    }

    /**
     * Performs Metropolis moves on randomly chosen contributions.
     *
     * <p>The loop continues until at least {@code minContributionMoves} moves have been accepted
     * and at least {@code minContributionProposals} moves have been proposed. Both counts are
     * checked after every proposal, whether or not it was accepted.
     *
     * @return the likelihood after the last accepted move, or {@code likelihood} if none was
     *     accepted
     */
    private double decorrelateContributions(TrainableState state, double likelihood, double minLikelihood) {
        int components = getComponents();
        ContributionGroup[] contributionGroups = getFacetteMap().getContributionGroups();
        ContributionSampler[] samplers = getSamplers();
        double currentPrior = prior(state);
        int moves = 0;
        int proposals = 0;
        while (moves < this.minContributionMoves || proposals < this.minContributionProposals) {
            if (components > 1 && Math.random() < PERMUTATION_PROPOSAL_PROBABILITY) {
                currentPrior = proposePermutation(state, components, currentPrior);
            }
            int group = (int) Math.floor(Math.random() * contributionGroups.length);
            int component = (int) Math.floor(Math.random() * components);
            ObjectMatrix1D contributions = state.getContributions(group);
            PenalizedVariate sample = samplers[group].sample(
                    ((ContributionItem) contributions.get(component)).getItem(),
                    getUncleVector(group, component));
            contributions.set(component, new SimpleContributionItem(sample.getVariate()));
            double balancePenalty = sample.getBalancePenalty();
            proposals++;

            double newPrior = prior(state);
            if (Math.random() > NativeMath.exp2((newPrior - currentPrior) + balancePenalty)) {
                state.rollback();
                continue;
            }
            double newLikelihood = state.likelihood();
            if (newLikelihood < minLikelihood) {
                state.rollback();
            } else {
                state.commit();
                likelihood = newLikelihood;
                currentPrior = newPrior;
                moves++;
            }
        }
        return likelihood;
    }

    /**
     * Proposes {@code permuteContributions} on two distinct, randomly chosen components.
     *
     * <p>The proposal is accepted or rejected by a Metropolis test on the prior alone. The
     * likelihood is not re-evaluated here.
     *
     * @return the prior of the state after the proposal (the new prior if it was committed,
     *     otherwise {@code currentPrior})
     */
    private double proposePermutation(TrainableState state, int components, double currentPrior) {
        int first = (int) Math.floor(Math.random() * components);
        int second = first;
        while (second == first) {
            second = (int) Math.floor(Math.random() * components);
        }
        state.permuteContributions(first, second);
        double newPrior = prior(state);
        if (newPrior > currentPrior || Math.random() < NativeMath.exp2(newPrior - currentPrior)) {
            state.commit();
            return newPrior;
        }
        state.rollback();
        return currentPrior;
    }
}
