Alink Talk (8): Как реализовать двухклассовую оценку AUC, KS, PRC, Precision, Recall, LiftChart

машинное обучение

0x00 сводка

Alink — это платформа алгоритмов машинного обучения нового поколения, разработанная Alibaba на основе вычислительного движка реального времени Flink.Это первая в отрасли платформа машинного обучения, которая поддерживает как пакетные, так и потоковые алгоритмы. Двухклассовая оценка предназначена для оценки влияния результатов прогнозирования двухклассового алгоритма. В этой статье будет проанализирована соответствующая реализация кода в Alink.

0x01 Связанные концепции

Если у вас есть сомнения относительно некоторых понятий в этой статье, вы можете обратиться к предыдущим статьям.[Средний анализ] Разберите понятия на примерах: точность, точность, полнота и F-мера.

0x02 Пример кода

public class EvalBinaryClassExample {

    AlgoOperator getData(boolean isBatch) {
        Row[] rows = new Row[]{
                Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}"),
                Row.of("prefix1", "{\"prefix1\": 0.8, \"prefix0\": 0.2}"),
                Row.of("prefix1", "{\"prefix1\": 0.7, \"prefix0\": 0.3}"),
                Row.of("prefix0", "{\"prefix1\": 0.75, \"prefix0\": 0.25}"),
                Row.of("prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}")
        };

        String[] schema = new String[]{"label", "detailInput"};

        if (isBatch) {
            return new MemSourceBatchOp(rows, schema);
        } else {
            return new MemSourceStreamOp(rows, schema);
        }
    }

    public static void main(String[] args) throws Exception {
        EvalBinaryClassExample test = new EvalBinaryClassExample();
        BatchOperator batchData = (BatchOperator) test.getData(true);

        BinaryClassMetrics metrics = new EvalBinaryClassBatchOp()
                .setLabelCol("label")
                .setPredictionDetailCol("detailInput")
                .linkFrom(batchData)
                .collectMetrics();

        System.out.println("RocCurve:" + metrics.getRocCurve());
        System.out.println("AUC:" + metrics.getAuc());
        System.out.println("KS:" + metrics.getKs());
        System.out.println("PRC:" + metrics.getPrc());
        System.out.println("Accuracy:" + metrics.getAccuracy());
        System.out.println("Macro Precision:" + metrics.getMacroPrecision());
        System.out.println("Micro Recall:" + metrics.getMicroRecall());
        System.out.println("Weighted Sensitivity:" + metrics.getWeightedSensitivity());
    }
}

вывод программы

RocCurve:([0.0, 0.0, 0.0, 0.5, 0.5, 1.0, 1.0],[0.0, 0.3333333333333333, 0.6666666666666666, 0.6666666666666666, 1.0, 1.0, 1.0])
AUC:0.8333333333333333
KS:0.6666666666666666
PRC:0.9027777777777777
Accuracy:0.6
Macro Precision:0.3
Micro Recall:0.6
Weighted Sensitivity:0.6

В Alink есть две реализации пакетной обработки и потоковой обработки для двухклассовой оценки.Далее мы представим их одну за другой (одна из сложностей Alink заключается в большом количестве тонких структур данных, поэтому большое количество переменных в программе будет напечатано ниже для всеобщего понимания).

2.1 Основная идея

  • Разделите [0,1] на гипотетические 100 000 ведер (ячеек). Так что получите положительныйBin/negativeBin два массива по 100000.

  • Присвоить значения PositiveBin/negativeBin на основе ввода. положительныйBin — это TP + FP, а отрицательныйBin — это TN + FN. Они являются основой для последующих расчетов.

  • Пройдите каждую значимую точку в бинах, вычислите totalTrue и totalFalse и вычислите матрицу путаницы, tpr и соответствующие данные rocCurve, RecallPrecisionCurve, liftChart в этой точке в каждой точке;

  • Рассчитать и сохранить AUC/PRC/KS на основе содержания кривой

В продолжении также приводится подробный обзор отношений вызова.

Пакет 0x03

3.1 EvalBinaryClassBatchOp

EvalBinaryClassBatchOp — это реализация оценки двух классов, и ее функция заключается в вычислении метрик оценки двух классов.

Существует два типа ввода:

  • label column and predResult column
  • столбец label и столбец predDetail. Если есть predDetail, predResult игнорируется

В нашем примере"prefix1"это ярлык,"{\"prefix1\": 0.9, \"prefix0\": 0.1}"предварительноПодробнее

Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}")

Конкретные классы выделяются следующим образом:

public class EvalBinaryClassBatchOp extends BaseEvalClassBatchOp<EvalBinaryClassBatchOp> implements BinaryEvaluationParams <EvalBinaryClassBatchOp>, EvaluationMetricsCollector<BinaryClassMetrics> {
  
	@Override
	public BinaryClassMetrics collectMetrics() {
		return new BinaryClassMetrics(this.collect().get(0));
	}  
}

Как видите, основная его работа выполняется в базовом классе BaseEvalClassBatchOp, поэтому сначала мы рассмотрим BaseEvalClassBatchOp.

3.2 BaseEvalClassBatchOp

Мы по-прежнему начнем с функции linkFrom, которая в основном делает несколько вещей:

  • Получить информацию о конфигурации
  • Извлечь определенные столбцы из ввода: "label", "detailInput"
  • callLabelPredDetailLocal рассчитает метрики оценки в соответствии с разделом
  • Комплексное сокращение приведенных выше результатов расчета
  • Функция SaveDataAsParams введет окончательное значение в выходную таблицу.

Конкретный код выглядит следующим образом

@Override
public T linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    String labelColName = this.get(MultiEvaluationParams.LABEL_COL);
    String positiveValue = this.get(BinaryEvaluationParams.POS_LABEL_VAL_STR);

    // Judge the evaluation type from params.
    ClassificationEvaluationUtil.Type type = ClassificationEvaluationUtil.judgeEvaluationType(this.getParams());

    DataSet<BaseMetricsSummary> res;
    switch (type) {
        case PRED_DETAIL: {
            String predDetailColName = this.get(MultiEvaluationParams.PREDICTION_DETAIL_COL);
            // 从输入中提取某些列:"label","detailInput" 
            DataSet<Row> data = in.select(new String[] {labelColName, predDetailColName}).getDataSet();
            // 按照partition分别计算evaluation metrics
            res = calLabelPredDetailLocal(data, positiveValue, binary);
            break;
        }
        ......
    }

    // 综合reduce上述计算结果
    DataSet<BaseMetricsSummary> metrics = res
        .reduce(new EvaluationUtil.ReduceBaseMetrics());

    // 把最终数值输入到 output table
    this.setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()),
        new String[] {DATA_OUTPUT}, new TypeInformation[] {Types.STRING});

    return (T)this;
}

// 执行中一些变量如下
labelColName = "label"
predDetailColName = "detailInput"  
type = {ClassificationEvaluationUtil$Type@2532} "PRED_DETAIL"
binary = true
positiveValue = null  

3.2.0 Обзор отношения вызова

