package net.derkholm.nmica.trainer;

import java.util.ConcurrentModificationException;
import net.derkholm.nmica.maths.NativeMath;
import net.derkholm.nmica.matrix.Matrix1D;
import net.derkholm.nmica.matrix.ObjectMatrix1D;
import net.derkholm.nmica.model.ContributionGroup;
import net.derkholm.nmica.model.ContributionItem;
import net.derkholm.nmica.model.ContributionPrior;
import net.derkholm.nmica.model.ContributionSampler;
import net.derkholm.nmica.model.Datum;
import net.derkholm.nmica.model.FacetteMap;
import net.derkholm.nmica.model.MixPolicy;
import net.derkholm.nmica.model.PenalizedVariate;
import net.derkholm.nmica.model.SimpleContributionItem;

/* loaded from: input_file:net/derkholm/nmica/trainer/MixtureResamplingTrainer.class */
/**
 * {@link Trainer} whose decorrelation step resamples mixture-component memberships of the data
 * and then resamples / permutes contributions until minimum move and proposal quotas are met.
 */
public class MixtureResamplingTrainer extends Trainer {
    private static final long serialVersionUID = -409176601994689042L;

    /** Minimum number of accepted contribution moves per {@link #decorrelateState} call. */
    private int minContributionMoves;
    /** Minimum number of contribution proposals per {@link #decorrelateState} call. */
    private int minContributionProposals;
    /** Number of mixture-component resampling rounds run at the start of each decorrelation. */
    private int mixtureDecopSessions;
    /** Cap on accepted membership changes per mixture round, as a fraction of the data set size. */
    private double moveFraction;
    /** Cap on membership proposals per mixture round, as a fraction of the data set size. */
    private double proposalFraction;

    /**
     * Signals that {@link #pick} found no element with the requested value.
     * (Recovered from bytecode; the decompiler reports it was originally private.)
     */
    public static class ZeroCardinalityException extends Exception {
    }

    /**
     * Creates a trainer; all arguments are passed unchanged to the {@link Trainer} constructor.
     * Defaults: 10 min moves, 20 min proposals, 5 mixture sessions, move fraction 0.5,
     * proposal fraction 2.0.
     */
    public MixtureResamplingTrainer(FacetteMap facetteMap, Datum[] datumArr, int i, ContributionPrior[] contributionPriorArr, ContributionSampler[] contributionSamplerArr, MixPolicy mixPolicy, int i2) {
        super(facetteMap, datumArr, i, contributionPriorArr, contributionSamplerArr, mixPolicy, i2);
        this.minContributionMoves = 10;
        this.minContributionProposals = 20;
        this.mixtureDecopSessions = 5;
        this.moveFraction = 0.5d;
        this.proposalFraction = 2.0d;
    }

    public void setMixtureDecopSessions(int mixtureDecopSessions) {
        this.mixtureDecopSessions = mixtureDecopSessions;
    }

    public void setMinContributionMoves(int minContributionMoves) {
        this.minContributionMoves = minContributionMoves;
    }

    public int getMinContributionMoves() {
        return this.minContributionMoves;
    }

    public void setMoveFraction(double moveFraction) {
        this.moveFraction = moveFraction;
    }

    public void setProposalFraction(double proposalFraction) {
        this.proposalFraction = proposalFraction;
    }

    public int getMinContributionProposals() {
        return this.minContributionProposals;
    }

    public void setMinContributionProposals(int minContributionProposals) {
        this.minContributionProposals = minContributionProposals;
    }

