package edu.ou.utz8239.bayesnet.evaluation;

import edu.ou.utz8239.bayesnet.BayesianNetwork;
import edu.ou.utz8239.bayesnet.BayesianNetworkFactory;
import edu.ou.utz8239.bayesnet.data.sources.ARFFSource;
import edu.ou.utz8239.bayesnet.data.sources.InstanceSource;
import edu.ou.utz8239.bayesnet.data.sources.KeyFilteringSource;
import edu.ou.utz8239.bayesnet.data.sources.MemoryBasedInstanceSource;
import edu.ou.utz8239.bayesnet.probabilties.Attribute;
import edu.ou.utz8239.bayesnet.probabilties.AttributeClass;
import edu.ou.utz8239.bayesnet.probabilties.Criteria;
import edu.ou.utz8239.bayesnet.probabilties.ProbabilityDistribution;
import gnu.trove.TIntHashSet;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.filefilter.FileFilterUtils;
import org.apache.commons.math.stat.StatUtils;
import org.apache.commons.math.stat.descriptive.DescriptiveStatistics;
import org.apache.commons.math.stat.descriptive.moment.StandardDeviation;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/ou/utz8239/bayesnet/evaluation/VariableImportanceCalculator.class */
public class VariableImportanceCalculator {
    private static Logger logger = Logger.getLogger(VariableImportanceCalculator.class);
    private static final String OOB = "outOfBag.txt";
    private final Collection<File> components;
    private final InstanceSource allTraining;
    private final Attribute attr;

    public Attribute getAttr() {
        return this.attr;
    }

    public VariableImportanceCalculator(File file, int i, InstanceSource instanceSource, Attribute attribute) {
        this.components = FileUtils.listFiles(file, FileFilterUtils.andFileFilter(FileFilterUtils.prefixFileFilter("network"), FileFilterUtils.suffixFileFilter("xml")), FileFilterUtils.trueFileFilter());
        this.allTraining = instanceSource;
        this.attr = attribute;
    }

    public VariableImportance getImportanceScores(AttributeClass attributeClass, int i) throws Exception {
        DescriptiveStatistics descriptiveStatistics = new DescriptiveStatistics();
        for (int i2 = 0; i2 < i; i2++) {
            descriptiveStatistics.addValue(getImportanceScore(attributeClass).getZScore());
        }
        return new VariableImportance(attributeClass, descriptiveStatistics.getMean(), descriptiveStatistics.getStandardDeviation());
    }

    public VariableImportance getImportanceScore(AttributeClass attributeClass) throws Exception {
        double[] dArr = new double[this.components.size()];
        int i = 0;
        for (File file : this.components) {
            BayesianNetwork network = getNetwork(file);
            KeyFilteringSource keyFilteringSource = new KeyFilteringSource(this.allTraining, getOutOfBag(file));
            int i2 = i;
            i++;
            dArr[i2] = getVotesCast(network, keyFilteringSource, this.attr) - getVotesCast(network, permuteClass(keyFilteringSource, attributeClass), this.attr);
        }
        double mean = StatUtils.mean(dArr) / (new StandardDeviation().evaluate(dArr) / Math.sqrt(dArr.length));
        return Double.isNaN(mean) ? new VariableImportance(attributeClass, Double.NaN, Double.NaN) : new VariableImportance(attributeClass, mean, 0.0d);
    }

    private double getVotesCast(BayesianNetwork bayesianNetwork, InstanceSource instanceSource, Attribute attribute) throws Exception {
        double d = 0.0d;
        for (int i : instanceSource.getProvidedKeys().toArray()) {
            if (bayesianNetwork.inferProbability(attribute.getVariable(), Criteria.createFromAttributes(Attribute.getAttributes(instanceSource.getValues(i, instanceSource.getProvidedClasses())))).getMostLikelyAttribute().equals(attribute)) {
                d += 1.0d;
            }
        }
        return d;
    }

    private InstanceSource permuteClass(InstanceSource instanceSource, AttributeClass attributeClass) throws Exception {
        MemoryBasedInstanceSource memoryBasedInstanceSource = new MemoryBasedInstanceSource();
        ArrayList arrayList = new ArrayList();
        for (int i : instanceSource.getProvidedKeys().toArray()) {
            arrayList.add(instanceSource.getValue(i, attributeClass).getMostLikelyAttribute());
        }
        Collections.shuffle(arrayList);
        for (int i2 : instanceSource.getProvidedKeys().toArray()) {
            Iterator it = instanceSource.getProvidedClasses().iterator();
            while (it.hasNext()) {
                AttributeClass attributeClass2 = (AttributeClass) it.next();
                if (attributeClass2.equals(attributeClass)) {
                    Attribute attribute = (Attribute) arrayList.remove(0);
                    memoryBasedInstanceSource.setValue(i2, attributeClass2, ProbabilityDistribution.createAbsoluteDistribution(attribute.getVariable(), attribute.getValue()));
                } else {
                    memoryBasedInstanceSource.setValue(i2, attributeClass2, instanceSource.getValue(i2, attributeClass2));
                }
            }
        }
        return memoryBasedInstanceSource;
    }

    private BayesianNetwork getNetwork(File file) throws Exception {
        return BayesianNetworkFactory.createBayesianNetwork(file);
    }

    private TIntHashSet getOutOfBag(File file) throws Exception {
        return new TIntHashSet(getIds(new File(file.getParent(), OOB)));
    }

    private int[] getIds(File file) throws IOException {
        ArrayList arrayList = new ArrayList();
        String trim = FileUtils.readFileToString(file).trim();
        for (String str : trim.substring(1, trim.length() - 1).split(",")) {
            arrayList.add(Integer.valueOf(Integer.parseInt(str.trim())));
        }
        int[] iArr = new int[arrayList.size()];
        int i = 0;
        Iterator it = arrayList.iterator();
        while (it.hasNext()) {
            int i2 = i;
            i++;
            iArr[i2] = ((Integer) it.next()).intValue();
        }
        return iArr;
    }

    public static void main(String[] strArr) throws Exception {
        ARFFSource aRFFSource = new ARFFSource(new File("path to training"));
        File file = new File("path to web");
        AttributeClass findAttributeClassByName = AttributeClass.findAttributeClassByName(aRFFSource.getProvidedClasses(), "PlainsFcst");
        VariableImportanceCalculator variableImportanceCalculator = new VariableImportanceCalculator(file, 200, aRFFSource, new Attribute(findAttributeClassByName, findAttributeClassByName.getValueId("SIG")));
        ArrayList arrayList = new ArrayList();
        Iterator it = aRFFSource.getProvidedClasses().iterator();
        while (it.hasNext()) {
            AttributeClass attributeClass = (AttributeClass) it.next();
            if (!attributeClass.equals(findAttributeClassByName)) {
                arrayList.add(variableImportanceCalculator.getImportanceScore(attributeClass));
            }
        }
        Collections.sort(arrayList, Collections.reverseOrder());
        Iterator it2 = arrayList.iterator();
        while (it2.hasNext()) {
            System.out.println((VariableImportance) it2.next());
        }
    }
}
