/*
 * Decompiled with CFR 0.152.
 */
package mitiv.conv;

import mitiv.array.ShapedArray;
import mitiv.base.Shape;
import mitiv.conv.ConvolutionDouble;
import mitiv.conv.ConvolutionDouble3D;
import mitiv.conv.WeightedConvolutionDouble;
import mitiv.linalg.Vector;
import mitiv.linalg.shaped.DoubleShapedVector;
import mitiv.linalg.shaped.ShapedVector;

/**
 * Weighted quadratic cost built on top of a 3-D convolution operator.
 *
 * <p>For variables {@code x}, the cost is
 * {@code (alpha/2) * sum_j w[j] * (z[j] - dat[j])^2}, where:
 * <ul>
 *   <li>{@code z} is the convolution of {@code x} restricted to the data
 *       region of the work array;</li>
 *   <li>{@code dat} are the data;</li>
 *   <li>{@code w} are the weights, all taken equal to 1 when {@code wgt}
 *       is {@code null}.</li>
 * </ul>
 *
 * <p>The data region spans {@code [offN, endN)} along each axis N of the work
 * array. Data/weight index {@code j} runs fastest along the 1st axis, then the
 * 2nd, then the 3rd.
 *
 * <p>The convolution work array stores two {@code double}s per element. They
 * are presumably real and imaginary parts (TODO confirm against
 * {@code ConvolutionDouble}). Only the first entry of each pair is read as the
 * model value.
 */
class WeightedConvolutionDouble3D extends WeightedConvolutionDouble {
    /** Work-array length along the 1st axis. */
    private final int dim1;
    /** First index of the data region along the 1st axis. */
    private final int off1;
    /** One past the last index of the data region along the 1st axis. */
    private final int end1;
    /** Work-array length along the 2nd axis. */
    private final int dim2;
    /** First index of the data region along the 2nd axis. */
    private final int off2;
    /** One past the last index of the data region along the 2nd axis. */
    private final int end2;
    /** Work-array length along the 3rd axis. */
    private final int dim3;
    /** First index of the data region along the 3rd axis. */
    private final int off3;
    /** One past the last index of the data region along the 3rd axis. */
    private final int end3;

    /**
     * Builds the weighted cost for a given 3-D convolution operator.
     *
     * <p>The data region is located in the operator's work array using its
     * output offsets and the shape of its output space.
     *
     * @param cnvl the 3-D convolution operator. Its input and output spaces
     *             become the input and output spaces of this cost.
     */
    public WeightedConvolutionDouble3D(ConvolutionDouble3D cnvl) {
        super(cnvl.getInputSpace(), cnvl.getOutputSpace());
        this.cnvl = cnvl;
        Shape workShape = cnvl.workShape;
        Shape dataShape = cnvl.getOutputSpace().getShape();
        int[] dataOffsets = cnvl.outputOffsets;
        this.dim1 = workShape.dimension(0);
        this.off1 = dataOffsets[0];
        this.end1 = this.off1 + dataShape.dimension(0);
        this.dim2 = workShape.dimension(1);
        this.off2 = dataOffsets[1];
        this.end2 = this.off2 + dataShape.dimension(1);
        this.dim3 = workShape.dimension(2);
        this.off3 = dataOffsets[2];
        this.end3 = this.off3 + dataShape.dimension(2);
    }

    /**
     * Computes the weighted cost without its gradient.
     *
     * @param alpha multiplier applied to the cost.
     * @param x     the variables. They are pushed into the convolution operator
     *              and convolved.
     * @return {@code alpha * sum_j w[j]*(z[j] - dat[j])^2 / 2}.
     */
    @Override
    protected double _cost(double alpha, Vector x) {
        checkSetup();
        cnvl.push((ShapedVector) x, false);
        cnvl.convolve(false);
        final double[] z = ((ConvolutionDouble) cnvl).getWorkArray();
        double sum = 0.0;
        int j = 0; // index into dat/wgt (data region, packed)
        if (wgt == null) {
            for (int i3 = off3; i3 < end3; ++i3) {
                for (int i2 = off2; i2 < end2; ++i2) {
                    // Index of the first pair of this row inside the data region.
                    int k = 2 * (off1 + dim1 * (i2 + dim2 * i3));
                    for (int i1 = off1; i1 < end1; ++i1) {
                        double r = z[k] - dat[j];
                        sum += r * r;
                        ++j;
                        k += 2;
                    }
                }
            }
        } else {
            for (int i3 = off3; i3 < end3; ++i3) {
                for (int i2 = off2; i2 < end2; ++i2) {
                    int k = 2 * (off1 + dim1 * (i2 + dim2 * i3));
                    for (int i1 = off1; i1 < end1; ++i1) {
                        double w = wgt[j];
                        double r = z[k] - dat[j];
                        sum += w * r * r;
                        ++j;
                        k += 2;
                    }
                }
            }
        }
        return alpha * sum / 2.0;
    }

