package net.derkholm.nmica.trainer.distributed;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import net.derkholm.nmica.maths.DoubleProcedure;
import net.derkholm.nmica.matrix.SimpleMatrix1D;
import net.derkholm.nmica.model.ContributionGroup;
import net.derkholm.nmica.trainer.EvaluationManager;
import net.derkholm.nmica.trainer.TrainableState;
import net.derkholm.nmica.trainer.TrainableStateContext;
import net.derkholm.nmica.trainer.distributed.messages.ContributionRequest;
import net.derkholm.nmica.trainer.distributed.messages.ContributionResponse;
import net.derkholm.nmica.trainer.distributed.messages.DatumRequest;
import net.derkholm.nmica.trainer.distributed.messages.DatumResponse;
import net.derkholm.nmica.trainer.distributed.messages.LikelihoodRequest;
import net.derkholm.nmica.trainer.distributed.messages.LikelihoodResponse;
import net.derkholm.nmica.trainer.distributed.messages.NotReady;
import net.derkholm.nmica.trainer.distributed.messages.Ready;
import net.derkholm.nmica.trainer.distributed.messages.Shutdown;
import net.derkholm.nmica.trainer.distributed.messages.TrainerConfigRequest;
import net.derkholm.nmica.trainer.distributed.messages.TrainerConfigResponse;
import net.derkholm.nmica.utils.mq.MessageQueue;
import net.derkholm.nmica.utils.mq.Packable;
import net.derkholm.nmica.utils.mq.QueueDeadException;

/**
 * An {@link EvaluationManager} that farms likelihood calculations out to remote worker nodes
 * over a {@link MessageQueue}.
 *
 * <p>A round consists of {@link #startLikelihoodCalculations}, any number of
 * {@link #enqueueLikelihoodCalculation} calls, and finally {@link #endLikelihoodCalculations},
 * which sends the queued work to the workers currently in the ready set. It then blocks until
 * every result has been written back. If progress stalls, it prunes silent workers and
 * re-dispatches whatever is still outstanding.
 *
 * <p>Threading: the constructor starts a {@link TrainingMonitor} thread that services every
 * incoming message (worker registration, data/contribution requests, likelihood results). The
 * trainer thread drives the start/enqueue/end cycle. State shared between them is volatile,
 * guarded by {@code workList}'s monitor, or held in a concurrent set.
 */
public class DistributedEvaluationManager implements EvaluationManager {
    /** Poll interval while waiting for a round to finish. */
    private static final long TICK_MILLIS = 500L;
    /** A ready worker not heard from (via {@code Ready}) for this long is pruned on restart. */
    private static final long STALE_WORKER_MILLIS = 3000L;

    private final int port;
    private volatile TrainableStateContext trainer;
    private volatile TrainableState currentState;
    private final MessageQueue<Packable> messageQ;
    private final TrainingMonitor monitor;
    private Thread ticker;
    private final Thread shutdownHook;
    /** Work units of the current round; indexed by work id. Guarded by its own monitor. */
    private final List<WorkUnit> workList = new ArrayList<>();
    private volatile short currentSid = 0;
    private volatile boolean inShutdown = false;
    /**
     * Workers that have announced {@code Ready}. Mutated by the monitor thread, pruned by the
     * trainer thread and iterated by the shutdown hook, hence a concurrent set.
     */
    private final Set<WorkerRecord> readySet =
            Collections.newSetFromMap(new ConcurrentHashMap<WorkerRecord, Boolean>());
    private boolean seedContributions = true;
    private boolean debug = false;
    private boolean bandwidthTicker = false;
    /** Likelihoods committed so far; written only by the monitor thread, read by the ticker. */
    private volatile int hoods = 0;
    private boolean crab = false;
    private double crabRate = 0.001d;
    /** Stored via {@link #setCrabDebug} but not currently read anywhere in this class. */
    private boolean crabDebug = false;
    private volatile FinishLine finishLine = null;

