package plugins.nherve.toolbox.image.feature;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import plugins.nherve.matrix.Matrix;
import plugins.nherve.toolbox.image.feature.signature.DenseVectorSignature;
import plugins.nherve.toolbox.image.feature.signature.SignatureException;
import plugins.nherve.toolbox.image.feature.signature.VectorSignature;

/* loaded from: input_file:plugins/nherve/toolbox/image/feature/LDA.class */
/**
 * Linear discriminant analysis over a labelled set of vector signatures.
 *
 * <p>{@link #compute()} estimates, from the training signatures and their
 * class labels, the per-class mean vectors, the inverse of the pooled
 * covariance matrix and a per-class constant term. {@link #project(VectorSignature)}
 * then maps a signature to one linear discriminant score per class.
 *
 * <p>Class labels are used directly as indices, so they must be exactly the
 * integers {@code 0 .. nbGroups-1}.
 */
public class LDA extends DimensionReductionAlgorithm {
    /** Class label of each training signature, indexed like the rows of the data matrix. */
    private final List<Integer> classes;
    /** Inverse of the pooled covariance matrix; {@code null} until {@link #compute()} has run. */
    private Matrix invPooled;
    /** 1 x nbGroups row of per-class constants: log(prior) - 0.5 * mean * invPooled * mean^T. */
    private Matrix constStuff;
    /** Mean row vector of each class, keyed by class label; {@code null} until {@link #compute()} has run. */
    private HashMap<Integer, Matrix> classesMean;
    /** Number of distinct class labels found by {@link #compute()}. */
    int nbGroups;

    /**
     * Creates an LDA over the given training data.
     *
     * @param list  training signatures, handed to the superclass
     * @param list2 class label of each training signature, in the same order
     */
    public LDA(List<VectorSignature> list, List<Integer> list2) {
        super(list);
        this.classes = list2;
        this.invPooled = null;
        this.classesMean = null;
        this.constStuff = null;
    }

