package mitiv.optim;

import mitiv.exception.IllegalLinearOperationException;
import mitiv.exception.IncorrectSpaceException;
import mitiv.linalg.LinearEndomorphism;
import mitiv.linalg.LinearOperator;
import mitiv.linalg.Vector;
import mitiv.linalg.VectorSpace;

/* loaded from: input_file:mitiv/optim/LBFGSOperator.class */
public class LBFGSOperator extends LinearEndomorphism {
    /** Scaling rule: do not automatically rescale the initial approximation. */
    static final int NO_SCALING = 0;

    /**
     * Scaling rule: at each accepted update, set {@code gamma = (s'y)/(y'y)}
     * (Oren-Spedicato scaling).
     */
    static final int OREN_SPEDICATO_SCALING = 1;

    /**
     * Scaling rule: at each accepted update, set {@code gamma = (s's)/(s'y)}
     * (Barzilai-Borwein scaling).
     */
    static final int BARZILAI_BORWEIN_SCALING = 2;

    /**
     * Scaling rule: keep {@code gamma} constant. When OR-ed with
     * {@link #OREN_SPEDICATO_SCALING} or {@link #BARZILAI_BORWEIN_SCALING},
     * the corresponding formula is applied only at the very first update.
     */
    static final int CONSTANT_SCALING = 4;

    /** Memorized variable differences, used as a circular buffer (see {@link #slot(int)}). */
    protected Vector[] s;

    /** Memorized gradient differences, parallel to {@link #s}. */
    protected Vector[] y;

    /** Maximum number of memorized pairs. */
    protected final int m;

    /** Number of pairs currently memorized (at most {@link #m}). */
    protected int mp;

    /** Number of non-rejected updates since construction; {@link #reset()} does not clear it. */
    protected int updates;

    /** Per-slot {@code 1/(s'y)}, or 0 for a pair that must be skipped. */
    protected double[] rho;

    /** Scale of the initial approximation, used when {@link #H0} is {@code null}. */
    protected double gamma;

    /** Per-slot workspace for the two-loop recursion. */
    protected double[] alpha;

    /** Optional initial inverse Hessian approximation; {@code null} means {@code gamma} times identity. */
    protected LinearOperator H0;

    /** Scaling rule, one of (or an OR-combination of) the {@code *_SCALING} constants. */
    protected int rule;

    /**
     * Creates an L-BFGS operator whose initial approximation is a scaled identity.
     *
     * @param vectorSpace the space on which the operator acts
     * @param i           the maximum number of memorized pairs; must be at least 1
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public LBFGSOperator(VectorSpace vectorSpace, int i) {
        super(vectorSpace);
        this.rule = OREN_SPEDICATO_SCALING;
        this.m = requireValidMemory(i);
        this.H0 = null;
        allocateWorkspace();
    }

    /**
     * Creates an L-BFGS operator with an explicit initial approximation.
     *
     * @param linearEndomorphism the initial inverse Hessian approximation; it also defines
     *                           the space on which this operator acts
     * @param i                  the maximum number of memorized pairs; must be at least 1
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public LBFGSOperator(LinearEndomorphism linearEndomorphism, int i) {
        super(linearEndomorphism.getSpace());
        this.rule = OREN_SPEDICATO_SCALING;
        this.m = requireValidMemory(i);
        this.H0 = linearEndomorphism;
        allocateWorkspace();
    }

    /** Rejects memory sizes for which the circular buffer would be empty (or negative). */
    private static int requireValidMemory(int size) {
        if (size < 1) {
            throw new IllegalArgumentException("number of memorized pairs must be at least 1");
        }
        return size;
    }

    /** Allocates the pair storage and per-slot workspaces, and clears the memory. */
    private void allocateWorkspace() {
        this.s = new Vector[this.m];
        this.y = new Vector[this.m];
        for (int k = 0; k < this.m; k++) {
            this.s[k] = this.space.create();
            this.y[k] = this.space.create();
        }
        this.alpha = new double[this.m];
        this.rho = new double[this.m];
        this.mp = 0;
        this.updates = 0;
        this.gamma = 1.0d;
    }

    /** Forgets all memorized pairs (the update counter and {@code gamma} are kept). */
    public void reset() {
        this.mp = 0;
    }

    /**
     * Sets the scaling rule.
     *
     * @param i one of the {@code *_SCALING} constants, or {@code CONSTANT_SCALING} OR-ed
     *          with a formula-based rule
     */
    public void setScaling(int i) {
        this.rule = i;
    }

    /** @return the current scaling rule */
    public int getScaling() {
        return this.rule;
    }

