package edu.ou.utz8239.bayesnet.learning;

import edu.ou.utz8239.bayesnet.BayesNode;
import edu.ou.utz8239.bayesnet.BayesianNetwork;
import edu.ou.utz8239.bayesnet.data.sources.InstanceSource;
import edu.ou.utz8239.bayesnet.evaluation.LogLikelihoodProvider;
import edu.ou.utz8239.bayesnet.exceptions.DataTableException;
import edu.ou.utz8239.bayesnet.probabilties.Attribute;
import edu.ou.utz8239.bayesnet.probabilties.AttributeClass;
import edu.ou.utz8239.bayesnet.probabilties.ConditionalProbabilityTable;
import edu.ou.utz8239.bayesnet.probabilties.Criteria;
import edu.ou.utz8239.bayesnet.probabilties.ProbabilityDistribution;
import gnu.trove.THashSet;
import gnu.trove.TIntArrayList;
import gnu.trove.TIntObjectHashMap;
import gnu.trove.TIntObjectProcedure;
import gnu.trove.TIntProcedure;
import gnu.trove.TObjectProcedure;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import org.apache.log4j.Logger;
import org.jgrapht.graph.EdgeReversedGraph;
import org.jgrapht.traverse.TopologicalOrderIterator;

/* loaded from: input_file:edu/ou/utz8239/bayesnet/learning/NetworkProbabilityLearner.class */
/**
 * Learns the conditional probability tables (CPTs) of a {@link BayesianNetwork} from the instances
 * of an {@link InstanceSource}.
 *
 * <p>Fully observable networks are learned with a single counting pass (the "M-step"): every CPT
 * row is set from counts returned by {@link InstanceSource#count}. Networks with hidden node
 * classes are learned with an expectation-maximization loop that alternates that M-step with an
 * E-step. The E-step writes, for every instance, the distribution inferred for each hidden node
 * back into the source.
 */
public class NetworkProbabilityLearner {
    private static final Logger logger = Logger.getLogger(NetworkProbabilityLearner.class);

    /** Training data; hidden values are overwritten in it during the E-step. */
    protected final InstanceSource source;

    /**
     * Creates a learner over the given data.
     *
     * @param instanceSource the data the CPTs are learned from (and, for hidden nodes, written to)
     */
    public NetworkProbabilityLearner(InstanceSource instanceSource) {
        this.source = instanceSource;
    }

    /**
     * Learns the CPTs of a network in which some node classes are hidden, using EM.
     *
     * @param network the network whose CPTs are (re)computed in place; CPTs must already exist
     * @param instances indices of the instances in {@link #source} to learn from
     * @param hiddenClasses the attribute classes whose values are inferred rather than observed
     * @param threshold EM stops once the change in log-likelihood is at most this value
     * @param maxIterations upper bound on EM iterations (at least one iteration always runs)
     * @param random source of randomness for CPT mutation
     * @throws Exception propagated from scoring, inference or data access
     */
    public void learnPartiallyObservableNetwork(BayesianNetwork network, TIntArrayList instances,
            Set<AttributeClass> hiddenClasses, double threshold, int maxIterations, Random random)
            throws Exception {
        logger.info("Learning network " + network.getName() + " threshold=" + threshold
                + " maximum iterations=" + maxIterations);
        learnHiddenNodes(network, instances, hiddenClasses, threshold, maxIterations, random);
        // Final M-step so every CPT reflects the hidden values written by the last E-step.
        performMStepFuzzyCounting(instances, network);
    }

    /**
     * Learns the CPTs of a network whose node classes are all observed (a single M-step).
     *
     * @param network the network whose CPTs are recomputed in place; CPTs must already exist
     * @param instances indices of the instances in {@link #source} to learn from
     * @throws Exception propagated from data access
     */
    public void learnFullyObservableNetwork(BayesianNetwork network, TIntArrayList instances)
            throws Exception {
        performMStepFuzzyCounting(instances, network);
    }