    /**
     * Computes the weighted cost and its gradient.
     *
     * <p>After the forward convolution, the work array is overwritten:
     * <ul>
     *   <li>inside the data region, each pair becomes
     *       {@code (alpha * w * (z - dat), 0)};</li>
     *   <li>outside the data region, each pair is zeroed.</li>
     * </ul>
     * Then {@code convolve(true)} is applied. This is presumably the adjoint
     * convolution (TODO confirm). The first entry of each of the first
     * {@code g.length} pairs is then stored into or added to the gradient.
     *
     * @param alpha multiplier applied to the cost (and thus to the gradient).
     * @param x     the variables.
     * @param gx    the gradient destination. It must be a
     *              {@link DoubleShapedVector}.
     * @param clr   if {@code true}, the gradient overwrites {@code gx};
     *              otherwise it is added to the current contents of {@code gx}.
     * @return {@code alpha * sum_j w[j]*(z[j] - dat[j])^2 / 2}.
     */
    @Override
    protected double _cost(double alpha, Vector x, Vector gx, boolean clr) {
        checkSetup();
        cnvl.push((ShapedVector) x, false);
        cnvl.convolve(false);
        final boolean weighted = (wgt != null);
        final double[] z = ((ConvolutionDouble) cnvl).getWorkArray();
        final int plane = dim1 * dim2;
        double sum = 0.0;
        int j = 0; // index into dat/wgt (data region, packed)

        // Whole planes before the data region.
        int k = zeroPairs(z, 0, off3 * plane);
        for (int i3 = off3; i3 < end3; ++i3) {
            // Whole rows before the data region in this plane.
            k = zeroPairs(z, k, off2 * dim1);
            for (int i2 = off2; i2 < end2; ++i2) {
                // Leading part of the row, outside the data region.
                k = zeroPairs(z, k, off1);
                if (weighted) {
                    for (int i1 = off1; i1 < end1; ++i1) {
                        double w = wgt[j];
                        double r = z[k] - dat[j];
                        double wr = w * r;
                        sum += r * wr;
                        z[k] = alpha * wr;
                        z[k + 1] = 0.0;
                        ++j;
                        k += 2;
                    }
                } else {
                    for (int i1 = off1; i1 < end1; ++i1) {
                        double r = z[k] - dat[j];
                        sum += r * r;
                        z[k] = alpha * r;
                        z[k + 1] = 0.0;
                        ++j;
                        k += 2;
                    }
                }
                // Trailing part of the row, outside the data region.
                k = zeroPairs(z, k, dim1 - end1);
            }
            // Whole rows after the data region in this plane.
            k = zeroPairs(z, k, (dim2 - end2) * dim1);
        }
        // Whole planes after the data region.
        zeroPairs(z, k, (dim3 - end3) * plane);

        // Fetch the destination before the second convolution, as before.
        final double[] g = ((DoubleShapedVector) gx).getData();
        cnvl.convolve(true);
        if (clr) {
            for (int i = 0, kk = 0; i < g.length; ++i, kk += 2) {
                g[i] = z[kk];
            }
        } else {
            for (int i = 0, kk = 0; i < g.length; ++i, kk += 2) {
                g[i] += z[kk];
            }
        }
        return alpha * sum / 2.0;
    }

    /**
     * Zeroes {@code count} consecutive pairs of entries of {@code z}, starting
     * at index {@code start}.
     *
     * @return the index just past the last zeroed entry. This is
     *         {@code start} unchanged when {@code count <= 0}.
     */
    private static int zeroPairs(double[] z, int start, int count) {
        int idx = start;
        for (int n = 0; n < count; ++n) {
            z[idx] = 0.0;
            z[idx + 1] = 0.0;
            idx += 2;
        }
        return idx;
    }

    /**
     * Sets the PSF of the underlying convolution operator.
     *
     * <p>The arguments are forwarded unchanged to the underlying operator.
     */
    @Override
    public void setPSF(ShapedArray psf, int[] off, boolean normalize) {
        cnvl.setPSF(psf, off, normalize);
    }

    /**
     * Sets the PSF of the underlying convolution operator.
     *
     * <p>The argument is forwarded unchanged to the underlying operator.
     */
    @Override
    public void setPSF(ShapedVector psf) {
        cnvl.setPSF(psf);
    }

    /**
     * Returns the model, i.e. the convolution result pulled into a new vector
     * of the output space.
     *
     * @param x the variables to convolve. If {@code null}, no convolution is
     *          performed. Whatever the operator currently holds is pulled,
     *          presumably the result of a previous call.
     * @return a newly created output-space vector holding the model.
     */
    @Override
    public ShapedVector getModel(ShapedVector x) {
        checkSetup();
        ShapedVector dst = cnvl.getOutputSpace().create();
        if (x != null) {
            cnvl.push(x, false);
            cnvl.convolve(false);
        }
        cnvl.pull(dst, false);
        return dst;
    }
}

