/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.sequence;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.evaluation.LabelMetric;
import org.tribuo.classification.evaluation.LabelMetrics;
import org.tribuo.classification.sequence.LabelSequenceEvaluation;
import org.tribuo.evaluation.metrics.MetricID;
import org.tribuo.evaluation.metrics.MetricTarget;
import org.tribuo.provenance.EvaluationProvenance;
import org.tribuo.sequence.AbstractSequenceEvaluator;
import org.tribuo.sequence.SequenceModel;

public class LabelSequenceEvaluator
extends AbstractSequenceEvaluator<Label, LabelMetric.Context, LabelSequenceEvaluation, LabelMetric> {
    protected Set<LabelMetric> createMetrics(SequenceModel<Label> model) {
        HashSet<LabelMetric> metrics = new HashSet<LabelMetric>();
        for (Label label : model.getOutputIDInfo().getDomain()) {
            MetricTarget tgt = new MetricTarget((Output)label);
            metrics.add(LabelMetrics.TP.forTarget((MetricTarget<Label>)tgt));
            metrics.add(LabelMetrics.FP.forTarget((MetricTarget<Label>)tgt));
            metrics.add(LabelMetrics.TN.forTarget((MetricTarget<Label>)tgt));
            metrics.add(LabelMetrics.FN.forTarget((MetricTarget<Label>)tgt));
            metrics.add(LabelMetrics.PRECISION.forTarget((MetricTarget<Label>)tgt));
            metrics.add(LabelMetrics.RECALL.forTarget((MetricTarget<Label>)tgt));
            metrics.add(LabelMetrics.F1.forTarget((MetricTarget<Label>)tgt));
            metrics.add(LabelMetrics.ACCURACY.forTarget((MetricTarget<Label>)tgt));
        }
        MetricTarget micro = MetricTarget.microAverageTarget();
        metrics.add(LabelMetrics.TP.forTarget((MetricTarget<Label>)micro));
        metrics.add(LabelMetrics.FP.forTarget((MetricTarget<Label>)micro));
        metrics.add(LabelMetrics.TN.forTarget((MetricTarget<Label>)micro));
        metrics.add(LabelMetrics.FN.forTarget((MetricTarget<Label>)micro));
        metrics.add(LabelMetrics.PRECISION.forTarget((MetricTarget<Label>)micro));
        metrics.add(LabelMetrics.RECALL.forTarget((MetricTarget<Label>)micro));
        metrics.add(LabelMetrics.F1.forTarget((MetricTarget<Label>)micro));
        metrics.add(LabelMetrics.ACCURACY.forTarget((MetricTarget<Label>)micro));
        MetricTarget macro = MetricTarget.macroAverageTarget();
        metrics.add(LabelMetrics.TP.forTarget((MetricTarget<Label>)macro));
        metrics.add(LabelMetrics.FP.forTarget((MetricTarget<Label>)macro));
        metrics.add(LabelMetrics.TN.forTarget((MetricTarget<Label>)macro));
        metrics.add(LabelMetrics.FN.forTarget((MetricTarget<Label>)macro));
        metrics.add(LabelMetrics.PRECISION.forTarget((MetricTarget<Label>)macro));
        metrics.add(LabelMetrics.RECALL.forTarget((MetricTarget<Label>)macro));
        metrics.add(LabelMetrics.F1.forTarget((MetricTarget<Label>)macro));
        metrics.add(LabelMetrics.ACCURACY.forTarget((MetricTarget<Label>)macro));
        metrics.add(LabelMetrics.BALANCED_ERROR_RATE.forTarget((MetricTarget<Label>)macro));
        return metrics;
    }

    protected LabelMetric.Context createContext(SequenceModel<Label> model, List<List<Prediction<Label>>> predictions) {
        return new LabelMetric.Context(model, LabelSequenceEvaluator.flattenList(predictions));
    }

    protected LabelSequenceEvaluation createEvaluation(LabelMetric.Context ctx, Map<MetricID<Label>, Double> results, EvaluationProvenance provenance) {
        return new LabelSequenceEvaluation(results, ctx, provenance);
    }

    private static List<Prediction<Label>> flattenList(List<List<Prediction<Label>>> predictions) {
        ArrayList<Prediction<Label>> flatList = new ArrayList<Prediction<Label>>();
        for (List<Prediction<Label>> list : predictions) {
            flatList.addAll(list);
        }
        return flatList;
    }
}

