/*
 * Decompiled with CFR 0.152.
 */
package io.bioimage.modelrunner.bioimageio.tiling;

import io.bioimage.modelrunner.bioimageio.description.Axis;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.bioimageio.tiling.PatchSpec;
import io.bioimage.modelrunner.bioimageio.tiling.TileGrid;
import io.bioimage.modelrunner.bioimageio.tiling.TileInfo;
import io.bioimage.modelrunner.tensor.Tensor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import net.imglib2.FinalInterval;
import net.imglib2.Interval;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;

/**
 * Computes how the images of a model's tensors are split into tiles.
 * <p>
 * For every input and output tensor a {@link PatchSpec} and a {@link TileGrid} are built. They can then be
 * queried by tensor name to obtain tile sizes, ROI sizes, tile positions, or a view of the n-th tile of an
 * image.
 * <p>
 * Instances are created with one of the {@code build} factories:
 * <ul>
 * <li>{@link #build(ModelDescriptor, List)}: the input tiles are validated against the model description and
 * the output tiles are derived from the output axes description;</li>
 * <li>{@link #build(List, List)}: explicit input and output tiles are given and no descriptor is used.</li>
 * </ul>
 */
public class TileMaker {
    /** Error message used whenever an axes description cannot be handled by this class. */
    private static final String TOO_COMPLEX_MSG = "Model specs too complex for JDLL. Please contact the team and create and issue attaching the rdf.yaml file so we can troubleshoot at: https://github.com/bioimage-io/JDLL/issues";

    /** Tiling specification of every input tensor, as provided by the caller. */
    private final List<TileInfo> inputTileInfo;
    /** Tiling specification of every output tensor, either provided or derived from the descriptor. */
    private List<TileInfo> outputTileInfo;
    /** Model description, or {@code null} when built from explicit input and output tiles. */
    private final ModelDescriptor descriptor;
    private final LinkedHashMap<String, PatchSpec> input = new LinkedHashMap<>();
    private final LinkedHashMap<String, PatchSpec> output = new LinkedHashMap<>();
    private final LinkedHashMap<String, TileGrid> inputGrid = new LinkedHashMap<>();
    private final LinkedHashMap<String, TileGrid> outputGrid = new LinkedHashMap<>();

    private TileMaker(ModelDescriptor descriptor, List<TileInfo> tileInfoList) {
        this.descriptor = descriptor;
        this.inputTileInfo = tileInfoList;
        this.validate();
        this.calculate();
    }

    private TileMaker(List<TileInfo> inputTiles, List<TileInfo> outputTiles) {
        this.inputTileInfo = inputTiles;
        this.outputTileInfo = outputTiles;
        this.descriptor = null;
        for (TileInfo tile : this.inputTileInfo) {
            PatchSpec patch = this.createPatch(tile);
            this.input.put(tile.getName(), patch);
            this.inputGrid.put(tile.getName(), TileGrid.create(patch));
        }
        // The output halos are passed through TileInfo.adaptHalos before the output patches are computed.
        TileInfo.adaptHalos(this.outputTileInfo);
        for (TileInfo tile : this.outputTileInfo) {
            PatchSpec patch = this.createPatch(tile);
            this.output.put(tile.getName(), patch);
            this.outputGrid.put(tile.getName(), TileGrid.create(patch));
        }
    }

    /**
     * Builds the patch specification of a tile when no model descriptor is available.
     * The halo of the tile is used as the minimum padding on both sides of every axis.
     *
     * @param tile tiling information of the tensor
     * @return the patch specification, with all arrays in the tile axes order
     */
    private PatchSpec createPatch(TileInfo tile) {
        long[] imSize = TileMaker.arrayToWantedAxesOrderAddOnes(tile.getImageDims(), tile.getImageAxesOrder(), tile.getTileAxesOrder());
        long[] tileSize = tile.getTileDims();
        long[] halo = TileMaker.arrayToWantedAxesOrderAddZeros(tile.getHalo(), tile.getHaloAxesOrder(), tile.getTileAxesOrder());
        int[][] paddingSize = new int[2][tileSize.length];
        for (int i = 0; i < halo.length; ++i) {
            paddingSize[0][i] = (int) halo[i];
            paddingSize[1][i] = (int) halo[i];
        }
        // Number of tiles per axis: each tile contributes its size minus the halo on both sides.
        int[] patchGridSize = IntStream.range(0, tileSize.length)
                .map(i -> (int) Math.ceil((double) imSize[i] / ((double) tileSize[i] - (double) (halo[i] * 2L))))
                .toArray();
        // If the tile is larger than the image, pad at least enough to fill the tile (extra goes to the end).
        paddingSize[0] = IntStream.range(0, tileSize.length)
                .map(i -> (int) Math.max((double) paddingSize[0][i], Math.ceil((double) (tileSize[i] - imSize[i]) / 2.0)))
                .toArray();
        paddingSize[1] = IntStream.range(0, tileSize.length)
                .map(i -> (int) Math.max((long) paddingSize[1][i], tileSize[i] - imSize[i] - (long) paddingSize[0][i]))
                .toArray();
        return PatchSpec.create(tile.getName(), tileSize, patchGridSize, paddingSize, imSize);
    }

