package edu.ou.utz8239.bayesnet.learning;

import edu.ou.utz8239.bayesnet.BayesNode;
import edu.ou.utz8239.bayesnet.BayesianNetwork;
import edu.ou.utz8239.bayesnet.data.sources.InstanceSource;
import edu.ou.utz8239.bayesnet.exceptions.DataTableException;
import edu.ou.utz8239.bayesnet.probabilties.AttributeClass;
import edu.ou.utz8239.bayesnet.probabilties.Criteria;
import edu.ou.utz8239.bayesnet.probabilties.ProbabilityDistribution;
import edu.ou.utz8239.bayesnet.probabilties.utils.MathUtils;
import gnu.trove.THashMap;
import gnu.trove.TIntArrayList;
import gnu.trove.TObjectDoubleHashMap;
import gnu.trove.TObjectProcedure;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/* loaded from: input_file:edu/ou/utz8239/bayesnet/learning/CachingNetworkProbabilityLearner.class */
/**
 * A {@link NetworkProbabilityLearner} that memoizes the results of
 * {@link #fuzzyCount(TIntArrayList, Criteria)} and
 * {@link #calculateProbabilities(TIntArrayList, AttributeClass, Criteria)}.
 *
 * <p>After {@link #updateHiddenNode(TIntArrayList, BayesianNetwork, BayesNode)} delegates to the
 * superclass, every cached entry that may depend on the updated node's probability class is
 * evicted, so the next request for it is recomputed by the superclass.
 */
public class CachingNetworkProbabilityLearner extends NetworkProbabilityLearner {
    /** Upper bound on the presized capacity of each cache; the maps still grow on demand. */
    private static final int MAX_INITIAL_CAPACITY = 1 << 16;

    /** Largest n whose factorial stays below {@link #MAX_INITIAL_CAPACITY} (8! = 40320). */
    private static final int MAX_FACTORIAL_ARGUMENT = 8;

    /** Memoized {@link #fuzzyCount} results. */
    private final TObjectDoubleHashMap<CountCacheKey> countCache;
    /** Memoized {@link #calculateProbabilities} results. */
    private final THashMap<DistCacheKey, ProbabilityDistribution> distCache;

    /** Cache key for {@link #fuzzyCount}: a snapshot of the int list plus the criteria. */
    public class CountCacheKey {
        private final TIntArrayList keys;
        private final Criteria criteria;
        private final int hash;

        public CountCacheKey(TIntArrayList tIntArrayList, Criteria criteria) {
            // Defensive copy: later mutation of the caller's list must not change this key.
            this.keys = new TIntArrayList(tIntArrayList.toNativeArray());
            this.criteria = criteria;
            // Must be computed after the fields above are assigned. As a field initializer it
            // ran before the constructor body, hashed nulls, and made every key collide.
            this.hash = Objects.hash(getOuterType(), this.criteria, this.keys);
        }

        @Override
        public int hashCode() {
            return this.hash;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            CountCacheKey other = (CountCacheKey) obj;
            // Cheap hash comparison first; the list comparison is linear in its length.
            return this.hash == other.hash
                    && getOuterType().equals(other.getOuterType())
                    && Objects.equals(this.criteria, other.criteria)
                    && this.keys.equals(other.keys);
        }

        /**
         * Returns whether this entry must be evicted after values of {@code attributeClass}
         * change, i.e. whether its criteria reference that attribute.
         *
         * @param attributeClass the attribute class that was updated
         * @return {@code true} if the cached count may be stale; {@code false} for null criteria
         */
        public boolean needsToDirty(AttributeClass attributeClass) {
            return this.criteria != null && this.criteria.findAttribute(attributeClass) != null;
        }

        private CachingNetworkProbabilityLearner getOuterType() {
            return CachingNetworkProbabilityLearner.this;
        }
    }

    /**
     * Cache key for {@link #calculateProbabilities}: a snapshot of the int list, the attribute
     * class whose distribution is computed, and the criteria.
     */
    public class DistCacheKey {
        private final TIntArrayList keys;
        private final Criteria criteria;
        private final AttributeClass clazz;
        private final int hash;

        public DistCacheKey(TIntArrayList tIntArrayList, AttributeClass attributeClass, Criteria criteria) {
            // Defensive copy: later mutation of the caller's list must not change this key.
            this.keys = new TIntArrayList(tIntArrayList.toNativeArray());
            this.criteria = criteria;
            this.clazz = attributeClass;
            // Computed after all fields are assigned (see CountCacheKey for the rationale).
            this.hash = Objects.hash(getOuterType(), this.clazz, this.criteria, this.keys);
        }

        @Override
        public int hashCode() {
            return this.hash;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            DistCacheKey other = (DistCacheKey) obj;
            return this.hash == other.hash
                    && getOuterType().equals(other.getOuterType())
                    && Objects.equals(this.clazz, other.clazz)
                    && Objects.equals(this.criteria, other.criteria)
                    && this.keys.equals(other.keys);
        }