Поскольку последующее отношение вызова кода сложное, сначала задайте отношение вызова:

  • Извлеките из ввода несколько столбцов: «label», «detailInput», in.select(new String[] {labelColName, predDetailColName}).getDataSet(). Поскольку во входных данных могут быть и другие столбцы, а для нашего расчета нужны только некоторые столбцы, извлекаются только эти столбцы.
  • Вычислить метрики оценки в соответствии с разделом, то есть вызвать callLabelPredDetailLocal(data,positiveValue,binary);
    • flatMap извлечет все метки из столбца метки и столбца предсказания (обратите внимание на имя метки) и отправит их нижестоящему оператору.
    • Основная функция reduceGroup состоит в том, чтобы дедуплицировать «имя метки» через buildLabelIndexLabelArray, затем присвоить каждой метке идентификатор, получить карту и, наконец, вернуть двоичный файл (карта, метки), то есть ({prefix1=0 , префикс0=1},[префикс1, префикс0]). Судя по следующему, Map используется для множественной классификации. Бинарная классификация использует только метки.
    • Раздел mapPartition вызывает CalLabelDetailLocal для вычисления матрицы путаницы, в основном раздел вызывает getDetailStatistics, а два кортежа (карта, метки), полученные в предыдущем разделе, будут переданы в качестве параметров.
      • getDetailStatistics просматривает данные строк, извлекает каждый элемент (например, "prefix1, {"prefix1": 0,8, "prefix0": 0,2}) и затем накапливает данные, необходимые для расчета матрицы путаницы с помощью updateBinaryMetricsSummary.
        • updateBinaryMetricsSummary делит [0,1] на, скажем, 100000 ячеек. Так что получите положительныйBin/negativeBin два массива по 100000. положительныйBin — это TP + FP, а отрицательныйBin — это TN + FN.
          • Если вероятность того, что выборка является положительным значением, равна p, то индекс ячейки, соответствующий выборке, равен p * 100000. PositiveBin[index]++, если p предсказано как положительное значение,
          • В противном случае прогнозируется отрицательное значение, а затем - отрицательныйBin[index]++.
  • Объедините приведенные выше результаты расчета сокращения, metrics = res.reduce(new EvaluationUtil.ReduceBaseMetrics());
    • Конкретный расчет находится в BinaryMetricsSummary.merge, функция которого заключается в слиянии бинов и добавлении logLoss.
  • Введите окончательное значение в выходную таблицу, setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()..);
    • После слияния всех BaseMetrics получите общие BaseMetrics, рассчитайте индексы и сохраните их в параметрах. Collector.collect(t.toMetrics().serialize());
      • Фактический бизнес находится в BinaryMetricsSummary.toMetrics, который рассчитывается на основе информации о корзине, а затем сохраняется в параметрах.
        • Функция ExtractMatrixThreCurve извлекает непустые бины и соответственно вычисляет массив ConfusionMatrix (матрица путаницы), массив порогов, rocCurve/recallPrecisionCurve/LiftChart.
          • Пройдите каждую значимую точку в бинах, вычислите totalTrue и totalFalse и вычислите в каждой точке:
          • curTrue += positiveBin[index]; curFalse += negativeBin[index];
          • Получить матрицу путаницы точки new ConfusionMatrix(new long[][] {{curTrue, curFalse}, {totalTrue - curTrue, totalFalse - curFalse}});
          • получить tpr = (totalTrue == 0 ? 1.0 : 1.0 * curTrue / totalTrue);
          • rocCurve, RecotPrecisionCurve, liftChart, соответствующие данным на данный момент;
        • Рассчитать и сохранить AUC/PRC/KS на основе содержания кривой
        • Выборка сгенерированного вывода rocCurve/recallPrecisionCurve/LiftChart
        • Сохранение RocCurve/RecallPrecisionCurve/LiftChar на основе выходных данных выборки
        • Метрики для хранения положительных образцов
        • Хранить Логлосс
        • Pick the middle point where threshold is 0.5.

3.2.1 calLabelPredDetailLocal

Эта функция вычисляет метрики оценки в соответствии с разделом. Да, код короткий, но есть один нюанс. Иногда простые вещи легче пропустить. Легкие пропуски:

Метки результата первой строки кода являются параметром второй строки кода, а не телом второй строки. Тело второй строки кода совпадает с телом первой строки кода, обе из которых являются данными.

private static DataSet<BaseMetricsSummary> calLabelPredDetailLocal(DataSet<Row> data, final String positiveValue, oolean binary) {
  
    DataSet<Tuple2<Map<String, Integer>, String[]>> labels = data.flatMap(new FlatMapFunction<Row, String>() {
        @Override
        public void flatMap(Row row, Collector<String> collector) {
            TreeMap<String, Double> labelProbMap;
            if (EvaluationUtil.checkRowFieldNotNull(row)) {
                labelProbMap = EvaluationUtil.extractLabelProbMap(row);
                labelProbMap.keySet().forEach(collector::collect);
                collector.collect(row.getField(0).toString());
            }
        }
    }).reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(binary, positiveValue));

    return data
        .rebalance()
        .mapPartition(new CalLabelDetailLocal(binary))
        .withBroadcastSet(labels, LABELS);
}

В callLabelPredDetailLocal есть три шага:

  • В flatMap все метки извлекаются из столбца метки и столбца предсказания (обратите внимание на имя метки) и отправляются нижестоящему оператору.
  • Основная функция reduceGroup состоит в том, чтобы дедуплицировать «имя метки», затем присвоить каждой метке идентификатор, а конечным результатом является карта .
  • mapPartition — это вызов секции CallLabelDetailLocal для вычисления матрицы путаницы.

Подробнее см. ниже.

3.2.1.1 flatMap

В flatMap, в основном, из столбца меток и столбца прогнозирования вынимаются все метки (обратите внимание, что имя меток вынимается) и отправляете нижестоящему оператору.

Функция EvaluationUtil.extractLabelProbMap состоит в том, чтобы проанализировать входной json и получить информацию в конкретном подробном входе.

Нисходящий оператор — reduceGroup, поэтому среда выполнения Flink автоматически дедуплицирует эти метки. Если вам интересна эта часть, вы можете обратиться к моей предыдущей статье о сокращении.[Анализ исходного кода] Что делают groupBy и reduce Flink?

Переменные в программе следующие

row = {Row@8922} "prefix1,{"prefix1": 0.9, "prefix0": 0.1}"
 fields = {Object[2]@8925} 
  0 = "prefix1"
  1 = "{"prefix1": 0.9, "prefix0": 0.1}"
    
labelProbMap = {TreeMap@9008}  size = 2
 "prefix0" -> {Double@9015} 0.1
 "prefix1" -> {Double@9017} 0.9
    
labelProbMap.keySet().forEach(collector::collect); //这里发送 "prefix0", "prefix1" 
collector.collect(row.getField(0).toString());  // 这里发送 "prefix1"   
// 因为下一个操作是reduceGroup,所以这些label会被runtime去重
3.2.1.2 reduceGroup

Основная функция состоит в том, чтобы дедуплицировать метки через buildLabelIndexLabelArray, а затем присвоить каждой метке идентификатор, а конечным результатом является карта .

reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(binary, positiveValue));

Функция DistinctLabelIndexMap состоит в том, чтобы извлечь все различные метки из столбца метки и столбца предсказания и вернуть карту . Согласно последующему коду, эта карта используется только для множественной классификации. Получите все отдельные метки из столбца метки и столбца предсказания и верните карту меток и их идентификаторов.

