package io.bioimage.modelrunner.gui.custom;

import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.gui.EnvironmentInstaller;
import io.bioimage.modelrunner.gui.workers.InstallEnvWorker;
import io.bioimage.modelrunner.model.special.stardist.Stardist2D;
import io.bioimage.modelrunner.model.special.stardist.StardistAbstract;
import io.bioimage.modelrunner.tensor.Tensor;
import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Component;
import java.awt.FlowLayout;
import java.awt.GridLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.function.Consumer;
import javax.swing.BorderFactory;
import javax.swing.Box;
import javax.swing.BoxLayout;
import javax.swing.JButton;
import javax.swing.JComboBox;
import javax.swing.JDialog;
import javax.swing.JFileChooser;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JOptionPane;
import javax.swing.JPanel;
import javax.swing.JProgressBar;
import javax.swing.JSpinner;
import javax.swing.JTextField;
import javax.swing.SpinnerNumberModel;
import javax.swing.SwingUtilities;
import javax.swing.border.EmptyBorder;
import javax.swing.event.PopupMenuEvent;
import javax.swing.event.PopupMenuListener;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.util.Cast;
import net.imglib2.view.Views;

/* loaded from: input_file:io/bioimage/modelrunner/gui/custom/StardistGUI.class */
/**
 * Swing panel for running StarDist segmentation from a host application.
 *
 * <p>The user picks one of the bundled pretrained models or a custom model on disk, can install the
 * model weights and the Python environment, and can run the selected model on the image the host
 * application currently has in focus. Every interaction with the host goes through the
 * {@link ConsumerInterface} passed to the constructor. Installation and inference run on a
 * background thread; progress-bar and button updates are posted to the EDT with
 * {@link SwingUtilities#invokeLater(Runnable)}.
 */
public class StardistGUI extends JPanel implements ActionListener {
    private static final long serialVersionUID = 5381352117710530216L;

    /** Combo-box entry that selects a user-provided model instead of a pretrained one. */
    private static final String CUSTOM_STR = "your custom model";

    /**
     * Labels of the user-facing parameters. They are registered with the consumer in the same order
     * as the input components (model selector, custom path, low percentile, high percentile).
     */
    private static final List<String> VAR_NAMES = Arrays.asList(
            "Select a model:",
            "Custom Model Path:",
            "Normalization low percentile:",
            "Normalization high percentile:");

    private final ConsumerInterface consumer;
    /** Identifier (pretrained model name or custom path) of the model held in {@link #model}. */
    private String whichLoaded;
    private StardistAbstract model;
    /** Name of the image the current run was started on; used to build output names. */
    private String inputTitle;
    private Runnable cancelCallback;
    /** Background thread executing the last install/run request; interrupted by Cancel. */
    Thread workerThread;
    private JComboBox<String> modelComboBox;
    private JLabel customLabel;
    private JTextField customModelPathField;
    private JButton browseButton;
    private JSpinner minPercField;
    private JSpinner maxPercField;
    private JProgressBar bar;
    private JButton cancelButton;
    private JButton installButton;
    private JButton runButton;

    /**
     * Builds the panel and registers its parameters with the host application.
     *
     * @param consumerInterface bridge to the host application; may be {@code null} only for the
     *     stand-alone preview started from {@link #main(String[])}, in which case nothing is
     *     registered with a host
     */
    public StardistGUI(ConsumerInterface consumerInterface) {
        this.consumer = consumerInterface;
        setLayout(new BorderLayout());

        JPanel content = new JPanel();
        content.setLayout(new BoxLayout(content, BoxLayout.Y_AXIS));
        content.add(buildModelSelectionPanel());
        content.add(buildCustomModelPanel());
        content.add(Box.createVerticalStrut(10));
        content.add(buildParametersPanel());
        content.add(Box.createVerticalStrut(10));
        content.add(buildFooterPanel());
        content.setBorder(BorderFactory.createEmptyBorder());
        add(content, BorderLayout.CENTER);

        // The stand-alone preview (main) has no host to register with.
        if (this.consumer != null) {
            this.consumer.setVariableNames(VAR_NAMES);
            // Same order as VAR_NAMES.
            ArrayList components = new ArrayList();
            components.add(this.modelComboBox);
            components.add(this.customModelPathField);
            components.add(this.minPercField);
            components.add(this.maxPercField);
            this.consumer.setComponents(components);
        }

        this.installButton.addActionListener(this);
        this.runButton.addActionListener(this);
        this.cancelButton.addActionListener(this);
        // Without this the "Browse" branch of actionPerformed could never be reached.
        this.browseButton.addActionListener(this);
        // Action events fire on every selection change (mouse, keyboard or programmatic), unlike
        // popup events, so the custom-path controls never get out of sync with the selection.
        this.modelComboBox.addActionListener(e -> updateCustomModelControls());
    }

