package io.bioimage.modelrunner.tensor;

import icy.math.UnitUtil;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.converter.Converter;
import net.imglib2.converter.RealTypeConverters;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.NativeType;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.RealType;
import net.imglib2.util.Intervals;
import net.imglib2.util.Util;

/* loaded from: input_file:io/bioimage/modelrunner/tensor/Tensor.class */
/**
 * A named tensor whose values are held in an ImgLib2 {@link RandomAccessibleInterval}.
 *
 * <p>Every tensor carries an axes-order string with one character per dimension, drawn
 * (case-insensitively) from {@code b, t, i, r, c, z, y, x}. When data is supplied at construction
 * time, the string must have as many characters as the data has dimensions. A tensor may also be
 * created "empty" (without data) and filled later through {@link #setData(RandomAccessibleInterval)}.
 * After {@link #close()} has been called, every instance method other than {@link #close()} and
 * {@link #isClosed()} throws {@link IllegalStateException}.
 *
 * @param <T> the ImgLib2 pixel type of the backing data
 */
public final class Tensor<T extends RealType<T> & NativeType<T>> {

    /** Element count from which data-type conversion copies are run multi-threaded. */
    private static final long MULTITHREADING_THRESHOLD = 20000;

    /** Name used to identify the tensor, e.g. in {@link #getTensorByNameFromList}. */
    private String tensorName;
    /** Per-axis codes produced by {@link #convertToTensorDimOrder(String)}. */
    private int[] axesArray;
    /** Axes order exactly as supplied by the caller. */
    private String axesString;
    /** Backing data; {@code null} for an empty tensor or after {@code setData(null)}. */
    private RandomAccessibleInterval<T> data;
    /** Initialised from the axes (contains x, y or z); may be changed via {@link #setIsImage}. */
    private boolean isImage;
    /** {@code true} until the tensor has been given data for the first time. */
    private boolean emptyTensor;
    /** Pixel type recorded when data was first attached; {@code null} while empty. */
    private Type<T> dType;
    /** Dimensions recorded when data was first attached; {@code null} while empty. */
    private int[] shape;
    /** Set by {@link #close()}. */
    private boolean closed = false;

    /**
     * Creates a tensor, optionally backed by data.
     *
     * @param tensorName name of the tensor, must not be {@code null}
     * @param axes axes order, must not be {@code null}
     * @param data backing data, or {@code null} to create an empty tensor
     * @throws NullPointerException if {@code tensorName} or {@code axes} is {@code null}
     * @throws IllegalArgumentException if the axes are invalid or do not match the data's rank
     */
    private Tensor(String tensorName, String axes, RandomAccessibleInterval<T> data) {
        Objects.requireNonNull(tensorName, "'tensorName' field should not be empty");
        Objects.requireNonNull(axes, "'axes' field should not be empty");
        if (data != null) {
            checkDims(data, axes);
        }
        this.tensorName = tensorName;
        this.axesString = axes;
        this.axesArray = convertToTensorDimOrder(axes);
        this.data = data;
        if (data != null) {
            setShape();
            this.dType = Util.getTypeFromInterval(data);
            this.emptyTensor = false;
        } else {
            this.emptyTensor = true;
        }
        // Compare case-insensitively so that "BCYX" is classified like "bcyx", consistent with
        // convertToTensorDimOrder and assertIsList, which both lower-case the axes first.
        String lowerAxes = axes.toLowerCase(Locale.ROOT);
        this.isImage = lowerAxes.indexOf('x') != -1
                || lowerAxes.indexOf('y') != -1
                || lowerAxes.indexOf('z') != -1;
    }

