package net.imagej.ops.deconvolve.accelerate;

import net.imagej.ops.Op;
import net.imagej.ops.Ops;
import net.imagej.ops.special.function.Functions;
import net.imagej.ops.special.function.UnaryFunctionOp;
import net.imagej.ops.special.inplace.AbstractUnaryInplaceOp;
import net.imglib2.Cursor;
import net.imglib2.Dimensions;
import net.imglib2.FinalDimensions;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.type.NativeType;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.RealType;
import net.imglib2.util.Util;
import net.imglib2.view.Views;
import org.scijava.plugin.Plugin;

@Plugin(type = Ops.Deconvolve.Accelerate.class, priority = 0.0d)
/**
 * In-place op that extrapolates the current estimate of an iterative
 * deconvolution.
 * <p>
 * Each call receives the latest estimate {@code x_k} and works as follows:
 * <ol>
 * <li>It forms the gradient {@code g_k = x_k - y_{k-1}} against the previous
 * prediction.</li>
 * <li>It derives {@code alpha = (g_k . g_{k-1}) / (g_{k-1} . g_{k-1})},
 * clamped to {@code [0, 1]}.</li>
 * <li>When {@code alpha > 0}, it overwrites the input with the prediction
 * {@code y_k = max(x_k + alpha * (x_k - x_{k-1}), 1e-4)}.</li>
 * </ol>
 * When no positive factor is available (the first two calls, a negative
 * factor, or a zero gradient history), the input is left unchanged.
 * </p>
 * <p>
 * NOTE(review): this looks like the Biggs &amp; Andrews vector acceleration
 * used with Richardson-Lucy style algorithms; confirm against the caller.
 * </p>
 * <p>
 * The op keeps mutable state between calls (previous estimate, prediction,
 * gradients) without any synchronization. One instance should therefore
 * serve a single deconvolution run at a time.
 * </p>
 */
public class VectorAccelerator<T extends RealType<T> & NativeType<T>> extends AbstractUnaryInplaceOp<RandomAccessibleInterval<T>> implements Ops.Deconvolve.Accelerate {

    /** Gradient of the current call: current estimate minus previous prediction. */
    Img<T> gk;

    /** Gradient of the previous call; {@code null} until one exists or after a reset. */
    Img<T> gkm1;

    /** Creates images of the output's element type backed by {@link #factory}. */
    private UnaryFunctionOp<Dimensions, Img<T>> create;

    /** Factory for all internal buffers (flat-ordered array images). */
    ArrayImgFactory<T> factory;

    /** Estimate received on the previous call. */
    Img<T> xkm1_previous = null;

    /** Latest prediction; a non-null value also means the buffers were allocated. */
    Img<T> yk_prediction = null;

    /** Step between the current and the previous estimate. */
    Img<T> hk_vector = null;

    /** Factor used on the most recent call (clamped to [0, 1] after computation). */
    double accelerationFactor = 0.0d;

    /**
     * Sets up the image factory and the image-creation op.
     * <p>
     * The creation op uses the element type of the op's output, which for an
     * in-place op is the interval being mutated.
     * </p>
     */
    @Override // net.imagej.ops.Initializable
    @SuppressWarnings({ "unchecked", "rawtypes" })
    public void initialize() {
        super.initialize();
        factory = new ArrayImgFactory<>();
        // Functions.unary is typed with the raw Img class, so the result is
        // converted to the parameterized field type here.
        create = (UnaryFunctionOp) Functions.unary(ops(), (Class<? extends Op>) Ops.Create.Img.class, Img.class,
            Dimensions.class, Util.getTypeFromInterval(out()), factory);
    }

    /**
     * Accelerates {@code randomAccessibleInterval} in place.
     *
     * @param randomAccessibleInterval the current estimate; overwritten with the prediction
     */
    @Override // net.imagej.ops.special.inplace.UnaryInplaceOp
    public void mutate(RandomAccessibleInterval<T> randomAccessibleInterval) {
        accelerate(randomAccessibleInterval);
    }

    /**
     * Allocates the internal buffers on first use.
     * <p>
     * The buffers are sized like {@code randomAccessibleInterval}. Later calls
     * do nothing, so all subsequent inputs are expected to have the same
     * dimensions.
     * </p>
     * <p>
     * This overloads the no-argument {@link #initialize()}, which sets up the
     * op itself.
     * </p>
     *
     * @param randomAccessibleInterval interval whose dimensions size the buffers
     */
    public void initialize(RandomAccessibleInterval<T> randomAccessibleInterval) {
        if (yk_prediction != null) {
            return;
        }
        final long[] dims = new long[randomAccessibleInterval.numDimensions()];
        randomAccessibleInterval.dimensions(dims);
        final FinalDimensions size = new FinalDimensions(dims);
        yk_prediction = create.calculate(size);
        xkm1_previous = create.calculate(size);
        gk = create.calculate(size);
        hk_vector = create.calculate(size);
    }

    /**
     * Runs one acceleration step, overwriting the input with the prediction.
     * <p>
     * Without a positive factor, the input is copied into the prediction
     * buffer and written back, so it is left unchanged.
     * </p>
     *
     * @param randomAccessibleInterval the current estimate {@code x_k}; mutated in place
     */
    public void accelerate(RandomAccessibleInterval<T> randomAccessibleInterval) {
        if (yk_prediction != null) {
            accelerationFactor = computeAccelerationFactor(randomAccessibleInterval);
            if (accelerationFactor < 0.0d) {
                // A negative factor is rejected: discard the gradient history
                // and skip acceleration for this call.
                gkm1 = null;
                accelerationFactor = 0.0d;
            }
            if (accelerationFactor > 1.0d) {
                accelerationFactor = 1.0d;
            }
        }
        if (accelerationFactor > 0.0d) {
            // Only reachable after the buffers exist: the factor is reset to 0
            // by default and only recomputed once yk_prediction is non-null.
            Subtract(randomAccessibleInterval, xkm1_previous, hk_vector);
            addAndScaleInto(randomAccessibleInterval, hk_vector, (float) accelerationFactor, yk_prediction);
        } else {
            initialize(randomAccessibleInterval);
            Copy(randomAccessibleInterval, yk_prediction);
        }
        Copy(randomAccessibleInterval, xkm1_previous);
        Copy(yk_prediction, randomAccessibleInterval);
    }

