package org.fastica;

import java.io.File;
import javax.sound.sampled.AudioFileFormat;
import javax.sound.sampled.AudioSystem;
import org.fastica.FastICAConfig;
import org.fastica.FastICAException;
import org.fastica.ProgressListener;
import org.fastica.math.EigenValueDecompositionSymm;
import org.fastica.math.Matrix;
import org.fastica.math.Vector;
import org.fastica.util.AudioBuffer;

/* loaded from: input_file:org/fastica/FastICA.class */
/**
 * Fixed-point independent component analysis (FastICA).
 *
 * <p>The input is centred and decomposed by {@link PCA}. An {@link EigenValueFilter} then decides
 * which eigenvalues/eigenvectors are kept. The centred data is whitened with
 * {@code D^-1/2 E^T}. Finally, a weight matrix is estimated in whitened space with the FastICA
 * fixed-point iteration, driven by a {@link ContrastFunction}. The weight matrix is estimated
 * either for all components at once ({@code SYMMETRIC}) or one component after another
 * ({@code DEFLATION}).
 *
 * <p>Input matrices are treated as one row per signal and one column per observation: averages
 * in the fixed-point update are taken over the columns of the whitened data.
 */
public class FastICA {
    /** The input matrix exactly as passed to the constructor (not copied). */
    private double[][] inVectors;
    /** Mean values reported by {@link PCA#getMeanValues()}. */
    private double[] meanValues;
    /** Centred input reported by {@link PCA#getVectorsZeroMean()}. */
    private double[][] vectorsZeroMean;
    /** {@code D^-1/2 E^T} built from the filtered eigenvalues/eigenvectors. */
    private double[][] whiteningMatrix;
    /** {@code E D^1/2}, the inverse mapping of {@link #whiteningMatrix}. */
    private double[][] dewhiteningMatrix;
    /** {@code whiteningMatrix * vectorsZeroMean}. */
    private double[][] whitenedVectors;
    /** Estimated unmixing matrix in whitened space, one row per independent component. */
    private double[][] weightMatrix;
    /** {@code dewhiteningMatrix * weightMatrix^T}; {@code null} if whitening failed. */
    private double[][] mixingMatrix;
    /** {@code weightMatrix * whiteningMatrix}; {@code null} if whitening failed. */
    private double[][] separatingMatrix;
    /** Lazily computed result of {@link #getICVectors()}. */
    private double[][] icVectors;

    /**
     * Runs FastICA with default settings.
     *
     * <p>The defaults are: a {@code FastICAConfig} built from {@code i} (all other settings as
     * chosen by that class), a {@code TanhCFunction(1.0)} contrast function, a
     * {@code BelowEVFilter(1.0E-12, false)} eigenvalue filter, and a progress listener that
     * ignores all notifications.
     *
     * @param dArr input data, one row per signal and one column per observation
     * @param i number of independent components to estimate; it is reduced to the number of
     *     whitened dimensions if larger
     * @throws FastICAException if the eigenvalue filter leaves no eigenvalues
     */
    public FastICA(double[][] dArr, int i) throws FastICAException {
        algorithm(dArr, new FastICAConfig(i), new TanhCFunction(1.0d), new BelowEVFilter(1.0E-12d, false), new ProgressListener() {
            @Override
            public void progressMade(ProgressListener.ComputationState computationState, int i2, int i3, int i4) {
                // Intentionally silent: the default constructor reports no progress.
            }
        });
    }

    /**
     * Runs FastICA with fully specified settings.
     *
     * @param dArr input data, one row per signal and one column per observation
     * @param fastICAConfig number of ICs, approach, convergence epsilon, iteration limit and
     *     optional initial mixing matrix
     * @param contrastFunction non-linearity {@code g} and its derivative used by the update
     * @param eigenValueFilter selects which PCA eigenvalues/eigenvectors are used for whitening
     * @param progressListener receives WHITENING, SYMMETRIC/DEFLATION and READY notifications
     * @throws FastICAException if the eigenvalue filter leaves no eigenvalues
     */
    public FastICA(double[][] dArr, FastICAConfig fastICAConfig, ContrastFunction contrastFunction, EigenValueFilter eigenValueFilter, ProgressListener progressListener) throws FastICAException {
        algorithm(dArr, fastICAConfig, contrastFunction, eigenValueFilter, progressListener);
    }

