package net.derkholm.nmica.trainer.distributed;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import net.derkholm.nmica.matrix.ObjectMatrix1D;
import net.derkholm.nmica.matrix.ObjectMatrix2D;
import net.derkholm.nmica.matrix.SimpleMatrix1D;
import net.derkholm.nmica.matrix.SimpleObjectMatrix2D;
import net.derkholm.nmica.model.ContributionGroup;
import net.derkholm.nmica.model.ContributionItem;
import net.derkholm.nmica.model.Facette;
import net.derkholm.nmica.model.LikelihoodCalculator;
import net.derkholm.nmica.model.SimpleContributionItem;
import net.derkholm.nmica.trainer.distributed.messages.ContributionRequest;
import net.derkholm.nmica.trainer.distributed.messages.ContributionResponse;
import net.derkholm.nmica.trainer.distributed.messages.DatumRequest;
import net.derkholm.nmica.trainer.distributed.messages.DatumResponse;
import net.derkholm.nmica.trainer.distributed.messages.Flush;
import net.derkholm.nmica.trainer.distributed.messages.LikelihoodRequest;
import net.derkholm.nmica.trainer.distributed.messages.LikelihoodResponse;
import net.derkholm.nmica.trainer.distributed.messages.NotReady;
import net.derkholm.nmica.trainer.distributed.messages.Ready;
import net.derkholm.nmica.trainer.distributed.messages.Shutdown;
import net.derkholm.nmica.trainer.distributed.messages.TrainerConfigRequest;
import net.derkholm.nmica.trainer.distributed.messages.TrainerConfigResponse;
import net.derkholm.nmica.utils.mq.MessageQueue;
import net.derkholm.nmica.utils.mq.Packable;
import net.derkholm.nmica.utils.mq.QueueDeadException;
import net.derkholm.nmica.utils.tracker.SimpleTracker;
import net.derkholm.nmica.utils.tracker.Task;
import org.biojava.bio.symbol.SymbolList;
import org.biojava.utils.ChangeSupport;

/**
 * Worker process for distributed likelihood evaluation.
 *
 * <p>{@link #start()} opens a {@link MessageQueue} to the master at {@code endpoint}, starts the
 * evaluation thread pool and runs {@link #run()} on a new thread. {@code run()} requests the
 * trainer configuration, starts a daemon "Readiness pinger" thread (which fetches every datum,
 * builds one {@link LikelihoodCalculator} per (facette, datum) pair and then sends {@link Ready}
 * about once a second), and then dispatches incoming messages until {@link #stop()} is called or
 * a {@link Shutdown} message arrives.
 *
 * <p>{@link LikelihoodRequest}s are evaluated on {@code numThreads} threads. The contributions a
 * calculator needs are fetched lazily from the master (see {@link ContributionGopher}) and stored
 * in a per-sid matrix that is replaced whenever a newer sid is seen.
 *
 * <p>Threading: the message-loop thread is the only thread that replaces {@code contributions},
 * writes {@code currentSid}, or touches {@code contributionCache}; evaluation threads read the
 * former two, which is why they are {@code volatile}.
 */
public class DistributedLikelihoodWorker implements Runnable, Worker {
    private InetSocketAddress endpoint;
    private int numThreads;
    private MessageQueue<Packable> messageQ;
    private MessageQueue.Peer master;
    private EvaluationQueue evalQ;
    private int components;
    private Facette[] facettes;
    private ContributionGroup[] contributionGroups;
    /** Likelihood calculators indexed by (facette, datum). */
    private ObjectMatrix2D hoodCalcs;
    /** The single outstanding datum request; set by the pinger thread, completed by the message loop. */
    private volatile Task currentDatumFetchTask;
    /** Maximum number of entries kept in {@code contributionCache}. */
    private int lruSize = 1000;
    /** Cleared by {@link #stop()}, a Shutdown message, or message-loop termination. */
    private volatile boolean running = true;
    private volatile int currentSid = -1;
    /**
     * Per-sid (contributionGroup, component) matrix. Each slot is null (not requested), a
     * {@link Task} (request in flight), or a {@link ContributionItem} (received).
     */
    private volatile ObjectMatrix2D contributions = null;
    /** Access-ordered LRU that deduplicates ContributionItems across sids; message-loop thread only. */
    private Map<Object, ContributionItem> contributionCache =
            new LinkedHashMap<Object, ContributionItem>(1000, 0.75f, true) {
                @Override
                protected boolean removeEldestEntry(Map.Entry<Object, ContributionItem> eldest) {
                    return size() > DistributedLikelihoodWorker.this.lruSize;
                }
            };
    private boolean throughputMonitor = false;
    /** Incremented concurrently by every evaluation thread. */
    private final AtomicInteger tasksRun = new AtomicInteger();
    /** Incremented concurrently by evaluation threads when throughput monitoring is on. */
    private final AtomicInteger bases = new AtomicInteger();
    private int sleepy = 0;