    /**
     * Creates a tile maker for a model. The output tiling is derived from the model's output axes.
     *
     * @param descriptor   description of the model
     * @param tileInfoList tiling information for every input tensor of the model
     * @return the tile maker
     * @throws IllegalArgumentException if the tiling is not compatible with the model or the images
     */
    public static TileMaker build(ModelDescriptor descriptor, List<TileInfo> tileInfoList) {
        return new TileMaker(descriptor, tileInfoList);
    }

    /**
     * Creates a tile maker from explicit input and output tiling information, without a model descriptor.
     *
     * @param inputTiles  tiling information of the input tensors
     * @param outputTiles tiling information of the output tensors
     * @return the tile maker
     */
    public static TileMaker build(List<TileInfo> inputTiles, List<TileInfo> outputTiles) {
        return new TileMaker(inputTiles, outputTiles);
    }

    /**
     * Runs every check on the input tiles and derives the output tiles.
     * The order matters: output tiles must exist before the halo check runs on them.
     */
    private void validate() {
        this.checkAllTensorsDefined();
        this.validateTileVsImageSize();
        this.validateStepMin();
        this.getOutputTiles();
        this.validateTileVsHalo();
        this.validateTileVsImageChannel();
        this.checkTilesCombine();
    }

    /**
     * Derives the output tile and image sizes from the output axes of the model and the input tiles, and
     * stores them in {@link #outputTileInfo}.
     *
     * @throws IllegalArgumentException if an output axis description cannot be handled
     */
    private void getOutputTiles() {
        this.outputTileInfo = new ArrayList<TileInfo>();
        boolean tilingAllowed = this.descriptor.isTilingAllowed();
        for (TensorSpec tt : this.descriptor.getOutputTensors()) {
            String outAxesOrder = tt.getAxesOrder();
            long[] tileSize = new long[outAxesOrder.length()];
            long[] imagSize = new long[outAxesOrder.length()];
            int i = 0;
            for (Axis ax : tt.getAxesInfo().getAxesList()) {
                if (!tilingAllowed) {
                    // Without tiling every output axis must have a fixed size.
                    if (ax.getStep() != 0) {
                        throw new IllegalArgumentException(TOO_COMPLEX_MSG);
                    }
                    tileSize[i] = ax.getMin();
                    imagSize[i] = ax.getMin();
                } else if (ax.getReferenceTensor() != null) {
                    this.fillReferencedOutputAxis(ax, i, tileSize, imagSize);
                } else if (ax.getStep() == 0 && ax.getMin() != 0) {
                    this.fillFixedOutputAxis(ax, i, tileSize, imagSize);
                } else {
                    throw new IllegalArgumentException(TOO_COMPLEX_MSG);
                }
                i++;
            }
            this.outputTileInfo.add(TileInfo.build(tt.getName(), imagSize, outAxesOrder, tileSize, outAxesOrder));
        }
    }

    /**
     * Fills position {@code i} of the output sizes for an axis that has a fixed size and no reference tensor.
     * The image size is scaled by the image/tile ratio of the same axis in the first image input.
     */
    private void fillFixedOutputAxis(Axis ax, int i, long[] tileSize, long[] imagSize) {
        TensorSpec imageInput = this.descriptor.getInputTensors().stream().filter(t -> t.isImage()).findFirst().orElse(null);
        if (imageInput == null) {
            throw new IllegalArgumentException(TOO_COMPLEX_MSG);
        }
        // Non-null: checkAllTensorsDefined() ran before and guarantees every input has tiling info.
        TileInfo inTile = TileMaker.findTile(this.inputTileInfo, imageInput.getName());
        int indTile = inTile.getTileAxesOrder().indexOf(ax.getAxis());
        int indIm = inTile.getImageAxesOrder().indexOf(ax.getAxis());
        if (indTile == -1 || indIm == -1) {
            imagSize[i] = ax.getMin();
            tileSize[i] = ax.getMin();
            return;
        }
        double factor = (double) inTile.getImageDims()[indIm] / (double) inTile.getTileDims()[indTile];
        double scaledMin = (double) ax.getMin() * factor;
        // Only integer image sizes are supported.
        if (Math.floor(scaledMin) != scaledMin) {
            throw new IllegalArgumentException(TOO_COMPLEX_MSG);
        }
        imagSize[i] = (long) scaledMin;
        tileSize[i] = ax.getMin();
    }