    /**
     * Performs whitening, initialises the weight matrix and runs the configured approach.
     *
     * <p>On success, this sets {@link #mixingMatrix} and {@link #separatingMatrix}.
     */
    private synchronized void algorithm(double[][] dArr, FastICAConfig fastICAConfig, ContrastFunction contrastFunction, EigenValueFilter eigenValueFilter, ProgressListener progressListener) throws FastICAException {
        progressListener.progressMade(ProgressListener.ComputationState.WHITENING, 0, 0, fastICAConfig.getNumICs());
        this.inVectors = dArr;
        this.icVectors = null;

        // Centre the data and let the filter choose the eigen-pairs used for whitening.
        PCA pca = new PCA(dArr);
        this.meanValues = pca.getMeanValues();
        this.vectorsZeroMean = pca.getVectorsZeroMean();
        eigenValueFilter.passEigenValues(pca.getEigenValues(), pca.getEigenVectors());
        double[] eigenValues = eigenValueFilter.getEigenValues();
        if (eigenValues == null || eigenValues.length == 0) {
            this.mixingMatrix = null;
            this.separatingMatrix = null;
            this.icVectors = null;
            throw new FastICAException(FastICAException.Reason.NO_MORE_EIGENVALUES);
        }
        double[][] eigenVectors = eigenValueFilter.getEigenVectors();

        // Whitening V = D^-1/2 E^T and its inverse E D^1/2.
        double[] sqrtEigenValues = sqrtVector(eigenValues);
        this.whiteningMatrix = Matrix.mult(Matrix.diag(invVector(sqrtEigenValues)), Matrix.transpose(eigenVectors));
        this.dewhiteningMatrix = Matrix.mult(eigenVectors, Matrix.diag(sqrtEigenValues));
        this.whitenedVectors = Matrix.mult(this.whiteningMatrix, this.vectorsZeroMean);

        int numOfRows = Matrix.getNumOfRows(this.whitenedVectors);
        // Transpose once so every observation is a row: the fixed-point loops then iterate rows
        // instead of re-extracting each column for every component and iteration.
        double[][] samples = Matrix.transpose(this.whitenedVectors);
        int numICs = Math.min(fastICAConfig.getNumICs(), numOfRows);

        double[][] initialMixingMatrix = fastICAConfig.getInitialMixingMatrix();
        if (initialMixingMatrix != null
                && Matrix.getNumOfColumns(initialMixingMatrix) == numICs
                && Matrix.getNumOfRows(initialMixingMatrix) == Matrix.getNumOfRows(this.vectorsZeroMean)) {
            // Map the user-supplied mixing matrix into whitened space: W = (V A)^T.
            this.weightMatrix = Matrix.transpose(Matrix.mult(this.whiteningMatrix, initialMixingMatrix));
        } else {
            // No initial guess, or one with incompatible dimensions: start from random weights.
            this.weightMatrix = Matrix.random(numICs, numOfRows);
        }
        this.weightMatrix = symmetricDecorrelation(this.weightMatrix);

        int maxIterations = fastICAConfig.getMaxIterations();
        double epsilon = fastICAConfig.getEpsilon();
        switch (fastICAConfig.getApproach()) {
            case SYMMETRIC:
                runSymmetric(samples, numICs, maxIterations, epsilon, contrastFunction, progressListener);
                // Without this break the symmetric result would be re-processed by deflation.
                break;
            case DEFLATION:
                runDeflation(samples, numICs, maxIterations, epsilon, contrastFunction, progressListener);
                break;
        }
        this.mixingMatrix = Matrix.mult(this.dewhiteningMatrix, Matrix.transpose(this.weightMatrix));
        this.separatingMatrix = Matrix.mult(this.weightMatrix, this.whiteningMatrix);
        progressListener.progressMade(ProgressListener.ComputationState.READY, numICs, maxIterations, numICs);
    }

    /**
     * Symmetric approach: updates every row of {@link #weightMatrix} and then re-decorrelates the
     * whole matrix.
     *
     * <p>This repeats until the mean absolute change drops below {@code epsilon} or
     * {@code maxIterations} sweeps have run.
     */
    private void runSymmetric(double[][] samples, int numICs, int maxIterations, double epsilon, ContrastFunction contrastFunction, ProgressListener progressListener) {
        boolean converged = false;
        for (int iteration = 0; iteration < maxIterations && !converged; iteration++) {
            progressListener.progressMade(ProgressListener.ComputationState.SYMMETRIC, 0, iteration, numICs);
            double[][] previous = Matrix.clone(this.weightMatrix);
            for (int ic = 0; ic < numICs; ic++) {
                double[] updated = fixedPointStep(Matrix.getVecOfRow(this.weightMatrix, ic), samples, contrastFunction);
                // Copy element-wise into the existing row (as the original did) rather than swapping rows.
                System.arraycopy(updated, 0, this.weightMatrix[ic], 0, updated.length);
            }
            this.weightMatrix = symmetricDecorrelation(this.weightMatrix);
            converged = deltaMatrices(this.weightMatrix, previous) < epsilon;
        }
    }