Как упоминалось ранее, строки параметров здесь были автоматически дедуплицированы.

public static class DistinctLabelIndexMap implements
    GroupReduceFunction<String, Tuple2<Map<String, Integer>, String[]>> {
    ......
    @Override
    public void reduce(Iterable<String> rows, Collector<Tuple2<Map<String, Integer>, String[]>> collector) throws Exception {
        HashSet<String> labels = new HashSet<>();
        rows.forEach(labels::add);
        collector.collect(buildLabelIndexLabelArray(labels, binary, positiveValue));
    }
}

// 变量为
labels = {HashSet@9008}  size = 2
 0 = "prefix1"
 1 = "prefix0"
binary = true

Функция buildLabelIndexLabelArray состоит в том, чтобы присвоить каждой метке идентификатор, получить карту и, наконец, вернуть два кортежа (карта, метки), то есть ({prefix1=0, prefix0=1}, [prefix1 , префикс0]).

// Give each label an ID, return a map of label and ID.
public static Tuple2<Map<String, Integer>, String[]> buildLabelIndexLabelArray(HashSet<String> set,boolean binary, String positiveValue) {
    String[] labels = set.toArray(new String[0]);
    Arrays.sort(labels, Collections.reverseOrder());

    Map<String, Integer> map = new HashMap<>(labels.length);
    if (binary && null != positiveValue) {
        if (labels[1].equals(positiveValue)) {
            labels[1] = labels[0];
            labels[0] = positiveValue;
        } 
        map.put(labels[0], 0);
        map.put(labels[1], 1);
    } else {
        for (int i = 0; i < labels.length; i++) {
            map.put(labels[i], i);
        }
    }
    return Tuple2.of(map, labels);
}

// 程序变量如下
labels = {String[2]@9013} 
 0 = "prefix1"
 1 = "prefix0"
map = {HashMap@9014}  size = 2
 "prefix1" -> {Integer@9020} 0
 "prefix0" -> {Integer@9021} 1
3.2.1.3 mapPartition

Основная функция здесь — вызов CalLabelDetailLocal по разделам для подготовки к последующему вычислению матрицы путаницы.

return data
    .rebalance()
    .mapPartition(new CalLabelDetailLocal(binary)) //这里是业务所在
    .withBroadcastSet(labels, LABELS);

Конкретную работу выполняет CalLabelDetailLocal, функция которого состоит в вызове getDetailStatistics по разделам.

// Calculate the confusion matrix based on the label and predResult.
static class CalLabelDetailLocal extends RichMapPartitionFunction<Row, BaseMetricsSummary> {
        private Tuple2<Map<String, Integer>, String[]> map;
        private boolean binary;

        @Override
        public void open(Configuration parameters) throws Exception {
            List<Tuple2<Map<String, Integer>, String[]>> list = getRuntimeContext().getBroadcastVariable(LABELS);
            this.map = list.get(0);// 前文生成的二元组(map, labels)
        }

        @Override
        public void mapPartition(Iterable<Row> rows, Collector<BaseMetricsSummary> collector) {
            // 调用到了 getDetailStatistics
            collector.collect(getDetailStatistics(rows, binary, map));
        }
    }  

Функция getDetailStatistics заключается в инициализации метрик оценки базовой классификации для оценки классификации и накоплении данных, необходимых для расчета матрицы путаницы. Главное — просмотреть данные строк, извлечь каждый элемент (например, «prefix1, {»prefix1»: 0,8, «prefix0»: 0,2}), а затем накопить данные, необходимые для расчета матрицы путаницы.

// Initialize the base classification evaluation metrics. There are two cases: BinaryClassMetrics and MultiClassMetrics.
    private static BaseMetricsSummary getDetailStatistics(Iterable<Row> rows,
                                         String positiveValue,
                                         boolean binary,
                                         Tuple2<Map<String, Integer>, String[]> tuple) {
        BinaryMetricsSummary binaryMetricsSummary = null;
        MultiMetricsSummary multiMetricsSummary = null;
        Tuple2<Map<String, Integer>, String[]> labelIndexLabelArray = tuple;  // 前文生成的二元组(map, labels)

        Iterator<Row> iterator = rows.iterator();
        Row row = null;
        while (iterator.hasNext() && !checkRowFieldNotNull(row)) {
            row = iterator.next();
        }

        Map<String, Integer> labelIndexMap = null;
        if (binary) {
           // 二分法在这里 
            binaryMetricsSummary = new BinaryMetricsSummary(
                new long[ClassificationEvaluationUtil.DETAIL_BIN_NUMBER],
                new long[ClassificationEvaluationUtil.DETAIL_BIN_NUMBER],
                labelIndexLabelArray.f1, 0.0, 0L);
        } else {
            // 
            labelIndexMap = labelIndexLabelArray.f0; // 前文生成的<labels, ID>Map看来是多分类才用到。
            multiMetricsSummary = new MultiMetricsSummary(
                new long[labelIndexMap.size()][labelIndexMap.size()],
                labelIndexLabelArray.f1, 0.0, 0L);
        }

        while (null != row) {
            if (checkRowFieldNotNull(row)) {
                TreeMap<String, Double> labelProbMap = extractLabelProbMap(row);
                String label = row.getField(0).toString();
                if (ArrayUtils.indexOf(labelIndexLabelArray.f1, label) >= 0) {
                    if (binary) {
                        // 二分法在这里 
                        updateBinaryMetricsSummary(labelProbMap, label, binaryMetricsSummary);
                    } else {
                        updateMultiMetricsSummary(labelProbMap, label, labelIndexMap, multiMetricsSummary);
                    }
                }
            }
            row = iterator.hasNext() ? iterator.next() : null;
        }

        return binary ? binaryMetricsSummary : multiMetricsSummary;
}

//变量如下
tuple = {Tuple2@9252} "({prefix1=0, prefix0=1},[prefix1, prefix0])"
 f0 = {HashMap@9257}  size = 2
  "prefix1" -> {Integer@9264} 0
  "prefix0" -> {Integer@9266} 1
 f1 = {String[2]@9258} 
  0 = "prefix1"
  1 = "prefix0"
 
row = {Row@9271} "prefix1,{"prefix1": 0.8, "prefix0": 0.2}"
 fields = {Object[2]@9276} 
  0 = "prefix1"
  1 = "{"prefix1": 0.8, "prefix0": 0.2}"
    
labelIndexLabelArray = {Tuple2@9240} "({prefix1=0, prefix0=1},[prefix1, prefix0])"
 f0 = {HashMap@9288}  size = 2
  "prefix1" -> {Integer@9294} 0
  "prefix0" -> {Integer@9296} 1
 f1 = {String[2]@9242} 
  0 = "prefix1"
  1 = "prefix0"
    
labelProbMap = {TreeMap@9342}  size = 2
 "prefix0" -> {Double@9378} 0.1
 "prefix1" -> {Double@9380} 0.9    

Сначала вспомните матрицу путаницы:

Прогнозируемое значение 0 Прогнозируемое значение 1
истинное значение 0 TN FP
истинное значение 1 FN TP

