package net.derkholm.nmica.apps;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import net.derkholm.nmica.maths.DoubleFunction;
import net.derkholm.nmica.maths.IdentityDoubleFunction;
import net.derkholm.nmica.model.ContributionGroup;
import net.derkholm.nmica.model.ContributionPrior;
import net.derkholm.nmica.model.ContributionSampler;
import net.derkholm.nmica.model.Datum;
import net.derkholm.nmica.model.Facette;
import net.derkholm.nmica.model.FlatMixPolicy;
import net.derkholm.nmica.model.MultiICAModel;
import net.derkholm.nmica.model.MultiplexContributionSampler;
import net.derkholm.nmica.model.SimpleContributionGroup;
import net.derkholm.nmica.model.SimpleDatum;
import net.derkholm.nmica.model.SimpleFacetteMap;
import net.derkholm.nmica.model.logitseq.LogisticSequenceFacette;
import net.derkholm.nmica.model.logitseq.WeightedWeightMatrix;
import net.derkholm.nmica.model.logitseq.WeightedWeightMatrixMatrixSampler;
import net.derkholm.nmica.model.logitseq.WeightedWeightMatrixPrior;
import net.derkholm.nmica.model.logitseq.WeightedWeightMatrixWeightSampler;
import net.derkholm.nmica.model.motif.IndelSampler;
import net.derkholm.nmica.model.motif.MotifClippedSimplexPrior;
import net.derkholm.nmica.model.motif.SpinSampler;
import net.derkholm.nmica.model.motif.SymbolMassScalingSampler;
import net.derkholm.nmica.model.motif.ZapSampler;
import net.derkholm.nmica.motif.Motif;
import net.derkholm.nmica.motif.MotifIOTools;
import net.derkholm.nmica.trainer.EvaluationManager;
import net.derkholm.nmica.trainer.LocalEvaluationManager;
import net.derkholm.nmica.trainer.RandomDecopTrainer;
import net.derkholm.nmica.trainer.Trainer;
import net.derkholm.nmica.trainer.distributed.DistributedEvaluationManager;
import net.derkholm.nmica.utils.CliTools;
import net.derkholm.nmica.utils.ConfigurationException;
import org.biojava.bio.seq.DNATools;
import org.biojava.bio.seq.Sequence;
import org.biojava.bio.seq.SequenceIterator;
import org.biojava.bio.seq.db.HashSequenceDB;
import org.biojava.bio.seq.db.SequenceDB;
import org.biojava.bio.seq.io.SeqIOTools;
import org.biojava.utils.ChangeSupport;

/* loaded from: input_file:net/derkholm/nmica/apps/Mocca.class */
/**
 * Command-line driver that trains a {@link RandomDecopTrainer} over one or more FASTA
 * sequence sets and writes the best model's motifs as an XMS motif set.
 *
 * <p>Options are set through the bean-style setters below; {@link #main} hands the instance to
 * {@code CliTools.configureBean}, which appears to map command-line flags such as {@code -seqs}
 * onto these setters (TODO confirm against CliTools). Training can either be built from
 * sequences or resumed from a serialized trainer checkpoint.
 */
public class Mocca {
    /** Training continues only while the trainer's likelihood IQR is at least this value. */
    private static final double CONVERGENCE_IQR = 0.01d;

    /** Finds a {@code label=1} or {@code label=-1} tag in a FASTA description line. */
    private static final Pattern LABEL_PATTERN = Pattern.compile("label=(1|-1)");