    /** Builds the row holding the model selector. */
    private JPanel buildModelSelectionPanel() {
        JPanel panel = new JPanel(new FlowLayout(FlowLayout.LEFT));
        panel.add(new JLabel(VAR_NAMES.get(0)));
        this.modelComboBox = new JComboBox<>(new String[] {
            "StarDist Fluorescence Nuclei Segmentation", "StarDist H&E Nuclei Segmentation", CUSTOM_STR});
        panel.add(this.modelComboBox);
        return panel;
    }

    /** Builds the row for choosing a custom model path; disabled until the custom entry is selected. */
    private JPanel buildCustomModelPanel() {
        JPanel panel = new JPanel(new FlowLayout(FlowLayout.LEFT));
        this.customLabel = new JLabel(VAR_NAMES.get(1));
        this.customLabel.setEnabled(false);
        panel.add(this.customLabel);
        this.customModelPathField = new JTextField(20);
        this.customModelPathField.setEnabled(false);
        panel.add(this.customModelPathField);
        this.browseButton = new JButton("Browse");
        this.browseButton.setEnabled(false);
        panel.add(this.browseButton);
        return panel;
    }

    /** Builds the box with the normalization percentile spinners. */
    private JPanel buildParametersPanel() {
        JPanel panel = new JPanel(new GridLayout(3, 2, 10, 10));
        panel.setBorder(BorderFactory.createTitledBorder("Optional Parameters"));
        panel.add(new JLabel(VAR_NAMES.get(2)));
        this.minPercField = new JSpinner(new SpinnerNumberModel(1.0d, 0.0d, 100.0d, 0.01d));
        panel.add(this.minPercField);
        panel.add(new JLabel(VAR_NAMES.get(3)));
        this.maxPercField = new JSpinner(new SpinnerNumberModel(99.8d, 0.0d, 100.0d, 0.01d));
        panel.add(this.maxPercField);
        return panel;
    }

    /** Builds the footer: progress bar on the left, Cancel/Install/Run buttons on the right. */
    private JPanel buildFooterPanel() {
        JPanel footer = new JPanel(new GridLayout(1, 2));
        footer.setBorder(BorderFactory.createEtchedBorder());

        JPanel progressPanel = new JPanel(new BorderLayout());
        progressPanel.setBorder(BorderFactory.createEmptyBorder(10, 5, 10, 5));
        this.bar = new JProgressBar();
        this.bar.setStringPainted(true);
        this.bar.setString("");
        progressPanel.add(this.bar, BorderLayout.CENTER);

        JPanel buttonPanel = new JPanel(new FlowLayout(FlowLayout.RIGHT));
        buttonPanel.setBorder(BorderFactory.createEtchedBorder());
        this.cancelButton = new JButton("Cancel");
        this.installButton = new JButton("Install");
        this.runButton = new JButton("Run");
        buttonPanel.add(this.cancelButton);
        buttonPanel.add(this.installButton);
        buttonPanel.add(this.runButton);

        footer.add(progressPanel);
        footer.add(buttonPanel);
        return footer;
    }

    /** Enables the custom-path label, field and Browse button only while the custom entry is selected. */
    private void updateCustomModelControls() {
        boolean custom = CUSTOM_STR.equals(this.modelComboBox.getSelectedItem());
        this.customLabel.setEnabled(custom);
        this.customModelPathField.setEnabled(custom);
        this.browseButton.setEnabled(custom);
    }

    /**
     * Sets the callback invoked after Cancel has been handled.
     *
     * @param runnable callback to run, or {@code null} for none
     */
    public void setCancelCallback(Runnable runnable) {
        this.cancelCallback = runnable;
    }