    /**
     * Runs the EM loop. Each iteration does an M-step, then an E-step, then a log-likelihood
     * score. The loop stops when the score changes by at most {@code threshold}, or after
     * {@code maxIterations} iterations.
     */
    private void learnHiddenNodes(BayesianNetwork network, TIntArrayList instances,
            Set<AttributeClass> hiddenClasses, double threshold, int maxIterations, Random random)
            throws Exception {
        int iteration = 0;
        double previousLikelihood;
        // Starting from -infinity guarantees the first iteration never looks converged.
        double likelihood = Double.NEGATIVE_INFINITY;
        do {
            performMStepFuzzyCounting(instances, network);
            performEStepUpdateHiddens(instances, network, hiddenClasses);
            previousLikelihood = likelihood;
            likelihood = LogLikelihoodProvider.getLogLikelihoodProvider()
                    .score(network, source, instances);
            iteration++;
            logger.debug("After iteration " + iteration + " likelihood is " + likelihood
                    + " previous likelihood was " + previousLikelihood);
            if (Math.abs(likelihood - previousLikelihood) <= threshold) {
                // NOTE(review): the loop always exits right after this branch, and the caller
                // then reruns the M-step over every CPT. So these mutations appear to be
                // overwritten before they influence learning. Kept unchanged (including the
                // draws taken from `random`); confirm whether a random restart was intended.
                logger.info("Hit threshold so we mutate before next iterations");
                for (AttributeClass hiddenClass : hiddenClasses) {
                    BayesNode node = network.getNode(hiddenClass);
                    logger.debug("Mutating " + node.getCPT());
                    mutateCPT(node.getCPT(), random);
                    logger.debug("Mutated to " + node.getCPT());
                }
            }
            if (iteration >= maxIterations) {
                break;
            }
        } while (Math.abs(likelihood - previousLikelihood) > threshold);

        // Check convergence first: converging on the last allowed iteration is still convergence.
        if (Math.abs(likelihood - previousLikelihood) <= threshold) {
            logger.info("Stopped learning after threshold reached at iteration=" + iteration);
        } else if (iteration >= maxIterations) {
            logger.warn("Learning stopped after maximum iterations where run");
        } else {
            // Only reachable when the comparison involved NaN, which makes the loop condition false.
            logger.warn("Learning stopped at iteration=" + iteration
                    + " because the likelihood change is not a number");
        }
    }

    /**
     * Overwrites one randomly chosen row of the CPT with
     * {@link ProbabilityDistribution#createRandomDistribution}.
     */
    private void mutateCPT(ConditionalProbabilityTable cpt, Random random) {
        Criteria row = generateRandomCriteria(cpt.getConditionals(), random);
        cpt.setDistribution(row, ProbabilityDistribution.createRandomDistribution(cpt.getVariable()));
    }

    /**
     * Builds a criteria that assigns each given class a value drawn uniformly from
     * {@code [0, degree)}.
     */
    private Criteria generateRandomCriteria(List<AttributeClass> conditionals, Random random) {
        HashSet<Attribute> attributes = new HashSet<Attribute>();
        for (AttributeClass conditional : conditionals) {
            attributes.add(new Attribute(conditional, (byte) random.nextInt(conditional.getDegree())));
        }
        return Criteria.createFromAttributes(attributes);
    }

    /**
     * E-step: updates every hidden node in reverse topological order (children before parents),
     * i.e. topological order of the edge-reversed network structure.
     */
    private void performEStepUpdateHiddens(TIntArrayList instances, BayesianNetwork network,
            Set<AttributeClass> hiddenClasses) throws Exception {
        // Raw type kept: the graph's edge type is not visible from this class.
        TopologicalOrderIterator order =
                new TopologicalOrderIterator(new EdgeReversedGraph(network.getNetworkStructure()));
        while (order.hasNext()) {
            BayesNode node = (BayesNode) order.next();
            if (hiddenClasses.contains(node.getProbabilityClass())) {
                updateHiddenNode(instances, network, node);
            }
        }
    }

    /**
     * Replaces the hidden node's value in every given instance with the distribution the network
     * infers for it. The evidence is that instance's values for the node's Markov blanket.
     *
     * <p>All inferences are computed before any value is written back, so no inference in this
     * call sees another instance's freshly written value.
     *
     * <p>NOTE(review): the decompiler reported this member as originally {@code protected}; it is
     * left {@code public} so existing callers keep compiling.
     *
     * @throws RuntimeException wrapping any {@link DataTableException} raised while reading,
     *     inferring or writing (the Trove callbacks cannot throw checked exceptions)
     */
    public void updateHiddenNode(TIntArrayList instances, final BayesianNetwork network,
            final BayesNode hiddenNode) throws Exception {
        final TIntObjectHashMap<THashSet<ProbabilityDistribution>> inferred =
                new TIntObjectHashMap<THashSet<ProbabilityDistribution>>();
        final Set<AttributeClass> markovBlanket = getMarkovBlanket(network, hiddenNode);

        // Phase 1: infer a posterior for every instance without modifying the source.
        instances.forEach(new TIntProcedure() {
            public boolean execute(int instance) {
                try {
                    // Raw HashSet kept: the element type of getValues(...).values() is not visible here.
                    Criteria evidence = Criteria.createFromDistributions(
                            new HashSet(source.getValues(instance, markovBlanket).values()));
                    ProbabilityDistribution posterior =
                            network.inferProbability(hiddenNode.getProbabilityClass(), evidence);
                    THashSet<ProbabilityDistribution> forInstance = inferred.get(instance);
                    if (forInstance == null) {
                        forInstance = new THashSet<ProbabilityDistribution>();
                        inferred.put(instance, forInstance);
                    }
                    forInstance.add(posterior);
                    return true;
                } catch (DataTableException e) {
                    throw new RuntimeException("Error updating hiddens for " + instance, e);
                }
            }
        });

        // Phase 2: write the inferred distributions back into the source.
        inferred.forEachEntry(new TIntObjectProcedure<THashSet<ProbabilityDistribution>>() {
            public boolean execute(final int instance, THashSet<ProbabilityDistribution> posteriors) {
                posteriors.forEach(new TObjectProcedure<ProbabilityDistribution>() {
                    public boolean execute(ProbabilityDistribution posterior) {
                        try {
                            source.setValue(instance, posterior.getAttributeClass(), posterior);
                            return true;
                        } catch (DataTableException e) {
                            throw new RuntimeException("Error updating hiddens for " + instance, e);
                        }
                    }
                });
                return true;
            }
        });
    }

