package net.derkholm.nmica.model.logitseq;

import java.util.ArrayList;
import net.derkholm.nmica.maths.NativeMath;
import net.derkholm.nmica.matrix.Matrix1D;
import net.derkholm.nmica.matrix.Matrix2D;
import net.derkholm.nmica.matrix.ObjectMatrix1D;
import net.derkholm.nmica.model.ContributionItem;
import net.derkholm.nmica.model.Facette;
import net.derkholm.nmica.model.LikelihoodCalculator;
import net.derkholm.nmica.utils.CollectTools;
import org.biojava.bio.seq.Sequence;
import org.biojava.bio.symbol.AlphabetIndex;
import org.biojava.bio.symbol.AtomicSymbol;
import org.biojava.bio.symbol.IllegalSymbolException;
import org.biojava.bio.symbol.Symbol;

/* loaded from: input_file:net/derkholm/nmica/model/logitseq/LogisticSequenceLikelihoodCalculator.class */
/**
 * Likelihood calculator for a single labelled sequence under a logistic model.
 *
 * <p>Each contribution item is expected to wrap a {@link WeightedWeightMatrix}. For each item, the
 * sequence is scanned with the item's forward weight-matrix view. Per-window log2 scores are
 * combined with {@code NativeMath.addLog2}, offset by a uniform 0.25-per-position background, scaled
 * by the item's weight, and accumulated into a linear predictor {@code d}. The returned likelihood
 * is log2 of {@code sigmoid(d)} when the sequence's label is {@code >= 0}, and log2 of
 * {@code 1 - sigmoid(d)} otherwise.
 */
class LogisticSequenceLikelihoodCalculator implements LikelihoodCalculator {
    /** Natural logarithm of 2, used to convert natural logs to base-2 logs. */
    private static final double LN2 = Math.log(2.0d);

    private final LogisticSequenceFacette facette;
    private final Sequence sequence;
    /** Alphabet index of each residue (0-based position); -1 marks a non-atomic symbol. */
    private final int[] trimmedSeqIndices;
    /** Integer value of the "mocca.label" annotation; {@code >= 0} is treated as the positive class. */
    private final int label;

    /**
     * Pre-indexes the sequence against the facette's alphabet index and reads its label.
     *
     * @param logisticSequenceFacette facette supplying the alphabet index and the weight-matrix view
     * @param sequence labelled sequence; must carry an Integer "mocca.label" annotation property
     * @throws IllegalSymbolException presumably when {@code AlphabetIndex.indexForSymbol} rejects a symbol
     * @throws IllegalArgumentException if the sequence has no "mocca.label" annotation property
     */
    public LogisticSequenceLikelihoodCalculator(LogisticSequenceFacette logisticSequenceFacette, Sequence sequence) throws IllegalSymbolException {
        this.facette = logisticSequenceFacette;
        this.sequence = sequence;
        AlphabetIndex alphabetIndex = logisticSequenceFacette.getAlphabetIndex();
        int[] indices = new int[sequence.length()];
        // BioJava sequences are 1-based; the index array is 0-based.
        for (int pos = 1; pos <= indices.length; pos++) {
            Symbol symbolAt = sequence.symbolAt(pos);
            // Non-atomic (ambiguity) symbols have no single index; -1 makes scoreWM reject any window
            // that covers them.
            indices[pos - 1] = (symbolAt instanceof AtomicSymbol) ? alphabetIndex.indexForSymbol(symbolAt) : -1;
        }
        this.trimmedSeqIndices = indices;
        if (!sequence.getAnnotation().containsProperty("mocca.label")) {
            throw new IllegalArgumentException("Unlabelled sequence");
        }
        this.label = ((Integer) sequence.getAnnotation().getProperty("mocca.label")).intValue();
    }

    /** Returns the facette this calculator was built for. */
    @Override
    public Facette getFacette() {
        return this.facette;
    }