Для матрицы путаницы роль BinaryMetricsSummary заключается в сохранении данных оценки для двоичной классификации. Конкретная идея расчета функции такова:

  • Разделите [0,1] на столько-то сегментов (бинов) ClassificationEvaluationUtil.DETAIL_BIN_NUMBER (100000). Таким образом, PositiveBin/negativeBin для binaryMetricsSummary — это два массива по 100 000 соответственно. Если вероятность того, что выборка является положительным значением, равна p, то индекс ячейки, соответствующий выборке, равен p * 100000. Если предсказано, что p будет положительным значением, то PositiveBin[index]++, в противном случае предсказано, что будет отрицательное значение (отрицательное значение), тогда отрицательныйBin[index]++. положительныйBin — это TP + FP, а отрицательныйBin — это TN + FN.

  • Так что здесь вход будет пройден, если определенный вход (с"prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}"пример), 0,9 — это вероятность для префикса 1 (положительный пример), а 0,1 — вероятность для префикса 0 (отрицательный пример).

    • Поскольку этот алгоритм выбирает префикс 1 (положительный пример), это означает, что этот алгоритм распознается как положительный, поэтому +1 добавляется к 90000 положительного бина.
    • Предполагая, что алгоритм выбирает префикс 0 (отрицательный пример), это означает, что алгоритм распознается как отрицательный, поэтому он должен быть +1 при 90000 отрицательного бина.

В частности, в соответствии с 5 примерами нашего примера кода классификация выглядит следующим образом:

Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}"),  positiveBin 90000处+1
Row.of("prefix1", "{\"prefix1\": 0.8, \"prefix0\": 0.2}"),  positiveBin 80000处+1
Row.of("prefix1", "{\"prefix1\": 0.7, \"prefix0\": 0.3}"),  positiveBin 70000处+1
Row.of("prefix0", "{\"prefix1\": 0.75, \"prefix0\": 0.25}"), negativeBin 75000处+1
Row.of("prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}")  negativeBin 60000处+1

Конкретный код выглядит следующим образом

public static void updateBinaryMetricsSummary(TreeMap<String, Double> labelProbMap,
                                              String label,
                                              BinaryMetricsSummary binaryMetricsSummary) {
    binaryMetricsSummary.total++;
    binaryMetricsSummary.logLoss += extractLogloss(labelProbMap, label);

    double d = labelProbMap.get(binaryMetricsSummary.labels[0]);
    int idx = d == 1.0 ? ClassificationEvaluationUtil.DETAIL_BIN_NUMBER - 1 :
        (int)Math.floor(d * ClassificationEvaluationUtil.DETAIL_BIN_NUMBER);
    if (idx >= 0 && idx < ClassificationEvaluationUtil.DETAIL_BIN_NUMBER) {
        if (label.equals(binaryMetricsSummary.labels[0])) {
            binaryMetricsSummary.positiveBin[idx] += 1;
        } else if (label.equals(binaryMetricsSummary.labels[1])) {
            binaryMetricsSummary.negativeBin[idx] += 1;
        } else {
					.....
        }
    }
}

private static double extractLogloss(TreeMap<String, Double> labelProbMap, String label) {
   Double prob = labelProbMap.get(label);
   prob = null == prob ? 0. : prob;
   return -Math.log(Math.max(Math.min(prob, 1 - LOG_LOSS_EPS), LOG_LOSS_EPS));
}

// 变量如下
ClassificationEvaluationUtil.DETAIL_BIN_NUMBER=100000
  
// 当 "prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}" 时候
labelProbMap = {TreeMap@9305}  size = 2
 "prefix0" -> {Double@9331} 0.1
 "prefix1" -> {Double@9333} 0.9
  
d = 0.9
idx = 90000
binaryMetricsSummary = {BinaryMetricsSummary@9262} 
 labels = {String[2]@9242} 
  0 = "prefix1"
  1 = "prefix0"
 total = 1
 positiveBin = {long[100000]@9263}  // 90000处+1
 negativeBin = {long[100000]@9264} 
 logLoss = 0.10536051565782628
   
// 当 "prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}" 时候  
labelProbMap = {TreeMap@9514}  size = 2
 "prefix0" -> {Double@9546} 0.4
 "prefix1" -> {Double@9547} 0.6
   
d = 0.6
idx = 60000    
 binaryMetricsSummary = {BinaryMetricsSummary@9262} 
 labels = {String[2]@9242} 
  0 = "prefix1"
  1 = "prefix0"
 total = 2
 positiveBin = {long[100000]@9263}  
 negativeBin = {long[100000]@9264} // 60000处+1
 logLoss = 1.0216512475319812  

3.2.2 ReduceBaseMetrics

Функция ReduceBaseMetrics заключается в агрегировании локально рассчитанных BaseMetrics.

DataSet<BaseMetricsSummary> metrics = res
    .reduce(new EvaluationUtil.ReduceBaseMetrics());

ReduceBaseMetrics выглядит следующим образом

public static class ReduceBaseMetrics implements ReduceFunction<BaseMetricsSummary> {
    @Override
    public BaseMetricsSummary reduce(BaseMetricsSummary t1, BaseMetricsSummary t2) throws Exception {
        return null == t1 ? t2 : t1.merge(t2);
    }
}

Конкретный расчет находится в BinaryMetricsSummary.merge, функция которого заключается в слиянии бинов и добавлении logLoss.

@Override
public BinaryMetricsSummary merge(BinaryMetricsSummary binaryClassMetrics) {
    for (int i = 0; i < this.positiveBin.length; i++) {
        this.positiveBin[i] += binaryClassMetrics.positiveBin[i];
    }
    for (int i = 0; i < this.negativeBin.length; i++) {
        this.negativeBin[i] += binaryClassMetrics.negativeBin[i];
    }
    this.logLoss += binaryClassMetrics.logLoss;
    this.total += binaryClassMetrics.total;
    return this;
}

// 程序变量是
this = {BinaryMetricsSummary@9316} 
 labels = {String[2]@9322} 
  0 = "prefix1"
  1 = "prefix0"
 total = 2
 positiveBin = {long[100000]@9320} 
 negativeBin = {long[100000]@9323} 
 logLoss = 1.742969305058623

3.2.3 SaveDataAsParams

this.setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()),
    new String[] {DATA_OUTPUT}, new TypeInformation[] {Types.STRING});

После слияния всех BaseMetrics получается общее количество BaseMetrics, индексы рассчитываются и сохраняются в params.

public static class SaveDataAsParams implements FlatMapFunction<BaseMetricsSummary, Row> {
    @Override
    public void flatMap(BaseMetricsSummary t, Collector<Row> collector) throws Exception {
        collector.collect(t.toMetrics().serialize());
    }
}

Собственно дело завершается в BinaryMetricsSummary.toMetrics, то есть на основе информационного расчета бинов получается массив путаницыMatrix, массив порогов, rocCurve/recallPrecisionCurve/LiftChart и т. д., которые затем сохраняются в params.