    /**
     * Creates a tensor backed by the given data.
     *
     * @param str tensor name
     * @param str2 axes order; must have one character per dimension of the data
     * @param randomAccessibleInterval backing data, must not be {@code null}
     * @return the new tensor
     * @throws IllegalArgumentException if the data is {@code null} or does not match the axes
     */
    public static <T extends RealType<T> & NativeType<T>> Tensor<T> build(String str, String str2, RandomAccessibleInterval<T> randomAccessibleInterval) {
        if (randomAccessibleInterval == null) {
            throw new IllegalArgumentException("Trying to create tensor from an empty NDArray");
        }
        return new Tensor<>(str, str2, randomAccessibleInterval);
    }

    /**
     * Creates an empty tensor (no data) that can be filled later with
     * {@link #setData(RandomAccessibleInterval)}.
     *
     * @param str tensor name
     * @param str2 axes order
     * @return the new empty tensor
     */
    public static <T extends RealType<T> & NativeType<T>> Tensor<T> buildEmptyTensor(String str, String str2) {
        return new Tensor<>(str, str2, null);
    }

    /**
     * Creates a tensor backed by a freshly allocated array image of the given dimensions and
     * pixel type.
     *
     * @param str tensor name
     * @param str2 axes order; must have one character per entry of {@code jArr}
     * @param jArr dimensions of the image to allocate
     * @param t pixel type of the image to allocate
     * @return the new tensor
     */
    public static <T extends RealType<T> & NativeType<T>> Tensor<T> buildBlankTensor(String str, String str2, long[] jArr, T t) {
        return new Tensor<>(str, str2, new ArrayImgFactory<>(t).create(jArr));
    }

    /**
     * Attaches data to this tensor or overwrites its values.
     *
     * <ul>
     *   <li>{@code null} while the tensor holds data: the data reference is dropped; shape, data
     *       type and the non-empty state are kept.</li>
     *   <li>Empty tensor: {@code newData} becomes the backend and its shape and data type are
     *       adopted.</li>
     *   <li>Non-empty tensor: {@code newData} must have one dimension per axis, the tensor's shape
     *       and the tensor's data type class. If the tensor currently holds data, the values of
     *       {@code newData} are copied into it (the argument itself is not retained); otherwise
     *       {@code newData} becomes the backend.</li>
     * </ul>
     *
     * @param newData data to attach or copy from
     * @throws IllegalStateException if the tensor has been closed
     * @throws IllegalArgumentException if dimensions, shape or data type do not match, or if
     *     {@code null} is passed to an empty tensor
     * @throws NullPointerException if {@code null} is passed to a non-empty tensor that currently
     *     holds no data
     */
    public void setData(RandomAccessibleInterval<T> newData) {
        throwExceptionIfClosed();
        if (newData == null && this.data != null) {
            this.data = null;
            return;
        }
        if (!this.emptyTensor) {
            checkDims(newData, this.axesString);
            if (!equalShape(newData.dimensionsAsLongArray())) {
                throw new IllegalArgumentException("Trying to set an array as the backend of the Tensor with a different shape than the Tensor. Tensor shape is: " + Arrays.toString(this.shape) + " and array shape is: " + Arrays.toString(newData.dimensionsAsLongArray()));
            }
            // Compare against the recorded type rather than the current backend, so the check
            // still applies after the backend was dropped with setData(null).
            T incomingType = Util.getTypeFromInterval(newData);
            if (this.dType.getClass() != incomingType.getClass()) {
                throw new IllegalArgumentException("Trying to set an array as the backend of the Tensor with a different data type than the Tensor. Tensor data type is: " + this.dType.toString() + " and array data type is: " + incomingType.toString());
            }
        }
        // NOTE(review): for an empty tensor the axes/rank match is not checked here, unlike in
        // the constructor — confirm whether that relaxation is intended.
        if (this.data == null) {
            this.data = newData;
        } else {
            LoopBuilder.setImages(this.data, newData).multiThreaded().forEachPixel((dst, src) -> dst.set(src));
        }
        if (this.emptyTensor) {
            setShape();
            this.dType = Util.getTypeFromInterval(newData);
            this.emptyTensor = false;
        }
    }

