package edu.ou.utz8239.bayesnet.web;

import edu.ou.utz8239.bayesnet.BIFXMLParser;
import edu.ou.utz8239.bayesnet.BayesianNetwork;
import edu.ou.utz8239.bayesnet.data.sources.ARFFSource;
import edu.ou.utz8239.bayesnet.probabilties.Attribute;
import edu.ou.utz8239.bayesnet.probabilties.AttributeClass;
import edu.ou.utz8239.bayesnet.probabilties.Criteria;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.cli.PosixParser;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.filefilter.DirectoryFileFilter;
import org.apache.commons.io.filefilter.FileFilterUtils;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/ou/utz8239/bayesnet/web/ARFFWebScorer.class */
/**
 * Scores an attribute against a "web" of Bayesian networks.
 *
 * <p>The networks are read from files whose names end in {@code xml}, found by recursively
 * listing a directory. Each network is asked for the probability of the requested attribute
 * value under the given criteria, and the per-network probabilities are combined by
 * {@link #calculateScore(double[])} (an arithmetic mean in this class).
 *
 * <p>Networks are re-parsed from disk on every {@code score} call; only the file locations are
 * kept in memory.
 */
public class ARFFWebScorer implements WebScorer {
    private static final Logger logger = Logger.getLogger(ARFFWebScorer.class);

    /** Command-line options understood by {@link #main(String[])}. */
    public static Options opts;

    /** Root directory that is searched for network files. */
    private final File webDir;

    /** Network files to score with, keyed by the name of the directory containing each file. */
    private final Map<String, File> components;

    private final BIFXMLParser parser = new BIFXMLParser();

    static {
        opts = new Options();
        opts.addOption("c", "curveFile", true, "File to write curve to").addOption("d", "scratchDirectory", true, "Path to web scratch location");
        opts.addOption("n", "number", true, "Number of networks to use in web");
        opts.addOption("t", "testSet", true, "File containing test set");
        opts.addOption("l", "classLabel", true, "Class we are trying to predict");
        opts.addOption("p", "classValue", true, "Class value for label we are trying to predict as positive");
    }

    /**
     * Creates a scorer over {@code i} network files found beneath {@code file}.
     *
     * @param file directory to search recursively for network files
     * @param i number of network files to use
     * @throws IllegalArgumentException if fewer than {@code i} network files are found
     * @throws Exception if the network files cannot be collected
     */
    public ARFFWebScorer(File file, int i) throws Exception {
        this.webDir = file;
        this.components = getWebFiles(i);
    }

    /**
     * Scores a single criteria against every network in the web.
     *
     * @param attribute variable and value whose probability is requested
     * @param criteria evidence passed to each network's inference
     * @return the combined score as computed by {@link #calculateScore(double[])}
     * @throws RuntimeException if any network cannot be read or queried; the underlying
     *     exception is attached as the cause
     */
    @Override // edu.ou.utz8239.bayesnet.web.WebScorer
    public double score(Attribute attribute, Criteria criteria) {
        double[] probabilities = new double[this.components.size()];
        int networkIndex = 0;
        for (Map.Entry<String, File> component : this.components.entrySet()) {
            String name = component.getKey();
            logger.info("Reading " + name);
            try {
                BayesianNetwork network = readNetwork(component.getValue());
                logger.info("Scoring " + name);
                probabilities[networkIndex++] = network.inferProbability(attribute.getVariable(), criteria).getProbability(attribute.getValue());
            } catch (Exception e) {
                // Keep the cause so the real failure (I/O, parse, inference) is not lost.
                throw new RuntimeException("Unable to open network " + name, e);
            }
        }
        return calculateScore(probabilities);
    }

    /**
     * Scores several criteria against every network in the web, parsing each network only once.
     *
     * @param attribute variable and value whose probability is requested
     * @param list evidence sets to score, one result per element
     * @return one combined score per element of {@code list}, in the same order
     * @throws RuntimeException if any network cannot be read or queried; the underlying
     *     exception is attached as the cause
     */
    @Override // edu.ou.utz8239.bayesnet.web.WebScorer
    public double[] score(Attribute attribute, List<Criteria> list) {
        int criteriaCount = list.size();
        // probabilities[criteriaIndex][networkIndex]
        double[][] probabilities = new double[criteriaCount][this.components.size()];
        int networkIndex = 0;
        for (Map.Entry<String, File> component : this.components.entrySet()) {
            String name = component.getKey();
            logger.info("Reading " + name);
            try {
                BayesianNetwork network = readNetwork(component.getValue());
                logger.info("Scoring " + name);
                // Iterate rather than call list.get(i) so linked lists are not quadratic.
                int criteriaIndex = 0;
                for (Criteria criteria : list) {
                    double probability = network.inferProbability(attribute.getVariable(), criteria).getProbability(attribute.getValue());
                    logger.info("On network " + networkIndex + " num " + criteriaIndex + " scored " + probability);
                    probabilities[criteriaIndex][networkIndex] = probability;
                    criteriaIndex++;
                }
                networkIndex++;
            } catch (Exception e) {
                throw new RuntimeException("Unable to open network " + name, e);
            }
        }
        double[] scores = new double[criteriaCount];
        for (int row = 0; row < criteriaCount; row++) {
            scores[row] = calculateScore(probabilities[row]);
        }
        return scores;
    }

