package edu.ou.utz8239.bayesnet.search;

import edu.ou.utz8239.bayesnet.BayesianNetwork;
import edu.ou.utz8239.bayesnet.BayesianNetworkFactory;
import edu.ou.utz8239.bayesnet.data.sources.InstanceSource;
import edu.ou.utz8239.bayesnet.evaluation.BICScoreProvider;
import edu.ou.utz8239.bayesnet.probabilties.AttributeClass;
import gnu.trove.TIntArrayList;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.log4j.Logger;
import org.jgrapht.DirectedGraph;
import org.jgrapht.alg.CycleDetector;
import org.jgrapht.graph.DefaultDirectedGraph;
import org.jgrapht.graph.DefaultEdge;

/* loaded from: input_file:edu/ou/utz8239/bayesnet/search/BayesianNetworkSearchProblem.class */
/**
 * Search problem over Bayesian network structures.
 *
 * <p>The initial state is a network whose structure graph contains every attribute class as a
 * vertex and no edges. Each successor of a state adds exactly one directed edge to that state's
 * structure graph. The edge must not be a self-loop, must not duplicate or reverse an existing
 * edge, and must not introduce a cycle. Every candidate network is built from the training keys
 * and scored with the BIC score on the test keys.
 */
public class BayesianNetworkSearchProblem implements SearchProblem<BayesianNetworkState> {
    private static final Logger logger = Logger.getLogger(BayesianNetworkSearchProblem.class);

    /** Attribute classes that become the vertices of the initial structure graph. */
    private final Set<AttributeClass> classes;
    /** Instance source passed to both network construction and scoring. */
    private final InstanceSource source;
    /** Keys handed to {@code BayesianNetworkFactory.createBayesianNetwork} when building states. */
    private final TIntArrayList trainingKeys;
    /** Keys handed to the BIC score provider when scoring states. */
    private final TIntArrayList testKeys;

    /**
     * Creates a search problem.
     *
     * @param classes      attribute classes forming the vertices of the initial network
     * @param source       instance source used to build and score networks
     * @param trainingKeys keys of the instances used to build each candidate network
     * @param testKeys     keys of the instances used to score each candidate network
     */
    public BayesianNetworkSearchProblem(Set<AttributeClass> classes, InstanceSource source,
            TIntArrayList trainingKeys, TIntArrayList testKeys) {
        this.classes = classes;
        this.source = source;
        this.trainingKeys = trainingKeys;
        this.testKeys = testKeys;
    }

    /**
     * Breaks a tie between equally good states by picking the first one.
     *
     * @param candidates tied states; must be non-empty
     * @return the first element of {@code candidates}
     * @throws IndexOutOfBoundsException if {@code candidates} is empty
     */
    @Override
    public BayesianNetworkState breakTies(List<BayesianNetworkState> candidates) {
        return candidates.get(0);
    }

    /**
     * Compares two states. {@code null} sorts below any non-null state, and two identical
     * references compare equal. Otherwise the comparison delegates to
     * {@link BayesianNetworkState#compareTo}.
     */
    @Override
    public int compareStates(BayesianNetworkState first, BayesianNetworkState second) {
        if (first == second) {
            return 0;
        }
        if (first == null) {
            return -1;
        }
        if (second == null) {
            return 1;
        }
        return first.compareTo(second);
    }

    /**
     * Builds the starting state: every attribute class as a vertex, with no edges.
     *
     * @return the scored initial state, named {@code "Initial"}
     */
    @Override
    public BayesianNetworkState getInitialState() {
        DirectedGraph<AttributeClass, DefaultEdge> graph =
                new DefaultDirectedGraph<AttributeClass, DefaultEdge>(DefaultEdge.class);
        for (AttributeClass attributeClass : this.classes) {
            graph.addVertex(attributeClass);
        }
        return createBayesianNetworkState("Initial", graph);
    }

    /**
     * Returns the score stored on a state.
     *
     * @param state the state to score; may be {@code null}
     * @return {@link Double#NaN} for {@code null}, otherwise {@code state.getScore()}
     */
    @Override
    public double getScore(BayesianNetworkState state) {
        if (state == null) {
            return Double.NaN;
        }
        return state.getScore();
    }