    /**
     * Deflation approach: estimates the rows of {@link #weightMatrix} one at a time.
     *
     * <p>After each update, the row is orthogonalised against all previously estimated rows
     * (Gram-Schmidt) and normalised to unit length.
     */
    private void runDeflation(double[][] samples, int numICs, int maxIterations, double epsilon, ContrastFunction contrastFunction, ProgressListener progressListener) {
        for (int ic = 0; ic < numICs; ic++) {
            double[] w = Matrix.getVecOfRow(this.weightMatrix, ic);
            boolean converged = false;
            for (int iteration = 0; iteration < maxIterations && !converged; iteration++) {
                progressListener.progressMade(ProgressListener.ComputationState.DEFLATION, ic, iteration, numICs);
                double[] previous = Vector.clone(w);
                double[] updated = fixedPointStep(previous, samples, contrastFunction);
                // Remove the projections onto the components that are already estimated.
                for (int k = 0; k < ic; k++) {
                    updated = Vector.sub(updated, Vector.scale(Vector.dot(updated, this.weightMatrix[k]), this.weightMatrix[k]));
                }
                w = Vector.scale(1.0d / Math.sqrt(Vector.dot(updated, updated)), updated);
                System.arraycopy(w, 0, this.weightMatrix[ic], 0, w.length);
                converged = deltaVectors(w, previous) < epsilon;
            }
        }
    }

    /**
     * Performs one FastICA fixed-point update of a single weight vector {@code w}.
     *
     * <p>With {@code y = w·z} for each row {@code z} of {@code samples}, {@code g} =
     * {@code contrastFunction.function} and {@code g'} = {@code contrastFunction.derivative},
     * returns {@code w - (mean(z g(y)) - mean(y g(y)) w) / (mean(g'(y)) - mean(y g(y)))}.
     *
     * @param w current weight vector; not modified
     * @param samples whitened observations, one per row
     * @param contrastFunction non-linearity and its derivative
     * @return the updated (not yet normalised or decorrelated) weight vector
     */
    private static double[] fixedPointStep(double[] w, double[][] samples, ContrastFunction contrastFunction) {
        int numSamples = samples.length;
        double sumYG = 0.0d;
        double sumDerivative = 0.0d;
        double[] sumZG = Vector.newVector(w.length, 0.0d);
        for (double[] z : samples) {
            double y = Vector.dot(w, z);
            double g = contrastFunction.function(y);
            sumYG += y * g;
            sumDerivative += contrastFunction.derivative(y);
            sumZG = Vector.add(sumZG, Vector.scale(g, z));
        }
        double meanYG = sumYG / numSamples;
        double meanDerivative = sumDerivative / numSamples;
        double[] direction = Vector.sub(Vector.scale(1.0d / numSamples, sumZG), Vector.scale(meanYG, w));
        return Vector.sub(w, Vector.scale(1.0d / (meanDerivative - meanYG), direction));
    }

    /**
     * Symmetric decorrelation: returns {@code powerSymmMatrix(Matrix.square(w), -0.5) * w}.
     *
     * <p>NOTE(review): this is {@code (W W^T)^-1/2 W} assuming {@code Matrix.square(W)} computes
     * {@code W W^T} — confirm against {@code Matrix}.
     */
    private static double[][] symmetricDecorrelation(double[][] w) {
        return Matrix.mult(powerSymmMatrix(Matrix.square(w), -0.5d), w);
    }

    /**
     * Returns the mean absolute element-wise difference between two matrices of equal size.
     */
    private static double deltaMatrices(double[][] dArr, double[][] dArr2) {
        double[][] sub = Matrix.sub(dArr, dArr2);
        double d = 0.0d;
        int numOfRows = Matrix.getNumOfRows(dArr);
        int numOfColumns = Matrix.getNumOfColumns(dArr);
        for (int i = 0; i < numOfRows; i++) {
            for (int i2 = 0; i2 < numOfColumns; i2++) {
                d += Math.abs(sub[i][i2]);
            }
        }
        return d / (numOfRows * numOfColumns);
    }

    /**
     * Returns the mean absolute element-wise difference between two vectors of equal length.
     *
     * <p>NOTE(review): {@code w} and {@code -w} describe the same component, but a sign flip
     * produces a large delta here. Such oscillation is therefore not detected as convergence.
     */
    private static double deltaVectors(double[] dArr, double[] dArr2) {
        double[] sub = Vector.sub(dArr, dArr2);
        double d = 0.0d;
        int length = dArr.length;
        for (int i = 0; i < length; i++) {
            d += Math.abs(sub[i]);
        }
        return d / length;
    }