    /** One sequence database per facette; sequences are matched across databases by ID. */
    private SequenceDB[] seqDBs;
    /** Motif length handed to {@link MotifClippedSimplexPrior}. */
    private int targetLength = 10;
    /** Negative means "derive from numMotifs" when a new trainer is built. */
    private int ensembleSize = -1;
    private int numMotifs = 10;
    // NOTE(review): revComp is stored but never read anywhere in this class.
    private boolean revComp = false;
    private String snapshotFile = null;
    private String outFile = "classmotifs.xms";
    private int sampleInterval = 1000;
    private String checkpoint = null;
    private int keepCheckpoints = 2;
    private File restartFromCheckpoint = null;
    private int checkpointInterval = 10000;
    private int workerThreads = 1;
    /** Hard limit on training cycles; 0 or less means "run until convergence". */
    private int maxCycles = 0;
    private double crossOverProb = 0.3d;
    private double simplexMaximumScale = 2.0d;
    private double replaceComponentProb = 0.2d;
    private double minClip = 0.002d;
    private double maxClip = 1.0d;
    /** Negative means "2 x numMotifs" when a new trainer is built. */
    private int minContributionMoves = -1;
    /** Negative means "8 x numMotifs" when a new trainer is built. */
    private int minContributionProposals = -1;
    // NOTE(review): distributed and port have no setters in this class, so the distributed
    // evaluation path in run() is unreachable unless they are set some other way (reflection?).
    private boolean distributed = false;
    private int port = 0;

    /** Sets the path prefix for checkpoint files, written as {@code <prefix>.<cycle>.jos}. */
    public void setCheckpoint(String str) {
        this.checkpoint = str;
    }

    /** Sets the number of cycles between checkpoints; 0 or less disables checkpointing. */
    public void setCheckpointInterval(int i) {
        this.checkpointInterval = i;
    }

    /** Sets the crossover probability passed to the trainer. */
    public void setCrossOverProb(double d) {
        this.crossOverProb = d;
    }

    /** Sets the ensemble size; a negative value derives it from the number of motifs. */
    public void setEnsembleSize(int i) {
        this.ensembleSize = i;
    }

    /** Sets how many of the most recent checkpoint files are kept on disk. */
    public void setKeepCheckpoints(int i) {
        this.keepCheckpoints = i;
    }

    /** Sets a hard cycle limit; 0 or less means train until convergence. */
    public void setMaxCycles(int i) {
        this.maxCycles = i;
    }

    /** Sets the trainer's minimum contribution moves; negative means 2 x numMotifs. */
    public void setMinContributionMoves(int i) {
        this.minContributionMoves = i;
    }

    /** Sets the trainer's minimum contribution proposals; negative means 8 x numMotifs. */
    public void setMinContributionProposals(int i) {
        this.minContributionProposals = i;
    }

    /** Sets the number of motifs (model components) to learn. */
    public void setNumMotifs(int i) {
        this.numMotifs = i;
    }

    /** Sets the XMS file the final motifs are written to; {@code null} skips the final write. */
    public void setOutFile(String str) {
        this.outFile = str;
    }

    /** Sets the component-replacement probability passed to the trainer. */
    public void setReplaceComponentProb(double d) {
        this.replaceComponentProb = d;
    }

    /** Resumes training from a serialized trainer checkpoint instead of building a new one. */
    public void setRestartFromCheckpoint(File file) {
        this.restartFromCheckpoint = file;
    }

    /** Sets the revComp flag (currently not read by this class). */
    public void setRevComp(boolean z) {
        this.revComp = z;
    }

    /** Sets the number of cycles between motif snapshots; 0 or less disables snapshots. */
    public void setSampleInterval(int i) {
        this.sampleInterval = i;
    }

    /**
     * Loads one FASTA DNA file per facette.
     *
     * @param fileArr FASTA files, in facette order
     * @throws Exception if a file cannot be read or parsed
     */
    public void setSeqs(File[] fileArr) throws Exception {
        this.seqDBs = new SequenceDB[fileArr.length];
        for (int i = 0; i < fileArr.length; i++) {
            this.seqDBs[i] = loadDB(fileArr[i]);
        }
    }