    /**
     * Read-only view of one contribution group's row of {@code contributions}. A missing entry
     * triggers requests for every not-yet-requested component in the row, after which the caller
     * waits up to two seconds for the responses.
     */
    private class ContributionGopher implements ObjectMatrix1D {
        private final int cg;

        public ContributionGopher(int cg) {
            this.cg = cg;
        }

        @Override
        public int size() {
            return DistributedLikelihoodWorker.this.components;
        }

        /**
         * Returns the contribution for {@code component}, fetching it from the master if needed.
         *
         * <p>NOTE(review): only a null slot triggers a fetch. If another evaluation thread has
         * already stored its pending {@link Task} in the slot, or the two-second wait times out,
         * the Task object itself is returned rather than a ContributionItem. Confirm how callers
         * cope with that.
         *
         * @throws RuntimeException wrapping any failure while sending or waiting
         */
        @Override
        public Object get(int component) {
            DistributedLikelihoodWorker outer = DistributedLikelihoodWorker.this;
            Object value = outer.contributions.get(this.cg, component);
            if (value != null) {
                return value;
            }
            try {
                SimpleTracker tracker = new SimpleTracker();
                int sent = 0;
                for (int c = 0; c < outer.components; c++) {
                    if (outer.contributions.get(this.cg, c) != null) {
                        continue; // already received or already requested
                    }
                    Task task = tracker.newTask();
                    ContributionRequest request = new ContributionRequest();
                    request.sid = (short) outer.currentSid;
                    request.component = c;
                    request.contributionGroup = this.cg;
                    task.setData(request);
                    // Mark the slot as in flight so other threads do not request it again.
                    outer.contributions.set(this.cg, c, task);
                    outer.messageQ.sendMessage(outer.master, request);
                    sent++;
                }
                System.err.println("Sent out " + sent + " contribution requests with sid=" + outer.currentSid);
                // One flush pushes out the whole batch of requests.
                outer.messageQ.flush();
                tracker.waitForTasks(2000L);
                return outer.contributions.get(this.cg, component);
            } catch (Exception e) {
                throw new RuntimeException("Error fetching contribution", e);
            }
        }

        @Override
        public void set(int i, Object obj) {
            throw new RuntimeException("ContributionGophers are read-only!");
        }
    }

    /** Pool of threads that evaluate queued {@link LikelihoodRequest} messages. */
    public class EvaluationQueue {
        private EQThread[] threads;
        private final BlockingQueue<MessageQueue.Message<Packable>> work =
                new LinkedBlockingQueue<MessageQueue.Message<Packable>>();
        /** Written by the message-loop thread, polled by every EQThread. */
        private volatile boolean run = true;

        /** One evaluation thread: takes requests from {@code work} and replies to the master. */
        public class EQThread extends Thread {
            /** Not read anywhere in this class; kept because the field is public. */
            public boolean running;

            private EQThread() {
                this.running = true;
            }