public BinaryClassMetrics toMetrics() {
    Params params = new Params();
    // 生成若干曲线,比如rocCurve/recallPrecisionCurve/LiftChart
    Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> matrixThreCurve =
        extractMatrixThreCurve(positiveBin, negativeBin, total);

    // 依据曲线内容计算并且存储 AUC/PRC/KS
    setCurveAreaParams(params, matrixThreCurve.f2);

    // 对生成的rocCurve/recallPrecisionCurve/LiftChart输出进行抽样
    Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> sampledMatrixThreCurve = sample(
        PROBABILITY_INTERVAL, matrixThreCurve);

    // 依据抽样后的输出存储 RocCurve/RecallPrecisionCurve/LiftChar
    setCurvePointsParams(params, sampledMatrixThreCurve);
    ConfusionMatrix[] matrices = sampledMatrixThreCurve.f0;
  
    // 存储正例样本的度量指标
    setComputationsArrayParams(params, sampledMatrixThreCurve.f1, sampledMatrixThreCurve.f0);
  
    // 存储Logloss
    setLoglossParams(params, logLoss, total);
  
    // Pick the middle point where threshold is 0.5.
    int middleIndex = getMiddleThresholdIndex(sampledMatrixThreCurve.f1);  
    setMiddleThreParams(params, matrices[middleIndex], labels);
    return new BinaryClassMetrics(params);
}

ExtractMatrixThreCurve находится в центре внимания всего текста. Вот извлеките ячейки, которые не пусты, оставьте средний порог 0,5, затем инициализируйте RocCurve, Recall-Precision Curve и Lift Curve и рассчитайте массив ConfusionMatrix (матрица путаницы), массив порогов, rocCurve/recallPrecisionCurve/LiftChart..

/**
 * Extract the bins who are not empty, keep the middle threshold 0.5.
 * Initialize the RocCurve, Recall-Precision Curve and Lift Curve.
 * RocCurve: (FPR, TPR), starts with (0,0). Recall-Precision Curve: (recall, precision), starts with (0, p), p is the precision with the lowest. LiftChart: (TP+FP/total, TP), starts with (0,0). confusion matrix = [TP FP][FN * TN].
 *
 * @param positiveBin positiveBins.
 * @param negativeBin negativeBins.
 * @param total       sample number
 * @return ConfusionMatrix array, threshold array, rocCurve/recallPrecisionCurve/LiftChart.
 */
static Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> extractMatrixThreCurve(long[] positiveBin, long[] negativeBin, long total) {
    ArrayList<Integer> effectiveIndices = new ArrayList<>();
    long totalTrue = 0, totalFalse = 0;
  
    // 计算totalTrue,totalFalse,effectiveIndices
    for (int i = 0; i < ClassificationEvaluationUtil.DETAIL_BIN_NUMBER; i++) {
        if (0L != positiveBin[i] || 0L != negativeBin[i]
            || i == ClassificationEvaluationUtil.DETAIL_BIN_NUMBER / 2) {
            effectiveIndices.add(i);
            totalTrue += positiveBin[i];
            totalFalse += negativeBin[i];
        }
    }

// 以我们例子,得到  
effectiveIndices = {ArrayList@9273}  size = 6
 0 = {Integer@9277} 50000 //这里加入了中间点
 1 = {Integer@9278} 60000
 2 = {Integer@9279} 70000
 3 = {Integer@9280} 75000
 4 = {Integer@9281} 80000
 5 = {Integer@9282} 90000
totalTrue = 3
totalFalse = 2
  
    // 继续初始化,生成若干curve
    final int length = effectiveIndices.size();
    final int newLen = length + 1;
    final double m = 1.0 / ClassificationEvaluationUtil.DETAIL_BIN_NUMBER;
    EvaluationCurvePoint[] rocCurve = new EvaluationCurvePoint[newLen];
    EvaluationCurvePoint[] recallPrecisionCurve = new EvaluationCurvePoint[newLen];
    EvaluationCurvePoint[] liftChart = new EvaluationCurvePoint[newLen];
    ConfusionMatrix[] data = new ConfusionMatrix[newLen];
    double[] threshold = new double[newLen];
    long curTrue = 0;
    long curFalse = 0;
  
// 以我们例子,得到 
length = 6
newLen = 7
m = 1.0E-5
  
    // 计算, 其中rocCurve,recallPrecisionCurve,liftChart 都可以从代码中看出
    for (int i = 1; i < newLen; i++) {
        int index = effectiveIndices.get(length - i);
        curTrue += positiveBin[index];
        curFalse += negativeBin[index];
        threshold[i] = index * m;
        // 计算出混淆矩阵
        data[i] = new ConfusionMatrix(
            new long[][] {{curTrue, curFalse}, {totalTrue - curTrue, totalFalse - curFalse}});
        double tpr = (totalTrue == 0 ? 1.0 : 1.0 * curTrue / totalTrue);
        // 比如当 90000 这点,得到 curTrue = 1 curFalse = 0 i = 1 index = 90000 tpr = 0.3333333333333333。totalTrue = 3 totalFalse = 2, 
        // 我们也知道,TPR = TP / (TP + FN) ,所以可以计算 tpr = 1 / 3   
        rocCurve[i] = new EvaluationCurvePoint(totalFalse == 0 ? 1.0 : 1.0 * curFalse / totalFalse, tpr, threshold[i]);
        recallPrecisionCurve[i] = new EvaluationCurvePoint(tpr, curTrue + curTrue == 0 ? 1.0 : 1.0 * curTrue / (curTrue + curFalse), threshold[i]);
        liftChart[i] = new EvaluationCurvePoint(1.0 * (curTrue + curFalse) / total, curTrue, threshold[i]);
    }
  
// 以我们例子,得到 
curTrue = 3
curFalse = 2
  
threshold = {double[7]@9349} 
 0 = 0.0
 1 = 0.9
 2 = 0.8
 3 = 0.7500000000000001
 4 = 0.7000000000000001
 5 = 0.6000000000000001
 6 = 0.5  
   
rocCurve = {EvaluationCurvePoint[7]@9315} 
 1 = {EvaluationCurvePoint@9440} 
  x = 0.0
  y = 0.3333333333333333
  p = 0.9
 2 = {EvaluationCurvePoint@9448} 
  x = 0.0
  y = 0.6666666666666666
  p = 0.8
 3 = {EvaluationCurvePoint@9449} 
  x = 0.5
  y = 0.6666666666666666
  p = 0.7500000000000001
 4 = {EvaluationCurvePoint@9450} 
  x = 0.5
  y = 1.0
  p = 0.7000000000000001
 5 = {EvaluationCurvePoint@9451} 
  x = 1.0
  y = 1.0
  p = 0.6000000000000001
 6 = {EvaluationCurvePoint@9452} 
  x = 1.0
  y = 1.0
  p = 0.5
    
recallPrecisionCurve = {EvaluationCurvePoint[7]@9320} 
 1 = {EvaluationCurvePoint@9444} 
  x = 0.3333333333333333
  y = 1.0
  p = 0.9
 2 = {EvaluationCurvePoint@9453} 
  x = 0.6666666666666666
  y = 1.0
  p = 0.8
 3 = {EvaluationCurvePoint@9454} 
  x = 0.6666666666666666
  y = 0.6666666666666666
  p = 0.7500000000000001
 4 = {EvaluationCurvePoint@9455} 
  x = 1.0
  y = 0.75
  p = 0.7000000000000001
 5 = {EvaluationCurvePoint@9456} 
  x = 1.0
  y = 0.6
  p = 0.6000000000000001
 6 = {EvaluationCurvePoint@9457} 
  x = 1.0
  y = 0.6
  p = 0.5
    
liftChart = {EvaluationCurvePoint[7]@9325} 
 1 = {EvaluationCurvePoint@9458} 
  x = 0.2
  y = 1.0
  p = 0.9
 2 = {EvaluationCurvePoint@9459} 
  x = 0.4
  y = 2.0
  p = 0.8
 3 = {EvaluationCurvePoint@9460} 
  x = 0.6
  y = 2.0
  p = 0.7500000000000001
 4 = {EvaluationCurvePoint@9461} 
  x = 0.8
  y = 3.0
  p = 0.7000000000000001
 5 = {EvaluationCurvePoint@9462} 
  x = 1.0
  y = 3.0
  p = 0.6000000000000001
 6 = {EvaluationCurvePoint@9463} 
  x = 1.0
  y = 3.0
  p = 0.5
    
data = {ConfusionMatrix[7]@9339} 
 0 = {ConfusionMatrix@9486} 
  longMatrix = {LongMatrix@9488} 
   matrix = {long[2][]@9491} 
    0 = {long[2]@9492} 
     0 = 0
     1 = 0
    1 = {long[2]@9493} 
     0 = 3
     1 = 2
   rowNum = 2
   colNum = 2
  labelCnt = 2
  total = 5
  actualLabelFrequency = {long[2]@9489} 
   0 = 3
   1 = 2
  predictLabelFrequency = {long[2]@9490} 
   0 = 0
   1 = 5
  tpCount = 2.0
  tnCount = 2.0
  fpCount = 3.0
  fnCount = 3.0
 1 = {ConfusionMatrix@9435} 
  longMatrix = {LongMatrix@9469} 
   matrix = {long[2][]@9472} 
    0 = {long[2]@9474} 
     0 = 1
     1 = 0
    1 = {long[2]@9475} 
     0 = 2
     1 = 2
   rowNum = 2
   colNum = 2
  labelCnt = 2
  total = 5
  actualLabelFrequency = {long[2]@9470} 
   0 = 3
   1 = 2
  predictLabelFrequency = {long[2]@9471} 
   0 = 1
   1 = 4
  tpCount = 3.0
  tnCount = 3.0
  fpCount = 2.0
  fnCount = 2.0
  ......  
    
    threshold[0] = 1.0;
    data[0] = new ConfusionMatrix(new long[][] {{0, 0}, {totalTrue, totalFalse}});
    rocCurve[0] = new EvaluationCurvePoint(0, 0, threshold[0]);
    recallPrecisionCurve[0] = new EvaluationCurvePoint(0, recallPrecisionCurve[1].getY(), threshold[0]);
    liftChart[0] = new EvaluationCurvePoint(0, 0, threshold[0]);

    return Tuple3.of(data, threshold, new EvaluationCurve[] {new EvaluationCurve(rocCurve),
        new EvaluationCurve(recallPrecisionCurve), new EvaluationCurve(liftChart)});
}