    /** Prints transfer rates and likelihood throughput to stderr roughly once a second. */
    private class BandwidthThread extends Thread {
        private final NumberFormat FORMAT = new DecimalFormat("######0.0");
        private long oldTime = -1L;
        private int oldTxPackets = 0;
        private int oldRxPackets = 0;
        private int oldTxBytes = 0;
        private int oldRxBytes = 0;
        private int oldHoods = 0;

        @Override
        public void run() {
            while (!inShutdown) {
                long now = System.currentTimeMillis();
                int txPackets = messageQ.getTxPackets();
                int txBytes = messageQ.getTxBytes();
                int rxPackets = messageQ.getRxPackets();
                int rxBytes = messageQ.getRxBytes();
                int committed = hoods;
                if (oldTime > 0) {
                    int elapsed = (int) (now - oldTime);
                    // bytes per millisecond, reported under the "kb/s" label
                    StringBuilder line = new StringBuilder();
                    line.append("Out ");
                    line.append(FORMAT.format((1.0d * (txBytes - oldTxBytes)) / elapsed));
                    line.append("kb/s ");
                    line.append(FORMAT.format((1000.0d * (txPackets - oldTxPackets)) / elapsed));
                    line.append("packs/s       In ");
                    line.append(FORMAT.format((1.0d * (rxBytes - oldRxBytes)) / elapsed));
                    line.append("kb/s ");
                    line.append(FORMAT.format((1000.0d * (rxPackets - oldRxPackets)) / elapsed));
                    line.append("packs/s       Throughput ");
                    line.append(FORMAT.format((1000.0d * (committed - oldHoods)) / elapsed));
                    line.append("lc/s");
                    System.err.println(line.toString());
                }
                oldTxPackets = txPackets;
                oldTxBytes = txBytes;
                oldRxPackets = rxPackets;
                oldRxBytes = rxBytes;
                oldHoods = committed;
                oldTime = now;
                try {
                    Thread.sleep(1000L);
                } catch (InterruptedException ignored) {
                    // Purely diagnostic thread: the loop condition re-checks inShutdown.
                }
            }
        }
    }

    /**
     * Tracks completion of one dispatch round.
     *
     * <p>{@link #setCounts} (trainer thread) publishes how many units each worker was given and
     * opens the gate. {@link #commitAndDecrement} (monitor thread) waits for that gate, writes a
     * result back and records which worker emptied its queue first ("winner"), and which worker
     * delivered the last result ("loser").
     */
    private static final class FinishLine {
        private int[] counts;
        private final CountDownLatch readyLatch = new CountDownLatch(1);
        private final CountDownLatch finishedLatch = new CountDownLatch(1);
        private int total = 0;
        /** Written only by the monitor thread; volatile so the trainer can poll progress. */
        private volatile int completed = 0;
        private int winner = -1;
        private int loser = -1;

        FinishLine() {
        }

        /** Records per-worker unit counts; a worker assigned nothing counts as a winner. */
        public void setCounts(int[] perWorker) {
            this.counts = perWorker;
            for (int w = 0; w < perWorker.length; w++) {
                if (perWorker[w] == 0) {
                    this.winner = w;
                } else {
                    this.total += perWorker[w];
                }
            }
            this.readyLatch.countDown();
            if (this.total == 0) {
                this.finishedLatch.countDown();
            }
        }

        /**
         * Writes {@code likelihood} back through the unit's callback and updates the tallies.
         *
         * @throws InterruptedException if interrupted while waiting for {@link #setCounts}
         */
        public void commitAndDecrement(WorkUnit workUnit, double likelihood)
                throws InterruptedException {
            this.readyLatch.await();
            workUnit.writeback.run(likelihood);
            workUnit.writeback = null;
            int worker = workUnit.assignedWorkerID;
            this.counts[worker]--;
            if (this.counts[worker] == 0 && this.winner < 0) {
                this.winner = worker;
            }
            this.completed++; // single writer (monitor thread)
            if (this.completed == this.total) {
                this.loser = worker;
                this.finishedLatch.countDown();
            }
        }