    /**
     * Fills position {@code i} of the output sizes for an axis defined relative to an axis of an input tensor:
     * {@code size = refSize * scale + 2 * offset}.
     */
    private void fillReferencedOutputAxis(Axis ax, int i, long[] tileSize, long[] imagSize) {
        String ref = ax.getReferenceTensor();
        TileInfo inTile = TileMaker.findTile(this.inputTileInfo, ref);
        if (inTile == null) {
            throw new IllegalArgumentException("Tile specs of input tensor '" + ref + "' not defined.");
        }
        int indTile = inTile.getTileAxesOrder().indexOf(ax.getReferenceAxis());
        int indIm = inTile.getImageAxesOrder().indexOf(ax.getReferenceAxis());
        if (indTile == -1 || indIm == -1) {
            throw new IllegalArgumentException("Reference axis '" + ax.getReferenceAxis() + "' not found in the tile or image axes of input tensor '" + ref + "'.");
        }
        long imDim = inTile.getImageDims()[indIm];
        long tileDim = inTile.getTileDims()[indTile];
        imagSize[i] = (long) ((double) imDim * ax.getScale() + ax.getOffset() * 2.0);
        // A tile dimension of -1 is replaced by the full image dimension.
        tileSize[i] = (long) ((double) (tileDim == -1L ? imDim : tileDim) * ax.getScale() + ax.getOffset() * 2.0);
    }

    /**
     * Checks that, for every output tile, each axis is larger than twice the halo declared for that axis.
     */
    private void validateTileVsHalo() {
        for (TileInfo tile : this.outputTileInfo) {
            TensorSpec tt = this.descriptor.findOutputTensor(tile.getName());
            for (Axis ax : tt.getAxesInfo().getAxesList()) {
                int ind = tile.getImageAxesOrder().indexOf(ax.getAxis());
                if (tile.getTileDims()[ind] - (long) (ax.getHalo() * 2) > 0L) continue;
                throw new IllegalArgumentException("Input size too small, halo would be bigger than the image accross dimension '" + ax.getAxis() + "'. Toal halo = " + ax.getHalo() * 2 + ", image size = " + tile.getTileDims()[ind] + ".");
            }
        }
    }

    /**
     * Alternative halo check that computes each output size from its reference input tile.
     * NOTE(review): not called from {@link #validate()} or anywhere else in this class.
     */
    private void validateTileVsHalo2() {
        for (TensorSpec tt : this.descriptor.getOutputTensors()) {
            for (Axis ax : tt.getAxesInfo().getAxesList()) {
                String ref = ax.getReferenceTensor();
                if (ref == null || ax.getMin() != 0 && ax.getStep() == 0) continue;
                TileInfo tile = TileMaker.findTile(this.inputTileInfo, ref);
                if (tile == null) {
                    throw new IllegalArgumentException("Tile specs of input tensor '" + ref + "' not defined.");
                }
                String axisStr = ax.getReferenceAxis();
                int ind = tile.getTileAxesOrder().indexOf(axisStr);
                long refSize = ind != -1 ? tile.getTileDims()[ind] : 1L;
                double outSize = ax.getScale() * (double) refSize + ax.getOffset() * 2.0;
                if (!(outSize - (double) (ax.getHalo() * 2) <= 0.0)) continue;
                throw new IllegalArgumentException("Input size too small, halo would be bigger than the image accross dimension '" + axisStr + "'. Toal halo = " + ax.getHalo() * 2 + ", image size = " + outSize + ".");
            }
        }
    }