    /**
     * Reads a FASTA DNA file into an in-memory database. A {@code label=1} or {@code label=-1}
     * tag in a description line is stored as an {@code Integer} annotation {@code mocca.label}.
     * The file is always closed, even if parsing fails.
     */
    private static SequenceDB loadDB(File file) throws Exception {
        BufferedReader reader = new BufferedReader(new FileReader(file));
        try {
            SequenceIterator sequences = SeqIOTools.readFastaDNA(reader);
            HashSequenceDB db = new HashSequenceDB();
            while (sequences.hasNext()) {
                Sequence seq = sequences.nextSequence();
                String description = (String) seq.getAnnotation().getProperty("description_line");
                Matcher matcher = LABEL_PATTERN.matcher(description);
                if (matcher.find()) {
                    seq.getAnnotation().setProperty(
                            "mocca.label", Integer.valueOf(Integer.parseInt(matcher.group(1))));
                }
                db.addSequence(seq);
            }
            return db;
        } finally {
            reader.close();
        }
    }

    /** Supplies the per-facette sequence databases directly, bypassing {@link #setSeqs}. */
    public void setSeqDBs(SequenceDB[] sequenceDBArr) {
        this.seqDBs = sequenceDBArr;
    }

    /** Sets the maximum scale used by the {@link SymbolMassScalingSampler}. */
    public void setSimplexMaximumScale(double d) {
        this.simplexMaximumScale = d;
    }

    /** Sets the path prefix for periodic motif snapshots, written as {@code <prefix>.<cycle>.xms}. */
    public void setSnapshotFile(String str) {
        this.snapshotFile = str;
    }

    /** Sets the motif length passed to the motif prior. */
    public void setTargetLength(int i) {
        this.targetLength = i;
    }

    /** Sets the number of worker threads for the local evaluation manager. */
    public void setWorkerThreads(int i) {
        this.workerThreads = i;
    }

    /** Configures a {@code Mocca} from the command line and runs it. */
    public static void main(String[] strArr) throws Exception {
        Mocca mocca = new Mocca();
        mocca.run(CliTools.configureBean(mocca, strArr));
    }

    /**
     * Builds or restores a trainer, trains until convergence (or {@code maxCycles}), and writes
     * the best model's motifs to {@code outFile}.
     *
     * @param args leftover command-line arguments; not used
     * @throws ConfigurationException if no checkpoint is given and no sequences were supplied
     * @throws Exception on I/O, deserialization or training failures
     */
    public void run(String[] args) throws Exception {
        Trainer trainer;
        // Non-null only for a newly built trainer; restored checkpoints skip init().
        RandomDecopTrainer freshTrainer = null;
        if (this.restartFromCheckpoint != null) {
            trainer = readCheckpoint(this.restartFromCheckpoint);
            this.numMotifs = trainer.getComponents();
        } else {
            if (this.seqDBs == null) {
                throw new ConfigurationException("You must specify a -seqs option");
            }
            freshTrainer = createTrainer(buildData());
            trainer = freshTrainer;
        }

        EvaluationManager evaluationManager;
        if (this.distributed) {
            evaluationManager = new DistributedEvaluationManager(this.port);
        } else {
            evaluationManager = new LocalEvaluationManager(this.workerThreads);
        }
        trainer.setEvaluationManager(evaluationManager);
        try {
            // init() must come after setEvaluationManager, as in the original ordering.
            if (freshTrainer != null) {
                freshTrainer.init();
            }
            enableGlobalChangeBypass();
            train(trainer);
            if (this.outFile != null) {
                writeMotifs(trainer.getBestModel(), this.outFile);
            }
        } finally {
            // Shut the distributed manager down even when training fails.
            if (evaluationManager instanceof DistributedEvaluationManager) {
                ((DistributedEvaluationManager) evaluationManager).shutdown();
            }
        }
    }

    /**
     * Deserializes a trainer checkpoint.
     *
     * <p>SECURITY NOTE(review): this is native Java deserialization of a user-supplied file;
     * only restore checkpoints from trusted sources.
     */
    private static Trainer readCheckpoint(File file) throws Exception {
        FileInputStream in = new FileInputStream(file);
        try {
            return (Trainer) new ObjectInputStream(in).readObject();
        } finally {
            in.close();
        }
    }

