/*
 * Decompiled with CFR 0.152.
 */
package io.bioimage.modelrunner.model.processing;

import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.bioimageio.description.TransformSpec;
import io.bioimage.modelrunner.model.processing.TransformationInstance;
import io.bioimage.modelrunner.tensor.Tensor;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.util.Cast;

/**
 * Applies the pre- and post-processing transformations declared in a {@link ModelDescriptor}
 * to lists of {@link Tensor}s.
 * <p>
 * On construction, each input tensor spec's preprocessing list and each output tensor spec's
 * postprocessing list is turned into a chain of {@link TransformationInstance}s, kept in
 * declaration order and keyed by tensor name. Transform specs whose name is {@code null}
 * are skipped.
 */
public class Processing {
    private final ModelDescriptor descriptor;
    /** Preprocessing chain for each input tensor name, in descriptor order. */
    private Map<String, List<TransformationInstance>> preMap;
    /** Postprocessing chain for each output tensor name, in descriptor order. */
    private Map<String, List<TransformationInstance>> postMap;
    /** Link to the bioimageio Python {@code _processing.py} implementation (not used by this class). */
    private static final String BIOIMAGEIO_PYTHON_TRANSFORMATIONS_WEB = "https://github.com/bioimage-io/core-bioimage-io-python/blob/b0ceac8fa5b412b1ea811c442697de2150fa1b90/bioimageio/core/prediction_pipeline/_processing.py#L105";

    private Processing(ModelDescriptor descriptor) throws IllegalArgumentException, RuntimeException {
        this.descriptor = descriptor;
        this.buildPreprocessing();
        this.buildPostprocessing();
    }

    /** Builds {@link #preMap} from the descriptor's input tensors. */
    private void buildPreprocessing() throws IllegalArgumentException, RuntimeException {
        this.preMap = new LinkedHashMap<>();
        for (TensorSpec tt : this.descriptor.getInputTensors()) {
            this.preMap.put(tt.getName(), instantiate(tt.getPreprocessing()));
        }
    }

    /** Builds {@link #postMap} from the descriptor's output tensors. */
    private void buildPostprocessing() throws IllegalArgumentException, RuntimeException {
        this.postMap = new LinkedHashMap<>();
        for (TensorSpec tt : this.descriptor.getOutputTensors()) {
            this.postMap.put(tt.getName(), instantiate(tt.getPostprocessing()));
        }
    }

    /**
     * Instantiates one {@link TransformationInstance} per named spec, preserving order.
     * Specs whose name is {@code null} are skipped.
     */
    private static List<TransformationInstance> instantiate(List<TransformSpec> specs) {
        List<TransformationInstance> chain = new ArrayList<>();
        for (TransformSpec spec : specs) {
            if (spec.getName() == null) {
                continue;
            }
            chain.add(TransformationInstance.create(spec));
        }
        return chain;
    }

    /**
     * Creates the processing pipeline for the given model.
     *
     * @param descriptor the model descriptor whose tensor specs declare the transformations
     * @return a new {@code Processing} with all transformations instantiated
     * @throws IllegalArgumentException if a transformation cannot be created
     * @throws RuntimeException if a transformation cannot be created
     */
    public static Processing init(ModelDescriptor descriptor) throws IllegalArgumentException, RuntimeException {
        return new Processing(descriptor);
    }

    /**
     * Runs the preprocessing transformations without forcing in-place execution.
     * Equivalent to {@code preprocess(tensorList, false)}.
     *
     * @param tensorList the input tensors, matched to the descriptor's inputs by name
     * @return the preprocessed tensors
     */
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<R>> preprocess(List<Tensor<T>> tensorList) {
        return this.preprocess(tensorList, false);
    }