    /** Closes the current model if one is loaded. */
    public void close() {
        if (this.model == null || !this.model.isLoaded()) {
            return;
        }
        this.model.close();
    }

    /**
     * Shows the panel in a stand-alone frame without a host application (layout preview only).
     *
     * @param strArr ignored
     */
    public static void main(String[] strArr) {
        SwingUtilities.invokeLater(() -> {
            JFrame frame = new JFrame("StarDist Plugin");
            frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
            frame.getContentPane().add(new StardistGUI(null));
            frame.pack();
            frame.setLocationRelativeTo((Component) null);
            frame.setVisible(true);
        });
    }

    /**
     * Dispatches button clicks. Run and Install are executed on a new {@link #workerThread}.
     *
     * @param actionEvent the click event
     */
    @Override
    public void actionPerformed(ActionEvent actionEvent) {
        Object source = actionEvent.getSource();
        if (source == this.browseButton) {
            browseFiles();
        } else if (source == this.runButton) {
            this.workerThread = new Thread(this::runAndReportErrors);
            this.workerThread.start();
        } else if (source == this.installButton) {
            this.workerThread = new Thread(this::installStardist);
            this.workerThread.start();
        } else if (source == this.cancelButton) {
            cancel();
        }
    }

    /**
     * Worker-thread body for Run: runs the model, then always re-enables the controls. On failure,
     * it shows an error message in the progress bar.
     */
    private void runAndReportErrors() {
        boolean failed = false;
        try {
            runStardist();
        } catch (LoadModelException | RunModelException | IOException e) {
            e.printStackTrace();
            failed = true;
        } catch (RuntimeException e) {
            // Unchecked failures would otherwise leave every control permanently disabled.
            e.printStackTrace();
            failed = true;
        }
        startModelInstallation(false);
        if (failed) {
            SwingUtilities.invokeLater(() -> this.bar.setString("Error running the model"));
        }
    }

    /** Interrupts the worker thread, closes the model and notifies the cancel callback. */
    private void cancel() {
        if (this.workerThread != null && this.workerThread.isAlive()) {
            this.workerThread.interrupt();
        }
        if (this.model != null) {
            this.model.close();
        }
        if (this.cancelCallback != null) {
            this.cancelCallback.run();
        }
    }

    /**
     * Installs anything missing, (re)loads the selected model if needed, and runs it on the focused
     * image. Called on the worker thread.
     */
    private <T extends RealType<T> & NativeType<T>> void runStardist() throws IOException, RunModelException, LoadModelException {
        startModelInstallation(true);
        installStardist(weightsInstalled(), StardistAbstract.isInstalled());
        RandomAccessibleInterval<T> image = this.consumer.getFocusedImageAsRai();
        this.inputTitle = this.consumer.getFocusedImageName();
        if (image == null) {
            // Swing dialogs must be created on the EDT, not on this worker thread.
            SwingUtilities.invokeLater(() -> JOptionPane.showMessageDialog(
                    (Component) null, "Please open an image", "No image open", JOptionPane.ERROR_MESSAGE));
            return;
        }
        SwingUtilities.invokeLater(() -> {
            this.bar.setIndeterminate(true);
            this.bar.setString("Loading model");
        });

        String selection = (String) this.modelComboBox.getSelectedItem();
        boolean custom = CUSTOM_STR.equals(selection);
        // A custom model is identified by its path, a pretrained one by its display name.
        String modelId = custom ? this.customModelPathField.getText() : selection;
        boolean mustCreate = this.whichLoaded == null || this.model == null || this.model.isClosed()
                || !this.whichLoaded.equals(modelId);
        if (mustCreate) {
            if (custom) {
                this.model = StardistAbstract.init(modelId);
            } else {
                try {
                    this.model = Stardist2D.fromPretained(modelId, this.consumer.getModelsDir(), false);
                } catch (InterruptedException e) {
                    // Restore the flag so the interruption (e.g. Cancel) is not lost.
                    Thread.currentThread().interrupt();
                    return;
                }
            }
        }
        if (!this.model.isLoaded()) {
            this.model.loadModel();
        }
        this.whichLoaded = modelId;
        SwingUtilities.invokeLater(() -> this.bar.setString("Running the model"));

        if (image.numDimensions() == 4) {
            runStardistOnFramesStack(image);
        } else {
            runStardistOnTensor(image);
        }
    }