            @Override
            public void run() {
                while (EvaluationQueue.this.run) {
                    try {
                        // A short timeout lets the thread notice shutdown; poll returns as soon as work arrives.
                        MessageQueue.Message<Packable> message =
                                EvaluationQueue.this.work.poll(200L, TimeUnit.MILLISECONDS);
                        if (message != null) {
                            evaluate((LikelihoodRequest) message.getBody());
                        }
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            }

            /** Computes one likelihood and sends the response, flushing when the queue has drained. */
            private void evaluate(LikelihoodRequest request) throws Exception {
                DistributedLikelihoodWorker outer = DistributedLikelihoodWorker.this;
                LikelihoodCalculator calculator = outer.getLikelihoodCalculator(request.facette, request.datum);
                if (outer.throughputMonitor) {
                    outer.bases.addAndGet(((SymbolList) calculator.getData()).length());
                }
                double likelihood = calculator.likelihood(
                        new ContributionGopher(request.contributionGroup),
                        new SimpleMatrix1D(request.weights));
                LikelihoodResponse response = new LikelihoodResponse();
                response.sid = request.sid;
                response.wid = request.wid;
                response.likelihood = likelihood;
                if (outer.sleepy > 0 && Math.random() < 1.0d / outer.sleepy) {
                    Thread.sleep(1L);
                }
                outer.messageQ.sendMessage(outer.master, response);
                if (EvaluationQueue.this.work.isEmpty()) {
                    outer.messageQ.flush();
                }
                outer.tasksRun.incrementAndGet();
            }
        }

        public EvaluationQueue(int threadCount) {
            this.threads = new EQThread[threadCount];
        }

        /** Creates and starts every evaluation thread. */
        public void start() {
            for (int t = 0; t < this.threads.length; t++) {
                this.threads[t] = new EQThread();
                this.threads[t].start();
            }
        }

        /** Discards all queued requests that no thread has taken yet. Running evaluations still finish. */
        public void flush() {
            this.work.clear(); // LinkedBlockingQueue.clear is already thread-safe
        }

        /** Queues a LikelihoodRequest message for evaluation. */
        public void enqueueWork(MessageQueue.Message<Packable> message) {
            while (true) {
                try {
                    this.work.put(message);
                    return;
                } catch (InterruptedException ignored) {
                    // The queue is unbounded, so put only throws if the caller was already
                    // interrupted. The interrupt is deliberately dropped so the message loop
                    // keeps running; retry the put.
                }
            }
        }

        /** Asks evaluation threads to exit; each notices within one poll timeout. */
        public void shutdown() {
            this.run = false;
        }
    }

    /** Prints tasks/second and bases/millisecond to stderr roughly once a second while running. */
    private class ThroughputMonitor extends Thread {
        private int oldTasksRun;
        private int oldBases;
        private long oldTime = -1L;

        private ThroughputMonitor() {
        }

        @Override
        public void run() {
            DistributedLikelihoodWorker outer = DistributedLikelihoodWorker.this;
            while (outer.running) {
                long now = System.currentTimeMillis();
                int tasks = outer.tasksRun.get();
                int baseCount = outer.bases.get();
                if (this.oldTime > 0) {
                    long elapsedMillis = now - this.oldTime;
                    // Guard against a zero or backwards clock step (division by zero / nonsense rates).
                    if (elapsedMillis > 0) {
                        System.err.println("" + ((1000L * (tasks - this.oldTasksRun)) / elapsedMillis)
                                + "\t" + ((baseCount - this.oldBases) / elapsedMillis));
                    }
                }
                this.oldTasksRun = tasks;
                this.oldBases = baseCount;
                this.oldTime = now;
                try {
                    Thread.sleep(1000L);
                } catch (InterruptedException ignored) {
                    // Nothing interrupts this thread on purpose; just report again.
                }
            }
        }
    }

    /** With {@code n > 0}, evaluation threads sleep 1 ms after roughly one in {@code n} requests. */
    public void setSleepy(int n) {
        this.sleepy = n;
    }

    /** Sets the maximum size of the contribution LRU cache. */
    public void setLruSize(int size) {
        this.lruSize = size;
    }

    /** Returns the calculator built for the given facette and datum index. */
    public LikelihoodCalculator getLikelihoodCalculator(int facette, int datum) {
        return (LikelihoodCalculator) this.hoodCalcs.get(facette, datum);
    }