    /**
     * Returns the attribute classes of the node's Markov blanket: its parents, its children, and
     * its children's other parents. The returned set is unmodifiable.
     */
    private Set<AttributeClass> getMarkovBlanket(BayesianNetwork network, BayesNode node) {
        Set<AttributeClass> blanket = new HashSet<AttributeClass>();
        for (BayesNode parent : node.getParents()) {
            blanket.add(parent.getProbabilityClass());
        }
        for (BayesNode child : network.getChildren(node)) {
            blanket.add(child.getProbabilityClass());
            for (BayesNode coParent : child.getParents()) {
                // The node itself is a parent of each of its children; exclude it from the blanket.
                if (!coParent.equals(node)) {
                    blanket.add(coParent.getProbabilityClass());
                }
            }
        }
        return Collections.unmodifiableSet(blanket);
    }

    /** M-step: recomputes every node's CPT from counts, visiting nodes in topological order. */
    private void performMStepFuzzyCounting(TIntArrayList instances, BayesianNetwork network)
            throws DataTableException {
        // Raw type kept: the graph's edge type is not visible from this class.
        TopologicalOrderIterator order = new TopologicalOrderIterator(network.getNetworkStructure());
        while (order.hasNext()) {
            calculateDistributions(instances, (BayesNode) order.next());
        }
    }

    /**
     * Recomputes every row of the node's CPT.
     *
     * @throws IllegalArgumentException if the node has no CPT
     */
    private void calculateDistributions(TIntArrayList instances, BayesNode node)
            throws DataTableException {
        ConditionalProbabilityTable cpt = node.getCPT();
        if (cpt == null) {
            throw new IllegalArgumentException(node + "'s cpt must be initialized");
        }
        // Each setDistribution call overwrites a row that the iteration itself produced.
        for (Criteria row : cpt) {
            cpt.setDistribution(row, calculateProbabilities(instances, node.getProbabilityClass(), row));
        }
    }

    /**
     * Estimates the distribution of {@code attributeClass} given {@code criteria} as
     * {@code count(criteria and value) / count(criteria)} for each value in {@code [0, degree)}.
     *
     * <p>If {@code count(criteria)} is not positive (for example, the parent configuration never
     * occurs in the data), a uniform distribution is returned. Without this fallback the division
     * yields NaN, which would corrupt the CPT and the log-likelihood used to stop EM.
     *
     * <p>NOTE(review): the decompiler reported this member as originally {@code protected}.
     */
    public ProbabilityDistribution calculateProbabilities(TIntArrayList instances,
            AttributeClass attributeClass, Criteria criteria) throws DataTableException {
        final int degree = attributeClass.getDegree();
        final double total = fuzzyCount(instances, criteria);
        ProbabilityDistribution distribution = new ProbabilityDistribution(attributeClass);
        for (int value = 0; value < degree; value++) {
            double probability;
            if (total > 0.0) {
                probability = fuzzyCount(instances,
                        criteria.add(new Attribute(attributeClass, (byte) value))) / total;
            } else {
                probability = 1.0 / degree;
            }
            distribution.setProbability(value, probability);
        }
        return distribution;
    }

    /**
     * Returns {@link InstanceSource#count} for the given instances and criteria. This is an
     * overridable hook for the M-step.
     *
     * <p>NOTE(review): the decompiler reported this member as originally {@code protected}.
     */
    public double fuzzyCount(TIntArrayList instances, Criteria criteria) throws DataTableException {
        return source.count(instances, criteria);
    }
}