    /**
     * Runs the model frame by frame on a 4-D image and displays the stacked masks.
     * NOTE(review): assumes dimension 3 is the frame axis and the remaining ones are x, y, c.
     */
    private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void runStardistOnFramesStack(RandomAccessibleInterval<R> randomAccessibleInterval) throws RunModelException {
        long[] dims = randomAccessibleInterval.dimensionsAsLongArray();
        // One float mask plane per frame, laid out as (x, y, frame).
        RandomAccessibleInterval<T> masks = Cast.unchecked(ArrayImgs.floats(new long[] {dims[0], dims[1], dims[3]}));
        long frames = dims[3];
        for (int i = 0; i < frames; i++) {
            ArrayList inputs = new ArrayList();
            inputs.add(Tensor.build("input", "xyc", Views.hyperSlice(randomAccessibleInterval, 3, i)));
            ArrayList outputs = new ArrayList();
            outputs.add(Tensor.build("mask", "xy", Views.hyperSlice(masks, 2, i)));
            this.model.run(inputs, outputs);
        }
        this.consumer.display(masks, "xyb", "mask");
    }

    /** Runs the model once on the image and displays every output tensor with more than one axis. */
    private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void runStardistOnTensor(RandomAccessibleInterval<R> randomAccessibleInterval) throws RunModelException {
        Tensor input = Tensor.build("input", "xyc", randomAccessibleInterval);
        ArrayList inputs = new ArrayList();
        inputs.add(input);
        for (Tensor<T> tensor : this.model.run(inputs)) {
            if (tensor.getAxesOrder().length != 1) {
                this.consumer.display(tensor.getData(), tensor.getAxesOrderString(), getOutputName(tensor.getName()));
            }
        }
    }

    /**
     * Builds an output image name as {@code <input title without extension>_<str>.tif}.
     * Titles without an extension are used unchanged; a missing title yields {@code _<str>.tif}.
     */
    private String getOutputName(String str) {
        String base = this.inputTitle == null ? "" : this.inputTitle;
        int extensionStart = base.lastIndexOf('.');
        if (extensionStart >= 0) {
            base = base.substring(0, extensionStart);
        }
        return base + "_" + str + ".tif";
    }

    /** Worker-thread body for Install: installs whatever is missing, or just resets the UI. */
    private void installStardist() {
        startModelInstallation(true);
        boolean envInstalled = StardistAbstract.isInstalled();
        boolean weightsReady = weightsInstalled();
        if (envInstalled && weightsReady) {
            startModelInstallation(false);
        } else {
            installStardist(weightsReady, envInstalled);
        }
    }