    /**
     * Fetches every datum from the master and builds a calculator for each (facette, datum) pair.
     * It then tries to turn on the BioJava global change bypass through reflection, which is
     * skipped if that method does not exist.
     *
     * @throws Exception from datum fetching or the reflective call
     */
    public void initLikelihoodCalculators() throws Exception {
        for (int facette = 0; facette < this.hoodCalcs.rows(); facette++) {
            for (int datum = 0; datum < this.hoodCalcs.columns(); datum++) {
                this.hoodCalcs.set(facette, datum,
                        this.facettes[facette].getLikelihoodCalculator(getDatum(datum, facette)));
            }
        }
        try {
            ChangeSupport.class.getMethod("setGlobalChangeBypass", Boolean.TYPE).invoke(null, Boolean.TRUE);
        } catch (NoSuchMethodException e) {
            System.err.println("Short-circuiting isn't available");
        }
    }

    /**
     * Requests one datum from the master and blocks until it arrives. Each attempt waits up to
     * two seconds and is retried until a response comes back. The retry is a loop, not
     * recursion, so repeated timeouts cannot overflow the stack.
     *
     * <p>NOTE(review): the message loop completes whatever task is current when any
     * DatumResponse arrives. A late response to a timed-out attempt could therefore be taken as
     * the answer to a later request. Confirm whether responses can be matched by index.
     */
    private Object getDatum(int datumIndex, int facette) throws Exception {
        while (true) {
            SimpleTracker tracker = new SimpleTracker();
            Task task = tracker.newTask();
            this.currentDatumFetchTask = task;
            DatumRequest request = new DatumRequest();
            request.datumIndex = datumIndex;
            request.facette = facette;
            task.setData(request);
            this.messageQ.sendMessage(this.master, request);
            this.messageQ.flush();
            tracker.waitForTasks(2000L);
            DatumResponse response = (DatumResponse) task.getResult();
            if (response != null) {
                return response.datum;
            }
            System.err.println("Didn't get datum");
        }
    }

    DistributedLikelihoodWorker(InetSocketAddress endpoint, int numThreads) {
        this.endpoint = endpoint;
        this.numThreads = numThreads;
    }

    /** Creates (but does not start) a worker for the master at {@code host:port}. */
    public static DistributedLikelihoodWorker connect(String host, int port, int numThreads) throws Exception {
        return new DistributedLikelihoodWorker(new InetSocketAddress(host, port), numThreads);
    }