    /** Serializes the trainer to {@code file}, closing the file even if writing fails. */
    private static void writeCheckpoint(Trainer trainer, File file) throws Exception {
        FileOutputStream out = new FileOutputStream(file);
        try {
            ObjectOutputStream objectOut = new ObjectOutputStream(out);
            objectOut.writeObject(trainer);
            objectOut.flush();
        } finally {
            out.close();
        }
    }

    /** Builds one datum per sequence ID seen in any database, with one slot per facette. */
    @SuppressWarnings("unchecked") // SequenceDB.ids() returns a raw Set of String IDs.
    private Datum[] buildData() throws Exception {
        Set<String> ids = new HashSet<String>();
        for (SequenceDB db : this.seqDBs) {
            ids.addAll(db.ids());
        }
        List<Datum> data = new ArrayList<Datum>(ids.size());
        for (String id : ids) {
            // A slot stays null when the ID is missing from that facette's database.
            Object[] facetValues = new Object[this.seqDBs.length];
            for (int i = 0; i < this.seqDBs.length; i++) {
                if (this.seqDBs[i].ids().contains(id)) {
                    facetValues[i] = this.seqDBs[i].getSequence(id);
                }
            }
            data.add(new SimpleDatum(id, facetValues));
        }
        return data.toArray(new Datum[data.size()]);
    }

    /**
     * Fills in derived defaults and constructs a configured (but not yet initialized)
     * {@link RandomDecopTrainer} over {@code data}.
     */
    private RandomDecopTrainer createTrainer(Datum[] data) throws Exception {
        if (this.ensembleSize < 0) {
            this.ensembleSize = Math.max(200, (int) (1000.0d / this.numMotifs));
        }
        if (this.minContributionMoves < 0) {
            this.minContributionMoves = this.numMotifs * 2;
        }
        if (this.minContributionProposals < 0) {
            this.minContributionProposals = this.numMotifs * 8;
        }

        MotifClippedSimplexPrior simplexPrior = new MotifClippedSimplexPrior(
                DNATools.getDNA(), this.targetLength, this.minClip, this.maxClip);
        WeightedWeightMatrixPrior prior = new WeightedWeightMatrixPrior(simplexPrior);

        // Relative weights select how often each move type is proposed.
        MultiplexContributionSampler sampler = new MultiplexContributionSampler();
        sampler.addSampler(new WeightedWeightMatrixMatrixSampler(
                new SymbolMassScalingSampler(this.simplexMaximumScale)), 16.0d);
        sampler.addSampler(new WeightedWeightMatrixMatrixSampler(new SpinSampler(simplexPrior)), 2.0d);
        sampler.addSampler(new WeightedWeightMatrixMatrixSampler(new ZapSampler(simplexPrior)), 4.0d);
        sampler.addSampler(new WeightedWeightMatrixMatrixSampler(new IndelSampler(simplexPrior)), 2.0d);
        sampler.addSampler(new WeightedWeightMatrixWeightSampler(), 5.0d);

        Facette[] facettes = new Facette[this.seqDBs.length];
        for (int i = 0; i < facettes.length; i++) {
            facettes[i] = new LogisticSequenceFacette();
        }
        SimpleContributionGroup motifGroup =
                new SimpleContributionGroup("motifs", WeightedWeightMatrix.class);
        SimpleFacetteMap facetteMap =
                new SimpleFacetteMap(new ContributionGroup[] {motifGroup}, facettes);
        for (Facette facette : facettes) {
            facetteMap.setContributesToFacette(motifGroup, facette, true);
        }

        RandomDecopTrainer trainer = new RandomDecopTrainer(
                facetteMap,
                data,
                this.numMotifs,
                new ContributionPrior[] {prior},
                new ContributionSampler[] {sampler},
                new FlatMixPolicy(),
                this.ensembleSize);
        trainer.setMinMixtureMoves(0);
        trainer.setMinContributionMoves(this.minContributionMoves);
        trainer.setMinMixtureProposals(0);
        trainer.setMinContributionProposals(this.minContributionProposals);
        trainer.setIgnoreMixturePrior(true);
        trainer.setCrossOverProb(this.crossOverProb);
        trainer.setReplaceComponentProb(this.replaceComponentProb);
        return trainer;
    }