3.2.4 Вычисление матрицы путаницы

Здесь я расскажу вам, как рассчитать матрицу путаницы Идея здесь довольно запутанная.

3.2.4.1 Исходная матрица

где вызов:

// 调用之处
data[i] = new ConfusionMatrix(
        new long[][] {{curTrue, curFalse}, {totalTrue - curTrue, totalFalse - curFalse}});
// 调用时候各种赋值
i = 1
index = 90000
totalTrue = 3
totalFalse = 2
curTrue = 1
curFalse = 0

Получить исходную матрицу, все нижеследующее cur, описание только для текущей точки.

curTrue = 1 curFalse = 0
totalTrue - curTrue = 2 totalFalse - curFalse = 2
3.2.4.2 Расчетные метки

В последующем вычислении ConfusionMatrix мы можем получить

actualLabelFrequency = longMatrix.getColSums();
predictLabelFrequency = longMatrix.getRowSums();

actualLabelFrequency = {long[2]@9322} 
 0 = 3
 1 = 2
predictLabelFrequency = {long[2]@9323} 
 0 = 1
 1 = 4  

Можно видеть, что алгоритм Alink считает, что сумма каждого столбца связана с фактической меткой, а сумма каждой строки связана с предсказанной меткой.

Новая матрица получается следующим образом

predictLabelFrequency
curTrue = 1 curFalse = 0 1 = curTrue + curFalse
totalTrue - curTrue = 2 totalFalse - curFalse = 2 4 = total - curTrue - curFalse
actualLabelFrequency 3 = totalTrue 2 = totalFalse

Дальнейшие расчеты будут основываться на них:

В расчете используются данные по диагонали longMatrix, а именно longMatrix(0)(0) и longMatrix(1)(1). Следует отметить, что здесь рассматриваетсяТекущий статус (нарисовано для акцента).

матрица дракона (0) (0): грубая правда

Матрица дракона (1) (1): полная ложь - грубый человек ложь

totalFalse : (TN + FN)

totalTrue : (TP + FP)

double numTrueNegative(Integer labelIndex) {
  // labelIndex为 0 时候,return 1 + 5 - 1 - 3 = 2;
  // labelIndex为 1 时候,return 2 + 5 - 4 - 2 = 1;
	return null == labelIndex ? tnCount : longMatrix.getValue(labelIndex, labelIndex) + total - predictLabelFrequency[labelIndex] - actualLabelFrequency[labelIndex];
}

double numTruePositive(Integer labelIndex) {
  // labelIndex为 0 时候,return 1; 这个是 curTrue,就是真实标签是True,判别也是True。是TP
  // labelIndex为 1 时候,return 2; 这个是 totalFalse - curFalse,总判别错 - 当前判别错。这就意味着“本来判别错了但是当前没有发现”,所以认为在当前状态下,这也算是TP
	return null == labelIndex ? tpCount : longMatrix.getValue(labelIndex, labelIndex);
}

double numFalseNegative(Integer labelIndex) {
  // labelIndex为 0 时候,return 3 - 1; 
  // actualLabelFrequency[0] = totalTrue。所以return totalTrue - curTrue,即当前“全部正确”中没有“判别为正确”,这个就可以认为是“判别错了且判别为负”
  // labelIndex为 1 时候,return 2 - 2;   
  // actualLabelFrequency[1] = totalFalse。所以return totalFalse - ( totalFalse - curFalse )  = curFalse
	return null == labelIndex ? fnCount : actualLabelFrequency[labelIndex] - longMatrix.getValue(labelIndex, labelIndex);
}

double numFalsePositive(Integer labelIndex) {
  // labelIndex为 0 时候,return 1 - 1;
  // predictLabelFrequency[0] = curTrue + curFalse。
  // 所以 return = curTrue + curFalse - curTrue = curFalse = current( TN + FN ) 这可以认为是判断错了实际是正确标签
  // labelIndex为 1 时候,return 4 - 2; 
  // predictLabelFrequency[1] = total - curTrue - curFalse。
  // 所以 return = total - curTrue - curFalse - (totalFalse - curFalse) = totalTrue - curTrue = ( TP + FP ) - currentTP = currentFP 
	return null == labelIndex ? fpCount : predictLabelFrequency[labelIndex] - longMatrix.getValue(labelIndex, labelIndex);
}

