/*
 * Decompiled with CFR 0.152.
 */
package mitiv.conv;

import mitiv.array.ShapedArray;
import mitiv.base.Shape;
import mitiv.conv.Convolution;
import mitiv.conv.ConvolutionDouble1D;
import mitiv.conv.ConvolutionDouble2D;
import mitiv.conv.ConvolutionDouble3D;
import mitiv.conv.ConvolutionFloat1D;
import mitiv.conv.ConvolutionFloat2D;
import mitiv.conv.ConvolutionFloat3D;
import mitiv.conv.WeightedConvolutionDouble1D;
import mitiv.conv.WeightedConvolutionDouble2D;
import mitiv.conv.WeightedConvolutionDouble3D;
import mitiv.conv.WeightedConvolutionFloat1D;
import mitiv.conv.WeightedConvolutionFloat2D;
import mitiv.conv.WeightedConvolutionFloat3D;
import mitiv.cost.DifferentiableCostFunction;
import mitiv.cost.WeightedData;
import mitiv.exception.IllegalTypeException;
import mitiv.linalg.Vector;
import mitiv.linalg.shaped.ShapedVector;
import mitiv.linalg.shaped.ShapedVectorSpace;
import mitiv.utils.Timer;

/**
 * Abstract base for differentiable, weighted-data cost functions built on a
 * {@link Convolution} operator.
 *
 * <p>Variables (the "object") belong to {@link #objectSpace}; the data space is
 * handed to the {@link WeightedData} super-class. Concrete subclasses supply the
 * {@code _cost} hooks and the PSF / model accessors, while the static
 * {@code build} factories select the subclass that matches the element type and
 * rank of a given convolution.
 */