    /**
     * Sets a constant scale for the initial approximation and switches to
     * {@link #CONSTANT_SCALING}.
     *
     * @param d the scale factor; must be strictly positive and finite
     * @throws IllegalArgumentException if {@code d} is not strictly positive, is NaN, or is infinite
     */
    public void setScale(double d) {
        // Written as !(d > 0) so that NaN (for which every comparison is false) is rejected too.
        if (!(d > 0.0d) || Double.isInfinite(d)) {
            throw new IllegalArgumentException("scale factor must be strictly positive and finite");
        }
        this.gamma = d;
        this.rule = CONSTANT_SCALING;
    }

    /** @return the current scale of the initial approximation */
    public double getScale() {
        return this.gamma;
    }

    /**
     * Maps a relative pair index to a storage slot.
     *
     * @param i 0 for the slot the next update will write, 1 for the most recent pair,
     *          ..., {@code mp} for the oldest memorized pair
     * @return the index into {@link #s}, {@link #y}, {@link #rho} and {@link #alpha}
     * @throws IndexOutOfBoundsException if {@code i < 0} or {@code i > mp}
     */
    protected int slot(int i) {
        if (i < 0 || i > this.mp) {
            throw new IndexOutOfBoundsException("BFGS slot index is out of bounds");
        }
        return (this.updates - i) % this.m;
    }

    /**
     * @param i relative pair index, as for {@link #slot(int)}
     * @return the memorized variable difference for that index
     */
    public Vector s(int i) {
        return this.s[slot(i)];
    }

    /**
     * @param i relative pair index, as for {@link #slot(int)}
     * @return the memorized gradient difference for that index
     */
    public Vector y(int i) {
        return this.y[slot(i)];
    }

    /**
     * Overwrites {@code d} with the product of the approximation and {@code d}, using the
     * stored {@code rho} and {@code gamma} (two-loop recursion). Pairs with
     * {@code rho <= 0} are skipped.
     *
     * @return always {@code true}
     */
    private boolean applyInPlace(Vector d) {
        // First loop: most recent pair to oldest.
        for (int i = 1; i <= this.mp; i++) {
            int k = slot(i);
            if (this.rho[k] > 0.0d) {
                this.alpha[k] = this.rho[k] * d.dot(this.s[k]);
                d.add(-this.alpha[k], this.y[k]);
            }
        }
        // Initial approximation.
        // NOTE(review): after a deferred update (gamma == 0) with no H0, this scales d by 0.
        if (this.H0 != null) {
            this.H0.apply(d, d);
        } else if (this.gamma != 1.0d) {
            d.scale(this.gamma);
        }
        // Second loop: oldest pair to most recent.
        for (int i = this.mp; i >= 1; i--) {
            int k = slot(i);
            if (this.rho[k] > 0.0d) {
                d.add(this.alpha[k] - this.rho[k] * d.dot(this.y[k]), this.s[k]);
            }
        }
        return true;
    }

    /**
     * Weighted variant of the two-loop recursion: every inner product is computed with
     * the three-argument {@code sel.dot(a, b)}. This is presumably
     * {@code sum_j sel_j a_j b_j}, e.g. a mask of free variables (TODO confirm in {@code Vector}).
     * {@code rho} is recomputed for every pair, and, when there is no {@code H0},
     * {@code gamma} is taken from the most recent usable pair.
     *
     * NOTE(review): this overwrites the fields {@code rho} and {@code gamma}, including a
     * {@code gamma} set through {@link #setScale(double)}.
     *
     * @return {@code false} if there is no {@code H0} and no memorized pair yields a positive
     *         weighted curvature and {@code y'y} (the contents of {@code d} are then
     *         unspecified); {@code true} otherwise
     */
    private boolean applyInPlace(Vector sel, Vector d) {
        boolean needGamma = this.H0 == null;
        this.gamma = 0.0d;
        for (int i = 1; i <= this.mp; i++) {
            int k = slot(i);
            double sty = sel.dot(this.y[k], this.s[k]);
            if (sty <= 0.0d) {
                this.rho[k] = 0.0d;
            } else {
                this.rho[k] = 1.0d / sty;
                this.alpha[k] = this.rho[k] * sel.dot(d, this.s[k]);
                d.add(-this.alpha[k], this.y[k]);
                if (needGamma) {
                    double yty = sel.dot(this.y[k], this.y[k]);
                    if (yty > 0.0d) {
                        this.gamma = sty / yty;
                        needGamma = false;
                    }
                }
            }
        }
        if (needGamma) {
            return false;
        }
        if (this.H0 != null) {
            this.H0.apply(d, d);
        } else if (this.gamma != 1.0d) {
            d.scale(this.gamma);
        }
        for (int i = this.mp; i >= 1; i--) {
            int k = slot(i);
            if (this.rho[k] > 0.0d) {
                d.add(this.alpha[k] - this.rho[k] * sel.dot(d, this.y[k]), this.s[k]);
            }
        }
        return true;
    }