    /**
     * Installs the missing pieces in parallel and blocks until all of them have finished.
     *
     * @param weightsReady whether the selected model's weights are already available
     * @param envReady whether the StarDist Python environment is already installed
     */
    private void installStardist(boolean weightsReady, boolean envReady) {
        if (weightsReady && envReady) {
            return;
        }
        SwingUtilities.invokeLater(() -> this.bar.setString("Installing..."));
        // One count per installation task that is actually started.
        CountDownLatch latch = new CountDownLatch((weightsReady || envReady) ? 1 : 2);
        if (!weightsReady) {
            installModelWeights(latch);
        }
        if (!envReady) {
            installEnv(latch);
        }
        try {
            latch.await();
        } catch (InterruptedException e) {
            // Keep the interruption visible to the caller (e.g. so a cancelled run stops).
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Checks whether the selected model's weights are available; always true for a custom model.
     * Any exception from the lookup is treated as "not installed".
     */
    private boolean weightsInstalled() {
        String selection = (String) this.modelComboBox.getSelectedItem();
        if (CUSTOM_STR.equals(selection)) {
            return true;
        }
        try {
            return Stardist2D.fromPretained(selection, this.consumer.getModelsDir(), false) != null;
        } catch (Exception e) {
            return false;
        }
    }

    /** Downloads the selected pretrained weights on a new thread, reporting progress in the bar. */
    private void installModelWeights(CountDownLatch countDownLatch) {
        Consumer<Double> progressConsumer = progress -> {
            // Percentage rounded to one decimal.
            double percent = Math.round(progress.doubleValue() * 1000.0d) / 10.0d;
            SwingUtilities.invokeLater(() -> {
                this.bar.setValue((int) Math.floor(percent));
                this.bar.setString(percent + "% of weights");
            });
        };
        SwingUtilities.invokeLater(() -> this.bar.setIndeterminate(false));
        new Thread(() -> {
            try {
                Stardist2D.donwloadPretrained((String) this.modelComboBox.getSelectedItem(), this.consumer.getModelsDir(), progressConsumer);
            } catch (IOException | InterruptedException e) {
                e.printStackTrace();
            }
            countDownLatch.countDown();
            checkModelInstallationFinished(countDownLatch);
        }).start();
    }

    /**
     * Installs the StarDist Python environment after user confirmation. The latch is counted down
     * immediately if the environment already exists or the user declines.
     */
    private void installEnv(CountDownLatch countDownLatch) {
        if (StardistAbstract.isInstalled()
                || JOptionPane.showConfirmDialog((Component) null, "Installation of Python environments might take up to 20 minutes.",
                        "Install Python for StarDist", JOptionPane.YES_NO_OPTION) != JOptionPane.YES_OPTION) {
            countDownLatch.countDown();
            checkModelInstallationFinished(countDownLatch);
            return;
        }
        JDialog dialog = new JDialog();
        dialog.setTitle("Installing StarDist");
        dialog.setDefaultCloseOperation(JDialog.DO_NOTHING_ON_CLOSE);
        InstallEnvWorker installEnvWorker = new InstallEnvWorker("StarDist", countDownLatch, () -> {
            checkModelInstallationFinished(countDownLatch);
            if (dialog.isVisible()) {
                dialog.dispose();
            }
        });
        EnvironmentInstaller installer = EnvironmentInstaller.create(installEnvWorker);
        installEnvWorker.setConsumer(message -> {
            installer.updateText(message, Color.black);
            // Only take over the progress bar once the weights download (if any) has finished.
            if (countDownLatch.getCount() != 1) {
                return;
            }
            SwingUtilities.invokeLater(() -> {
                if (!this.bar.isIndeterminate() || !"Installing Python".equals(this.bar.getString())) {
                    this.bar.setIndeterminate(true);
                    this.bar.setString("Installing Python");
                }
            });
        });
        installEnvWorker.execute();
        // NOTE(review): the dialog is never made visible here; confirm addToFrame shows it.
        installer.addToFrame(dialog);
        dialog.setSize(600, 300);
    }

    /** Resets the UI once every installation task tracked by the latch has finished. */
    private void checkModelInstallationFinished(CountDownLatch countDownLatch) {
        if (countDownLatch.getCount() == 0) {
            startModelInstallation(false);
        }
    }

    /**
     * Switches the UI between its busy and idle states (on the EDT).
     *
     * @param z {@code true} to disable the controls and show an indeterminate bar, {@code false}
     *     to re-enable them and clear the bar
     */
    private void startModelInstallation(boolean z) {
        SwingUtilities.invokeLater(() -> {
            this.runButton.setEnabled(!z);
            this.installButton.setEnabled(!z);
            this.modelComboBox.setEnabled(!z);
            this.minPercField.setEnabled(!z);
            this.maxPercField.setEnabled(!z);
            if (z) {
                this.bar.setString("Checking stardist installed...");
                this.bar.setIndeterminate(true);
            } else {
                this.bar.setIndeterminate(false);
                this.bar.setValue(0);
                this.bar.setString("");
            }
        });
    }

    /** Lets the user pick a model file and puts its absolute path into the custom-path field. */
    private void browseFiles() {
        JFileChooser chooser = new JFileChooser();
        chooser.setFileSelectionMode(JFileChooser.FILES_ONLY);
        if (chooser.showOpenDialog(this) == JFileChooser.APPROVE_OPTION) {
            this.customModelPathField.setText(chooser.getSelectedFile().getAbsolutePath());
        }
    }
}