        public int getTotal() {
            return this.total;
        }

        public int getCompleted() {
            return this.completed;
        }

        public int getWinner() {
            return this.winner;
        }

        public int getLoser() {
            return this.loser;
        }

        public void await() throws InterruptedException {
            this.finishedLatch.await();
        }

        public boolean await(long timeout, TimeUnit unit) throws InterruptedException {
            return this.finishedLatch.await(timeout, unit);
        }
    }

    /**
     * Services every message arriving on {@link #messageQ} until the queue dies.
     *
     * <p>A failure while handling one message is logged and the loop continues. Without this,
     * a single badly-timed request would silently stop all result processing.
     */
    private class TrainingMonitor extends Thread {
        @Override
        public void run() {
            while (true) {
                try {
                    MessageQueue.Message message = messageQ.next();
                    try {
                        handle(message);
                    } catch (RuntimeException e) {
                        System.err.println("Failed to handle incoming message, continuing");
                        e.printStackTrace();
                    }
                } catch (QueueDeadException e) {
                    return;
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                } catch (Exception e) {
                    e.printStackTrace();
                    return;
                }
            }
        }

        /** Dispatches a single incoming message on its body type. */
        private void handle(MessageQueue.Message message) throws Exception {
            MessageQueue.Peer sender = message.getSender();
            WorkerRecord worker = (WorkerRecord) sender.getUserData();
            if (worker == null) {
                System.err.println("Got first call-in from new node");
                worker = new WorkerRecord(sender);
                sender.setUserData(worker);
            }
            Packable packable = (Packable) message.getBody();
            if (debug) {
                System.err.println("Got " + packable.getClass().getName());
            }

            if (packable instanceof TrainerConfigRequest) {
                TrainableStateContext context = trainer;
                if (context == null) {
                    System.err.println("Ignoring TrainerConfigRequest: not yet bound to a trainer");
                    return;
                }
                TrainerConfigResponse response = new TrainerConfigResponse();
                response.components = context.getComponents();
                response.dataSetSize = context.getDataSet().length;
                response.facetteMap = context.getFacetteMap();
                messageQ.sendMessage(sender, response);
                messageQ.flush();
            } else if (packable instanceof Ready) {
                // Stamp before publishing so a concurrent prune never sees lastPing == 0.
                worker.lastPing = System.currentTimeMillis();
                readySet.add(worker);
            } else if (packable instanceof NotReady) {
                if (!readySet.remove(worker)) {
                    System.err.println("Strange, node isn 't in ready set");
                }
            } else if (packable instanceof DatumRequest) {
                TrainableStateContext context = trainer;
                if (context == null) {
                    System.err.println("Ignoring DatumRequest: not yet bound to a trainer");
                    return;
                }
                DatumRequest request = (DatumRequest) packable;
                DatumResponse response = new DatumResponse();
                response.datumIndex = request.datumIndex;
                response.facette = request.facette;
                response.datum =
                        context.getDataSet()[request.datumIndex].getFacettedData()[request.facette];
                messageQ.sendMessage(sender, response);
                messageQ.flush();
            } else if (packable instanceof ContributionRequest) {
                TrainableState state = currentState;
                if (state == null) {
                    System.err.println("Ignoring ContributionRequest: no likelihood round open");
                    return;
                }
                ContributionRequest request = (ContributionRequest) packable;
                ContributionResponse response = new ContributionResponse();
                response.sid = currentSid;
                response.component = request.component;
                response.contributionGroup = request.contributionGroup;
                response.contribution =
                        state.getContribution(request.contributionGroup, request.component).getItem();
                messageQ.sendMessage(sender, response);
            } else if (packable instanceof LikelihoodResponse) {
                commitLikelihood((LikelihoodResponse) packable);
            } else {
                System.err.println("Unrecognized message class: " + packable.getClass().getName());
            }
        }