    /** Returns the sequence this calculator scores. */
    @Override
    public Object getData() {
        return this.sequence;
    }

    /**
     * Computes the log2 likelihood of this sequence's label given the current contribution items.
     *
     * @param objectMatrix1D contribution items; each is cast to {@link ContributionItem} whose item is a
     *     {@link WeightedWeightMatrix}
     * @param matrix1D unused by this implementation
     * @return log2 of {@code sigmoid(d)} for labels {@code >= 0}, otherwise log2 of {@code 1 - sigmoid(d)}
     * @throws IllegalArgumentException if scoring a window fails (the original failure is the cause)
     */
    @Override
    public double likelihood(ObjectMatrix1D objectMatrix1D, Matrix1D matrix1D) {
        double d = 0.0d;
        for (int i = 0; i < objectMatrix1D.size(); i++) {
            ContributionItem contributionItem = (ContributionItem) objectMatrix1D.get(i);
            WeightedWeightMatrix weightedWeightMatrix = (WeightedWeightMatrix) contributionItem.getItem();
            Matrix2D matrix2D = (Matrix2D) contributionItem.getItemView(this.facette.forwardBmView());
            int columns = matrix2D.columns();
            // log2 probability of a window of this width under a uniform 0.25-per-position background.
            // NOTE(review): 0.25 assumes a 4-letter alphabet — confirm this facette is DNA-only.
            double background = columns * NativeMath.fastlog2(0.25d);
            // NOTE(review): if the matrix is wider than the sequence this is negative and the
            // allocation below throws NegativeArraySizeException — confirm callers guarantee a fit.
            int windows = (this.trimmedSeqIndices.length - columns) + 1;
            double[] windowScores = new double[windows];
            for (int offset = 0; offset < windows; offset++) {
                try {
                    windowScores[offset] = scoreWM(matrix2D, offset);
                } catch (Exception e) {
                    throw new IllegalArgumentException("Failed to score weight matrix at offset " + offset, e);
                }
            }
            d += weightedWeightMatrix.getWeight() * (NativeMath.addLog2(windowScores) - background);
        }
        // 1 - sigmoid(d) == sigmoid(-d); evaluating it that way avoids the cancellation (and -Infinity)
        // that computing 1.0 - sigmoid(d) produces for large d.
        return this.label >= 0 ? log2Sigmoid(d) : log2Sigmoid(-d);
    }

    /**
     * Returns {@code log2(1 / (1 + e^-x))} without overflow in {@code exp} and without rounding the
     * sigmoid to 0 or 1 first.
     *
     * <p>For {@code x >= 0}: {@code -log1p(e^-x)}. For {@code x < 0}:
     * {@code x - log1p(e^x)}. Both are then divided by ln 2. NaN propagates; {@code +Infinity} maps to
     * 0 and {@code -Infinity} to {@code -Infinity}.
     */
    private static double log2Sigmoid(double x) {
        if (x >= 0.0d) {
            return -Math.log1p(Math.exp(-x)) / LN2;
        }
        return (x - Math.log1p(Math.exp(x))) / LN2;
    }

    /**
     * Sums the weight-matrix scores for the window starting at 0-based offset {@code i}.
     *
     * @param matrix2D matrix indexed as {@code get(symbolIndex, column)}
     * @param i window start offset into {@link #trimmedSeqIndices}
     * @return the summed score, or {@link Double#NEGATIVE_INFINITY} if the window covers a
     *     non-atomic symbol
     */
    private double scoreWM(Matrix2D matrix2D, int i) throws Exception {
        double d = 0.0d;
        int columns = matrix2D.columns();
        for (int col = 0; col < columns; col++) {
            int symbolIndex = this.trimmedSeqIndices[i + col];
            if (symbolIndex < 0) {
                return Double.NEGATIVE_INFINITY;
            }
            d += matrix2D.get(symbolIndex, col);
        }
        return d;
    }
}