    /**
     * Checks that every input tile size is compatible with the model's {@code min + n * step} constraint.
     * For axes with step 0 the tile must be exactly {@code min}.
     */
    private void validateStepMin() {
        for (TileInfo tile : this.inputTileInfo) {
            TensorSpec tt = this.descriptor.findInputTensor(tile.getName());
            if (tt == null) continue;
            String axesTile = TileMaker.addMissingAxes(tt.getAxesOrder(), tile.getTileAxesOrder());
            long[] tileDims = TileMaker.arrayToWantedAxesOrderAddOnes(tile.getTileDims(), tile.getTileAxesOrder(), axesTile);
            int[] min = TileMaker.arrayToWantedAxesOrderAddOnes(tt.getMinTileSizeArr(), tt.getAxesOrder(), axesTile);
            int[] step = TileMaker.arrayToWantedAxesOrderAddZeros(tt.getTileStepArr(), tt.getAxesOrder(), axesTile);
            for (int i = 0; i < tileDims.length; ++i) {
                String axisName = String.valueOf(axesTile.charAt(i)).toUpperCase();
                if (min[i] != -1 && tileDims[i] != (long) min[i] && step[i] == 0) {
                    throw new IllegalArgumentException("Invalid tile size for axis '" + axisName + "'. Only allowed tile size for this axis is: " + min[i]);
                }
                if (step[i] == 0) continue;
                // A tile below the minimum is invalid even when (tile - min) is a multiple of step.
                boolean belowMin = min[i] != -1 && tileDims[i] < (long) min[i];
                if (!belowMin && (tileDims[i] - (long) min[i]) % (long) step[i] == 0L) continue;
                throw new IllegalArgumentException("Invalid tile size for axis '" + axisName + "'. Tile size for this axis should satisfy: " + min[i] + " + n x " + step[i] + " where n can be any positive integer.");
            }
        }
    }

    /**
     * Checks that tiling never splits the channel axis ({@code "c"}). An axis missing from the tile or the
     * image counts as having one channel.
     */
    private void validateTileVsImageChannel() {
        for (TileInfo tile : this.inputTileInfo) {
            int indTile = tile.getTileAxesOrder().indexOf("c");
            int indIm = tile.getImageAxesOrder().indexOf("c");
            long tileChannels = indTile == -1 ? 1L : tile.getTileDims()[indTile];
            long imChannels = indIm == -1 ? 1L : tile.getImageDims()[indIm];
            if (tileChannels == imChannels) continue;
            throw new IllegalArgumentException("Tiling cannot happen accross the channel dimension. The tile number of channels (" + tileChannels + ") must be the same as the image number of channels (" + imChannels + ").");
        }
    }

    /**
     * Checks that no tile axis is larger than three times the corresponding image axis. Image axes missing
     * from the image count as size 1.
     *
     * @throws IllegalArgumentException if a tile is too large or the axes/dims lengths do not match
     */
    private void validateTileVsImageSize() throws IllegalArgumentException {
        for (TileInfo tile : this.inputTileInfo) {
            TileMaker.checkAxisSize(tile);
            String axesTile = tile.getTileAxesOrder();
            long[] tileDims = tile.getTileDims();
            // Reordered to the tile axes order, so index i refers to the same axis in both arrays.
            long[] imDims = TileMaker.arrayToWantedAxesOrderAddOnes(tile.getImageDims(), tile.getImageAxesOrder(), axesTile);
            for (int i = 0; i < axesTile.length(); ++i) {
                long maxTile = imDims[i] * 3L;
                if (tileDims[i] <= maxTile) continue;
                String axis = String.valueOf(axesTile.charAt(i));
                throw new IllegalArgumentException("Error in the axes size selected. The axes size introduced in any of the dimensions cannot be bigger than 3 times the image size of that same axes. The image selected has " + axis + "-dimension of size " + imDims[i] + " and the tile is of size " + tileDims[i] + ". Maximum tile size for " + axis + "-axis in this image is " + maxTile);
            }
        }
    }

    /**
     * Checks that the axes strings and the dimension arrays of a tile have matching lengths, both for the
     * tile and for the image.
     */
    private static void checkAxisSize(TileInfo tile) {
        String axesTile = tile.getTileAxesOrder();
        long[] tileDims = tile.getTileDims();
        if (axesTile.length() != tileDims.length) {
            throw new IllegalArgumentException("The tile dimensions and tile axes should be of the same length: " + axesTile + " (" + axesTile.length() + ") vs " + Arrays.toString(tileDims) + " (" + tileDims.length + ")");
        }
        String axesImage = tile.getImageAxesOrder();
        long[] imDims = tile.getImageDims();
        if (axesImage.length() != imDims.length) {
            throw new IllegalArgumentException("The image dimensions and image axes should be of the same length: " + axesImage + " (" + axesImage.length() + ") vs " + Arrays.toString(imDims) + " (" + imDims.length + ")");
        }
    }

    /** Placeholder for a check that the tiles of different tensors are compatible; currently does nothing. */
    private void checkTilesCombine() {
    }