        /** Writes a current-round likelihood back, unless it is stale or already committed. */
        private void commitLikelihood(LikelihoodResponse response) throws InterruptedException {
            if (response.sid != currentSid) {
                System.err.println("Yuck, got back an out-of-date reponse");
                return;
            }
            WorkUnit unit;
            synchronized (workList) {
                if (response.wid < 0 || response.wid >= workList.size()) {
                    System.err.println("Ignoring likelihood for unknown work unit " + response.wid);
                    return;
                }
                unit = workList.get(response.wid);
            }
            FinishLine line = finishLine;
            // writeback is cleared on commit, so duplicate results (e.g. after a restart) are dropped.
            if (unit.writeback != null && line != null) {
                line.commitAndDecrement(unit, response.likelihood);
                hoods++; // single writer (this thread)
            }
        }
    }

    /** Per-worker bookkeeping, attached to the worker's {@link MessageQueue.Peer} as user data. */
    public static class WorkerRecord {
        public final MessageQueue.Peer node;
        /** Time of the last {@code Ready}; written by the monitor, read by the trainer thread. */
        public volatile long lastPing = 0;
        /** Not used by this class; raw type kept for compatibility with existing callers. */
        public Set runningWork = new HashSet();
        /** Relative dispatch weight used when "crab" load balancing is enabled. */
        public double weight = 1.0d;

        public WorkerRecord(MessageQueue.Peer peer) {
            this.node = peer;
        }
    }

    /** Enables weighted ("crab") work distribution that rewards the fastest worker each round. */
    public void setCrab(boolean enabled) {
        this.crab = enabled;
    }

    /** Sets the weight increment given to the winning worker when crab mode is on. */
    public void setCrabRate(double rate) {
        this.crabRate = rate;
    }

    public void setCrabDebug(boolean enabled) {
        this.crabDebug = enabled;
    }

    /** Unregisters the JVM shutdown hook and shuts down immediately. */
    public void shutdown() {
        Runtime.getRuntime().removeShutdownHook(this.shutdownHook);
        doShutdown();
    }

    /** Tells every ready worker to shut down (best effort), then stops the message queue. */
    public void doShutdown() {
        for (WorkerRecord worker : this.readySet) {
            try {
                this.messageQ.sendMessage(worker.node, new Shutdown());
            } catch (QueueDeadException ignored) {
                // Queue already gone: nothing more we can tell this worker.
            }
        }
        this.messageQ.shutdown();
        this.inShutdown = true;
    }

    /**
     * Opens the message queue on {@code port}, starts the monitor thread and registers a JVM
     * shutdown hook that calls {@link #doShutdown()}.
     *
     * @throws IOException if the message queue cannot be created
     */
    public DistributedEvaluationManager(int port) throws IOException {
        this.port = port;
        this.messageQ = new MessageQueue<>(port);
        this.messageQ.setCodec(Protocol.CODEC);
        this.messageQ.setFlushThreshold(40);
        this.messageQ.start();
        this.monitor = new TrainingMonitor();
        this.monitor.start();
        if (this.bandwidthTicker) {
            this.ticker = new BandwidthThread();
            this.ticker.start();
        }
        this.shutdownHook = new Thread() {
            @Override
            public void run() {
                DistributedEvaluationManager.this.doShutdown();
            }
        };
        Runtime.getRuntime().addShutdownHook(this.shutdownHook);
    }

    public InetSocketAddress getDatagramEndpoint() {
        return this.messageQ.getEndpoint();
    }

    /**
     * Opens a new round for {@code trainableState}, binding to its context on first use.
     *
     * @throws IllegalStateException if a previous round has not been ended
     * @throws RuntimeException if the state belongs to a different trainer than the bound one
     */
    @Override
    public void startLikelihoodCalculations(TrainableState trainableState) {
        if (this.currentState != null) {
            throw new IllegalStateException("Can't start likelihood calculations while queues are busy");
        }
        TrainableStateContext context = trainableState.getContext();
        if (this.trainer == null) {
            this.trainer = context;
        } else if (this.trainer != context) {
            throw new RuntimeException("Already bound to a trainer");
        }
        this.currentState = trainableState;
        // New session id so responses from earlier rounds are recognised as stale.
        short nextSid = (short) (this.currentSid + 1);
        this.currentSid = nextSid > 16384 ? (short) 0 : nextSid;
        synchronized (this.workList) {
            this.workList.clear();
        }
    }