    /**
     * Raises a symmetric matrix to a real power via its eigendecomposition:
     * {@code E diag(lambda^d) E^T}.
     *
     * <p>Non-positive eigenvalues combined with a negative {@code d} yield infinite or NaN entries.
     */
    private static double[][] powerSymmMatrix(double[][] dArr, double d) {
        EigenValueDecompositionSymm eigenValueDecompositionSymm = new EigenValueDecompositionSymm(dArr);
        int numOfRows = Matrix.getNumOfRows(dArr);
        double[][] eigenVectors = eigenValueDecompositionSymm.getEigenVectors();
        double[] eigenValues = eigenValueDecompositionSymm.getEigenValues();
        for (int i = 0; i < numOfRows; i++) {
            eigenValues[i] = Math.pow(eigenValues[i], d);
        }
        return Matrix.mult(Matrix.mult(eigenVectors, Matrix.diag(eigenValues)), Matrix.transpose(eigenVectors));
    }

    /** Returns a new array holding the element-wise reciprocal {@code 1 / x}. */
    private static double[] invVector(double[] dArr) {
        int length = dArr.length;
        double[] dArr2 = new double[length];
        for (int i = 0; i < length; i++) {
            dArr2[i] = 1.0d / dArr[i];
        }
        return dArr2;
    }

    /** Returns a new array holding the element-wise square root. */
    private static double[] sqrtVector(double[] dArr) {
        int length = dArr.length;
        double[] dArr2 = new double[length];
        for (int i = 0; i < length; i++) {
            dArr2[i] = Math.sqrt(dArr[i]);
        }
        return dArr2;
    }

    /**
     * Returns the independent components: the separating matrix applied to the original input.
     *
     * <p>The input is the uncentred {@code inVectors}. The result is computed once and cached.
     *
     * @return the cached IC matrix (not a copy)
     */
    public synchronized double[][] getICVectors() {
        if (this.icVectors == null) {
            this.icVectors = Matrix.mult(this.separatingMatrix, this.inVectors);
        }
        return this.icVectors;
    }

    /**
     * Returns the estimated mixing matrix {@code dewhitening * W^T}.
     *
     * @return the internal array (not a copy)
     */
    public double[][] getMixingMatrix() {
        return this.mixingMatrix;
    }

    /**
     * Returns the estimated separating matrix {@code W * whitening}.
     *
     * @return the internal array (not a copy)
     */
    public double[][] getSeparatingMatrix() {
        return this.separatingMatrix;
    }

    /**
     * Command-line demo.
     *
     * <p>It reads a WAVE file and mixes its channels with a fixed 5x2 matrix, so the input is
     * expected to have two channels. It then separates the mixture with deflation FastICA and
     * writes the resulting components as a WAVE file. Any failure is printed to stderr.
     *
     * @param strArr {@code [input wave] [number of independent components] [output wave]}
     */
    public static void main(String[] strArr) {
        if (strArr.length != 3) {
            System.out.println("Usage:");
            System.out.println("java org.fastica.FastICA [input wave] [number of independent components] [output wave]");
            System.out.println();
            return;
        }
        try {
            AudioBuffer audioBuffer = new AudioBuffer(new File(strArr[0]));
            // Fixed demo mixing: 5 observed signals from 2 sources.
            double[][] newMatrix = Matrix.newMatrix(5, 2);
            newMatrix[0][0] = 0.5d;
            newMatrix[0][1] = 0.5d;
            newMatrix[1][0] = 0.3d;
            newMatrix[1][1] = 0.7d;
            newMatrix[2][0] = 0.6d;
            newMatrix[2][1] = 0.2d;
            newMatrix[3][0] = 0.2d;
            newMatrix[3][1] = 0.6d;
            newMatrix[4][0] = 0.3d;
            newMatrix[4][1] = 0.5d;
            double[][] mult = Matrix.mult(newMatrix, audioBuffer.getData());
            CompositeEVFilter compositeEVFilter = new CompositeEVFilter();
            compositeEVFilter.add(new BelowEVFilter(1.0E-8d, false));
            compositeEVFilter.add(new SortingEVFilter(true, true));
            FastICAConfig fastICAConfig = new FastICAConfig(Integer.parseInt(strArr[1]), FastICAConfig.Approach.DEFLATION, 1.0d, 1.0E-16d, 1000, null);
            ProgressListener progressListener = new ProgressListener() {
                @Override
                public void progressMade(ProgressListener.ComputationState computationState, int i, int i2, int i3) {
                    System.out.print("\r" + Integer.toString(i) + " - " + Integer.toString(i2) + "     ");
                }
            };
            System.out.println("Performing ICA");
            FastICA fastICA = new FastICA(mult, fastICAConfig, new Power3CFunction(), compositeEVFilter, progressListener);
            System.out.println();
            AudioSystem.write(new AudioBuffer(fastICA.getICVectors(), audioBuffer.getSampleRate()).getStream(), AudioFileFormat.Type.WAVE, new File(strArr[2]));
        } catch (Exception e) {
            // CLI boundary: report any failure and exit normally.
            e.printStackTrace(System.err);
        }
    }
}