    /** Checks that tiling information was provided for every input tensor of the model. */
    private void checkAllTensorsDefined() {
        for (TensorSpec tensor : this.descriptor.getInputTensors()) {
            if (TileMaker.findTile(this.inputTileInfo, tensor.getName()) != null) continue;
            throw new IllegalArgumentException("Tiling info for input tensor '" + tensor.getName() + "' not defined.");
        }
    }

    /** Computes and stores the patch specification and tile grid of every input and output tensor. */
    private void calculate() {
        for (TensorSpec tt : this.descriptor.getInputTensors()) {
            TileInfo tile = TileMaker.findTile(this.inputTileInfo, tt.getName());
            PatchSpec patch = this.computePatchSpecs(tt, tile);
            this.input.put(tt.getName(), patch);
            this.inputGrid.put(tt.getName(), TileGrid.create(patch));
        }
        for (TensorSpec tt : this.descriptor.getOutputTensors()) {
            TileInfo tile = TileMaker.findTile(this.outputTileInfo, tt.getName());
            PatchSpec patch = this.computePatchSpecs(tt, tile);
            this.output.put(tt.getName(), patch);
            this.outputGrid.put(tt.getName(), TileGrid.create(patch));
        }
    }

    /**
     * Builds the patch specification of a tensor described by the model.
     * The halo is used as padding only when the model allows tiling and is not pyramidal. The grid is
     * 1 along every axis when tiling is not allowed.
     *
     * @param spec tensor description
     * @param tile tiling information of the tensor
     * @return the patch specification, with all arrays in the tensor axes order
     */
    private PatchSpec computePatchSpecs(TensorSpec spec, TileInfo tile) {
        String specAxes = spec.getAxesInfo().getAxesOrder();
        long[] imSize = TileMaker.arrayToWantedAxesOrderAddOnes(tile.getImageDims(), tile.getImageAxesOrder(), specAxes);
        long[] tileSize = TileMaker.arrayToWantedAxesOrderAddOnes(tile.getTileDims(), tile.getTileAxesOrder(), specAxes);
        int[] halo = spec.getHaloArr();
        int[][] paddingSize = new int[2][tileSize.length];
        boolean tilingAllowed = this.descriptor.isTilingAllowed();
        if (!this.descriptor.isPyramidal() && tilingAllowed) {
            for (int i = 0; i < halo.length; ++i) {
                paddingSize[0][i] = halo[i];
                paddingSize[1][i] = halo[i];
            }
        }
        int[] patchGridSize = new int[imSize.length];
        Arrays.fill(patchGridSize, 1);
        if (tilingAllowed) {
            patchGridSize = IntStream.range(0, tileSize.length)
                    .map(i -> (int) Math.ceil((double) imSize[i] / ((double) tileSize[i] - (double) (halo[i] * 2))))
                    .toArray();
        }
        paddingSize[0] = IntStream.range(0, tileSize.length)
                .map(i -> (int) Math.max((double) paddingSize[0][i], Math.ceil((double) (tileSize[i] - imSize[i]) / 2.0)))
                .toArray();
        paddingSize[1] = IntStream.range(0, tileSize.length)
                .map(i -> (int) Math.max((long) paddingSize[1][i], tileSize[i] - imSize[i] - (long) paddingSize[0][i]))
                .toArray();
        return PatchSpec.create(spec.getName(), tileSize, patchGridSize, paddingSize, imSize);
    }

    /**
     * Returns the number of tiles, counted on the grid of the first input tensor. That is the first input of
     * the descriptor, or the first input tile when built without a descriptor.
     *
     * @return number of tiles
     */
    public int getNumberOfTiles() {
        String firstInput = this.descriptor != null
                ? this.descriptor.getInputTensors().get(0).getName()
                : this.inputTileInfo.get(0).getName();
        return this.inputGrid.get(firstInput).getRoiPostionsInImage().size();
    }

    /**
     * Not implemented yet.
     *
     * @return always {@code null}
     */
    public Map<String, Integer> getTilesPerAxis() {
        return null;
    }

    /**
     * Not implemented yet; only validates that the tensor is a tiled input tensor.
     *
     * @throws IllegalArgumentException if {@code tensorName} is not a tiled input tensor
     */
    public void getInputInsertionPoints(String tensorName, int nTile) {
        this.requireInputTile(tensorName);
    }

    /**
     * Not implemented yet; only validates that the tensor is a tiled output tensor.
     *
     * @throws IllegalArgumentException if {@code tensorName} is not a tiled output tensor
     */
    public void getOutputInsertionPoints(String tensorName, int nTile) {
        this.requireOutputTile(tensorName);
    }