    /**
     * Applies the L-BFGS approximation of the inverse Hessian, optionally in a weighted
     * metric.
     *
     * @param sel weights for the inner products, or {@code null} for the unweighted
     *            recursion using the stored {@code rho} and {@code gamma}
     * @param src the input vector; left unmodified unless it is the same object as {@code dst}
     * @param dst the vector receiving the result (may be {@code src})
     * @return {@code false} if no pair is memorized ({@code dst} is then untouched) or if
     *         the weighted recursion fails; {@code true} otherwise
     * @throws IncorrectSpaceException if a non-null argument does not belong to the
     *                                 operator space
     */
    public boolean apply(Vector sel, Vector src, Vector dst) {
        if ((sel != null && !sel.belongsTo(this.space)) || !src.belongsTo(this.space)
                || !dst.belongsTo(this.space)) {
            throw new IncorrectSpaceException();
        }
        if (this.mp < 1) {
            return false;
        }
        if (src != dst) {
            dst.copy(src);
        }
        // Transform the copy (as _apply does). Transforming src instead would leave dst
        // holding the untransformed input and corrupt the caller's source vector.
        return sel == null ? applyInPlace(dst) : applyInPlace(sel, dst);
    }

    /**
     * Unweighted application. The approximation is symmetric, so the direct and adjoint
     * operations are the same.
     * Assumes {@code Vector.copy(v)} copies {@code v} into the receiver.
     *
     * @throws IllegalLinearOperationException for any job other than DIRECT or ADJOINT
     */
    @Override // mitiv.linalg.LinearOperator
    protected void _apply(Vector dst, Vector src, int job) {
        if (job != DIRECT && job != ADJOINT) {
            throw new IllegalLinearOperationException();
        }
        if (dst != src) {
            dst.copy(src);
        }
        applyInPlace(dst);
    }

    /**
     * Memorizes the pair {@code s = x1 - x0}, {@code y = g1 - g0}, computing {@code rho}
     * and, depending on the scaling rule, {@code gamma} immediately.
     *
     * @param x1 new variables
     * @param x0 previous variables
     * @param g1 presumably the gradient at {@code x1}
     * @param g0 presumably the gradient at {@code x0}
     * @throws IncorrectSpaceException if the vectors do not belong to the operator space
     *                                 (presumably raised by {@code Vector.combine})
     */
    public void update(Vector x1, Vector x0, Vector g1, Vector g0) throws IncorrectSpaceException {
        update(x1, x0, g1, g0, false);
    }

    /**
     * Memorizes the pair {@code s = x1 - x0}, {@code y = g1 - g0}.
     *
     * <p>If {@code deferred} is true, the pair is stored with {@code rho = 0} and
     * {@code gamma} is reset to 0. Presumably this is for use with the weighted
     * {@link #apply(Vector, Vector, Vector)}, which recomputes both.
     *
     * <p>Otherwise, a pair with {@code s'y <= 0} is rejected. Its {@code rho} is set to 0 and
     * neither {@code updates} nor {@code mp} is incremented, so the next update reuses
     * the slot.
     * NOTE(review): when the memory is full, that slot is the one holding the oldest
     * pair, which is therefore lost while {@code mp} still counts it.
     *
     * @param x1       new variables
     * @param x0       previous variables
     * @param g1       presumably the gradient at {@code x1}
     * @param g0       presumably the gradient at {@code x0}
     * @param deferred whether to postpone the computation of {@code rho} and {@code gamma}
     * @throws IncorrectSpaceException if the vectors do not belong to the operator space
     *                                 (presumably raised by {@code Vector.combine})
     */
    public void update(Vector x1, Vector x0, Vector g1, Vector g0, boolean deferred)
            throws IncorrectSpaceException {
        int k = slot(0);
        this.s[k].combine(1.0d, x1, -1.0d, x0);
        this.y[k].combine(1.0d, g1, -1.0d, g0);
        if (deferred) {
            this.rho[k] = 0.0d;
            this.gamma = 0.0d;
        } else {
            double sty = this.s[k].dot(this.y[k]);
            if (sty <= 0.0d) {
                this.rho[k] = 0.0d;
                return;
            }
            this.rho[k] = 1.0d / sty;
            boolean firstUpdate = this.updates == 0;
            if (this.rule == OREN_SPEDICATO_SCALING
                    || (this.rule == (CONSTANT_SCALING | OREN_SPEDICATO_SCALING) && firstUpdate)) {
                // gamma = s'y / ||y||^2, dividing twice by the norm to avoid overflow in ||y||^2.
                double ynorm = this.y[k].norm2();
                this.gamma = (sty / ynorm) / ynorm;
            } else if (this.rule == BARZILAI_BORWEIN_SCALING
                    || (this.rule == (CONSTANT_SCALING | BARZILAI_BORWEIN_SCALING) && firstUpdate)) {
                // gamma = ||s||^2 / s'y, factored to avoid overflow in ||s||^2.
                double snorm = this.s[k].norm2();
                this.gamma = (snorm / sty) * snorm;
            }
        }
        this.updates++;
        if (this.mp < this.m) {
            this.mp++;
        }
    }
}