    /**
     * Returns the backing data.
     *
     * @return the backing data, never {@code null}
     * @throws IllegalStateException if the tensor has been closed
     * @throws IllegalArgumentException if the tensor is empty or currently holds no data
     */
    public RandomAccessibleInterval<T> getData() {
        throwExceptionIfClosed();
        if (this.data == null) {
            if (this.emptyTensor) {
                throw new IllegalArgumentException("Tensor '" + this.tensorName + "' is empty.");
            }
            throw new IllegalArgumentException("If you want to retrieve the tensor data as an NDArray, please first transform the tensor data into an NDArray using: TensorManager.buffer2array(tensor)");
        }
        return this.data;
    }

    /**
     * Copies the backend of {@code tensor} into this tensor (see
     * {@link #setData(RandomAccessibleInterval)}).
     *
     * @param tensor source tensor; must hold data
     * @throws IllegalStateException if either tensor has been closed
     * @throws IllegalArgumentException if the source holds no data or is incompatible
     */
    public void copyTensorBackend(Tensor<T> tensor) {
        throwExceptionIfClosed();
        // getData() never returns null (it throws instead), so no null check is needed.
        copyRAITensorBackend(tensor);
    }

    /**
     * Passes the data of {@code tensor} to {@link #setData(RandomAccessibleInterval)} on this
     * tensor.
     *
     * @param tensor source tensor; must hold data
     * @throws IllegalStateException if either tensor has been closed
     * @throws IllegalArgumentException if the source holds no data or is incompatible
     */
    public void copyRAITensorBackend(Tensor<T> tensor) {
        throwExceptionIfClosed();
        setData(tensor.getData());
    }

    /**
     * Builds a new tensor with the same name and axes as {@code tensor}, whose data is a copy of
     * the source data converted to the pixel type of {@code r}.
     *
     * @param tensor source tensor; must hold data
     * @param r an instance of the wanted pixel type
     * @return the new tensor
     * @throws IllegalStateException if {@code tensor} has been closed
     */
    public static <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> Tensor<R> createCopyOfTensorInWantedDataType(Tensor<T> tensor, R r) {
        tensor.throwExceptionIfClosed();
        RandomAccessibleInterval<R> converted = createCopyOfRaiInWantedDataType(tensor.getData(), r);
        return build(tensor.getName(), tensor.getAxesOrderString(), converted);
    }

    /**
     * Copies {@code randomAccessibleInterval} into a new image of the pixel type of {@code r},
     * converting every value with {@link RealTypeConverters#getConverter}.
     *
     * <p>The target image is created by {@link Util#getArrayOrCellImgFactory}; the copy runs
     * multi-threaded when the image has at least {@value #MULTITHREADING_THRESHOLD} elements.
     *
     * @param randomAccessibleInterval source data
     * @param r an instance of the wanted pixel type
     * @return the converted copy
     */
    public static <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> RandomAccessibleInterval<R> createCopyOfRaiInWantedDataType(RandomAccessibleInterval<T> randomAccessibleInterval, R r) {
        Img<R> copy = Util.getArrayOrCellImgFactory(randomAccessibleInterval, r).create(randomAccessibleInterval);
        Converter<? super T, ? super R> converter = RealTypeConverters.getConverter(
                Util.getTypeFromInterval(randomAccessibleInterval), Util.getTypeFromInterval(copy));
        LoopBuilder.setImages(randomAccessibleInterval, copy)
                .multiThreaded(Intervals.numElements(copy) >= MULTITHREADING_THRESHOLD)
                .forEachPixel(converter::convert);
        return copy;
    }

    /** Throws {@link IllegalStateException} if {@link #close()} has already been called. */
    private void throwExceptionIfClosed() {
        if (this.closed) {
            throw new IllegalStateException("The tensor that is trying to be modified has already been closed.");
        }
    }

