Alink's Talk (21): Анализ исходного кода регрессионной оценки

машинное обучение

0x00 сводка

Alink — это платформа алгоритмов машинного обучения нового поколения, разработанная Alibaba на основе вычислительного движка реального времени Flink.Это первая в отрасли платформа машинного обучения, которая поддерживает как пакетные, так и потоковые алгоритмы. В этой статье вы проанализируете реализацию регрессионной оценки в Alink.

Это самое простое время для анализа Alink. Потому что концепция и логика реализации здесь очень ясны.

0x01 Фоновая концепция

1.1 Описание функций

Оценка регрессии предназначена для оценки влияния результатов прогнозирования алгоритма регрессии, и поддерживаются следующие индикаторы оценки. Эти показатели в основном являются понятиями в области статистики.

1.2 Специальные показатели

Alink предоставляет следующие показатели:

countРяды

SSTСумма в квадрате для итога, которая измеряет разброс Y в выборке.

SST=i=1N(yiyˉ)2SST=\sum_{i=1}^{N}(y_i-\bar{y})^2

SSEСумма квадратов ошибки, мера общего отклонения выборки.

SSE=i=1N(yifi)2"SSE=\sum_{i=1}^{N}(y_i-f_i)^2"

SSRСумма квадратов для регрессии измеряет выборочную вариацию остатков.

SSR=i=1N(fiyˉ)2SSR=\sum_{i=1}^{N}(f_i-\bar{y})^2

R^2Коэффициент детерминации используется для оценки того, хорошо ли уравнение регрессии соответствует выборочным данным.Коэффициент детерминации обеспечивает меру согласия для оцененного уравнения регрессии.

R2=1SSESSTR^2=1-\dfrac{SSE}{SST}

RКоэффициент множественной корреляции — это мера линейной зависимости между случайной величиной и группой случайных величин.

R=R2R=\sqrt{R^2}

MSEСреднеквадратическая ошибка, среднеквадратическая ошибка (стандартное отклонение) и дисперсия используются для описания степени дисперсии набора данных.

Среднеквадратическая ошибка является более удобной мерой «средней ошибки», с помощью которой можно оценить степень разброса данных. С точки зрения категории, он относится к комбинации предсказания, оценки и предсказания; буквально «среднее» относится к среднему, то есть к нахождению его среднего, а «дисперсия» используется для измерения случайных величин и их оценочных значений. в теории вероятностей (это мера степени отклонения между средним значением), а «ошибка» может пониматься как ошибка между измеренным значением и истинным значением.

MSE=1Ni=1N(fiyi)2MSE=\dfrac{1}{N}\sum_{i=1}^{N}(f_i-y_i)^2

RMSEСреднеквадратичная ошибка

RMSE=MSERMSE=\sqrt{MSE}

SAE/SADАбсолютная ошибка (сумма абсолютной ошибки/разности)

SAE=i=1NfiyiSAE=\sum_{i=1}^{N}|f_i-y_i|

MAE/MADСредняя абсолютная ошибка/разница

MAE=1Ni=1NfiyiMAE=\dfrac{1}{N}\sum_{i=1}^{N}|f_i-y_i|

MAPEСредняя абсолютная ошибка в процентах

MAPE=100Ni=1NfiyiyiMAPE=\dfrac{100}{N}\sum_{i=1}^{N}|\dfrac{f_i-y_i}{y_i}|

explained varianceобъясненная дисперсия

explainedVariance=SSRNexplained Variance=\dfrac{SSR}{N}

0x02 Пример кода

Выньте пример кода Alink напрямую.

public class EvalRegressionBatchOpExp {
    