    /**
     * Updates the gradients and returns {@code (g_k . g_{k-1}) / (g_{k-1} . g_{k-1})}.
     * <p>
     * Returns 0 when there is no previous gradient yet. Also returns 0 when the
     * previous gradient is all zeros, which would otherwise divide by zero and
     * yield NaN.
     * </p>
     *
     * @param randomAccessibleInterval the current estimate {@code x_k}
     * @return the unclamped acceleration factor
     */
    double computeAccelerationFactor(RandomAccessibleInterval<T> randomAccessibleInterval) {
        Subtract(randomAccessibleInterval, yk_prediction, gk);
        if (gkm1 == null) {
            gkm1 = gk.copy();
            return 0.0d;
        }
        final double numerator = DotProduct(gk, gkm1);
        final double denominator = DotProduct(gkm1, gkm1);
        // Reuse the existing buffer instead of allocating a copy every call.
        Copy(gk, gkm1);
        if (denominator == 0.0d) {
            return 0.0d;
        }
        return numerator / denominator;
    }

    /**
     * Computes the element-wise dot product of two images, in flat iteration order.
     * <p>
     * Products are formed in double precision.
     * </p>
     *
     * @param img first image
     * @param img2 second image; expected to have at least as many elements as {@code img}
     * @return the sum of the element-wise products
     */
    public double DotProduct(Img<T> img, Img<T> img2) {
        final Cursor<T> left = Views.flatIterable(img).cursor();
        final Cursor<T> right = Views.flatIterable(img2).cursor();
        double sum = 0.0d;
        while (left.hasNext()) {
            sum += left.next().getRealDouble() * right.next().getRealDouble();
        }
        return sum;
    }

    /**
     * Copies every value of the source into the target.
     * <p>
     * Flat iteration order is used on both sides, so pixels pair up by
     * position regardless of the images' storage layout.
     * </p>
     *
     * @param randomAccessibleInterval source
     * @param randomAccessibleInterval2 target, expected to have the same dimensions
     */
    protected void Copy(RandomAccessibleInterval<T> randomAccessibleInterval, RandomAccessibleInterval<T> randomAccessibleInterval2) {
        final Cursor<T> source = Views.flatIterable(randomAccessibleInterval).cursor();
        final Cursor<T> target = Views.flatIterable(randomAccessibleInterval2).cursor();
        while (source.hasNext()) {
            target.next().set(source.next());
        }
    }

    /**
     * Writes {@code a - b} element-wise into {@code difference}, in flat iteration order.
     *
     * @param randomAccessibleInterval minuend {@code a}
     * @param randomAccessibleInterval2 subtrahend {@code b}
     * @param randomAccessibleInterval3 output, expected to have the same dimensions
     */
    protected void Subtract(RandomAccessibleInterval<T> randomAccessibleInterval, RandomAccessibleInterval<T> randomAccessibleInterval2, RandomAccessibleInterval<T> randomAccessibleInterval3) {
        final Cursor<T> minuend = Views.flatIterable(randomAccessibleInterval).cursor();
        final Cursor<T> subtrahend = Views.flatIterable(randomAccessibleInterval2).cursor();
        final Cursor<T> difference = Views.flatIterable(randomAccessibleInterval3).cursor();
        while (minuend.hasNext()) {
            final T out = difference.next();
            out.set(minuend.next());
            out.sub(subtrahend.next());
        }
    }

    /**
     * Returns a new image holding {@code max(x + f * h, 1e-4)} element-wise.
     *
     * @param randomAccessibleInterval base image {@code x}
     * @param img step image {@code h}
     * @param f scale applied to the step
     * @return a newly created image with the dimensions of {@code randomAccessibleInterval}
     */
    public Img<T> AddAndScale(RandomAccessibleInterval<T> randomAccessibleInterval, Img<T> img, float f) {
        final Img<T> result = create.calculate(randomAccessibleInterval);
        addAndScaleInto(randomAccessibleInterval, img, f, result);
        return result;
    }

    /**
     * Writes {@code max(x + f * h, 1e-4)} element-wise into {@code target}, in flat iteration order.
     * <p>
     * Arithmetic is done in float precision. The lower bound keeps every
     * output value at least {@code 1e-4}.
     * NOTE(review): presumably this keeps the estimate strictly positive for
     * the deconvolution; confirm.
     * </p>
     */
    private void addAndScaleInto(final RandomAccessibleInterval<T> base, final RandomAccessibleInterval<T> step,
        final float f, final RandomAccessibleInterval<T> target) {
        final Cursor<T> baseCursor = Views.flatIterable(base).cursor();
        final Cursor<T> stepCursor = Views.flatIterable(step).cursor();
        final Cursor<T> targetCursor = Views.flatIterable(target).cursor();
        while (baseCursor.hasNext()) {
            final float value = baseCursor.next().getRealFloat() + f * stepCursor.next().getRealFloat();
            targetCursor.next().setReal(Math.max(value, 1.0E-4f));
        }
    }
}