    /**
     * Decorrelates {@code trainableState} in two phases.
     *
     * <p>First it runs {@code mixtureDecopSessions} rounds of mixture-membership resampling, each
     * on a uniformly chosen component. Then it proposes contribution changes until at least
     * {@code minContributionMoves} have been accepted AND at least
     * {@code minContributionProposals} have been made. On each iteration (when
     * {@code contributionCount > 1}) a contribution permutation is proposed with probability 0.1,
     * followed by a resample of one contribution. Note: while {@code minContributionMoves} is
     * unmet the loop only ends via accepted moves, so it will not terminate if none can ever be
     * accepted.
     *
     * @param trainableState    the state to modify; changes are committed or rolled back in place
     * @param likelihoodFloor   proposals whose resulting likelihood is below this are rejected
     * @param firstContribution first contribution index eligible for moves
     * @param contributionCount number of consecutive contribution indices eligible for moves
     * @return the likelihood as tracked through the last accepted update
     */
    @Override // net.derkholm.nmica.trainer.Trainer
    protected double decorrelateState(TrainableState trainableState, double likelihoodFloor, int firstContribution, int contributionCount) {
        int components = getComponents();
        ContributionGroup[] contributionGroups = getFacetteMap().getContributionGroups();
        ContributionSampler[] samplers = getSamplers();

        double likelihood = trainableState.likelihood();
        for (int session = 0; session < this.mixtureDecopSessions; session++) {
            likelihood = resampleMixtureComponent(trainableState, (int) (Math.random() * components), likelihoodFloor);
        }

        double prior = prior(trainableState);
        int acceptedMoves = 0;
        int proposals = 0;
        // Both quotas are re-checked after every iteration (not only after an acceptance), so the
        // proposal quota can end the loop even when a run of proposals is rejected.
        while (acceptedMoves < this.minContributionMoves || proposals < this.minContributionProposals) {
            if (contributionCount > 1 && Math.random() < 0.1d) {
                // NOTE(review): likelihood is not recomputed after an accepted permutation;
                // presumably permuting contributions leaves the likelihood unchanged -- confirm.
                prior = proposePermutation(trainableState, prior, firstContribution, contributionCount);
            }

            int group = (int) Math.floor(Math.random() * contributionGroups.length);
            int index = randomIndex(firstContribution, contributionCount);
            ObjectMatrix1D contributions = trainableState.getContributions(group);
            ContributionItem current = (ContributionItem) contributions.get(index);
            PenalizedVariate sample = samplers[group].sample(current.getItem(), getUncleVector(group, index));
            contributions.set(index, new SimpleContributionItem(sample.getVariate()));
            double balancePenalty = sample.getBalancePenalty();
            proposals++;

            // Accept with probability min(1, 2^(priorDelta + balancePenalty)); priors presumably
            // on a log2 scale given the use of exp2.
            double newPrior = prior(trainableState);
            if (Math.random() > NativeMath.exp2((newPrior - prior) + balancePenalty)) {
                trainableState.rollback();
                continue;
            }
            double newLikelihood = trainableState.likelihood();
            if (newLikelihood < likelihoodFloor) {
                trainableState.rollback();
                continue;
            }
            trainableState.commit();
            likelihood = newLikelihood;
            prior = newPrior;
            acceptedMoves++;
        }
        return likelihood;
    }

    /**
     * Proposes permuting two distinct contribution indices drawn from
     * {@code [first, first + count)} and accepts with probability
     * {@code min(1, 2^(newPrior - currentPrior))}. Requires {@code count > 1}.
     *
     * @return the prior after the proposal ({@code currentPrior} if it was rolled back)
     */
    private double proposePermutation(TrainableState trainableState, double currentPrior, int first, int count) {
        int a = randomIndex(first, count);
        int b = a;
        while (b == a) {
            b = randomIndex(first, count);
        }
        trainableState.permuteContributions(a, b);
        double newPrior = prior(trainableState);
        if (newPrior > currentPrior || Math.random() < NativeMath.exp2(newPrior - currentPrior)) {
            trainableState.commit();
            return newPrior;
        }
        trainableState.rollback();
        return currentPrior;
    }

    /** Returns a uniformly random index in {@code [first, first + count)}. */
    private static int randomIndex(int first, int count) {
        return first + (int) Math.floor(Math.random() * count);
    }