    public static void main(String[] args) throws Exception {
        Row[] data =
                new Row[] {
                        Row.of(0.4, 0.5),
                        Row.of(0.3, 0.5),
                        Row.of(0.2, 0.6),
                        Row.of(0.6, 0.7),
                        Row.of(0.1, 0.5)
                };

        MemSourceBatchOp input = new MemSourceBatchOp(data, new String[] {"label", "pred"});

        RegressionMetrics metrics = new EvalRegressionBatchOp()
                .setLabelCol("label")
                .setPredictionCol("pred")
                .linkFrom(input)
                .collectMetrics();

        System.out.println(metrics.getRmse());
        System.out.println(metrics.getR2());
        System.out.println(metrics.getSse());
        System.out.println(metrics.getMape());
        System.out.println(metrics.getMae());
        System.out.println(metrics.getSsr());
        System.out.println(metrics.getSst());
    }
}

Результат:

0.27568097504180444
-1.5675675675675653
0.38
141.66666666666669
0.24
0.31999999999999973
0.14800000000000013

0x03 Общая логика

Общая логика такова:

  • Вызвать CalcLocal для расчета различных статистических значений по секциям;
  • reduce вызывает ReduceBaseMetrics для слияния различных статистических значений;
  • Вызовите SaveDataAsParams для сохранения;

getLabelCol — это y, а getPredictionCol — это y_hat.

public EvalRegressionBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator in = checkAndGetFirst(inputs);

    // 这里就是找到y, y_hat
    TableUtil.findColIndexWithAssertAndHint(in.getColNames(), this.getLabelCol());
    TableUtil.findColIndexWithAssertAndHint(in.getColNames(), this.getPredictionCol());
	
  	// 利用y, y_hat来构建Metrics
    TableUtil.assertNumericalCols(in.getSchema(), this.getLabelCol(), this.getPredictionCol());
    DataSet<Row> out = in.select(new String[] {this.getLabelCol(), this.getPredictionCol()})
        .getDataSet()
        .rebalance()
        .mapPartition(new CalcLocal())
        .reduce(new EvaluationUtil.ReduceBaseMetrics())
        .flatMap(new EvaluationUtil.SaveDataAsParams());

    this.setOutputTable(DataSetConversionUtil.toTable(getMLEnvironmentId(),
        out, new TableSchema(new String[] {"regression_eval_result"}, new TypeInformation[] {Types.STRING})
    ));
    return this;
}

0x04 Статистика расчета раздела

Вызовите CalcLocal для вычисления различных статистических значений по секциям и косвенно вызовите getRegressionStatistics.

/**
 * Get the label sum, predResult sum, SSE, MAE, MAPE of one partition.
 */
public static class CalcLocal implements MapPartitionFunction<Row, BaseMetricsSummary> {
    @Override
    public void mapPartition(Iterable<Row> rows, Collector<BaseMetricsSummary> collector)
        throws Exception {
        collector.collect(getRegressionStatistics(rows));
    }
}

Функция getRegressionStatistics состоит в том, чтобы просматривать входные данные, вычислять различные накопленные значения внутри этого раздела и готовиться к последующим действиям.

/**
 * Calculate the RegressionMetrics from local data.
 *
 * @param rows Input rows, the first field is label value, the second field is prediction value.
 * @return RegressionMetricsSummary.
 */
public static RegressionMetricsSummary getRegressionStatistics(Iterable<Row> rows) {
    RegressionMetricsSummary regressionSummary = new RegressionMetricsSummary();
    for (Row row : rows) {
        if (checkRowFieldNotNull(row)) {
            double yVal = ((Number)row.getField(0)).doubleValue();
            double predictVal = ((Number)row.getField(1)).doubleValue();
            double diff = Math.abs(yVal - predictVal);
            regressionSummary.ySumLocal += yVal;
            regressionSummary.ySum2Local += yVal * yVal;
            regressionSummary.predSumLocal += predictVal;
            regressionSummary.predSum2Local += predictVal * predictVal;
            regressionSummary.maeLocal += diff;
            regressionSummary.sseLocal += diff * diff;
            regressionSummary.mapeLocal += Math.abs(diff / yVal);
            regressionSummary.total++;
        }
    }
    return regressionSummary.total == 0 ? null : regressionSummary;
}