    /**
     * Combines the per-network probabilities into a single score.
     *
     * @param dArr probabilities reported by each network
     * @return the arithmetic mean of {@code dArr}; {@code NaN} when {@code dArr} is empty
     */
    protected double calculateScore(double[] dArr) {
        double sum = 0.0d;
        for (double value : dArr) {
            sum += value;
        }
        return sum / dArr.length;
    }

    /**
     * Collects the network files that make up the web.
     *
     * <p>Files are keyed by the name of their parent directory, so two network files in the same
     * directory collapse to a single entry.
     *
     * @param i number of network files to select
     * @return map from parent-directory name to network file
     * @throws IllegalArgumentException if fewer than {@code i} network files are found
     * @throws Exception declared for overriding implementations
     */
    protected Map<String, File> getWebFiles(int i) throws Exception {
        Map<String, File> webFiles = new HashMap<String, File>();
        for (File networkFile : getObjectFiles(this.webDir, i)) {
            webFiles.put(new File(networkFile.getParent()).getName(), networkFile);
        }
        return webFiles;
    }

    /** Returns the first {@code i} network files found beneath {@code file}. */
    private Collection<File> getObjectFiles(File file, int i) {
        return getN(getObjectFiles(file), i);
    }

    /** Recursively lists every file beneath {@code file} whose name ends in {@code xml}. */
    private Collection<File> getObjectFiles(File file) {
        return FileUtils.listFiles(file, FileFilterUtils.suffixFileFilter("xml"), DirectoryFileFilter.DIRECTORY);
    }

    /**
     * Returns the first {@code i} elements of {@code collection} in its iteration order.
     *
     * @throws IllegalArgumentException if the collection holds fewer than {@code i} elements
     */
    private <T> Collection<T> getN(Collection<T> collection, int i) {
        if (collection.size() < i) {
            throw new IllegalArgumentException("Only " + collection.size() + " elements to start with");
        }
        List<T> firstN = new ArrayList<T>(Math.max(i, 0));
        for (T element : collection) {
            if (firstN.size() >= i) {
                break;
            }
            firstN.add(element);
        }
        return firstN;
    }

    /** Parses one network file, always closing the underlying stream. */
    private BayesianNetwork readNetwork(File networkFile) throws Exception {
        FileInputStream in = FileUtils.openInputStream(networkFile);
        try {
            return this.parser.fromXML(in);
        } finally {
            IOUtils.closeQuietly(in);
        }
    }

    /**
     * Returns the value of a command-line option, failing with a descriptive message when the
     * option was not supplied.
     */
    private static String requireOption(CommandLine commandLine, String option) {
        String value = commandLine.getOptionValue(option);
        if (value == null) {
            throw new IllegalArgumentException("Missing required option -" + option);
        }
        return value;
    }

    /**
     * Command-line entry point: scores every instance of a test set against the web and writes
     * one line per instance to the curve file, containing the score, a tab, and {@code 1} if the
     * instance's most likely value for the target variable equals the target value, else
     * {@code 0}.
     *
     * @param strArr command-line arguments, parsed against {@link #opts}
     * @throws IllegalArgumentException if a required option is missing
     * @throws Exception if the test set, the networks or the curve file cannot be processed
     */
    public static void main(String[] strArr) throws Exception {
        CommandLine commandLine = null;
        try {
            commandLine = new PosixParser().parse(opts, strArr);
        } catch (ParseException e) {
            logger.error("Error parsing arguments", e);
            System.exit(1);
        }
        String testSetPath = requireOption(commandLine, "t");
        // NOTE(review): the attribute class is looked up with -p and the value id with -l, which
        // is the reverse of the option descriptions ("classLabel" / "classValue"). Kept as-is
        // because existing invocations depend on it -- confirm which side is intended.
        String attributeClassName = requireOption(commandLine, "p");
        String attributeValueName = requireOption(commandLine, "l");
        String scratchPath = requireOption(commandLine, "d");
        String networkCount = requireOption(commandLine, "n");
        String curvePath = requireOption(commandLine, "c");

        ARFFSource aRFFSource = new ARFFSource(new File(testSetPath));
        AttributeClass targetClass = AttributeClass.findAttributeClassByName(aRFFSource.getProvidedClasses(), attributeClassName);
        Attribute attribute = new Attribute(targetClass, targetClass.getValueId(attributeValueName));
        File scratchDir = new File(scratchPath);
        logger.debug("Scratch dirctory is " + scratchDir);
        ARFFWebScorer scorer = new ARFFWebScorer(scratchDir, Integer.parseInt(networkCount));
        BufferedWriter writer = new BufferedWriter(new FileWriter(new File(curvePath)));
        try {
            List<Criteria> criteriaList = new ArrayList<Criteria>();
            int[] keys = new int[aRFFSource.getProvidedKeys().size()];
            int keyCount = 0;
            for (int key : aRFFSource.getProvidedKeys().toArray()) {
                keys[keyCount++] = key;
                criteriaList.add(Criteria.createFromAttributes(Attribute.getAttributes(aRFFSource.getValues(key, aRFFSource.getProvidedClasses()))));
            }
            double[] scores = scorer.score(attribute, criteriaList);
            for (int row = 0; row < keys.length; row++) {
                writer.write(String.valueOf(scores[row]) + "\t");
                writer.write(aRFFSource.getValue(keys[row], attribute.getVariable()).getMostLikelyAttribute().getValue() == attribute.getValue() ? "1" : "0");
                writer.newLine();
            }
            writer.flush();
        } finally {
            // Release the file handle even when scoring or writing fails.
            writer.close();
        }
    }
}
