package ca.uwaterloo.crysp.trainerLXG;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;

/* loaded from: classes.dex */
/**
 * k-nearest-neighbour classifier over fixed-length {@code double[]} feature vectors.
 *
 * <p>Two independent code paths exist:
 * <ul>
 *   <li>{@link #train(List, List)} followed by {@link #classify(double[])}, which use the
 *       data, labels and per-feature scaling factors stored in {@link #model};</li>
 *   <li>the batch {@link #classify(List, List, List, boolean)}, which is handed its reference
 *       sets explicitly and does not read {@link #model}.</li>
 * </ul>
 *
 * <p>The class is not synchronised; it also reads and writes static state on
 * {@code MainActivity} (see {@link #updateVictimScore()}).
 */
public class KNNClassifier {
    /** Number of nearest neighbours inspected by the batch classifier when run in score mode. */
    private static final int SCORE_NEIGHBOURS = 25;

    int k;
    public KNNModel model;
    int numFeatures;
    String user;
    /** When true, {@link #train(List, List)} derives per-feature scaling factors from the data. */
    public boolean scaleFeatures = true;
    ArrayList<Data> ds = MainActivity.ds;

    /**
     * A (label, distance) pair describing how far one reference sample is from a query.
     *
     * <p>Kept as a non-static inner class so existing {@code outer.new ComputedDistance(..)}
     * call sites keep compiling.
     */
    public class ComputedDistance {
        double distance;
        int label;

        /**
         * @param i label of the reference sample
         * @param d distance between the reference sample and the query
         */
        public ComputedDistance(int i, double d) {
            this.label = i;
            this.distance = d;
        }
    }

    /**
     * Trained state used by {@link KNNClassifier#classify(double[])}: the reference samples,
     * their labels, and one scaling factor per feature.
     */
    public class KNNModel {
        public List<double[]> data;
        public int k;
        public List<Integer> labels;
        public int numFeatures;
        public double[] scalingFactor;

        /**
         * @param i  number of neighbours to consult
         * @param i2 number of features per sample
         */
        public KNNModel(int i, int i2) {
            this.k = i;
            this.numFeatures = i2;
            this.scalingFactor = new double[i2];
            // Start from unit scale. A zero factor would make classify() divide by zero and
            // turn every distance into NaN whenever feature scaling is switched off.
            for (int feature = 0; feature < i2; feature++) {
                this.scalingFactor[feature] = 1.0d;
            }
        }
    }

    /**
     * @param i  number of neighbours to consult
     * @param i2 number of features per sample
     */
    public KNNClassifier(int i, int i2) {
        this.k = i;
        this.numFeatures = i2;
        this.model = new KNNModel(i, i2);
    }

    /**
     * Classifies one sample against the trained model using scaled Euclidean distance over
     * the first {@code model.numFeatures} features.
     *
     * @param dArr feature vector to classify; must hold at least {@code model.numFeatures} values
     * @return the majority label among the {@code model.k} nearest training samples, or 0 when
     *         no neighbour could be selected (e.g. every distance is NaN)
     * @throws IllegalStateException if {@link #train(List, List)} has not supplied data and labels
     */
    public int classify(double[] dArr) throws IllegalStateException {
        if (this.model.data == null || this.model.labels == null) {
            throw new IllegalStateException("Classifier has not been trained");
        }
        List<ComputedDistance> distances = new ArrayList<>(this.model.data.size());
        for (int row = 0; row < this.model.data.size(); row++) {
            double[] sample = this.model.data.get(row);
            double sumOfSquares = 0.0d;
            for (int feature = 0; feature < this.model.numFeatures; feature++) {
                double scale = this.model.scalingFactor[feature];
                if (scale == 0.0d) {
                    // Same rule train() applies: a zero factor means "do not scale".
                    scale = 1.0d;
                }
                sumOfSquares += Math.pow((sample[feature] / scale) - (dArr[feature] / scale), 2.0d);
            }
            distances.add(new ComputedDistance(this.model.labels.get(row).intValue(), Math.sqrt(sumOfSquares)));
        }
        return getMajorityLabel(sortDistances(distances, this.model.k));
    }