    /**
     * Queues the likelihood of {@code datum}/{@code facette}; {@code doubleProcedure} receives
     * the result from the monitor thread.
     *
     * @throws IllegalStateException if {@code trainableState} is not the open round's state
     */
    @Override
    public void enqueueLikelihoodCalculation(TrainableState trainableState, int datum, int facette,
            DoubleProcedure doubleProcedure) {
        if (trainableState != this.currentState) {
            throw new IllegalStateException("Requested state is not open for likelihood evaluations");
        }
        synchronized (this.workList) {
            this.workList.add(new WorkUnit(this.workList.size(), this.currentSid, facette,
                    this.trainer.facetteIndexToContributionIndex(facette), datum,
                    new SimpleMatrix1D(trainableState.getMixture(datum)), doubleProcedure));
        }
    }

    /**
     * Dispatches all queued work and blocks until every result has been written back.
     *
     * <p>If no result checks in for a while, silent workers are pruned and the still-outstanding
     * units are re-dispatched to whichever workers remain ready. The round is closed
     * ({@code currentState} cleared) even if waiting is aborted by an exception.
     *
     * @throws IllegalStateException if {@code trainableState} is not the open round's state
     * @throws RuntimeException wrapping {@link InterruptedException} if the thread is interrupted
     */
    @Override
    public void endLikelihoodCalculations(TrainableState trainableState) {
        if (trainableState != this.currentState) {
            throw new IllegalStateException("Requested state is not open for likelihood evaluations");
        }
        try {
            while (true) {
                WorkerRecord[] workers = awaitReadyWorkers();
                if (this.seedContributions) {
                    broadcastContributions(workers);
                }
                FinishLine line = new FinishLine();
                this.finishLine = line;
                boolean enoughWorkToJudge = dispatchOutstandingWork(workers, line);
                this.messageQ.flush();
                if (awaitRound(line)) {
                    if (this.crab && enoughWorkToJudge) {
                        rewardWinner(workers, line.getWinner());
                    }
                    return;
                }
                // Stalled: stale workers were pruned; loop to re-dispatch outstanding units.
            }
        } finally {
            this.currentState = null;
        }
    }