    /**
     * Enumerates every structure reachable by adding one allowed edge to the state's structure.
     * Each returned graph is an independent copy, so the state's own graph is not changed.
     *
     * @param state the state to expand
     * @return one new graph per ordered vertex pair accepted by {@link #edgeAllowed}
     */
    protected Collection<DirectedGraph<AttributeClass, DefaultEdge>> getPossibleStructures(
            BayesianNetworkState state) {
        DirectedGraph<AttributeClass, DefaultEdge> structureGraph =
                BayesianNetworkFactory.toStructureGraph(state.getNetwork());
        List<DirectedGraph<AttributeClass, DefaultEdge>> structures =
                new ArrayList<DirectedGraph<AttributeClass, DefaultEdge>>();
        for (AttributeClass from : structureGraph.vertexSet()) {
            for (AttributeClass to : structureGraph.vertexSet()) {
                if (edgeAllowed(from, to, structureGraph)) {
                    DirectedGraph<AttributeClass, DefaultEdge> candidate = copyGraph(structureGraph);
                    candidate.addEdge(from, to);
                    structures.add(candidate);
                }
            }
        }
        return structures;
    }

    /**
     * Builds and scores every one-edge extension of {@code state}, in random order.
     * Successors are named {@code Network_1}, {@code Network_2}, ... in that shuffled order.
     *
     * @param state the state to expand
     * @return the scored successor states
     */
    @Override
    public Collection<BayesianNetworkState> getSuccessors(BayesianNetworkState state) {
        // Defensive copy: an overriding getPossibleStructures may return an unmodifiable collection.
        List<DirectedGraph<AttributeClass, DefaultEdge>> structures =
                new ArrayList<DirectedGraph<AttributeClass, DefaultEdge>>(getPossibleStructures(state));
        Collections.shuffle(structures);
        List<BayesianNetworkState> successors = new ArrayList<BayesianNetworkState>(structures.size());
        for (DirectedGraph<AttributeClass, DefaultEdge> structure : structures) {
            successors.add(createBayesianNetworkState("Network_" + (successors.size() + 1), structure));
        }
        return successors;
    }

    /** Returns a new DefaultDirectedGraph with the same vertices and edges as {@code graph}. */
    private DirectedGraph<AttributeClass, DefaultEdge> copyGraph(
            DirectedGraph<AttributeClass, DefaultEdge> graph) {
        DirectedGraph<AttributeClass, DefaultEdge> copy =
                new DefaultDirectedGraph<AttributeClass, DefaultEdge>(DefaultEdge.class);
        for (AttributeClass vertex : graph.vertexSet()) {
            copy.addVertex(vertex);
        }
        for (DefaultEdge edge : graph.edgeSet()) {
            copy.addEdge(graph.getEdgeSource(edge), graph.getEdgeTarget(edge));
        }
        return copy;
    }

    /**
     * Builds a network from {@code structure} using the training keys, then scores it with BIC
     * on the test keys.
     *
     * @throws RuntimeException wrapping any failure from network construction or scoring
     */
    private BayesianNetworkState createBayesianNetworkState(String name,
            DirectedGraph<AttributeClass, DefaultEdge> structure) {
        try {
            BayesianNetwork network = BayesianNetworkFactory.createBayesianNetwork(
                    structure, name, this.source, this.trainingKeys);
            double score = BICScoreProvider.getBICScoreProvider().score(network, this.source, this.testKeys);
            if (logger.isInfoEnabled()) {
                logger.info("Score for " + name + " is " + score);
            }
            return new BayesianNetworkState(network, score);
        } catch (Exception e) {
            throw new RuntimeException("Unable to create state named " + name, e);
        }
    }

    /**
     * Decides whether the edge {@code from -> to} may be added to {@code graph}.
     *
     * <p>The edge is rejected in four cases:
     * <ul>
     *   <li>it is a self-loop;</li>
     *   <li>it already exists;</li>
     *   <li>its reverse already exists;</li>
     *   <li>adding it would create a cycle.</li>
     * </ul>
     *
     * <p>To check for cycles, the edge is added to {@code graph} and then removed again. So
     * {@code graph} must be mutable and must not be read concurrently during this call. The
     * removal is in a {@code finally} so the graph is restored even if cycle detection throws.
     *
     * @return {@code true} if the edge may be added
     */
    protected boolean edgeAllowed(AttributeClass from, AttributeClass to,
            DirectedGraph<AttributeClass, DefaultEdge> graph) {
        if (from.equals(to) || graph.containsEdge(from, to) || graph.containsEdge(to, from)) {
            return false;
        }
        graph.addEdge(from, to);
        try {
            return !new CycleDetector<AttributeClass, DefaultEdge>(graph).detectCycles();
        } finally {
            graph.removeEdge(from, to);
        }
    }
}