    /**
     * Classifies (or scores) every query in {@code list3} against two explicit reference sets.
     * Samples from {@code list} are treated as label 1 and samples from {@code list2} as
     * label -1. Per-feature weights are the largest absolute value seen in {@code list}
     * (1 when that maximum is 0).
     *
     * @param list  reference samples labelled 1; also the source of the feature weights
     * @param list2 reference samples labelled -1
     * @param list3 query samples
     * @param z     when true, each result is the count of label-1 samples among the 25 nearest
     *              neighbours; when false, each result is the majority label among the
     *              {@code k} nearest neighbours
     * @return one entry per query, in the iteration order of {@code list3}
     */
    public ArrayList<Integer> classify(List<double[]> list, List<double[]> list2, List<double[]> list3, boolean z) {
        double[] weights = new double[this.numFeatures];
        for (double[] reference : list) {
            for (int feature = 0; feature < this.numFeatures; feature++) {
                double magnitude = Math.abs(reference[feature]);
                if (magnitude >= weights[feature]) {
                    weights[feature] = magnitude;
                }
            }
        }
        for (int feature = 0; feature < this.numFeatures; feature++) {
            if (weights[feature] == 0.0d) {
                weights[feature] = 1.0d;
            }
        }

        ArrayList<Integer> results = new ArrayList<>();
        List<ComputedDistance> distances = new ArrayList<>();
        for (double[] query : list3) {
            distances.clear();
            for (double[] positiveSample : list) {
                distances.add(new ComputedDistance(1, weightedDistance(positiveSample, query, weights)));
            }
            for (double[] negativeSample : list2) {
                distances.add(new ComputedDistance(-1, weightedDistance(negativeSample, query, weights)));
            }
            if (z) {
                results.add(Integer.valueOf(getNeighborScore(sortDistances(distances, SCORE_NEIGHBOURS))));
            } else {
                results.add(Integer.valueOf(getMajorityLabel(sortDistances(distances, this.k))));
            }
        }
        return results;
    }