0x05 Статистика слияния

reduce вызывает ReduceBaseMetrics для слияния различной статистики:

/**
 * Merge the BaseMetrics calculated locally.
 */
public static class ReduceBaseMetrics implements ReduceFunction<BaseMetricsSummary> {
    @Override
    public BaseMetricsSummary reduce(BaseMetricsSummary t1, BaseMetricsSummary t2) throws Exception {
        return null == t1 ? t2 : t1.merge(t2);
    }
}

Модель хранения 0x06

Это вызывает SaveDataAsParams для сохранения модели.

/**
 * After merging all the BaseMetrics, we get the total BaseMetrics. Calculate the indexes and save them into params.
 */
public static class SaveDataAsParams implements FlatMapFunction<BaseMetricsSummary, Row> {
    @Override
    public void flatMap(BaseMetricsSummary t, Collector<Row> collector) throws Exception {
        collector.collect(t.toMetrics().serialize());
    }
}

0x07 toMetrics

Наконец, представлены статистические показатели.

public RegressionMetrics toMetrics() {
    Params params = new Params();
    params.set(RegressionMetrics.SST, ySum2Local - ySumLocal * ySumLocal / total);
    params.set(RegressionMetrics.SSE, sseLocal);
    params.set(RegressionMetrics.SSR,
        predSum2Local - 2 * ySumLocal * predSumLocal / total + ySumLocal * ySumLocal / total);
    params.set(RegressionMetrics.R2, 1 - params.get(RegressionMetrics.SSE) / params.get(RegressionMetrics.SST));
    params.set(RegressionMetrics.R, Math.sqrt(params.get(RegressionMetrics.R2)));
    params.set(RegressionMetrics.MSE, params.get(RegressionMetrics.SSE) / total);
    params.set(RegressionMetrics.RMSE, Math.sqrt(params.get(RegressionMetrics.MSE)));
    params.set(RegressionMetrics.SAE, maeLocal);
    params.set(RegressionMetrics.MAE, params.get(RegressionMetrics.SAE) / total);
    params.set(RegressionMetrics.COUNT, (double)total);
    params.set(RegressionMetrics.MAPE, mapeLocal * 100 / total);
    params.set(RegressionMetrics.Y_MEAN, ySumLocal / total);
    params.set(RegressionMetrics.PREDICTION_MEAN, predSumLocal / total);
    params.set(RegressionMetrics.EXPLAINED_VARIANCE, params.get(RegressionMetrics.SSR) / total);

    return new RegressionMetrics(params);
}

окончательный результат

params = {Params@9098} "Params {R2=-1.5675675675675693, predictionMean=0.5599999999999999, SSE=0.38, count=5.0, MAPE=141.66666666666666, RMSE=0.27568097504180444, MAE=0.24, R=NaN, SSR=0.3200000000000002, yMean=0.32, SST=0.1479999999999999, SAE=1.2, Explained Variance=0.06400000000000003, MSE=0.076}"
 params = {HashMap@9101}  size = 14
  "R2" -> "-1.5675675675675693"
  "predictionMean" -> "0.5599999999999999"
  "SSE" -> "0.38"
  "count" -> "5.0"
  "MAPE" -> "141.66666666666666"
  "RMSE" -> "0.27568097504180444"
  "MAE" -> "0.24"
  "R" -> "NaN"
  "SSR" -> "0.3200000000000002"
  "yMean" -> "0.32"
  "SST" -> "0.1479999999999999"
  "SAE" -> "1.2"
  "Explained Variance" -> "0.06400000000000003"
  "MSE" -> "0.076"

0xEE Личная информация

★★★★★★Думая о жизни и технологиях★★★★★★

Публичный аккаунт WeChat:мысли Росси

Если вы хотите получать своевременные новости о статьях, написанных отдельными лицами, или хотите видеть технические материалы, рекомендованные отдельными лицами, обратите внимание.

ссылка 0xFF

среднеквадратическая ошибка