package edu.ou.utz8239.bayesnet;

import edu.ou.utz8239.bayesnet.data.sources.HiddenVariableSource;
import edu.ou.utz8239.bayesnet.data.sources.InstanceSource;
import edu.ou.utz8239.bayesnet.data.sources.MultiInstanceSource;
import edu.ou.utz8239.bayesnet.evaluation.LogLikelihoodProvider;
import edu.ou.utz8239.bayesnet.learning.NetworkProbabilityLearner;
import edu.ou.utz8239.bayesnet.learning.ProbabilityLearnerFactory;
import edu.ou.utz8239.bayesnet.probabilties.AttributeClass;
import edu.ou.utz8239.bayesnet.probabilties.utils.RandomFactory;
import gnu.trove.THashSet;
import gnu.trove.TIntArrayList;
import gnu.trove.TIntHashSet;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.InputStream;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.zip.GZIPInputStream;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.apache.log4j.Logger;
import org.jgrapht.DirectedGraph;
import org.jgrapht.graph.DefaultDirectedGraph;
import org.jgrapht.graph.DefaultEdge;
import org.jgrapht.graph.SimpleDirectedGraph;

/* loaded from: input_file:edu/ou/utz8239/bayesnet/BayesianNetworkFactory.class */
/**
 * Static factory methods for building {@link BayesianNetwork} instances, either by parsing a file
 * with {@link BIFXMLParser} or by learning probabilities for a given structure graph.
 */
public class BayesianNetworkFactory {
    private static final Logger logger = Logger.getLogger(BayesianNetworkFactory.class);

    /**
     * Number of independent learning attempts made when the structure contains variables that the
     * instance source does not provide; the attempt with the highest log-likelihood is kept.
     */
    private static final int HIDDEN_INIT_TRIES = 20;

    /**
     * Loads a network by running {@link BIFXMLParser} over the contents of {@code file}.
     *
     * <p>If the file name has the extension {@code gz} or {@code gzip}, the contents are
     * decompressed with {@link GZIPInputStream} before parsing.
     *
     * @param file the file to read
     * @return the network produced by the parser
     * @throws Exception if the file cannot be opened or decompressed, or if parsing fails
     */
    public static BayesianNetwork createBayesianNetwork(File file) throws Exception {
        boolean compressed = FilenameUtils.isExtension(file.getName(), new String[]{"gz", "gzip"});
        // try-with-resources releases the file handle even when the GZIP header is rejected or the
        // parser throws. Closing the GZIP stream also closes the buffered stream; the second close
        // of the already-closed buffered stream is harmless.
        try (InputStream buffered = new BufferedInputStream(FileUtils.openInputStream(file));
             InputStream in = compressed ? new GZIPInputStream(buffered) : buffered) {
            return new BIFXMLParser().fromXML(in);
        }
    }