// 最后得到
tpCount = 3.0
tnCount = 3.0
fpCount = 2.0
fnCount = 2.0
3.2.4.3 Специальный код
// 具体计算 
public ConfusionMatrix(LongMatrix longMatrix) {
  
longMatrix = {LongMatrix@9297} 
  0 = {long[2]@9324} 
   0 = 1
   1 = 0
  1 = {long[2]@9325} 
   0 = 2
   1 = 2
     
    this.longMatrix = longMatrix;
    labelCnt = this.longMatrix.getRowNum();
    // 这里就是计算
    actualLabelFrequency = longMatrix.getColSums();
    predictLabelFrequency = longMatrix.getRowSums();
  
actualLabelFrequency = {long[2]@9322} 
 0 = 3
 1 = 2
predictLabelFrequency = {long[2]@9323} 
 0 = 1
 1 = 4  
labelCnt = 2
total = 5  

    total = longMatrix.getTotal();
    for (int i = 0; i < labelCnt; i++) {
        tnCount += numTrueNegative(i);
        tpCount += numTruePositive(i);
        fnCount += numFalseNegative(i);
        fpCount += numFalsePositive(i);
    }
}

обработка потока 0x04

4.1 Пример

В исходном примере кода Alink на Python часть Stream не имеет вывода, потому что MemSourceStreamOp не связан со временем, а Alink не предоставляет основанный на времени StreamOperator, поэтому я могу написать только один, имитирующий MemSourceBatchOp. Код немного уродлив, но, по крайней мере, он обеспечивает вывод, чтобы его можно было отладить.

4.1.1 Основной класс

public class EvalBinaryClassExampleStream {

    AlgoOperator getData(boolean isBatch) {
        Row[] rows = new Row[]{
                Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}")
        };
        String[] schema = new String[]{"label", "detailInput"};
        if (isBatch) {
            return new MemSourceBatchOp(rows, schema);
        } else {
            return new TimeMemSourceStreamOp(rows, schema, new EvalBinaryStreamSource());
        }
    }

    public static void main(String[] args) throws Exception {
        EvalBinaryClassExampleStream test = new EvalBinaryClassExampleStream();
        StreamOperator streamData = (StreamOperator) test.getData(false);
        StreamOperator sOp = new EvalBinaryClassStreamOp()
                .setLabelCol("label")
                .setPredictionDetailCol("detailInput")
                .setTimeInterval(1)
                .linkFrom(streamData);
        sOp.print();
        StreamOperator.execute();
    }
}

4.1.2 TimeMemSourceStreamOp

Я сделал это сам. Заимствовано из MemSourceStreamOp.

public final class TimeMemSourceStreamOp extends StreamOperator<TimeMemSourceStreamOp> {

    public TimeMemSourceStreamOp(Row[] rows, String[] colNames, EvalBinaryStrSource source) {
        super(null);
        init(source, Arrays.asList(rows), colNames);
    }

    private void init(EvalBinaryStreamSource source, List <Row> rows, String[] colNames) {
        Row first = rows.iterator().next();
        int arity = first.getArity();
        TypeInformation <?>[] types = new TypeInformation[arity];

        for (int i = 0; i < arity; ++i) {
            types[i] = TypeExtractor.getForObject(first.getField(i));
        }

        init(source, colNames, types);
    }

    private void init(EvalBinaryStreamSource source, String[] colNames, TypeInformation <?>[] colTypes) {
        DataStream <Row> dastr = MLEnvironmentFactory.get(getMLEnvironmentId())
                .getStreamExecutionEnvironment().addSource(source);
        StringBuilder sbd = new StringBuilder();
        sbd.append(colNames[0]);
      
        for (int i = 1; i < colNames.length; i++) {
            sbd.append(",").append(colNames[i]);
        }
        this.setOutput(dastr, colNames, colTypes);
    }

    @Override
    public TimeMemSourceStreamOp linkFrom(StreamOperator<?>... inputs) {
        return null;
    }
}

4.1.3 Source

Строка предоставляется регулярно, и случайные числа добавляются для изменения вероятности.

class EvalBinaryStreamSource extends RichSourceFunction[Row] {

  override def run(ctx: SourceFunction.SourceContext[Row]) = {
    while (true) {
      val rdm = Math.random() // 这里加入了随机数,让概率有变化
      val rows: Array[Row] = Array[Row](
        Row.of("prefix1", "{\"prefix1\": " + rdm + ", \"prefix0\": " + (1-rdm) + "}"),
        Row.of("prefix1", "{\"prefix1\": 0.8, \"prefix0\": 0.2}"),
        Row.of("prefix1", "{\"prefix1\": 0.7, \"prefix0\": 0.3}"),
        Row.of("prefix0", "{\"prefix1\": 0.75, \"prefix0\": 0.25}"),
        Row.of("prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}"))
      for(row <- rows) {
        println(s"当前值:$row")
        ctx.collect(row)
      }
      Thread.sleep(1000)
    }
  }

  override def cancel() = ???
}

4.2 BaseEvalClassStreamOp

Класс обработки потока Alink — EvalBinaryClassStreamOp, который в основном работает со своим базовым классом BaseEvalClassStreamOp, поэтому мы сосредоточимся на последнем.

public class BaseEvalClassStreamOp<T extends BaseEvalClassStreamOp<T>> extends StreamOperator<T> {
    @Override
    public T linkFrom(StreamOperator<?>... inputs) {
        StreamOperator<?> in = checkAndGetFirst(inputs);
        String labelColName = this.get(MultiEvaluationStreamParams.LABEL_COL);
        String positiveValue = this.get(BinaryEvaluationStreamParams.POS_LABEL_VAL_STR);
        Integer timeInterval = this.get(MultiEvaluationStreamParams.TIME_INTERVAL);

        ClassificationEvaluationUtil.Type type = ClassificationEvaluationUtil.judgeEvaluationType(this.getParams());

        DataStream<BaseMetricsSummary> statistics;

        switch (type) {
            case PRED_RESULT: {
              ......
            }
            case PRED_DETAIL: {               
                String predDetailColName = this.get(MultiEvaluationStreamParams.PREDICTION_DETAIL_COL);
                // 
                PredDetailLabel eval = new PredDetailLabel(positiveValue, binary);
                // 获取输入数据,重点是timeWindowAll
                statistics = in.select(new String[] {labelColName, predDetailColName})
                    .getDataStream()
                    .timeWindowAll(Time.of(timeInterval, TimeUnit.SECONDS))
                    .apply(eval);
                break;
            }
        }
        // 把各个窗口的数据累积到 totalStatistics,注意,这里是新变量了。
        DataStream<BaseMetricsSummary> totalStatistics = statistics
            .map(new EvaluationUtil.AllDataMerge())
            .setParallelism(1); // 并行度设置为1

        // 基于两种 bins 计算&序列化,得到当前的 statistics
        DataStream<Row> windowOutput = statistics.map(
            new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.WINDOW.f0));
        // 基于bins计算&序列化,得到累积的 totalStatistics
        DataStream<Row> allOutput = totalStatistics.map(
            new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.ALL.f0));

      	// "当前" 和 "累积" 做联合,最终返回
        DataStream<Row> union = windowOutput.union(allOutput);

        this.setOutput(union,
            new String[] {ClassificationEvaluationUtil.STATISTICS_OUTPUT, DATA_OUTPUT},
            new TypeInformation[] {Types.STRING, Types.STRING});

        return (T)this;
    }
}