    /**
     * Resamples, for every datum, whether {@code component} is switched on in its mixture.
     *
     * <p>Every datum's membership is tentatively toggled, and the per-datum prior and likelihood
     * deltas are recorded. A short stochastic search then runs over the on/off assignment, using
     * single flips and on/off swaps. It never lets the summed likelihood drop below
     * {@code likelihoodFloor}. Toggles that the search does not keep are rolled back per datum,
     * and the result is committed.
     *
     * @return the summed per-datum likelihood of the chosen assignment
     */
    private double resampleMixtureComponent(TrainableState trainableState, int component, double likelihoodFloor) {
        boolean ignoreMixturePrior = getIgnoreMixturePrior();
        Datum[] dataSet = getDataSet();
        MixPolicy mixPolicy = getMixPolicy();
        int n = dataSet.length;

        trainableState.commit();

        // Toggle the component in every mixture; record the original state and the prior gain
        // of having the component on versus off.
        boolean[] wasOn = new boolean[n];
        double[] priorGainOn = new double[n];
        for (int datum = 0; datum < n; datum++) {
            Matrix1D mixture = trainableState.getMixture(datum);
            wasOn[datum] = mixture.get(component) != 0.0d;
            double priorBefore = mixPolicy.prior(mixture);
            mixture.set(component, wasOn[datum] ? 0.0d : 1.0d);
            if (!ignoreMixturePrior) {
                double priorAfter = mixPolicy.prior(mixture);
                priorGainOn[datum] = wasOn[datum] ? priorBefore - priorAfter : priorAfter - priorBefore;
            }
        }

        // Evaluate the toggled state. "Old" / "new" datum likelihoods are presumably the committed
        // and toggled values respectively. Start from the per-datum best choice.
        trainableState.likelihood();
        boolean[] on = new boolean[n];
        double[] likelihoodGainOn = new double[n];
        double total = 0.0d;
        for (int datum = 0; datum < n; datum++) {
            double oldLikelihood = trainableState.getOldDatumLikelihood(datum);
            double newLikelihood = trainableState.getNewDatumLikelihood(datum);
            double delta = newLikelihood - oldLikelihood;
            likelihoodGainOn[datum] = wasOn[datum] ? -delta : delta;
            on[datum] = likelihoodGainOn[datum] > 0.0d;
            total += Math.max(oldLikelihood, newLikelihood);
        }

        int maxMoves = (int) Math.ceil(this.moveFraction * n);
        int maxProposals = (int) Math.ceil(this.proposalFraction * n);
        int moves = 0;
        for (int proposals = 0; moves < maxMoves && proposals < maxProposals; proposals++) {
            if (Math.random() < 0.2d) {
                // Swap move: turn one "on" datum off and one "off" datum on.
                // NOTE(review): this move ignores the mixture-prior gains -- confirm intended.
                try {
                    int toOff = pick(on, true);
                    int toOn = pick(on, false);
                    double candidate = (total - likelihoodGainOn[toOff]) + likelihoodGainOn[toOn];
                    if (candidate >= likelihoodFloor) {
                        on[toOff] = false;
                        on[toOn] = true;
                        total = candidate;
                        moves++;
                    }
                } catch (ZeroCardinalityException ignored) {
                    // All data are on (or all off), so no swap exists; this proposal is wasted.
                }
            } else {
                // Single flip, accepted if the floor holds and a prior-based test passes.
                int datum = (int) (Math.random() * n);
                boolean currentlyOn = on[datum];
                double candidate = currentlyOn ? total - likelihoodGainOn[datum] : total + likelihoodGainOn[datum];
                double priorDelta = currentlyOn ? -priorGainOn[datum] : priorGainOn[datum];
                if (candidate >= likelihoodFloor && (priorDelta >= 0.0d || Math.random() < NativeMath.exp2(priorDelta))) {
                    total = candidate;
                    on[datum] = !currentlyOn;
                    moves++;
                }
            }
        }

        // Undo the tentative toggle wherever the chosen state matches the original one.
        for (int datum = 0; datum < n; datum++) {
            if (on[datum] == wasOn[datum]) {
                trainableState.rollbackDatum(datum);
            }
        }
        trainableState.commit();
        return total;
    }

    /**
     * Returns the index of a uniformly chosen element of {@code flags} equal to {@code value}.
     *
     * @throws ZeroCardinalityException if no element equals {@code value}
     */
    private static int pick(boolean[] flags, boolean value) throws ZeroCardinalityException {
        int cardinality = 0;
        for (boolean flag : flags) {
            if (flag == value) {
                cardinality++;
            }
        }
        if (cardinality == 0) {
            throw new ZeroCardinalityException();
        }
        int hit = (int) Math.floor(Math.random() * cardinality);
        for (int i = 0; i < flags.length; i++) {
            if (flags[i] == value) {
                if (hit == 0) {
                    return i;
                }
                hit--;
            }
        }
        // Only reachable if flags changed between the counting and selection passes.
        throw new ConcurrentModificationException(String.format("Strange.  card=%d hit=%d", Integer.valueOf(cardinality), Integer.valueOf(hit)));
    }
}