    /**
     * Opens the message queue and starts the evaluation threads, the message-loop thread and,
     * if enabled, the throughput monitor.
     *
     * @throws RuntimeException wrapping an IOException from the message queue
     */
    @Override
    public void start() {
        try {
            this.messageQ = new MessageQueue<>();
            this.messageQ.setCodec(Protocol.CODEC);
            this.messageQ.setFlushThreshold(200);
            this.messageQ.start();
            this.master = this.messageQ.getPeer(this.endpoint);
            this.evalQ = new EvaluationQueue(this.numThreads);
            this.evalQ.start();
            new Thread(this).start();
            if (this.throughputMonitor) {
                new ThroughputMonitor().start();
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Asks the message loop to finish after its current batch of messages. */
    @Override
    public void stop() {
        this.running = false;
    }

    /**
     * Switches to {@code sid} if needed.
     *
     * <p>A sid that is older than the current one by fewer than 1000 is stale and rejected. Any
     * other change resets the contribution matrix. NOTE(review): a large backward jump is
     * presumably wrap-around of the {@code short} sid sent in requests; TODO confirm.
     *
     * @return false if the sid is stale and the message should be ignored
     */
    private boolean updateSid(int sid) throws Exception {
        if (sid == this.currentSid) {
            return true;
        }
        if (sid < this.currentSid && this.currentSid - sid < 1000) {
            System.err.println("Hit a request for an out-of-date SID (currentSid=" + this.currentSid
                    + " requestedSid=" + sid + "), bailing out...");
            return false;
        }
        this.contributions = new SimpleObjectMatrix2D(this.contributionGroups.length, this.components);
        this.currentSid = sid;
        return true;
    }

    /**
     * Message loop: configure, start the readiness pinger, then dispatch messages until stopped.
     * The evaluation threads (and throughput monitor) are always told to stop on exit, including
     * when an exception ends the loop. Otherwise those non-daemon threads would keep the JVM
     * alive.
     *
     * @throws RuntimeException wrapping any communication failure
     */
    @Override
    public void run() {
        try {
            configure();
            startReadinessPinger();
            ArrayList<MessageQueue.Message<Packable>> batch = new ArrayList<MessageQueue.Message<Packable>>();
            while (this.running) {
                batch.clear();
                this.messageQ.next(batch);
                for (MessageQueue.Message<Packable> message : batch) {
                    dispatch(message);
                }
            }
            this.messageQ.sendMessage(this.master, new NotReady());
            this.messageQ.shutdown();
        } catch (Exception e) {
            throw new RuntimeException("Error communicating with trainer", e);
        } finally {
            this.running = false;
            if (this.evalQ != null) {
                this.evalQ.shutdown();
            }
        }
    }

    /** Requests the trainer configuration and sizes the internal structures from it. */
    private void configure() throws Exception {
        this.messageQ.sendMessage(this.master, new TrainerConfigRequest());
        this.messageQ.flush();
        Packable body = this.messageQ.next().getBody();
        if (!(body instanceof TrainerConfigResponse)) {
            throw new IllegalStateException("Unexpected response: " + body);
        }
        TrainerConfigResponse config = (TrainerConfigResponse) body;
        this.components = config.components;
        this.facettes = config.facetteMap.getFacettes();
        this.contributionGroups = config.facetteMap.getContributionGroups();
        this.hoodCalcs = new SimpleObjectMatrix2D(this.facettes.length, config.dataSetSize);
        System.err.println("We're configured!");
    }

    /**
     * Starts the daemon thread that builds the calculators and then sends Ready about once a
     * second until the queue dies. The calculators are built off the message-loop thread because
     * {@link #getDatum} blocks until the message loop delivers the DatumResponse.
     */
    private void startReadinessPinger() {
        Thread pinger = new Thread("Readiness pinger") {
            @Override
            public void run() {
                DistributedLikelihoodWorker outer = DistributedLikelihoodWorker.this;
                try {
                    outer.initLikelihoodCalculators();
                    while (true) {
                        try {
                            outer.messageQ.sendMessage(outer.master, new Ready());
                            Thread.sleep(1000L);
                        } catch (InterruptedException ignored) {
                            // Nothing interrupts this thread on purpose; keep pinging.
                        } catch (QueueDeadException dead) {
                            return;
                        }
                    }
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        };
        pinger.setDaemon(true);
        pinger.start();
    }

    /** Routes one incoming message according to its body type. */
    private void dispatch(MessageQueue.Message<Packable> message) throws Exception {
        Packable body = message.getBody();
        if (body instanceof LikelihoodRequest) {
            if (updateSid(((LikelihoodRequest) body).sid)) {
                this.evalQ.enqueueWork(message);
            }
        } else if (body instanceof DatumResponse) {
            this.currentDatumFetchTask.completed(body);
        } else if (body instanceof ContributionResponse) {
            storeContribution((ContributionResponse) body);
        } else if (body instanceof Flush) {
            this.evalQ.flush();
        } else if (body instanceof Shutdown) {
            this.running = false;
        } else {
            System.err.println("Unexpected message type " + body.getClass().getName());
        }
    }

    /**
     * Stores a received contribution, reusing a cached ContributionItem for an equal contribution.
     * Any waiter is released by completing the pending Task that was in the slot.
     */
    private void storeContribution(ContributionResponse response) throws Exception {
        if (!updateSid(response.sid)) {
            return;
        }
        Object previous = this.contributions.get(response.contributionGroup, response.component);
        ContributionItem item = this.contributionCache.get(response.contribution);
        if (item == null) {
            item = new SimpleContributionItem(response.contribution);
            this.contributionCache.put(response.contribution, item);
        }
        // Publish the item before completing the task so the woken waiter can read it.
        this.contributions.set(response.contributionGroup, response.component, item);
        if (previous instanceof Task) {
            ((Task) previous).completed();
        }
    }
}