    /**
     * Builds a network with the given structure and learns its probabilities from
     * {@code instanceSource}.
     *
     * <p>If every vertex of {@code directedGraph} is provided by the instance source, a single
     * fully-observable learning pass is run. Otherwise the missing vertices are treated as hidden
     * variables. Partially-observable learning is then attempted {@value #HIDDEN_INIT_TRIES} times,
     * and the candidate with the highest log-likelihood is returned.
     *
     * @param directedGraph the structure to learn, over attribute classes
     * @param str the name given to the created network
     * @param instanceSource source of observed data
     * @param tIntArrayList passed to the learners and the scorer; presumably the indices of the
     *     instances to learn from (TODO confirm against {@link NetworkProbabilityLearner})
     * @return the learned network; never {@code null} while {@code HIDDEN_INIT_TRIES > 0}
     * @throws Exception propagated from learning or scoring
     */
    public static BayesianNetwork createBayesianNetwork(DirectedGraph<AttributeClass, DefaultEdge> directedGraph, String str, InstanceSource instanceSource, TIntArrayList tIntArrayList) throws Exception {
        logger.info("Attempting to learn a network for structure " + str);
        logger.info(directedGraph);

        // Structure variables that the instance source does not supply must be treated as hidden.
        Set<AttributeClass> hiddenClasses = new THashSet<>();
        hiddenClasses.addAll(directedGraph.vertexSet());
        hiddenClasses.removeAll(instanceSource.getProvidedClasses());

        if (hiddenClasses.isEmpty()) {
            NetworkProbabilityLearner learner = ProbabilityLearnerFactory.createProbabilityLearner(instanceSource);
            BayesianNetwork network = createNetwork(directedGraph, str);
            learner.learnFullyObservableNetwork(network, tIntArrayList);
            return network;
        }

        BayesianNetwork best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int attempt = 0; attempt < HIDDEN_INIT_TRIES; attempt++) {
            BayesianNetwork candidate = createNetwork(directedGraph, str);
            // A fresh source is built for every attempt, matching the original behaviour.
            // It is not hoisted because the source objects may be stateful.
            MultiInstanceSource source = new MultiInstanceSource(new HiddenVariableSource(new TIntHashSet(tIntArrayList.toNativeArray()), hiddenClasses), instanceSource);
            // NOTE(review): 0.001d and 25 look like a convergence tolerance and an iteration cap;
            // confirm against NetworkProbabilityLearner.learnPartiallyObservableNetwork.
            ProbabilityLearnerFactory.createProbabilityLearner(source).learnPartiallyObservableNetwork(candidate, tIntArrayList, hiddenClasses, 0.001d, 25, RandomFactory.getRandomGenerator());
            double score = LogLikelihoodProvider.getLogLikelihoodProvider().score(candidate, source, tIntArrayList);
            logger.info("Iteration attempt " + attempt + " had likelihood of " + score);

            // NaN never compares greater than anything, so rank it as -infinity; that lets a
            // later finite score still replace a NaN-scored candidate.
            double comparable = Double.isNaN(score) ? Double.NEGATIVE_INFINITY : score;
            // Always accept the first candidate so the method cannot return null when every
            // attempt scores -infinity or NaN.
            if (best == null || comparable > bestScore) {
                best = candidate;
                bestScore = comparable;
            }
        }
        logger.info("Best initalization attempt yielded a network with likelihood of " + bestScore);
        logger.info(best);
        return best;
    }

    /**
     * Projects a network's node graph onto a graph over the nodes' attribute classes.
     *
     * @param bayesianNetwork the network whose structure is extracted
     * @return a new graph with one vertex per node's {@code getProbabilityClass()} and one edge per
     *     edge of the network structure
     */
    public static DirectedGraph<AttributeClass, DefaultEdge> toStructureGraph(BayesianNetwork bayesianNetwork) {
        DirectedGraph<BayesNode, DefaultEdge> networkStructure = bayesianNetwork.getNetworkStructure();
        DirectedGraph<AttributeClass, DefaultEdge> structureGraph = new DefaultDirectedGraph<>(DefaultEdge.class);
        for (BayesNode node : networkStructure.vertexSet()) {
            structureGraph.addVertex(node.getProbabilityClass());
        }
        for (DefaultEdge edge : networkStructure.edgeSet()) {
            AttributeClass source = networkStructure.getEdgeSource(edge).getProbabilityClass();
            AttributeClass target = networkStructure.getEdgeTarget(edge).getProbabilityClass();
            structureGraph.addEdge(source, target);
        }
        return structureGraph;
    }

    /**
     * Creates an unlearned network whose node graph mirrors {@code directedGraph}.
     *
     * <p>Each vertex becomes a new {@link BayesNode}, and each edge is copied between the
     * corresponding nodes.
     *
     * @param directedGraph the structure over attribute classes
     * @param str the network name
     * @return the new network
     */
    private static BayesianNetwork createNetwork(DirectedGraph<AttributeClass, DefaultEdge> directedGraph, String str) {
        SimpleDirectedGraph<BayesNode, DefaultEdge> nodeGraph = new SimpleDirectedGraph<>(DefaultEdge.class);
        // Index the nodes by attribute class once. The previous implementation used a linear
        // scan per edge endpoint, which cost O(E * V).
        Map<AttributeClass, BayesNode> nodesByClass = new HashMap<>();
        for (AttributeClass attributeClass : directedGraph.vertexSet()) {
            BayesNode node = new BayesNode(attributeClass);
            nodeGraph.addVertex(node);
            nodesByClass.put(node.getProbabilityClass(), node);
        }
        for (DefaultEdge edge : directedGraph.edgeSet()) {
            BayesNode source = nodesByClass.get(directedGraph.getEdgeSource(edge));
            BayesNode target = nodesByClass.get(directedGraph.getEdgeTarget(edge));
            nodeGraph.addEdge(source, target);
        }
        return new BayesianNetwork(str, nodeGraph);
    }
}