public abstract class WeightedConvolutionCost extends WeightedData
        implements DifferentiableCostFunction {

    /**
     * Value of {@link Convolution#getType()} for which the single-precision
     * ({@code WeightedConvolutionFloat*}) implementations are built.
     */
    private static final int FLOAT_TYPE_CODE = 4;

    /**
     * Value of {@link Convolution#getType()} for which the double-precision
     * ({@code WeightedConvolutionDouble*}) implementations are built.
     */
    private static final int DOUBLE_TYPE_CODE = 5;

    /** Convolution operator; never assigned in this class (NOTE(review): presumably set by subclasses). */
    protected Convolution cnvl;

    /** Space to which the variables of the cost function must belong. */
    protected final ShapedVectorSpace objectSpace;

    /** Timer whose elapsed time is reported by {@link #getElapsedTimeInFFT()}. */
    protected Timer timerForFFT = new Timer();

    /** Timer whose elapsed time is reported by {@link #getElapsedTime()}. */
    protected Timer timer = new Timer();

    /**
     * Initialize the spaces of the cost function.
     *
     * @param objectSpace space of the variables (stored in {@link #objectSpace})
     * @param dataSpace   space of the data (forwarded to {@link WeightedData})
     */
    protected WeightedConvolutionCost(ShapedVectorSpace objectSpace, ShapedVectorSpace dataSpace) {
        super(dataSpace);
        this.objectSpace = objectSpace;
    }

    /**
     * Get the space of the variables.
     *
     * @return the object space given at construction
     */
    public ShapedVectorSpace getObjectSpace() {
        return objectSpace;
    }

    /**
     * Get the input space of the cost function, which is the object space.
     *
     * @return the object space given at construction
     */
    @Override
    public ShapedVectorSpace getInputSpace() {
        return objectSpace;
    }

    /**
     * Create the weighted convolution cost matching a convolution operator.
     *
     * @param cnvl convolution operator; its type code selects float or double
     *             precision and its rank selects the 1D, 2D or 3D implementation
     * @return a new cost wrapping {@code cnvl}
     * @throws IllegalTypeException     if the type code is neither float nor double
     * @throws IllegalArgumentException if the rank is not 1, 2 or 3
     */
    public static WeightedConvolutionCost build(Convolution cnvl) {
        final int typeCode = cnvl.getType();
        if (typeCode == FLOAT_TYPE_CODE) {
            final int rank = cnvl.getRank();
            if (rank == 1) {
                return new WeightedConvolutionFloat1D((ConvolutionFloat1D) cnvl);
            } else if (rank == 2) {
                return new WeightedConvolutionFloat2D((ConvolutionFloat2D) cnvl);
            } else if (rank == 3) {
                return new WeightedConvolutionFloat3D((ConvolutionFloat3D) cnvl);
            }
        } else if (typeCode == DOUBLE_TYPE_CODE) {
            final int rank = cnvl.getRank();
            if (rank == 1) {
                return new WeightedConvolutionDouble1D((ConvolutionDouble1D) cnvl);
            } else if (rank == 2) {
                return new WeightedConvolutionDouble2D((ConvolutionDouble2D) cnvl);
            } else if (rank == 3) {
                return new WeightedConvolutionDouble3D((ConvolutionDouble3D) cnvl);
            }
        } else {
            throw new IllegalTypeException("Only float and double types are implemented");
        }
        // Reached only for a supported type code with an unsupported rank.
        throw new IllegalArgumentException("Only 1D, 2D and 3D convolution are implemented");
    }

    /**
     * Create a weighted convolution cost from {@code Convolution.build(space)}.
     *
     * @param space space passed to {@link Convolution#build(ShapedVectorSpace)}
     * @return the cost built by {@link #build(Convolution)}
     */
    public static WeightedConvolutionCost build(ShapedVectorSpace space) {
        final Convolution operator = Convolution.build(space);
        return build(operator);
    }

    /**
     * Create a weighted convolution cost from
     * {@code Convolution.build(objectSpace, dataSpace)}.
     *
     * @param objectSpace space of the variables
     * @param dataSpace   space of the data
     * @return the cost built by {@link #build(Convolution)}
     */
    public static WeightedConvolutionCost build(ShapedVectorSpace objectSpace, ShapedVectorSpace dataSpace) {
        final Convolution operator = Convolution.build(objectSpace, dataSpace);
        return build(operator);
    }

    /**
     * Create a weighted convolution cost from {@code Convolution.build(wrk, inp, out)}.
     *
     * @param wrk shape forwarded to {@code Convolution.build}
     * @param inp input space forwarded to {@code Convolution.build}
     * @param out output space forwarded to {@code Convolution.build}
     * @return the cost built by {@link #build(Convolution)}
     */
    public static WeightedConvolutionCost build(Shape wrk, ShapedVectorSpace inp, ShapedVectorSpace out) {
        final Convolution operator = Convolution.build(wrk, inp, out);
        return build(operator);
    }

    /**
     * Create a weighted convolution cost from
     * {@code Convolution.build(wrk, inp, inpOff, out, outOff)}.
     *
     * @param wrk    shape forwarded to {@code Convolution.build}
     * @param inp    input space forwarded to {@code Convolution.build}
     * @param inpOff input offsets forwarded to {@code Convolution.build}
     * @param out    output space forwarded to {@code Convolution.build}
     * @param outOff output offsets forwarded to {@code Convolution.build}
     * @return the cost built by {@link #build(Convolution)}
     */
    public static WeightedConvolutionCost build(Shape wrk, ShapedVectorSpace inp, int[] inpOff, ShapedVectorSpace out, int[] outOff) {
        final Convolution operator = Convolution.build(wrk, inp, inpOff, out, outOff);
        return build(operator);
    }

    /**
     * Ensure that a vector belongs to the object space.
     *
     * @param v       vector to check
     * @param message message of the exception thrown on failure
     * @throws IllegalArgumentException if {@code v} does not belong to {@link #objectSpace}
     */
    private void requireInObjectSpace(Vector v, String message) {
        if (!v.belongsTo(objectSpace)) {
            throw new IllegalArgumentException(message);
        }
    }

    /**
     * Evaluate the cost for multiplier {@code alpha}.
     *
     * @param alpha multiplier; when exactly zero, {@code _cost} is not called
     * @param x     variables, must belong to the object space
     * @return 0 if {@code alpha == 0}, otherwise the value of {@code _cost(alpha, x)}
     * @throws IllegalArgumentException if {@code x} is not in the object space
     */
    @Override
    public double evaluate(double alpha, Vector x) {
        requireInObjectSpace(x, "Variables X does not belong to the object space");
        return (alpha == 0.0) ? 0.0 : _cost(alpha, x);
    }

    /**
     * Evaluate the cost and its gradient for multiplier {@code alpha}.
     *
     * @param alpha multiplier; when exactly zero, {@code _cost} is not called
     * @param x     variables, must belong to the object space
     * @param gx    gradient, must belong to the object space
     * @param clr   whether {@code gx} should be cleared first; in the zero
     *              multiplier case this means {@code gx.zero()} is called
     * @return 0 if {@code alpha == 0}, otherwise the value of
     *         {@code _cost(alpha, x, gx, clr)}
     * @throws IllegalArgumentException if {@code x} or {@code gx} is not in the
     *                                  object space (checked in that order)
     */
    @Override
    public double computeCostAndGradient(double alpha, Vector x, Vector gx, boolean clr) {
        requireInObjectSpace(x, "Variables X does not belong to the object space");
        requireInObjectSpace(gx, "Gradient GX does not belong to the object space");
        if (alpha != 0.0) {
            return _cost(alpha, x, gx, clr);
        }
        if (clr) {
            gx.zero();
        }
        return 0.0;
    }

    /**
     * Subclass hook called by {@link #evaluate(double, Vector)} once {@code x}
     * has been validated and {@code alpha} is nonzero.
     *
     * @param alpha nonzero multiplier (NOTE(review): presumably a scale factor of the cost)
     * @param x     variables in the object space
     * @return the cost value
     */
    protected abstract double _cost(double alpha, Vector x);

    /**
     * Subclass hook called by {@link #computeCostAndGradient(double, Vector, Vector, boolean)}
     * once {@code x} and {@code gx} have been validated and {@code alpha} is nonzero.
     *
     * @param alpha nonzero multiplier
     * @param x     variables in the object space
     * @param gx    gradient in the object space
     * @param clr   clear flag forwarded from the caller
     * @return the cost value
     */
    protected abstract double _cost(double alpha, Vector x, Vector gx, boolean clr);

    /**
     * Set the point spread function from a shaped vector.
     *
     * @param psf the PSF
     */
    public abstract void setPSF(ShapedVector psf);

    /**
     * Set the PSF with no offset ({@code null}) and no normalization.
     *
     * @param psf the PSF
     */
    public void setPSF(ShapedArray psf) {
        final int[] noOffset = null;
        setPSF(psf, noOffset, false);
    }

    /**
     * Set the PSF with a given offset and no normalization.
     *
     * @param psf the PSF
     * @param off offset forwarded unchanged
     */
    public void setPSF(ShapedArray psf, int[] off) {
        final boolean normalize = false;
        setPSF(psf, off, normalize);
    }

    /**
     * Set the PSF with no offset ({@code null}) and optional normalization.
     *
     * @param psf       the PSF
     * @param normalize normalization flag forwarded unchanged
     */
    public void setPSF(ShapedArray psf, boolean normalize) {
        final int[] noOffset = null;
        setPSF(psf, noOffset, normalize);
    }

    /**
     * Set the PSF; the shorter overloads delegate here with {@code off == null}
     * and/or {@code normalize == false}.
     *
     * @param psf       the PSF
     * @param off       offset, may be {@code null}
     * @param normalize normalization flag
     */
    public abstract void setPSF(ShapedArray psf, int[] off, boolean normalize);

    /**
     * Compute the model for given variables. {@link #getModel()} passes
     * {@code null}; {@link #getModel(ShapedArray)} passes a vector created in
     * the object space.
     *
     * @param x variables, possibly {@code null}
     * @return the model
     */
    public abstract ShapedVector getModel(ShapedVector x);

    /**
     * Delegate to {@link #getModel(ShapedVector)} with {@code null}.
     *
     * @return the model
     */
    public ShapedVector getModel() {
        final ShapedVector none = null;
        return getModel(none);
    }

    /**
     * Wrap an array as an object-space vector and delegate to
     * {@link #getModel(ShapedVector)}.
     *
     * @param objArray array given to {@code objectSpace.create}
     * @return the model
     */
    public ShapedVector getModel(ShapedArray objArray) {
        final ShapedVector obj = objectSpace.create(objArray);
        return getModel(obj);
    }

    /**
     * Compute the linear offset, in the input, of the first element of an
     * output region. Dimension 0 varies fastest (its stride is 1).
     *
     * @param rank        expected rank of both shapes
     * @param inputShape  shape of the input
     * @param outputShape shape of the output; no dimension may exceed the input's
     * @param offset      per-dimension position of the output region, or
     *                    {@code null} to center it ({@code inpDim/2 - outDim/2})
     * @return the linear offset of the output region in the input
     * @throws IllegalArgumentException on rank mismatch, wrong offset length,
     *                                  oversized output or out-of-bounds region
     */
    protected static int outputOffset(int rank, Shape inputShape, Shape outputShape, int[] offset) {
        if (inputShape.rank() != rank) {
            throw new IllegalArgumentException("Bad rank for input space");
        }
        if (outputShape.rank() != rank) {
            throw new IllegalArgumentException("Bad rank for output space");
        }
        final boolean centered = (offset == null);
        if (!centered && offset.length != rank) {
            throw new IllegalArgumentException("Bad number of coordinates for the first position");
        }
        int linear = 0;
        int step = 1;
        for (int dim = 0; dim < rank; ++dim) {
            final int inpLen = inputShape.dimension(dim);
            final int outLen = outputShape.dimension(dim);
            if (outLen > inpLen) {
                throw new IllegalArgumentException("Output dimensions must be at most as large as input dimensions");
            }
            final int start;
            if (centered) {
                start = inpLen / 2 - outLen / 2;
            } else {
                start = offset[dim];
                if (start < 0 || start + outLen > inpLen) {
                    throw new IllegalArgumentException("Output region is outside bounds");
                }
            }
            linear += step * start;
            step *= inpLen;
        }
        return linear;
    }

    /**
     * Stop and reset a timer.
     *
     * @param t the timer
     */
    private static void stopAndReset(Timer t) {
        t.stop();
        t.reset();
    }

    /** Stop and reset the FFT timer, then the overall timer. */
    public void resetTimers() {
        stopAndReset(timerForFFT);
        stopAndReset(timer);
    }

    /**
     * Get the elapsed time of {@link #timer}.
     *
     * @return value of {@code timer.getElapsedTime()} (units as defined by {@code Timer})
     */
    public double getElapsedTime() {
        return timer.getElapsedTime();
    }

    /**
     * Get the elapsed time of {@link #timerForFFT}.
     *
     * @return value of {@code timerForFFT.getElapsedTime()} (units as defined by {@code Timer})
     */
    public double getElapsedTimeInFFT() {
        return timerForFFT.getElapsedTime();
    }
}