    /**
     * Estimates the class means, the inverse pooled covariance and the
     * per-class constants from the training data.
     *
     * @throws SignatureException       propagated from the superclass helpers
     * @throws IllegalArgumentException if the class labels are not exactly {@code 0 .. nbGroups-1}
     */
    @Override
    public void compute() throws SignatureException {
        check();
        Matrix all = getMatrix(this.signatures);
        int nbSamples = all.getRowDimension();
        Matrix globalMean = getMean(all);
        if (isLogEnabled()) {
            log("Global mean : ");
            globalMean.print(20, 15);
        }

        // Count the members of each class.
        Map<Integer, Integer> remaining = new HashMap<>();
        for (int row = 0; row < nbSamples; row++) {
            int label = this.classes.get(row);
            Integer count = remaining.get(label);
            remaining.put(label, count == null ? 1 : count + 1);
        }
        this.nbGroups = remaining.size();

        // Labels are used below as column indices of 1 x nbGroups matrices and
        // as lookup keys 0..nbGroups-1 in project(); fail early with a clear message.
        for (Integer label : remaining.keySet()) {
            if (label < 0 || label >= this.nbGroups) {
                throw new IllegalArgumentException("LDA class labels must be the integers 0.."
                        + (this.nbGroups - 1) + " but found label " + label);
            }
        }

        // Allocate one (members x dim) matrix per class.
        Map<Integer, Matrix> groups = new HashMap<>();
        for (Map.Entry<Integer, Integer> entry : remaining.entrySet()) {
            int size = entry.getValue();
            log("Class " + entry.getKey() + " has " + size + " members");
            groups.put(entry.getKey(), new Matrix(size, this.dim));
        }

        // Copy each sample into its class matrix. The per-class counter is
        // consumed downwards, so each class matrix is filled from its last row up.
        for (int row = 0; row < nbSamples; row++) {
            int label = this.classes.get(row);
            Matrix group = groups.get(label);
            int target = remaining.get(label) - 1;
            for (int col = 0; col < this.dim; col++) {
                group.set(target, col, all.get(row, col));
            }
            remaining.put(label, target);
        }

        // Class means are taken on the raw (not yet centred) data.
        this.classesMean = new HashMap<>();
        for (Map.Entry<Integer, Matrix> entry : groups.entrySet()) {
            Matrix classMean = getMean(entry.getValue());
            if (isLogEnabled()) {
                log("Class " + entry.getKey() + " mean : ");
                classMean.print(20, 15);
            }
            this.classesMean.put(entry.getKey(), classMean);
        }

        // Subtract the global mean from every row of every class matrix:
        // (rows x 1 column of -1) * (1 x dim mean) is a rows x dim matrix of -mean.
        for (Matrix group : groups.values()) {
            group.plusEquals(new Matrix(group.getRowDimension(), 1, -1.0d).times(globalMean));
        }

        // Pooled covariance: per-class covariances weighted by class size,
        // divided by the total number of samples.
        Matrix pooled = new Matrix(this.dim, this.dim, 0.0d);
        for (Map.Entry<Integer, Matrix> entry : groups.entrySet()) {
            Matrix group = entry.getValue();
            Matrix covariance = getVarCovMatrix(group);
            if (isLogEnabled()) {
                log("C" + entry.getKey() + ": ");
                covariance.print(20, 15);
            }
            pooled.plusEquals(covariance.times((double) group.getRowDimension()));
        }
        pooled.timesEquals(1.0d / nbSamples);
        if (isLogEnabled()) {
            log("Pooled C: ");
            pooled.print(20, 15);
        }
        this.invPooled = pooled.inverse();

        // Per-class constant: log(prior) - 0.5 * mean * invPooled * mean^T.
        this.constStuff = new Matrix(1, this.nbGroups);
        for (Map.Entry<Integer, Matrix> entry : groups.entrySet()) {
            int label = entry.getKey();
            // The ratio must be a floating-point division: an int/int division
            // truncates to 0 for every class smaller than the whole set,
            // which would make the log -Infinity.
            double logPrior = Math.log((double) entry.getValue().getRowDimension() / nbSamples);
            Matrix classMean = this.classesMean.get(label);
            double quadratic = classMean.times(this.invPooled).times(classMean.transpose()).get(0, 0);
            this.constStuff.set(0, label, logPrior - (0.5d * quadratic));
        }
        log("LDA done");
    }

    /**
     * Computes the linear discriminant score of a signature for every class.
     *
     * <p>Component {@code c} of the result is
     * {@code mean_c * invPooled * x^T + constStuff[c]}. The returned signature
     * is created with {@code max(nbGroups, 3)} components; only the first
     * {@code nbGroups} are written here.
     *
     * @param vectorSignature the signature to project
     * @return a dense signature holding one score per class
     * @throws SignatureException    propagated from the signature/matrix helpers
     * @throws IllegalStateException if {@link #compute()} has not completed yet
     */
    public VectorSignature project(VectorSignature vectorSignature) throws SignatureException {
        if (this.classesMean == null || this.invPooled == null || this.constStuff == null) {
            throw new IllegalStateException("LDA.compute() must be called before project()");
        }
        DenseVectorSignature scores = new DenseVectorSignature(Math.max(this.nbGroups, 3));
        List<VectorSignature> single = new ArrayList<>(1);
        single.add(vectorSignature);
        // The sample as a column vector, shared by every class score below.
        Matrix sampleColumn = getMatrix(single).transpose();
        for (int label = 0; label < this.nbGroups; label++) {
            double linear = this.classesMean.get(label).times(this.invPooled).times(sampleColumn).get(0, 0);
            scores.set(label, linear + this.constStuff.get(0, label));
        }
        return scores;
    }

    /**
     * Projects every signature of the list with {@link #project(VectorSignature)}.
     *
     * @param list the signatures to project
     * @return the projected signatures, in the same order as the input
     * @throws SignatureException propagated from the single-signature projection
     */
    @Override
    public List<VectorSignature> project(List<VectorSignature> list) throws SignatureException {
        List<VectorSignature> projected = new ArrayList<>(list.size());
        for (VectorSignature signature : list) {
            projected.add(project(signature));
        }
        return projected;
    }
}