    /**
     * Runs each input tensor's preprocessing chain.
     * <p>
     * If no input tensor is declared at all, {@code tensorList} itself is returned. Otherwise only
     * tensors whose name matches a declared input appear in the result. A declared input with an
     * empty chain is passed through unchanged. When a transformation returns a tensor with the same
     * name as its input, that tensor replaces the corresponding entry of {@code tensorList}, so the
     * list must be mutable.
     *
     * @param tensorList the input tensors, matched to the descriptor's inputs by name
     * @param inplace forwarded to every {@link TransformationInstance}
     * @return the produced tensors, one per distinct name; a later tensor with an existing name replaces the earlier one
     */
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<R>> preprocess(List<Tensor<T>> tensorList, boolean inplace) {
        return applyTransformations(this.preMap, tensorList, inplace);
    }

    /**
     * Runs the postprocessing transformations without forcing in-place execution.
     * Equivalent to {@code postprocess(tensorList, false)}.
     *
     * @param tensorList the output tensors, matched to the descriptor's outputs by name
     * @return the postprocessed tensors
     */
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<R>> postprocess(List<Tensor<T>> tensorList) {
        return this.postprocess(tensorList, false);
    }

    /**
     * Runs each output tensor's postprocessing chain.
     * <p>
     * If no output tensor is declared at all, {@code tensorList} itself is returned. Otherwise only
     * tensors whose name matches a declared output appear in the result. A declared output with an
     * empty chain is passed through unchanged. When a transformation returns a tensor with the same
     * name as its input, that tensor replaces the corresponding entry of {@code tensorList}, so the
     * list must be mutable.
     *
     * @param tensorList the output tensors, matched to the descriptor's outputs by name
     * @param inplace forwarded to every {@link TransformationInstance}
     * @return the produced tensors, one per distinct name; a later tensor with an existing name replaces the earlier one
     */
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<R>> postprocess(List<Tensor<T>> tensorList, boolean inplace) {
        return applyTransformations(this.postMap, tensorList, inplace);
    }

    /** Shared implementation of {@link #preprocess(List, boolean)} and {@link #postprocess(List, boolean)}. */
    private static <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<R>> applyTransformations(
            Map<String, List<TransformationInstance>> chains, List<Tensor<T>> tensorList, boolean inplace) {
        if (chains.isEmpty()) {
            return Cast.unchecked(tensorList);
        }
        List<Tensor<R>> outputs = new ArrayList<>();
        for (Map.Entry<String, List<TransformationInstance>> entry : chains.entrySet()) {
            String name = entry.getKey();
            int index = indexOfName(tensorList, name);
            if (index == -1) {
                continue;
            }
            List<TransformationInstance> chain = entry.getValue();
            if (chain.isEmpty()) {
                outputs.add(Cast.unchecked(tensorList.get(index)));
                continue;
            }
            for (TransformationInstance trans : chain) {
                List<Tensor<R>> outList = trans.run(tensorList.get(index), inplace);
                // Feed the same-named result into the next transformation of the chain.
                int sameName = indexOfName(outList, name);
                if (sameName != -1) {
                    tensorList.set(index, Cast.unchecked(outList.get(sameName)));
                }
                for (Tensor<R> produced : outList) {
                    putByName(outputs, produced);
                }
            }
        }
        return outputs;
    }

    /** Returns the index of the first tensor whose name equals {@code name}, or -1 if none. */
    private static int indexOfName(List<? extends Tensor<?>> tensors, String name) {
        return IntStream.range(0, tensors.size())
                .filter(i -> tensors.get(i).getName().equals(name))
                .findFirst()
                .orElse(-1);
    }

    /** Replaces the first tensor in {@code outputs} with the same name as {@code tensor}, or appends it. */
    private static <R extends RealType<R> & NativeType<R>> void putByName(List<Tensor<R>> outputs, Tensor<R> tensor) {
        for (int k = 0; k < outputs.size(); ++k) {
            if (outputs.get(k).getName().equals(tensor.getName())) {
                outputs.set(k, tensor);
                return;
            }
        }
        outputs.add(tensor);
    }
}