    /**
     * @return the names of the input tensors: from the descriptor, or from the input tiles when built without
     * one
     */
    public List<String> getInputTensorNames() {
        if (this.descriptor == null) {
            return this.inputTileInfo.stream().map(TileInfo::getName).collect(Collectors.toList());
        }
        return this.descriptor.getInputTensors().stream().map(tt -> tt.getName()).collect(Collectors.toList());
    }

    /**
     * @return the names of the output tensors: from the descriptor, or from the output tiles when built
     * without one
     */
    public List<String> getOutputTensorNames() {
        if (this.descriptor == null) {
            return this.outputTileInfo.stream().map(TileInfo::getName).collect(Collectors.toList());
        }
        return this.descriptor.getOutputTensors().stream().map(tt -> tt.getName()).collect(Collectors.toList());
    }

    /**
     * @param tensorName name of an input tensor
     * @return the tile dimensions of that tensor, in its tile axes order
     * @throws IllegalArgumentException if {@code tensorName} is not a tiled input tensor
     */
    public long[] getInputTileSize(String tensorName) {
        return this.requireInputTile(tensorName).getTileDims();
    }

    /**
     * @param tensorName name of an output tensor
     * @return the tile dimensions of that tensor, in its tile axes order
     * @throws IllegalArgumentException if {@code tensorName} is not a tiled output tensor
     */
    public long[] getOutputTileSize(String tensorName) {
        return this.requireOutputTile(tensorName).getTileDims();
    }

    /**
     * @param tensorName name of an input tensor
     * @return the ROI size of the tiles of that tensor, as given by its {@link TileGrid}
     * @throws IllegalArgumentException if {@code tensorName} is not a tiled input tensor
     */
    public int[] getInputRoiSize(String tensorName) {
        this.requireInputTile(tensorName);
        return this.inputGrid.get(tensorName).getRoiSize();
    }

    /**
     * @param tensorName name of an output tensor
     * @return the ROI size of the tiles of that tensor, as given by its {@link TileGrid}
     * @throws IllegalArgumentException if {@code tensorName} is not a tiled output tensor
     */
    public int[] getOutputRoiSize(String tensorName) {
        this.requireOutputTile(tensorName);
        return this.outputGrid.get(tensorName).getRoiSize();
    }

    /**
     * @param tensorName name of an input tensor
     * @return the start position of every tile in the input image
     * @throws IllegalArgumentException if {@code tensorName} is not a tiled input tensor
     */
    public List<long[]> getTilePostionsInputImage(String tensorName) {
        this.requireInputTile(tensorName);
        return this.inputGrid.get(tensorName).getTilePostionsInImage();
    }

    /**
     * @param tensorName name of an output tensor
     * @return the start position of every tile in the output image
     * @throws IllegalArgumentException if {@code tensorName} is not a tiled output tensor
     */
    public List<long[]> getTilePostionsOutputImage(String tensorName) {
        this.requireOutputTile(tensorName);
        return this.outputGrid.get(tensorName).getTilePostionsInImage();
    }

    /**
     * Returns a view of the n-th input tile. Regions outside {@code rai} are filled by double mirroring.
     *
     * @param tensorName name of the input tensor
     * @param rai        full input image
     * @param n          index of the tile, starting at 0
     * @return the tile view
     * @throws IllegalArgumentException if {@code n} is out of range or the tensor is not a tiled input
     */
    public <T extends NativeType<T> & RealType<T>> RandomAccessibleInterval<T> getNthTileInput(String tensorName, RandomAccessibleInterval<T> rai, int n) {
        List<long[]> tiles = this.getTilePostionsInputImage(tensorName);
        if (n < 0 || tiles.size() <= n) {
            throw new IllegalArgumentException("There are only " + tiles.size() + " tiles. Tile " + n + " is out of bounds.");
        }
        long[] minLim = tiles.get(n);
        long[] size = this.getInputTileSize(tensorName);
        long[] maxLim = new long[size.length];
        for (int i = 0; i < size.length; ++i) {
            maxLim[i] = minLim[i] + size[i] - 1L;
        }
        return Views.interval(Views.extendMirrorDouble(rai), new FinalInterval(minLim, maxLim));
    }