        /**
         * Returns whether this entry must be evicted after values of {@code attributeClass}
         * change: either the cached distribution is over that class, or its criteria reference
         * it.
         *
         * @param attributeClass the attribute class that was updated
         * @return {@code true} if the cached distribution may be stale
         */
        public boolean needsToDirty(AttributeClass attributeClass) {
            // A distribution *of* the updated class depends on its values even when the
            // criteria do not mention it; evicting it is conservative (only costs a recompute).
            if (this.clazz != null && this.clazz.equals(attributeClass)) {
                return true;
            }
            return this.criteria != null && this.criteria.findAttribute(attributeClass) != null;
        }

        private CachingNetworkProbabilityLearner getOuterType() {
            return CachingNetworkProbabilityLearner.this;
        }
    }

    /**
     * Creates a caching learner over {@code instanceSource}.
     *
     * @param instanceSource the data source passed to the superclass; the number of provided
     *     classes is used to presize both caches
     */
    public CachingNetworkProbabilityLearner(InstanceSource instanceSource) {
        super(instanceSource);
        int capacity = initialCapacity(instanceSource.getProvidedClasses().size());
        this.countCache = new TObjectDoubleHashMap<>(capacity);
        this.distCache = new THashMap<>(capacity);
    }

    /**
     * Returns {@code classCount!} as a presize hint, capped at {@link #MAX_INITIAL_CAPACITY}.
     * Uncapped, the factorial allocates multi-million-slot tables from 10 classes on and
     * overflows {@code int} from 13 classes on.
     */
    private static int initialCapacity(int classCount) {
        return classCount <= MAX_FACTORIAL_ARGUMENT
                ? MathUtils.factorial(classCount)
                : MAX_INITIAL_CAPACITY;
    }

    /**
     * Delegates to the superclass, then evicts every cached count and distribution whose
     * {@code needsToDirty} reports a dependency on {@code bayesNode}'s probability class.
     */
    @Override
    public void updateHiddenNode(TIntArrayList tIntArrayList, BayesianNetwork bayesianNetwork, final BayesNode bayesNode) throws Exception {
        super.updateHiddenNode(tIntArrayList, bayesianNetwork, bayesNode);
        final AttributeClass updatedClass = bayesNode.getProbabilityClass();

        // Stale keys are collected first and removed afterwards: removing from a Trove map
        // while forEachKey walks it is a structural modification during iteration.
        final List<CountCacheKey> staleCounts = new ArrayList<>();
        this.countCache.forEachKey(new TObjectProcedure<CountCacheKey>() {
            @Override
            public boolean execute(CountCacheKey key) {
                if (key.needsToDirty(updatedClass)) {
                    staleCounts.add(key);
                }
                return true;
            }
        });
        for (CountCacheKey key : staleCounts) {
            this.countCache.remove(key);
        }

        final List<DistCacheKey> staleDists = new ArrayList<>();
        this.distCache.forEachKey(new TObjectProcedure<DistCacheKey>() {
            @Override
            public boolean execute(DistCacheKey key) {
                if (key.needsToDirty(updatedClass)) {
                    staleDists.add(key);
                }
                return true;
            }
        });
        for (DistCacheKey key : staleDists) {
            this.distCache.remove(key);
        }
    }

    /**
     * Returns the cached distribution for these arguments, computing and caching it via the
     * superclass on a miss. A {@code null} result from the superclass is stored but treated as
     * a miss on the next call.
     *
     * <p>NOTE(review): the cached instance is returned to every caller; this assumes callers
     * never mutate a {@link ProbabilityDistribution} — confirm.
     */
    @Override
    public ProbabilityDistribution calculateProbabilities(TIntArrayList tIntArrayList, AttributeClass attributeClass, Criteria criteria) throws DataTableException {
        DistCacheKey key = new DistCacheKey(tIntArrayList, attributeClass, criteria);
        ProbabilityDistribution cached = this.distCache.get(key);
        if (cached != null) {
            return cached;
        }
        ProbabilityDistribution computed = super.calculateProbabilities(tIntArrayList, attributeClass, criteria);
        this.distCache.put(key, computed);
        return computed;
    }

    /**
     * Returns the cached fuzzy count for these arguments, computing and caching it via the
     * superclass on a miss.
     */
    @Override
    public double fuzzyCount(TIntArrayList tIntArrayList, Criteria criteria) throws DataTableException {
        CountCacheKey key = new CountCacheKey(tIntArrayList, criteria);
        // containsKey is required: TObjectDoubleHashMap.get cannot distinguish "absent" from 0.
        if (this.countCache.containsKey(key)) {
            return this.countCache.get(key);
        }
        double count = super.fuzzyCount(tIntArrayList, criteria);
        this.countCache.put(key, count);
        return count;
    }
}