Конкретный бизнес это:

  • PredDetailLabel выполнит дедупликацию имени метки и соберет данные, необходимые для расчета матрицы путаницы.
    • buildLabelIndexLabelArray дедуплицирует «имя метки», затем присваивает каждой метке идентификатор, а конечным результатом является карта .
    • getDetailStatistics просматривает данные строк, извлекает каждый элемент (например, "prefix1, {"prefix1": 0,8, "prefix0": 0,2}) и затем накапливает данные, необходимые для расчета матрицы путаницы с помощью updateBinaryMetricsSummary.
  • Получить данные из окна в соответствии со статистикой метки = in.select().getDataStream().timeWindowAll() .apply(eval);
  • EvaluationUtil.AllDataMerge накапливает данные каждого окна в totalStatistics.
  • Получите windowOutput -------- EvaluationUtil.SaveDataStream и обработайте "текущую статистику данных". Фактический бизнес находится в BinaryMetricsSummary.toMetrics, который рассчитывается на основе информации о корзине, затем сохраняется в параметрах, сериализуется и возвращается в строку.
    • Функция ExtractMatrixThreCurve извлекает непустые бины и соответственно вычисляет массив ConfusionMatrix (матрица путаницы), массив порогов, rocCurve/recallPrecisionCurve/LiftChart.
    • Рассчитать и сохранить AUC/PRC/KS на основе содержания кривой
    • Выборка сгенерированного вывода rocCurve/recallPrecisionCurve/LiftChart
    • Сохранение RocCurve/RecallPrecisionCurve/LiftChar на основе выходных данных выборки
    • Метрики для хранения положительных образцов
    • Хранить Логлосс
    • Pick the middle point where threshold is 0.5.
  • Получите allOutput -------- EvaluationUtil.SaveDataStream и обработайте «накопленные данные totalStatistics».
    • Подробный процесс обработки такой же, как и для windowOutput.
  • windowOutput и allOutput выполняют объединение. Наконец, верните DataStream union = windowOutput.union(allOutput);

4.2.1 PredDetailLabel

static class PredDetailLabel implements AllWindowFunction<Row, BaseMetricsSummary, TimeWindow> {
    @Override
    public void apply(TimeWindow timeWindow, Iterable<Row> rows, Collector<BaseMetricsSummary> collector) throws Exception {
        HashSet<String> labels = new HashSet<>();
        // 首先还是获取 labels 名字
        for (Row row : rows) {
            if (EvaluationUtil.checkRowFieldNotNull(row)) {
                labels.addAll(EvaluationUtil.extractLabelProbMap(row).keySet());
                labels.add(row.getField(0).toString());
            }
        }
labels = {HashSet@9757}  size = 2
 0 = "prefix1"
 1 = "prefix0"   
        // 之前介绍过,buildLabelIndexLabelArray 去重 "labels名字",然后给每一个label一个ID,最后结果是一个<labels, ID>Map。
        // getDetailStatistics 遍历 rows 数据,累积计算混淆矩阵所需数据( "TP + FN"  /  "TN + FP")。
        if (labels.size() > 0) {
            collector.collect(
                getDetailStatistics(rows, binary, buildLabelIndexLabelArray(labels, binary, positiveValue)));
        }
    }
}

4.2.2 AllDataMerge

EvaluationUtil.AllDataMerge накапливает данные каждого окна

/**
 * Merge data from different windows.
 */
public static class AllDataMerge implements MapFunction<BaseMetricsSummary, BaseMetricsSummary> {
    private BaseMetricsSummary statistics;
    @Override
    public BaseMetricsSummary map(BaseMetricsSummary value) {
        this.statistics = (null == this.statistics ? value : this.statistics.merge(value));
        return this.statistics;
    }
}

4.2.3 SaveDataStream

Функции, специально вызываемые SaveDataStream, были введены в пакетную обработку ранее.Фактический бизнес находится в BinaryMetricsSummary.toMetrics, который вычисляется на основе информации о корзине и сохраняется в параметрах.

Отличие от пакетной обработки здесь в том, что «информация о созданных метриках» возвращается непосредственно пользователю.

public static class SaveDataStream implements MapFunction<BaseMetricsSummary, Row> {
    @Override
    public Row map(BaseMetricsSummary baseMetricsSummary) throws Exception {
        BaseMetricsSummary metrics = baseMetricsSummary;
        BaseMetrics baseMetrics = metrics.toMetrics();
        Row row = baseMetrics.serialize();
        return Row.of(funtionName, row.getField(0));
    }
}

// 最后得到的 row 其实就是最终返回给用户的度量信息
row = {Row@10008} "{"PRC":"0.9164636268708667","SensitivityArray":"[0.38461538461538464,0.6923076923076923,0.6923076923076923,1.0,1.0,1.0]","ConfusionMatrix":"[[13,8],[0,0]]","MacroRecall":"0.5","MacroSpecificity":"0.5","FalsePositiveRateArray":"[0.0,0.0,0.5,0.5,1.0,1.0]" ...... 还有很多其他的

4.2.4 Union

DataStream<Row> windowOutput = statistics.map(
    new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.WINDOW.f0));
DataStream<Row> allOutput = totalStatistics.map(
    new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.ALL.f0));

DataStream<Row> union = windowOutput.union(allOutput);

Наконец вернуть два вида статистики

4.2.4.1 allOutput
all|{"PRC":"0.7341146115890359","SensitivityArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,0.7333333333333333,0.8,0.8,0.8666666666666667,0.8666666666666667,0.9333333333333333,1.0]","ConfusionMatrix":"[[13,10],[2,0]]","MacroRecall":"0.43333333333333335","MacroSpecificity":"0.43333333333333335","FalsePositiveRateArray":"[0.0,0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.0]","TruePositiveRateArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,0.7333333333333333,0.8,0.8,0.8666666666666667,0.8666666666666667,0.9333333333333333,1.0]","AUC":"0.5666666666666667","MacroAccuracy":"0.52", ......

4.2.4.2 windowOutput

window|{"PRC":"0.7638888888888888","SensitivityArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,1.0,1.0,1.0]","ConfusionMatrix":"[[3,2],[0,0]]","MacroRecall":"0.5","MacroSpecificity":"0.5","FalsePositiveRateArray":"[0.0,0.5,0.5,0.5,1.0,1.0]","TruePositiveRateArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,1.0,1.0,1.0]","AUC":"0.6666666666666666","MacroAccuracy":"0.6","RecallArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,1.0,1.0,1.0]","KappaArray":"[0.28571428571428564,-0.15384615384615377,0.1666666666666666,0.5454545454545455,0.0,0.0]","MicroFalseNegativeRate":"0.4","WeightedRecall":"0.6","WeightedPrecision":"0.36","Recall":"1.0","MacroPrecision":"0.3",......

ссылка 0xFF

[Средний анализ] Разберите понятия на примерах: точность, точность, полнота и F-мера.

★★★★★★Думая о жизни и технологиях★★★★★★

Публичный аккаунт WeChat: мысли Росси

Если вы хотите получать своевременные новости о статьях, написанных отдельными лицами, или хотите видеть технические материалы, рекомендованные отдельными лицами, обратите внимание.