    /**
     * Returns a view of the n-th output tile. Only the ROI (the tile minus a padding split as evenly as
     * possible, larger half first) is read from {@code rai}; the padded border is zero.
     *
     * @param tensorName name of the output tensor
     * @param rai        full output image
     * @param n          index of the tile, starting at 0
     * @return the tile view
     * @throws IllegalArgumentException if {@code n} is out of range or the tensor is not a tiled output
     */
    public <T extends NativeType<T> & RealType<T>> RandomAccessibleInterval<T> getNthTileOutput(String tensorName, RandomAccessibleInterval<T> rai, int n) {
        List<long[]> tiles = this.getTilePostionsOutputImage(tensorName);
        if (n < 0 || tiles.size() <= n) {
            throw new IllegalArgumentException("There are only " + tiles.size() + " tiles. Tile " + n + " is out of bounds.");
        }
        long[] minLim = tiles.get(n);
        int[] rois = this.getOutputRoiSize(tensorName);
        long[] size = this.getOutputTileSize(tensorName);
        long[] maxLim = new long[size.length];
        long[] minLimNoPad = new long[minLim.length];
        long[] maxLimNoPad = new long[size.length];
        for (int i = 0; i < size.length; ++i) {
            long padStart = (long) Math.ceil((double) (size[i] - (long) rois[i]) / 2.0);
            long padEnd = size[i] - (long) rois[i] - padStart;
            maxLim[i] = minLim[i] + size[i] - 1L;
            minLimNoPad[i] = minLim[i] + padStart;
            maxLimNoPad[i] = maxLim[i] - padEnd;
        }
        IntervalView<T> tileRai = Views.interval(rai, new FinalInterval(minLimNoPad, maxLimNoPad));
        return Views.interval(Views.extendZero(tileRai), new FinalInterval(minLim, maxLim));
    }

    /**
     * Tensor variant of {@link #getNthTileInput(String, RandomAccessibleInterval, int)}; keeps the tensor's
     * name and axes order.
     */
    public <T extends NativeType<T> & RealType<T>> Tensor<T> getNthTileInput(Tensor<T> tensor, int n) {
        RandomAccessibleInterval<T> rai = this.getNthTileInput(tensor.getName(), tensor.getData(), n);
        return Tensor.build(tensor.getName(), tensor.getAxesOrderString(), rai);
    }

    /**
     * Tensor variant of {@link #getNthTileOutput(String, RandomAccessibleInterval, int)}; keeps the tensor's
     * name and axes order.
     */
    public <T extends NativeType<T> & RealType<T>> Tensor<T> getNthTileOutput(Tensor<T> tensor, int n) {
        RandomAccessibleInterval<T> rai = this.getNthTileOutput(tensor.getName(), tensor.getData(), n);
        return Tensor.build(tensor.getName(), tensor.getAxesOrderString(), rai);
    }

    /**
     * @param tensorName name of an output tensor
     * @return the full output image dimensions of that tensor
     * @throws IllegalArgumentException if {@code tensorName} is not an output tensor
     */
    public long[] getOutputImageSize(String tensorName) {
        TileInfo tile = TileMaker.findTile(this.outputTileInfo, tensorName);
        if (tile == null) {
            throw new IllegalArgumentException("The tensor ID proposed does not correspond to an output tensor: '" + tensorName + "'.");
        }
        return tile.getImageDims();
    }

    /** Returns the tile in {@code tiles} named {@code name}, or {@code null} if there is none. */
    private static TileInfo findTile(List<TileInfo> tiles, String name) {
        return tiles.stream().filter(t -> t.getName().equals(name)).findFirst().orElse(null);
    }

    /** Returns the input tile named {@code tensorName}, or throws if there is none. */
    private TileInfo requireInputTile(String tensorName) {
        TileInfo tile = TileMaker.findTile(this.inputTileInfo, tensorName);
        if (tile == null) {
            throw new IllegalArgumentException("Input tensor '" + tensorName + "' does not require tiling.");
        }
        return tile;
    }

    /** Returns the output tile named {@code tensorName}, or throws if there is none. */
    private TileInfo requireOutputTile(String tensorName) {
        TileInfo tile = TileMaker.findTile(this.outputTileInfo, tensorName);
        if (tile == null) {
            throw new IllegalArgumentException("Output tensor '" + tensorName + "' does not require tiling.");
        }
        return tile;
    }

    /**
     * Reorders {@code size} from {@code orginalAxes} to {@code targetAxes} (case-insensitive). Axes missing
     * from the original get 1.
     */
    public static long[] arrayToWantedAxesOrderAddOnes(long[] size, String orginalAxes, String targetAxes) {
        orginalAxes = orginalAxes.toLowerCase();
        String[] axesArr = targetAxes.toLowerCase().split("");
        long[] finalSize = new long[targetAxes.length()];
        for (int i = 0; i < finalSize.length; ++i) {
            int ind = orginalAxes.indexOf(axesArr[i]);
            finalSize[i] = ind == -1 ? 1L : size[ind];
        }
        return finalSize;
    }