    /**
     * Marks the tensor as closed and drops its references to data and metadata. Calling it again
     * has no effect.
     */
    public void close() {
        if (this.closed) {
            return;
        }
        this.closed = true;
        this.axesArray = null;
        this.data = null;
        this.axesString = null;
        this.dType = null;
        this.shape = null;
        this.tensorName = null;
    }

    /**
     * Returns a tensor from {@code list} that is not closed and whose name equals {@code str}.
     *
     * @param list tensors to search
     * @param str wanted name
     * @return a matching tensor, or {@code null} if there is none
     */
    public static <T extends NativeType<T> & RealType<T>> Tensor<T> getTensorByNameFromList(List<Tensor<T>> list, String str) {
        return list.stream()
                .filter(tensor -> !tensor.isClosed() && tensor.getName() != null && tensor.getName().equals(str))
                .findAny()
                .orElse(null);
    }

    /**
     * Checks whether {@code dims} equals the recorded shape.
     * The comparison is done in {@code long} so that oversized dimensions cannot alias after an
     * {@code int} cast.
     */
    private boolean equalShape(long[] dims) {
        if (dims.length != this.shape.length) {
            return false;
        }
        for (int i = 0; i < dims.length; i++) {
            if (dims[i] != this.shape[i]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Converts an axes-order string into per-axis codes: x→0, y→1, c→2, z/i/r→3, b/t→4.
     * Case is ignored (locale-independently).
     *
     * @param str axes order, e.g. {@code "bcyx"}
     * @return one code per character of {@code str}
     * @throws IllegalArgumentException if a character is not a legal axis, if both 'b' and 't'
     *     (or one of them twice) appear, if more than one of 'i', 'z', 'r' appears, or if any
     *     axis is repeated
     */
    public static int[] convertToTensorDimOrder(String str) throws IllegalArgumentException {
        String axes = str.toLowerCase(Locale.ROOT);
        int[] order = new int[axes.length()];
        int nB = 0;
        int nT = 0;
        int nI = 0;
        int nR = 0;
        int nZ = 0;
        int nC = 0;
        int nY = 0;
        int nX = 0;
        for (int i = 0; i < axes.length(); i++) {
            char axis = axes.charAt(i);
            // Every axis is counted (not just flagged) so repeated axes are detected below.
            switch (axis) {
                case 'x':
                    order[i] = 0;
                    nX++;
                    break;
                case 'y':
                    order[i] = 1;
                    nY++;
                    break;
                case 'c':
                    order[i] = 2;
                    nC++;
                    break;
                case 'z':
                    order[i] = 3;
                    nZ++;
                    break;
                case 'i':
                    order[i] = 3;
                    nI++;
                    break;
                case 'r':
                    order[i] = 3;
                    nR++;
                    break;
                case 'b':
                    order[i] = 4;
                    nB++;
                    break;
                case 't':
                    order[i] = 4;
                    nT++;
                    break;
                default:
                    throw new IllegalArgumentException("Illegal axis for tensor dim order " + axes + " (" + axis + ")");
            }
        }
        if (nB + nT > 1) {
            throw new IllegalArgumentException("Tensor axes order can only have either one 'b' or one 't'. These axes are exclusive .");
        }
        if (nZ + nR + nI > 1) {
            throw new IllegalArgumentException("Tensor axes order can only have either one 'i', one 'z' or one 'r'.");
        }
        if (nY > 1 || nX > 1 || nC > 1 || nZ > 1 || nR > 1 || nT > 1 || nI > 1 || nB > 1) {
            throw new IllegalArgumentException("There cannot be repeated dimensions in the axes order as this tensor has (" + axes + ").");
        }
        return order;
    }

    /**
     * Records the dimensions of the current backend as the tensor shape.
     *
     * @throws IllegalArgumentException if there is no backend
     * @throws ArithmeticException if a dimension does not fit in an {@code int}
     */
    private void setShape() {
        if (this.data == null) {
            throw new IllegalArgumentException("Trying to create tensor from an empty NDArray");
        }
        long[] dims = this.data.dimensionsAsLongArray();
        this.shape = new int[dims.length];
        for (int i = 0; i < dims.length; i++) {
            // Fail loudly instead of silently storing a truncated (possibly negative) size.
            this.shape[i] = Math.toIntExact(dims[i]);
        }
    }

    /** Returns the tensor name. @throws IllegalStateException if closed */
    public String getName() {
        throwExceptionIfClosed();
        return this.tensorName;
    }

    /**
     * Returns the recorded shape; the internal array is returned, not a copy.
     *
     * @return the shape, or {@code null} while the tensor is empty
     * @throws IllegalStateException if closed
     */
    public int[] getShape() {
        throwExceptionIfClosed();
        return this.shape;
    }

    /** Returns the axes order as supplied at construction. @throws IllegalStateException if closed */
    public String getAxesOrderString() {
        throwExceptionIfClosed();
        return this.axesString;
    }

    /**
     * Returns the per-axis codes from {@link #convertToTensorDimOrder(String)}; the internal
     * array is returned, not a copy.
     *
     * @throws IllegalStateException if closed
     */
    public int[] getAxesOrder() {
        throwExceptionIfClosed();
        return this.axesArray;
    }

    /**
     * Sets whether the tensor is treated as an image.
     *
     * @param z {@code false} to treat the tensor as a list
     * @throws IllegalStateException if closed
     * @throws IllegalArgumentException if {@code z} is {@code false} and the axes contain
     *     'x', 'y', 't' or 'z'
     */
    public void setIsImage(boolean z) {
        throwExceptionIfClosed();
        if (!z) {
            assertIsList();
        }
        this.isImage = z;
    }

    /** Returns whether the tensor is treated as an image. @throws IllegalStateException if closed */
    public boolean isImage() {
        throwExceptionIfClosed();
        return this.isImage;
    }

    /** Returns whether the tensor has never been given data. @throws IllegalStateException if closed */
    public boolean isEmpty() {
        throwExceptionIfClosed();
        return this.emptyTensor;
    }

    /**
     * Returns the pixel type of the data. If the backend is currently absent, the type recorded
     * when data was first attached is returned instead.
     *
     * @return the pixel type, or {@code null} if the tensor has never held data
     * @throws IllegalStateException if closed
     */
    public Type<T> getDataType() {
        throwExceptionIfClosed();
        if (this.data == null) {
            return this.dType;
        }
        return Util.getTypeFromInterval(this.data);
    }

    /** Returns whether {@link #close()} has been called. */
    public boolean isClosed() {
        return this.closed;
    }

    /** Throws if {@code str} does not have exactly one character per dimension of the data. */
    private void checkDims(RandomAccessibleInterval<T> randomAccessibleInterval, String str) {
        int rank = randomAccessibleInterval.dimensionsAsLongArray().length;
        if (rank != str.length()) {
            throw new IllegalArgumentException("The axes order introduced has to correspond to the same number of dimenensions that the NDArray has. In this case the axes order is specfied for " + str.length() + " dimensions while the array has " + rank + " dimensions.");
        }
    }

    /** Throws if the axes contain 'x', 'y', 't' or 'z' (case-insensitive). */
    private void assertIsList() {
        String lowerAxes = this.axesString.toLowerCase(Locale.ROOT);
        boolean hasImageAxis = lowerAxes.indexOf('x') != -1
                || lowerAxes.indexOf('y') != -1
                || lowerAxes.indexOf('t') != -1
                || lowerAxes.indexOf('z') != -1;
        if (hasImageAxis) {
            throw new IllegalArgumentException("Tensor '" + this.tensorName + "' cannot be represented as a ist because lists can only have the axes: 'b', 'i', 'c' and 'r'. The axes for this tensor are :" + this.axesString + ".");
        }
    }
}