    /** Snapshots the ready set, polling once a second until at least one worker is ready. */
    private WorkerRecord[] awaitReadyWorkers() {
        WorkerRecord[] workers = this.readySet.toArray(new WorkerRecord[0]);
        while (workers.length == 0) {
            System.err.println("No worker nodes available");
            try {
                Thread.sleep(1000L);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
            workers = this.readySet.toArray(new WorkerRecord[0]);
        }
        return workers;
    }

    /** Pushes every current contribution to every worker ahead of the likelihood requests. */
    private void broadcastContributions(WorkerRecord[] workers) {
        ContributionGroup[] groups = this.trainer.getFacetteMap().getContributionGroups();
        int components = this.trainer.getComponents();
        for (int group = 0; group < groups.length; group++) {
            for (int component = 0; component < components; component++) {
                ContributionResponse response = new ContributionResponse();
                response.component = component;
                response.contributionGroup = group;
                response.sid = this.currentSid;
                response.contribution = this.currentState.getContribution(group, component).getItem();
                for (WorkerRecord worker : workers) {
                    try {
                        this.messageQ.sendMessage(worker.node, response);
                    } catch (QueueDeadException ignored) {
                        // Best effort: a dead queue means we are shutting down.
                    }
                }
            }
        }
    }

    /**
     * Sends every unit whose result is still outstanding to a worker and publishes the per-worker
     * counts to {@code line}.
     *
     * @return whether at least four units per worker were sent, i.e. enough to judge a winner
     */
    private boolean dispatchOutstandingWork(WorkerRecord[] workers, FinishLine line) {
        int[] assigned = new int[workers.length];
        double[] shares = new double[workers.length];
        double totalShare = 0.0d;
        if (this.crab) {
            for (int w = 0; w < workers.length; w++) {
                // NOTE(review): the 1.3142 scale presumably keeps the unit step below from
                // aligning with worker boundaries - confirm intent.
                shares[w] = workers[w].weight * 1.3142d;
                totalShare += shares[w];
            }
        }
        boolean weighted = this.crab && totalShare > 0.0d; // guard against an endless wrap loop
        int dispatched = 0;
        synchronized (this.workList) {
            double cursor = Math.random() * totalShare;
            for (WorkUnit unit : this.workList) {
                if (unit.writeback == null) {
                    continue; // already committed (only possible when re-dispatching)
                }
                LikelihoodRequest request = new LikelihoodRequest();
                request.sid = this.currentSid;
                request.wid = unit.wid;
                request.contributionGroup = unit.contributionGroup;
                request.datum = unit.datum;
                request.facette = unit.facette;
                request.weights = unit.weights.getRaw();

                int target;
                if (weighted) {
                    while (cursor >= totalShare) {
                        cursor -= totalShare;
                    }
                    target = pickWorker(shares, cursor);
                    cursor += 1.0d;
                } else {
                    target = request.wid % workers.length;
                }
                try {
                    unit.assignedWorkerID = target;
                    this.messageQ.sendMessage(workers[target].node, request);
                    assigned[target]++;
                    dispatched++;
                } catch (QueueDeadException ignored) {
                    // Best effort: a dead queue means we are shutting down.
                }
            }
            line.setCounts(assigned);
        }
        return dispatched >= 4 * workers.length;
    }

    /** Returns the first worker whose cumulative share reaches {@code cursor} (last on round-off). */
    private static int pickWorker(double[] shares, double cursor) {
        double cumulative = 0.0d;
        for (int w = 0; w < shares.length; w++) {
            cumulative += shares[w];
            if (cumulative >= cursor) {
                return w;
            }
        }
        return shares.length - 1;
    }

    /**
     * Polls {@code line} every {@link #TICK_MILLIS} ms.
     *
     * @return {@code true} once every unit has been committed; {@code false} after the second
     *     tick without progress, having pruned silent workers, meaning the caller should re-dispatch
     */
    private boolean awaitRound(FinishLine line) {
        int lastCompleted = 0;
        int stalledTicks = 0;
        try {
            while (!line.await(TICK_MILLIS, TimeUnit.MILLISECONDS)) {
                int completed = line.getCompleted();
                if (completed == lastCompleted) {
                    System.err.println("Tick has passed without any work checked in");
                    stalledTicks++;
                    if (stalledTicks > 1) {
                        System.err.println("Some likelihood calculations (" + (line.getTotal() - completed)
                                + "/" + line.getTotal() + ") seem to have failed, restarting");
                        pruneSilentWorkers();
                        return false;
                    }
                }
                lastCompleted = completed;
            }
            return true;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /** Removes ready workers whose last {@code Ready} is older than {@link #STALE_WORKER_MILLIS}. */
    private void pruneSilentWorkers() {
        long now = System.currentTimeMillis();
        for (Iterator<WorkerRecord> it = this.readySet.iterator(); it.hasNext();) {
            if (now - it.next().lastPing > STALE_WORKER_MILLIS) {
                System.err.println("Haven't heard from this node for a while, removing from readyset");
                it.remove();
            }
        }
    }

    /** Bumps the winner's weight by {@link #crabRate}, then renormalises all weights to mean 1. */
    private void rewardWinner(WorkerRecord[] workers, int winner) {
        if (winner < 0) {
            return;
        }
        workers[winner].weight += this.crabRate;
        double sum = 0.0d;
        for (WorkerRecord worker : workers) {
            sum += worker.weight;
        }
        double mean = sum / workers.length;
        for (WorkerRecord worker : workers) {
            worker.weight /= mean;
        }
    }
}