    /**
     * Reorders {@code size} from {@code orginalAxes} to {@code targetAxes} (case-insensitive). Axes missing
     * from the original get 1.
     */
    public static int[] arrayToWantedAxesOrderAddOnes(int[] size, String orginalAxes, String targetAxes) {
        orginalAxes = orginalAxes.toLowerCase();
        String[] axesArr = targetAxes.toLowerCase().split("");
        int[] finalSize = new int[targetAxes.length()];
        for (int i = 0; i < finalSize.length; ++i) {
            int ind = orginalAxes.indexOf(axesArr[i]);
            finalSize[i] = ind == -1 ? 1 : size[ind];
        }
        return finalSize;
    }

    /**
     * Reorders {@code size} from {@code orginalAxes} to {@code targetAxes} (case-insensitive). Axes missing
     * from the original get 1.
     */
    public static float[] arrayToWantedAxesOrderAddOnes(float[] size, String orginalAxes, String targetAxes) {
        orginalAxes = orginalAxes.toLowerCase();
        String[] axesArr = targetAxes.toLowerCase().split("");
        float[] finalSize = new float[targetAxes.length()];
        for (int i = 0; i < finalSize.length; ++i) {
            int ind = orginalAxes.indexOf(axesArr[i]);
            finalSize[i] = ind == -1 ? 1.0f : size[ind];
        }
        return finalSize;
    }

    /**
     * Reorders {@code size} from {@code orginalAxes} to {@code targetAxes} (case-insensitive). Axes missing
     * from the original get 0.
     */
    public static float[] arrayToWantedAxesOrderAddZeros(float[] size, String orginalAxes, String targetAxes) {
        orginalAxes = orginalAxes.toLowerCase();
        String[] axesArr = targetAxes.toLowerCase().split("");
        float[] finalSize = new float[targetAxes.length()];
        for (int i = 0; i < finalSize.length; ++i) {
            int ind = orginalAxes.indexOf(axesArr[i]);
            if (ind == -1) continue;
            finalSize[i] = size[ind];
        }
        return finalSize;
    }

    /**
     * Reorders {@code size} from {@code orginalAxes} to {@code targetAxes} (case-insensitive). Axes missing
     * from the original get 0.
     */
    public static long[] arrayToWantedAxesOrderAddZeros(long[] size, String orginalAxes, String targetAxes) {
        orginalAxes = orginalAxes.toLowerCase();
        String[] axesArr = targetAxes.toLowerCase().split("");
        long[] finalSize = new long[targetAxes.length()];
        for (int i = 0; i < finalSize.length; ++i) {
            int ind = orginalAxes.indexOf(axesArr[i]);
            if (ind == -1) continue;
            finalSize[i] = size[ind];
        }
        return finalSize;
    }

    /**
     * Reorders {@code size} from {@code orginalAxes} to {@code targetAxes} (case-insensitive). Axes missing
     * from the original get 0.
     */
    public static int[] arrayToWantedAxesOrderAddZeros(int[] size, String orginalAxes, String targetAxes) {
        orginalAxes = orginalAxes.toLowerCase();
        String[] axesArr = targetAxes.toLowerCase().split("");
        int[] finalSize = new int[targetAxes.length()];
        for (int i = 0; i < finalSize.length; ++i) {
            int ind = orginalAxes.indexOf(axesArr[i]);
            if (ind == -1) continue;
            finalSize[i] = size[ind];
        }
        return finalSize;
    }

    /**
     * Appends to {@code axes2} every axis of {@code axes1} that it lacks. The exception is {@code "b"}, which
     * is not appended when {@code axes2} already contains {@code "t"}.
     *
     * @return {@code axes2} with the missing axes appended
     */
    public static String addMissingAxes(String axes1, String axes2) {
        for (String ax : axes1.split("")) {
            if (ax.equals("b") && axes2.indexOf(ax) == -1 && axes2.indexOf("t") != -1) continue;
            if (ax.equals("t") && axes2.indexOf(ax) == -1 && axes2.indexOf("b") != -1) {
                axes2 = axes2 + ax;
                continue;
            }
            if (axes2.indexOf(ax) != -1) continue;
            axes2 = axes2 + ax;
        }
        return axes2;
    }
}