    /**
     * Distance used by the batch classifier: Euclidean distance over features
     * {@code 1 .. numFeatures-1}, each coordinate multiplied by its weight.
     *
     * <p>The weights can never be NaN (a NaN magnitude fails the {@code >=} comparison that
     * updates them), so no NaN filtering is needed here.
     *
     * <p>NOTE(review): feature index 0 is skipped, and coordinates are multiplied by the
     * max-abs weight whereas {@link #classify(double[])} divides by it. Both are preserved
     * from the original behaviour; confirm they are intentional.
     */
    private double weightedDistance(double[] reference, double[] query, double[] weights) {
        double sumOfSquares = 0.0d;
        for (int feature = 1; feature < this.numFeatures; feature++) {
            sumOfSquares += Math.pow(
                    (reference[feature] * weights[feature]) - (query[feature] * weights[feature]), 2.0d);
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Returns the most frequent label in {@code list}.
     *
     * <p>Ties are resolved in favour of the label that appears first in {@code list}; for the
     * distance-sorted lists produced by {@link #sortDistances(List, int)} that is the label of
     * the nearest neighbour. This makes the result independent of hash-map iteration order.
     *
     * @param list neighbours to vote
     * @return the winning label, or 0 when {@code list} is empty
     */
    public int getMajorityLabel(List<ComputedDistance> list) {
        if (list.size() == 0) {
            return 0;
        }
        HashMap<Integer, Integer> votes = new HashMap<>();
        for (ComputedDistance neighbour : list) {
            Integer key = Integer.valueOf(neighbour.label);
            Integer seen = votes.get(key);
            votes.put(key, Integer.valueOf(seen == null ? 1 : seen.intValue() + 1));
        }
        int winner = -1;
        int winnerVotes = -1;
        // Walk the neighbours (not the map) so the first-seen label wins a tie.
        for (ComputedDistance neighbour : list) {
            int count = votes.get(Integer.valueOf(neighbour.label)).intValue();
            if (count > winnerVotes) {
                winner = neighbour.label;
                winnerVotes = count;
            }
        }
        return winner;
    }

    /**
     * Counts how many entries of {@code list} carry label 1.
     *
     * @param list neighbours to inspect
     * @return the number of label-1 entries (0 for an empty list)
     */
    public int getNeighborScore(List<ComputedDistance> list) {
        int positives = 0;
        for (ComputedDistance neighbour : list) {
            if (neighbour.label == 1) {
                positives++;
            }
        }
        return positives;
    }

    /**
     * Returns up to {@code i} entries of {@code list} with the smallest distances, in ascending
     * order of distance. Entries with equal distance keep their relative order from
     * {@code list}. Entries whose distance is NaN or not below {@link Double#MAX_VALUE} are
     * never selected.
     *
     * <p>The input list is left untouched; the result holds fresh {@link ComputedDistance}
     * instances.
     *
     * @param list candidate neighbours
     * @param i    maximum number of neighbours to return
     * @return the selected neighbours, nearest first
     */
    public ArrayList<ComputedDistance> sortDistances(List<ComputedDistance> list, int i) {
        // Select from a private copy so the caller's list is not consumed.
        List<ComputedDistance> remaining = new ArrayList<>(list);
        ArrayList<ComputedDistance> nearest = new ArrayList<>();
        for (int picked = 0; picked < i; picked++) {
            int bestIndex = -1;
            double bestDistance = Double.MAX_VALUE;
            for (int candidate = 0; candidate < remaining.size(); candidate++) {
                double distance = remaining.get(candidate).distance;
                // Strict comparison: the earliest of several equal distances wins, NaN never does.
                if (bestDistance > distance) {
                    bestIndex = candidate;
                    bestDistance = distance;
                }
            }
            if (bestIndex == -1) {
                break;
            }
            ComputedDistance chosen = remaining.remove(bestIndex);
            nearest.add(new ComputedDistance(chosen.label, chosen.distance));
        }
        return nearest;
    }

    /**
     * Stores the training data in {@link #model}, computes the per-feature scaling factors and
     * then calls {@link #updateVictimScore()}.
     *
     * <p>With {@link #scaleFeatures} enabled, each factor is the largest absolute feature value
     * over all samples whose label is not -1 (1 when that maximum is 0). With it disabled, all
     * factors are reset to 1 so distances are computed on the raw values.
     *
     * <p>The lists are stored by reference, not copied.
     *
     * @param list  training samples; must be non-null and non-empty
     * @param list2 one label per training sample
     * @throws IllegalArgumentException if {@code list} is null or empty, or if {@code list2} is
     *         null or holds fewer labels than there are samples
     */
    public void train(List<double[]> list, List<Integer> list2) {
        if (list == null || list.size() < 1) {
            throw new IllegalArgumentException("Provided data is invalid");
        }
        if (list2 == null || list2.size() < list.size()) {
            throw new IllegalArgumentException("Provided labels are invalid");
        }
        this.model.data = list;
        this.model.labels = list2;

        double[] factors = this.model.scalingFactor;
        if (this.scaleFeatures) {
            for (int feature = 0; feature < this.model.numFeatures; feature++) {
                factors[feature] = 0.0d;
            }
            for (int row = 0; row < list.size(); row++) {
                if (list2.get(row).intValue() == -1) {
                    // Samples labelled -1 do not contribute to the scaling factors.
                    continue;
                }
                double[] sample = list.get(row);
                for (int feature = 0; feature < this.model.numFeatures; feature++) {
                    double magnitude = Math.abs(sample[feature]);
                    if (factors[feature] < magnitude) {
                        factors[feature] = magnitude;
                    }
                }
            }
            for (int feature = 0; feature < this.model.numFeatures; feature++) {
                if (factors[feature] == 0.0d) {
                    factors[feature] = 1.0d;
                }
            }
        } else {
            // No scaling requested: use unit factors rather than stale or zero values.
            for (int feature = 0; feature < this.model.numFeatures; feature++) {
                factors[feature] = 1.0d;
            }
        }
        updateVictimScore();
    }

    /**
     * Scores every sample in {@code MainActivity.positive} and publishes the index of the
     * best one.
     *
     * <p>The positive samples are split into two halves. Each half is scored with the batch
     * classifier in score mode, using the other half as the label-1 reference set and every
     * entry of {@code MainActivity.trainingSet} whose label is not 1 as the label -1 reference
     * set. The index (into {@code MainActivity.positive}) of the first sample with the highest
     * score is written to {@code MainActivity.targetIndex} (0 when there are no samples), and
     * {@code MainActivity.populateRaw()} is invoked afterwards.
     */
    // The element type of MainActivity.trainingSet is not visible from this file, so the
    // negatives list stays raw exactly as in the original code.
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void updateVictimScore() {
        int half = MainActivity.positive.size() / 2;
        ArrayList<double[]> positives = MainActivity.positive;

        ArrayList negatives = new ArrayList();
        for (int row = 0; row < MainActivity.trainingSet.size(); row++) {
            if (MainActivity.trainingLabels.get(row).intValue() != 1) {
                negatives.add(MainActivity.trainingSet.get(row));
            }
        }

        List<double[]> firstHalf = positives.subList(0, half);
        List<double[]> secondHalf = positives.subList(half, positives.size());

        // Scores are appended first-half then second-half, so list position == index in positives.
        List<Integer> scores = new ArrayList<>();
        scores.addAll(classify(secondHalf, negatives, firstHalf, true));
        scores.addAll(classify(firstHalf, negatives, secondHalf, true));

        int bestIndex = 0;
        int bestScore = -1;
        for (int index = 0; index < scores.size(); index++) {
            int score = scores.get(index).intValue();
            if (score > bestScore) {
                bestScore = score;
                bestIndex = index;
            }
        }
        MainActivity.targetIndex = bestIndex;
        MainActivity.populateRaw();
    }
}
