package weka.classifiers.mi;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.classifiers.mi.miti.AlgorithmConfiguration;
import weka.classifiers.mi.miti.Bag;
import weka.classifiers.mi.miti.NextSplitHeuristic;
import weka.classifiers.mi.miti.TreeNode;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

/* loaded from: input_file:weka/classifiers/mi/MITI.class */
/**
 * MITI (Multi Instance Tree Inducer): a multi-instance classifier that grows a single decision
 * tree over the instances of all training bags, using Blockeel et al.'s algorithm.
 *
 * <p>At prediction time a bag is classified as positive as soon as one of its instances reaches
 * a positive leaf (see {@link #distributionForInstance(Instance)}).
 */
public class MITI extends RandomizableClassifier
        implements OptionHandler, AdditionalMeasureProducer, TechnicalInformationHandler,
                MultiInstanceCapabilitiesHandler {
    static final long serialVersionUID = -217735168397644244L;

    /** The learned model; {@code null} until {@link #buildClassifier(Instances)} has run. */
    protected MultiInstanceDecisionTree tree;

    public static final int SPLITMETHOD_GINI = 1;
    public static final int SPLITMETHOD_MAXBEPP = 2;
    public static final int SPLITMETHOD_SSBEPP = 3;
    public static final Tag[] TAGS_SPLITMETHOD = {
        new Tag(SPLITMETHOD_GINI, "Gini: E * (1 - E)"),
        new Tag(SPLITMETHOD_MAXBEPP, "MaxBEPP: E"),
        new Tag(SPLITMETHOD_SSBEPP, "Sum Squared BEPP: E * E")
    };

    /** Split criterion id, one of the SPLITMETHOD_* constants (option -M). */
    protected int m_SplitMethod = SPLITMETHOD_MAXBEPP;
    /** Whether bag weights are scaled by the average bag size (option -L). */
    protected boolean m_scaleK = false;
    /** Whether bag-based counts are used when scoring nodes (option -B). */
    protected boolean m_useBagCount = false;
    /** Whether the unbiased estimate is used instead of BEPP (option -U). */
    protected boolean m_unbiasedEstimate = false;
    /** Constant K passed to node scoring (option -K). */
    protected int m_kBEPPConstant = 5;
    /** Number of randomly selected attributes to consider for a split (option -A). */
    protected int m_AttributesToSplit = -1;
    /** Number of top-scoring splits to pick from at random (option -An). */
    protected int m_AttributeSplitChoices = 1;
    /** Multiplier for the count influence of a bag (option -Ba). */
    protected double m_bagInstanceMultiplier = 0.5d;

    /**
     * The decision tree learned by MITI. This is an inner (non-static) class on purpose: growing
     * the tree reads the enclosing classifier's seed, split settings and debug flag.
     */
    public class MultiInstanceDecisionTree implements Serializable {
        private static final long serialVersionUID = 4037700809781784985L;

        /** Root node; assigned by {@link #makeTree}. */
        private TreeNode root;
        /** Maps every training instance to the bag it was taken from. */
        private final HashMap<Instance, Bag> m_instanceBags;
        /** Number of leaves counted while the tree was grown. */
        private int numLeaves;

        /**
         * Returns the number of leaves counted while growing the tree.
         *
         * @return the leaf count
         */
        public int getNumLeaves() {
            return this.numLeaves;
        }

        /**
         * Builds a tree from multi-instance training data: each bag-level row is wrapped in a
         * {@link Bag}, every member instance is mapped to its bag, and the tree is grown over all
         * member instances.
         *
         * @param instances the multi-instance training data (one row per bag)
         */
        protected MultiInstanceDecisionTree(Instances instances) {
            this.numLeaves = 0;
            this.m_instanceBags = new HashMap<>();
            ArrayList<Instance> allMembers = new ArrayList<>();
            double totalMembers = 0.0d;
            double bagCount = 0.0d;
            for (Instance bagRow : instances) {
                Bag bag = new Bag(bagRow);
                for (Instance member : bag.instances()) {
                    this.m_instanceBags.put(member, bag);
                    allMembers.add(member);
                }
                bagCount += 1.0d;
                totalMembers += bag.instances().numInstances();
            }
            if (MITI.this.m_scaleK) {
                double averageBagSize = totalMembers / bagCount;
                // values() yields a bag once per member instance; the setter is simply
                // re-applied with the same value, exactly as before.
                for (Bag bag : this.m_instanceBags.values()) {
                    bag.setBagWeightMultiplier(averageBagSize);
                }
            }
            makeTree(this.m_instanceBags, allMembers, false);
        }

        /**
         * Builds a tree from an existing instance-to-bag map and instance list.
         *
         * @param instanceBags maps each instance to its bag
         * @param instances the instances placed at the root
         * @param stopOnFirstPositiveLeaf whether to stop growing once a positive leaf is created
         */
        public MultiInstanceDecisionTree(
                HashMap<Instance, Bag> instanceBags,
                ArrayList<Instance> instances,
                boolean stopOnFirstPositiveLeaf) {
            this.numLeaves = 0;
            this.m_instanceBags = instanceBags;
            makeTree(instanceBags, instances, stopOnFirstPositiveLeaf);
        }

        /**
         * Grows the tree from a list of open nodes. After every expansion the open list is sorted
         * with {@code Collections.reverseOrder(new NextSplitHeuristic())} and its first element
         * is expanded next. A pure-positive node becomes a positive leaf and calls
         * {@code deactivateRelatedInstances}; afterwards every open node drops deactivated
         * instances and is re-scored. The same refresh happens when a split attempt turns a node
         * into a positive leaf.
         *
         * @param instanceBags maps each instance to its bag
         * @param instances the instances placed at the root
         * @param stopOnFirstPositiveLeaf if {@code true}, return as soon as a positive leaf is
         *     created (for a pure-positive node only if it deactivated something). The leaf that
         *     triggers the early return is not added to {@link #numLeaves}, and the debug dump of
         *     the tree is skipped.
         */
        private void makeTree(
                HashMap<Instance, Bag> instanceBags,
                ArrayList<Instance> instances,
                boolean stopOnFirstPositiveLeaf) {
            Random random = new Random(MITI.this.getSeed());
            AlgorithmConfiguration settings = MITI.this.getSettings();
            ArrayList<TreeNode> openNodes = new ArrayList<>();
            this.root = new TreeNode(null, instances);
            openNodes.add(this.root);
            this.numLeaves = 0;
            while (!openNodes.isEmpty()) {
                // Math.min(1, size) is always 1 here, so nextInt returns 0 and the head of the
                // sorted list is taken. The call still advances the generator, so it must stay:
                // replacing it with remove(0) would change the random numbers splitInstances sees.
                TreeNode node = openNodes.remove(random.nextInt(Math.min(1, openNodes.size())));
                if (node == null) {
                    continue;
                }
                if (node.isPurePositive(instanceBags)) {
                    node.makeLeafNode(true);
                    // Raw on purpose: the element type is defined by TreeNode/Bag.
                    ArrayList deactivated = new ArrayList();
                    node.deactivateRelatedInstances(instanceBags, deactivated);
                    if (MITI.this.m_Debug && deactivated.size() > 0) {
                        Bag.printDeactivatedInstances(deactivated);
                    }
                    refreshOpenNodes(openNodes, instanceBags);
                    if (stopOnFirstPositiveLeaf && deactivated.size() > 0) {
                        return;
                    }
                } else if (node.isPureNegative(instanceBags)) {
                    node.makeLeafNode(false);
                } else {
                    node.splitInstances(instanceBags, settings, random, MITI.this.m_Debug);
                    if (node.isLeafNode()) {
                        // No usable split was found and the node was turned into a leaf.
                        if (node.isPositiveLeaf()) {
                            refreshOpenNodes(openNodes, instanceBags);
                            if (stopOnFirstPositiveLeaf) {
                                return;
                            }
                        }
                    } else if (node.split.isNominal) {
                        for (TreeNode child : node.nominals()) {
                            rescore(child, instanceBags);
                            openNodes.add(child);
                        }
                    } else {
                        rescore(node.left(), instanceBags);
                        openNodes.add(node.left());
                        rescore(node.right(), instanceBags);
                        openNodes.add(node.right());
                    }
                }
                if (node.isLeafNode()) {
                    this.numLeaves++;
                }
                Collections.sort(openNodes, Collections.reverseOrder(new NextSplitHeuristic()));
            }
            if (MITI.this.m_Debug) {
                System.out.println(this.root.render(1, instanceBags));
            }
        }

        /** Drops deactivated instances from every open node and re-scores it. */
        private void refreshOpenNodes(ArrayList<TreeNode> openNodes, HashMap<Instance, Bag> instanceBags) {
            for (TreeNode open : openNodes) {
                open.removeDeactivatedInstances(instanceBags);
                rescore(open, instanceBags);
            }
        }

        /** Recomputes a node's score with the enclosing classifier's estimation settings. */
        private void rescore(TreeNode node, HashMap<Instance, Bag> instanceBags) {
            node.calculateNodeScore(
                    instanceBags,
                    MITI.this.m_unbiasedEstimate,
                    MITI.this.m_kBEPPConstant,
                    MITI.this.m_useBagCount,
                    MITI.this.m_bagInstanceMultiplier);
        }

        /**
         * Tells whether a single instance ends up in a positive leaf.
         *
         * @param instance the instance to route through the tree
         * @return {@code true} if the reached node is a positive leaf
         */
        public boolean isPositive(Instance instance) {
            TreeNode reached = traverseTree(instance);
            return reached != null && reached.isPositiveLeaf();
        }

        /**
         * Routes an instance from the root to a leaf (or to a {@code null} child). A nominal test
         * follows the child indexed by the attribute value; a numeric test goes left when the
         * value is below the split point and right otherwise. Missing values get no special
         * treatment here (NOTE(review): a NaN casts to index 0 / goes right — confirm intended).
         */
        private TreeNode traverseTree(Instance instance) {
            TreeNode node = this.root;
            while (node != null && !node.isLeafNode()) {
                Attribute attribute = node.split.attribute;
                if (attribute.isNominal()) {
                    node = node.nominals()[(int) instance.value(attribute)];
                } else {
                    node = instance.value(attribute) < node.split.splitPoint
                            ? node.left()
                            : node.right();
                }
            }
            return node;
        }

        /**
         * Renders the tree as text.
         *
         * @return the rendering produced by the root node
         */
        public String render() {
            return this.root.render(0, this.m_instanceBags);
        }

        /**
         * Delegates to {@code TreeNode.trimNegativeBranches()} on the root.
         *
         * @return the value returned by the root node
         */
        public boolean trimNegativeBranches() {
            return this.root.trimNegativeBranches();
        }

        /**
         * Counts positive rules and their conditions.
         *
         * @return a two-element array: [0] the number of positive leaves, [1] the total number of
         *     split tests on the paths from the root to those leaves
         */
        public int[] numPosRulesAndNumPosConditions() {
            return numPosRulesAndNumPosConditions(this.root);
        }

        /** Recursive helper for {@link #numPosRulesAndNumPosConditions()}. */
        private int[] numPosRulesAndNumPosConditions(TreeNode node) {
            int[] counts = new int[2];
            if (node == null) {
                return counts;
            }
            if (node.isLeafNode()) {
                if (node.isPositiveLeaf()) {
                    counts[0] = 1;
                }
                return counts;
            }
            TreeNode[] children = node.split.attribute.isNominal()
                    ? node.nominals()
                    : new TreeNode[] {node.left(), node.right()};
            for (TreeNode child : children) {
                int[] childCounts = numPosRulesAndNumPosConditions(child);
                counts[0] += childCounts[0];
                // This node's test is one more condition on every positive rule below it.
                counts[1] += childCounts[1] + childCounts[0];
            }
            return counts;
        }
    }

    /**
     * Returns a description of the classifier for the GUI.
     *
     * @return the description, followed by the technical information
     */
    public String globalInfo() {
        return "MITI (Multi Instance Tree Inducer): multi-instance classification based on a decision tree learned using Blockeel et al.'s algorithm. For more information, see\n\n" + getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic references for this classifier.
     *
     * @return the Blockeel et al. (2005) paper with the Bjerring and Frank (2011) paper attached
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Hendrik Blockeel and David Page and Ashwin Srinivasan");
        result.setValue(TechnicalInformation.Field.TITLE, "Multi-instance Tree Learning");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "Proceedings of the International Conference on Machine Learning");
        result.setValue(TechnicalInformation.Field.YEAR, "2005");
        result.setValue(TechnicalInformation.Field.PAGES, "57-64");
        result.setValue(TechnicalInformation.Field.PUBLISHER, "ACM");
        TechnicalInformation followUp = result.add(TechnicalInformation.Type.INPROCEEDINGS);
        followUp.setValue(TechnicalInformation.Field.AUTHOR, "Luke Bjerring and Eibe Frank");
        followUp.setValue(TechnicalInformation.Field.TITLE, "Beyond Trees: Adopting MITI to Learn Rules and Ensemble Classifiers for Multi-instance Data");
        followUp.setValue(TechnicalInformation.Field.BOOKTITLE, "Proceedings of the Australasian Joint Conference on Artificial Intelligence");
        followUp.setValue(TechnicalInformation.Field.YEAR, "2011");
        followUp.setValue(TechnicalInformation.Field.PUBLISHER, "Springer");
        return result;
    }

    /**
     * Returns the capabilities for bag-level data: nominal and relational attributes, no missing
     * values, binary class, multi-instance format only.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.RELATIONAL_ATTRIBUTES);
        result.disable(Capabilities.Capability.MISSING_VALUES);
        result.disableAllClasses();
        result.disableAllClassDependencies();
        result.enable(Capabilities.Capability.BINARY_CLASS);
        result.enable(Capabilities.Capability.ONLY_MULTIINSTANCE);
        return result;
    }

    /**
     * Returns the capabilities required of the relational (bag) data, which carries no class.
     *
     * @return the multi-instance capabilities
     */
    @Override
    public Capabilities getMultiInstanceCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAllClasses();
        result.enable(Capabilities.Capability.NO_CLASS);
        return result;
    }

    /**
     * Checks the data against {@link #getCapabilities()} and grows the tree.
     *
     * @param instances multi-instance training data
     * @throws Exception if the data is not supported
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);
        this.tree = new MultiInstanceDecisionTree(instances);
    }

    /** Returns the learned tree, failing with a clear message if none has been built yet. */
    private MultiInstanceDecisionTree builtTree() {
        if (this.tree == null) {
            throw new IllegalStateException("No model built yet!");
        }
        return this.tree;
    }

    /**
     * Lists the additional measures supported by {@link #getMeasure(String)}.
     *
     * @return the measure names
     */
    @Override
    public Enumeration<String> enumerateMeasures() {
        Vector<String> measures = new Vector<>(3);
        measures.addElement("measureNumRules");
        measures.addElement("measureNumPositiveRules");
        measures.addElement("measureNumConditionsInPositiveRules");
        return measures.elements();
    }

    /**
     * Returns the value of an additional measure (names are matched case-insensitively).
     *
     * @param str the measure name
     * @return the measure value
     * @throws IllegalArgumentException if the measure is not supported
     * @throws IllegalStateException if a supported measure is requested before the model is built
     */
    @Override
    public double getMeasure(String str) {
        if (str.equalsIgnoreCase("measureNumRules")) {
            return builtTree().getNumLeaves();
        }
        if (str.equalsIgnoreCase("measureNumPositiveRules")) {
            return builtTree().numPosRulesAndNumPosConditions()[0];
        }
        if (str.equalsIgnoreCase("measureNumConditionsInPositiveRules")) {
            return builtTree().numPosRulesAndNumPosConditions()[1];
        }
        throw new IllegalArgumentException(str + " not supported (" + getClass().getSimpleName() + ")");
    }

    /**
     * Classifies a bag: it is positive (probability 1 for class index 1) if any of its instances
     * reaches a positive leaf, negative otherwise. The bag is read from the relational attribute
     * at index 1, as in Weka's multi-instance format (bag id, bag, class).
     *
     * @param instance the bag-level instance
     * @return a two-element class distribution containing only 0s and 1s
     * @throws Exception if the model has not been built
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        MultiInstanceDecisionTree model = builtTree();
        boolean positive = false;
        for (Instance member : instance.relationalValue(1)) {
            if (model.isPositive(member)) {
                positive = true;
                break;
            }
        }
        double[] distribution = new double[2];
        distribution[1] = positive ? 1.0d : 0.0d;
        distribution[0] = 1.0d - distribution[1];
        return distribution;
    }

    /**
     * Bundles the current split settings for the tree nodes.
     *
     * @return a new configuration holding the current option values
     */
    protected AlgorithmConfiguration getSettings() {
        return new AlgorithmConfiguration(this.m_SplitMethod, this.m_unbiasedEstimate, this.m_kBEPPConstant, this.m_useBagCount, this.m_bagInstanceMultiplier, this.m_AttributesToSplit, this.m_AttributeSplitChoices);
    }

    /**
     * Describes the options of this classifier, followed by the superclass options.
     *
     * @return the option descriptions
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<>();
        options.addElement(new Option("\tThe method used to determine best split:\n\t1. Gini; 2. MaxBEPP; 3. SSBEPP", "M", 1, "-M [1|2|3]"));
        options.addElement(new Option("\tThe constant used in the tozero() heuristic", "K", 1, "-K [kBEPPConstant]"));
        options.addElement(new Option("\tScales the value of K to the size of the bags", "L", 0, "-L"));
        options.addElement(new Option("\tUse unbiased estimate rather than BEPP, i.e. UEPP.", "U", 0, "-U"));
        options.addElement(new Option("\tUses the instances present for the bag counts at each node when splitting,\n\tweighted according to 1 - Ba ^ n, where n is the number of instances\n\tpresent which belong to the bag, and Ba is another parameter (default 0.5)", "B", 0, "-B"));
        options.addElement(new Option("\tMultiplier for count influence of a bag based on the number of its instances", "Ba", 1, "-Ba [multiplier]"));
        options.addElement(new Option("\tThe number of randomly selected attributes to split\n\t-1: All attributes\n\t-2: square root of the total number of attributes", "A", 1, "-A [number of attributes]"));
        options.addElement(new Option("\tThe number of top scoring attribute splits to randomly pick from\n\t-1: All splits (completely random selection)\n\t-2: square root of the number of splits", "An", 1, "-An [number of splits]"));
        options.addAll(Collections.list(super.listOptions()));
        return options.elements();
    }

    /**
     * Parses the options; any option that is absent is reset to its default. Options are consumed
     * in the order -M, -K, -L, -U, -A, -An, -B, -Ba, and the rest goes to the superclass.
     *
     * @param strArr the options (consumed entries are removed by {@link Utils})
     * @throws Exception if an option value cannot be parsed
     */
    @Override
    public void setOptions(String[] strArr) throws Exception {
        String splitMethod = Utils.getOption('M', strArr);
        int splitMethodId = splitMethod.length() != 0 ? Integer.parseInt(splitMethod) : SPLITMETHOD_MAXBEPP;
        setSplitMethod(new SelectedTag(splitMethodId, TAGS_SPLITMETHOD));

        String k = Utils.getOption('K', strArr);
        setK(k.length() != 0 ? Integer.parseInt(k) : 5);

        setL(Utils.getFlag('L', strArr));
        setUnbiasedEstimate(Utils.getFlag('U', strArr));

        String attributesToSplit = Utils.getOption('A', strArr);
        setAttributesToSplit(attributesToSplit.length() != 0 ? Integer.parseInt(attributesToSplit) : -1);

        String topN = Utils.getOption("An", strArr);
        setTopNAttributesToSplit(topN.length() != 0 ? Integer.parseInt(topN) : 1);

        setB(Utils.getFlag('B', strArr));
        String bagMultiplier = Utils.getOption("Ba", strArr);
        setBa(bagMultiplier.length() != 0 ? Double.parseDouble(bagMultiplier) : 0.5d);

        super.setOptions(strArr);
    }

    /**
     * Returns the current options, followed by the superclass options.
     *
     * @return the options as a string array
     */
    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<>();
        options.add("-K");
        options.add("" + this.m_kBEPPConstant);
        if (getL()) {
            options.add("-L");
        }
        if (getUnbiasedEstimate()) {
            options.add("-U");
        }
        if (getB()) {
            options.add("-B");
        }
        options.add("-Ba");
        options.add("" + this.m_bagInstanceMultiplier);
        options.add("-M");
        options.add("" + this.m_SplitMethod);
        options.add("-A");
        options.add("" + this.m_AttributesToSplit);
        options.add("-An");
        options.add("" + this.m_AttributeSplitChoices);
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    /** @return tip text for the K property */
    public String kTipText() {
        return "The value used in the tozero() method.";
    }

    /** @return the constant K (option -K) */
    public int getK() {
        return this.m_kBEPPConstant;
    }

    /** @param i the constant K (option -K) */
    public void setK(int i) {
        this.m_kBEPPConstant = i;
    }

    /** @return tip text for the L property */
    public String lTipText() {
        return "Whether to scale based on the number of instances.";
    }

    /** @return whether bag weights are scaled by the average bag size (option -L) */
    public boolean getL() {
        return this.m_scaleK;
    }

    /** @param z whether bag weights are scaled by the average bag size (option -L) */
    public void setL(boolean z) {
        this.m_scaleK = z;
    }

    /** @return tip text for the unbiasedEstimate property */
    public String unbiasedEstimateTipText() {
        return "Whether to use unbiased estimate (EPP instead of BEPP).";
    }

    /** @return whether the unbiased estimate is used (option -U) */
    public boolean getUnbiasedEstimate() {
        return this.m_unbiasedEstimate;
    }

    /** @param z whether the unbiased estimate is used (option -U) */
    public void setUnbiasedEstimate(boolean z) {
        this.m_unbiasedEstimate = z;
    }

    /** @return tip text for the B property */
    public String bTipText() {
        return "Whether to use bag-based statistics for estimates of proportion.";
    }

    /** @return whether bag-based counts are used (option -B) */
    public boolean getB() {
        return this.m_useBagCount;
    }

    /** @param z whether bag-based counts are used (option -B) */
    public void setB(boolean z) {
        this.m_useBagCount = z;
    }

    /** @return tip text for the Ba property */
    public String baTipText() {
        return "Multiplier for count influence of a bag based on the number of its instances.";
    }

    /** @return the bag count multiplier (option -Ba) */
    public double getBa() {
        return this.m_bagInstanceMultiplier;
    }

    /** @param d the bag count multiplier (option -Ba) */
    public void setBa(double d) {
        this.m_bagInstanceMultiplier = d;
    }

    /** @return tip text for the attributesToSplit property */
    public String attributesToSplitTipText() {
        return "The number of randomly chosen attributes to consider for splitting.";
    }

    /** @return the number of attributes considered per split (option -A) */
    public int getAttributesToSplit() {
        return this.m_AttributesToSplit;
    }

    /** @param i the number of attributes considered per split (option -A) */
    public void setAttributesToSplit(int i) {
        this.m_AttributesToSplit = i;
    }

    /** @return tip text for the topNAttributesToSplit property */
    public String topNAttributesToSplitTipText() {
        return "Value of N to use for top-N attributes to choose randomly from.";
    }

    /** @return the number of top splits to choose from (option -An) */
    public int getTopNAttributesToSplit() {
        return this.m_AttributeSplitChoices;
    }

    /** @param i the number of top splits to choose from (option -An) */
    public void setTopNAttributesToSplit(int i) {
        this.m_AttributeSplitChoices = i;
    }

    /** @return tip text for the splitMethod property */
    public String splitMethodTipText() {
        return "The method used to determine best split: 1. Gini; 2. MaxBEPP; 3. SSBEPP";
    }

    /**
     * Sets the split criterion; tags that do not come from {@link #TAGS_SPLITMETHOD} are ignored.
     *
     * @param selectedTag the selected split method
     */
    public void setSplitMethod(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_SPLITMETHOD) {
            this.m_SplitMethod = selectedTag.getSelectedTag().getID();
        }
    }

    /** @return the current split method as a tag from {@link #TAGS_SPLITMETHOD} */
    public SelectedTag getSplitMethod() {
        return new SelectedTag(this.m_SplitMethod, TAGS_SPLITMETHOD);
    }

    /**
     * Renders the tree followed by the positive-rule statistics.
     *
     * @return the textual model, or a notice if no model has been built
     */
    @Override
    public String toString() {
        if (this.tree == null) {
            return "No model built yet!";
        }
        return this.tree.render()
                + "\n\nNumber of positive rules: " + getMeasure("measureNumPositiveRules") + "\n"
                + "Number of conditions in positive rules: " + getMeasure("measureNumConditionsInPositiveRules") + "\n";
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param strArr command-line options
     */
    public static void main(String[] strArr) {
        runClassifier(new MITI(), strArr);
    }
}