    /** Turns on BioJava's global change bypass via reflection, if this BioJava version has it. */
    private static void enableGlobalChangeBypass() throws Exception {
        try {
            ChangeSupport.class.getMethod("setGlobalChangeBypass", Boolean.TYPE)
                    .invoke(null, Boolean.TRUE);
        } catch (NoSuchMethodException e) {
            System.err.println("Short-circuiting isn't available");
        }
    }

    /**
     * Runs training cycles until the likelihood IQR drops below {@link #CONVERGENCE_IQR} (or is
     * NaN) or {@code maxCycles} is reached. Each cycle prints a tab-separated line of cycle,
     * weight, likelihood, IQR and elapsed milliseconds. It also writes periodic snapshots and
     * checkpoints, keeping only the newest {@code keepCheckpoints} checkpoint files.
     */
    private void train(Trainer trainer) throws Exception {
        List<File> checkpointFiles = new ArrayList<File>();
        // Start the clock here so the first line reports a real duration, not epoch millis.
        long lastTime = System.currentTimeMillis();
        do {
            Trainer.WeightedModel model = trainer.next();
            int cycle = trainer.getCycle();
            if (this.snapshotFile != null && this.sampleInterval > 0
                    && cycle % this.sampleInterval == 0) {
                writeMotifs(trainer.getBestModel(), this.snapshotFile + '.' + cycle + ".xms");
            }

            long now = System.currentTimeMillis();
            System.out.println("" + cycle + '\t' + model.getWeight() + '\t' + model.getLikelihood()
                    + "\t" + trainer.getLikelihoodIQR()
                    + "\t" + (now - lastTime));
            lastTime = now;

            if (this.checkpoint != null && this.checkpointInterval > 0
                    && cycle % this.checkpointInterval == 0) {
                File file = new File(this.checkpoint + '.' + cycle + ".jos");
                writeCheckpoint(trainer, file);
                checkpointFiles.add(file);
                while (checkpointFiles.size() > this.keepCheckpoints) {
                    checkpointFiles.remove(0).delete();
                }
            }
            if (this.maxCycles > 0 && cycle >= this.maxCycles) {
                break;
            }
            // Written as ">= threshold" so a NaN IQR also stops training.
        } while (trainer.getLikelihoodIQR() >= CONVERGENCE_IQR);
    }

    /**
     * Writes each component of the first contribution group as a motif named
     * {@code motif<i>}, annotated with its {@code mocca.weight}, to an XMS file.
     *
     * @param multiICAModel model whose contributions are {@link WeightedWeightMatrix} items
     * @param str output file path
     */
    private void writeMotifs(MultiICAModel multiICAModel, String str) throws Exception {
        ContributionGroup group = multiICAModel.getFacetteMap().getContributionGroups()[0];
        int components = multiICAModel.getComponents();
        Motif[] motifs = new Motif[components];
        for (int i = 0; i < components; i++) {
            WeightedWeightMatrix wwm =
                    (WeightedWeightMatrix) multiICAModel.getContribution(group, i).getItem();
            Motif motif = new Motif();
            motif.setName("motif" + i);
            motif.setWeightMatrix(wwm.getWeightMatrix());
            motif.getAnnotation().setProperty("mocca.weight", Double.valueOf(wwm.getWeight()));
            motifs[i] = motif;
        }
        FileOutputStream out = new FileOutputStream(str);
        try {
            MotifIOTools.writeMotifSetXML(out, motifs);
        } finally {
            out.close();
        }
    }
}
