Alink's Talk (19): Квантиль дискретизации анализа исходного кода

машинное обучение

0x00 сводка

Alink — это платформа алгоритмов машинного обучения нового поколения, разработанная Alibaba на основе вычислительного движка реального времени Flink.Это первая в отрасли платформа машинного обучения, которая поддерживает как пакетные, так и потоковые алгоритмы. В этой статье вы проанализируете реализацию Quantile в Alink.

Поскольку общедоступная информация Alink слишком мала, нижеследующее - все предположения, и обязательно будут упущения и ошибки. Я надеюсь, что все укажут, и я обновлю их в любое время.

Причина написания этой статьи в том, что я хочу проанализировать GBDT и обнаружить, что GBDT включает использование Quantile, поэтому я могу сначала проанализировать только Quantile.

0x01 Фоновая концепция

1.1 Дискретность

Дискретизация: это отображение конечных людей в бесконечном пространстве в конечное пространство (обработка биннинга). Большинство операций дискретизации данных выполняются над непрерывными данными, и распределение диапазона данных после обработки изменится с непрерывных атрибутов на дискретные атрибуты.

Метод дискретизации повлияет на последующее моделирование данных и эффекты приложения:

  • Использование деревьев решений имеет тенденцию отдавать предпочтение небольшому количеству интервалов дискретизации, а слишком большая дискретизация приведет к тому, что фрагментированные интервалы повлияют на слишком много правил.
  • Правила ассоциации должны дискретизировать все функции вместе.Правила ассоциации сосредоточены на отношениях ассоциации всех функций.Если каждый столбец дискретизирован отдельно, общая регулярность будет потеряна.

Результаты дискретизации непрерывных данных можно разделить на две категории:

  • Один представляет собой набор, который делит непрерывные данные на определенные интервалы, такие как {(0,10], (10,20], (20,50],(50,100]};
  • Один класс предназначен для разделения непрерывных данных на определенные классы, такие как класс 1, класс 2, класс 3;

1.2 Квантиль

Квантиль, также известный как квантиль, относится к числовой точке, которая делит диапазон распределения вероятностей случайной величины на несколько равных частей, процентилей и т. д.

Если имеется 1000 чисел (положительных чисел), квантили 5%, 30%, 50%, 70%, 99% этих чисел равны [3,0, 5,0, 6,0, 9,0, 12,0] соответственно, что означает, что

  • 5% чисел находятся в диапазоне от 0 до 3,0.
  • 25% чисел попадают в диапазон от 3,0 до 5,0.
  • 20% чисел попадают в диапазон от 5,0 до 6,0.
  • 20% чисел попадают в диапазон от 6,0 до 9,0.
  • 29% чисел попадают в диапазон от 9,0 до 12,0.
  • 1% чисел больше 12,0

Это статистическое понимание квантилей.

Следовательно, чтобы найти квантиль числа в группе чисел, вам нужно только отсортировать группу чисел, затем посчитать число, меньшее или равное числу, и разделить на общее количество чисел.

Два способа определения позиций p-квантиля

  • position = (n+1)p
  • position = 1 + (n-1)p

1.3 Квартили

Здесь мы используем квартили для дальнейшего объяснения.

Квартиль концепция: Расположите заданные случайные значения от меньшего к большему и разделите их на четыре равные части.Значения в трех точках деления являются квартилями.

1-й квартиль (Q1), также известный как «меньший квартиль», равен 25-му процентилю всех значений в выборке, расположенных от меньшего к большему.

2-й квартиль (Q2), также известный как «медиана», равен 50-му процентилю всех значений в выборке, расположенных от меньшего к большему.

3-й квартиль (Q3), также известный как «больший квартиль», равен 75-му процентилю всех значений в выборке, расположенных от меньшего к большему.

Межквартильный диапазон(Межквартильный диапазон, IQR) = разрыв между 3-м квартилем и 1-м квартилем.

0x02 Пример кода

Что завершает функцию квантиля в Alink?QuantileDiscretizer.QuantileDiscretizerВводите столбцы непрерывных признаков и выводите категориальные признаки.

  • Квантильная дискретизация вычисляет квантили для выбранных столбцов, а затем использует эти квантили для дискретизации. Сгенерируйте q-квантиль, соответствующий выбранному столбцу, один из которых может быть указан для всех столбцов или один для каждого столбца.
  • Количество бинов (количество необходимых дискретов, т.е. разбитых на сегменты) определяется параметромnumBuckets(количество ковшей) указать. Границы бинов получаются с помощью алгоритма аппроксимации.

Пример кода этой статьи выглядит следующим образом.

public class QuantileDiscretizerExample {
    public static void main(String[] args) throws Exception {
        NumSeqSourceBatchOp numSeqSourceBatchOp = new NumSeqSourceBatchOp(1001, 2000, "col0"); // 就是把1001 ~ 2000 这个连续数值分段

        Pipeline pipeline = new Pipeline()
                .add(new QuantileDiscretizer()
                        .setNumBuckets(6) // 指定分箱数数目
                        .setSelectedCols(new String[]{"col0"}));

        List<Row> result = pipeline.fit(numSeqSourceBatchOp).transform(numSeqSourceBatchOp).collect();
        System.out.println(result);
    }
}

вывод

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 
.....
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
.....
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]

0x03 Общая логика

Сначала приведем общую логическую легенду

-------------------------------- 准备阶段 --------------------------------
       │
       │
       │  
┌───────────────────┐ 
│  getSelectedCols  │ 获取需要分位的列名字
└───────────────────┘ 
       │
       │
       │
┌─────────────────────┐ 
│     quantileNum     │ 获取分箱数
└─────────────────────┘ 
       │
       │
       │
┌──────────────────────┐ 
│ Preprocessing.select │ 从输入中根据列名字select出数据
└──────────────────────┘ 
       │
       │
       │
-------------------------------- 预处理阶段 --------------------------------
       │ 
       │
       │
┌──────────────────────┐ 
│       quantile       │ 后续步骤 就是 计算分位数
└──────────────────────┘ 
       │
       │
       │ 
┌────────────────────────────────┐ 
│   countElementsPerPartition    │ 在每一个partition中获取该分区的所有元素个数
└────────────────────────────────┘ 
       │ <task id, count in this task>
       │
       │
┌──────────────────────┐ 
│       sum(1)         │ 这里对第二个参数,即"count in this task"进行累积,得出所有元素的个数
└──────────────────────┘ 
       │  
       │
       │
┌──────────────────────┐ 
│        map           │ 取出所有元素个数,cnt在后续会使用
└──────────────────────┘ 
       │    
       │    
       │
       │    
┌──────────────────────┐ 
│     missingCount     │ 分区查找应选的列中,有哪些数据没有被查到,比如zeroAsMissing, null, isNaN
└──────────────────────┘ 
       │
       │
       │
┌────────────────┐ 
│  mapPartition  │ 把输入数据Row打散,对于Row中的子元素按照Row内顺序一一发送出来
└────────────────┘ 
       │ <idx in row, item in row>, 即<row中第几个元素,元素>
       │
       │  
┌──────────────┐ 
│    pSort     │ 将flatten数据进行排序
└──────────────┘ 
       │ 返回的是二元组
       │ f0: dataset which is indexed by partition id
       │ f1: dataset which has partition id and count
       │ 
       │  
-------------------------------- 计算阶段 --------------------------------
       │ 
       │
       │ 
┌─────────────────┐ 
│  MultiQuantile  │ 后续都是具体计算步骤
└─────────────────┘ 
       │
       │ 
       │
┌─────────────────┐ 
│      open       │ 从广播中获取变量,初步处理counts(排序),totalCnt,missingCounts(排序)
└─────────────────┘ 
       │
       │ 
       │
┌─────────────────┐ 
│  mapPartition   │ 具体计算
└─────────────────┘         
       │
       │ 
       │
┌─────────────────┐ 
│    groupBy(0)   │ 依据 列idx 分组
└─────────────────┘   
       │
       │ 
       │
┌─────────────────┐ 
│   reduceGroup   │ 归并排序
└─────────────────┘    
       │set(Tuple2<column idx, 真实数据值>)
       │ 
       │ 
-------------------------------- 序列化模型 --------------------------------
       │ 
       │
       │    
┌──────────────┐ 
│  reduceGroup │ 分组归并
└──────────────┘ 
       │ 
       │
       │   
┌─────────────────┐ 
│  SerializeModel │ 序列化模型
└─────────────────┘ 
  

Изображение ниже предназначено для масштабирования и настройки отображения на мобильных телефонах.

QuantileDiscretizerTrainBatchOp.linkFrom выглядит следующим образом:

public QuantileDiscretizerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
   BatchOperator<?> in = checkAndGetFirst(inputs);

   // 示例中设置了 .setSelectedCols(new String[]{"col0"}));, 所以这里 quantileColNames 的数值是"col0 
   String[] quantileColNames = getSelectedCols();

   int[] quantileNum = null;

   // 示例中设置了 .setNumBuckets(6),所以这里 quantileNum 是 quantileNum = {int[1]@2705} 0 = 6
   if (getParams().contains(QuantileDiscretizerTrainParams.NUM_BUCKETS)) {
      quantileNum = new int[quantileColNames.length];
      Arrays.fill(quantileNum, getNumBuckets());
   } else {
      quantileNum = Arrays.stream(getNumBucketsArray()).mapToInt(Integer::intValue).toArray();
   }

   /* filter the selected column from input */
   // 获取了 选择的列 "col0"
   DataSet<Row> input = Preprocessing.select(in, quantileColNames).getDataSet();

   // 计算分位数
   DataSet<Row> quantile = quantile(
      input, quantileNum,
      getParams().get(HasRoundMode.ROUND_MODE),
      getParams().get(Preprocessing.ZERO_AS_MISSING)
   );

   // 序列化模型
   quantile = quantile.reduceGroup(
      new SerializeModel(
         getParams(),
         quantileColNames,
         TableUtil.findColTypesWithAssertAndHint(in.getSchema(), quantileColNames),
         BinTypes.BinDivideType.QUANTILE
      )
   );

   /* set output */
   setOutput(quantile, new QuantileDiscretizerModelDataConverter().getModelSchema());

   return this;
}

Его общая логика такова:

  • Получите имя столбца, для которого требуется квантиль
  • Получить количество ящиков
  • Выберите данные из ввода на основе имени столбца
  • вызвать квантили для вычисления квантилей
    • Вызовите countElementsPerPartition, чтобы получить количество всех элементов раздела в каждом разделе, возвратите , а затем суммируйте сумму (1) для количества элементов, то есть «количество в этой задаче», чтобы накапливать, получать Получить количество всех элементов cnt;
    • В столбцах, которые должны быть выбраны при поиске раздела, данные которых не найдены, с точки зрения кода, это zeroAsMissing, null, isNaN, а затем groupBy(0) в соответствии с идентификатором раздела и накапливать и суммировать, чтобы получитьmissingCount ;
    • Разбейте строку входных данных и отправьте дочерние элементы в строке один за другим в соответствии с порядком в строке, что делает тип строки плоским и возвращает flatten = , то есть in ;
    • Сортируйте данные сглаживания, pSort — это крупномасштабная сортировка разделов, и в настоящее время классификация отсутствует. pSort возвращает двоичный файл sortedData, f0: набор данных, который индексируется по идентификатору раздела, f1: набор данных, который имеет идентификатор раздела и количество;
    • Вызовите MultiQuantile, чтобы вычислить квантиль для sortedData.f0 (f0: набор данных, который индексируется по идентификатору секции), в частности, для расчета секции mapPartition:
      • Накопить, чтобы получить начальную позицию текущей задачи, то есть какие данные начать вычислять из числа n входных данных;
      • Какие данные должна обрабатывать эта задача, получается из подсчетов по taskId, то есть начальной и конечной позиции данных;
      • Вставить данные в allRows.add(value), значение можно рассматривать как ;
      • Вызовите QIndex, чтобы вычислить квантильные метаданные; quantileNum делится на несколько сегментов, а q1 — это размер каждого сегмента. Если он разделен на 6 сегментов, размер каждого сегмента равен 1/6;
      • Пройдите до количества контейнеров BINS и вызовите Qindex.genindex (j) каждый раз, чтобы получить индекс каждого корзина. Затем получают реальное значение данных от входных данных в соответствии с индексом этого биннинга, и это значение реального данных является индексом реальных данных. Например, непрерывная область составляет 1001 ~ 2000, разделенная на 6 частей, затем первой вызовы QIndex.genindex (j), чтобы получить 167, то согласно 167 году, полученные реальные данные 1001 + 167 = 1168, то есть, то есть В 1001 ~ 2000 года первый квантильный индекс составляет 1168.
    • Группа по столбцу IDX для установки (Tuple2 );
  • сериализовать модель

0x04 обучение

4.1 quantile

Обучение проводится с помощью квантилей и примерно состоит из следующих шагов.

  • Вызовите countElementsPerPartition, чтобы получить количество всех элементов раздела в каждом разделе, возвратите , а затем суммируйте сумму (1) для количества элементов, то есть «количество в этой задаче», чтобы накапливать, получать Получить количество всех элементов cnt;
  • В столбцах, которые должны быть выбраны при поиске раздела, данные которых не найдены, с точки зрения кода, это zeroAsMissing, null, isNaN, а затем groupBy(0) в соответствии с идентификатором раздела и накапливать и суммировать, чтобы получитьmissingCount ;
  • Разбейте строку входных данных и отправьте дочерние элементы в строке один за другим в соответствии с порядком в строке, что делает тип строки плоским и возвращает flatten = , то есть in ;
  • Сортируйте данные сглаживания, pSort — это крупномасштабная сортировка разделов, и в настоящее время классификация отсутствует. pSort возвращает двоичный файл sortedData, f0: набор данных, который индексируется по идентификатору раздела, f1: набор данных, который имеет идентификатор раздела и количество;
  • Вызовите MultiQuantile, чтобы вычислить квантили для sortedData.f0 (f0: набор данных, индексированный по идентификатору раздела).

детали следующим образом

public static DataSet<Row> quantile(
   DataSet<Row> input,
   final int[] quantileNum,
   final HasRoundMode.RoundMode roundMode,
   final boolean zeroAsMissing) {
  
   /* instance count of dataset */
   // countElementsPerPartition 的作用是:在每一个partition中获取该分区的所有元素个数,返回<task id, count in this task>。
   DataSet<Long> cnt = DataSetUtils
      .countElementsPerPartition(input)
      .sum(1) // 这里对第二个参数,即"count in this task"进行累积,得出所有元素的个数。
      .map(new MapFunction<Tuple2<Integer, Long>, Long>() {
         @Override
         public Long map(Tuple2<Integer, Long> value) throws Exception {
            return value.f1; // 取出所有元素个数
         }
      }); // cnt在后续会使用

   /* missing count of columns */
   // 会查找应选的列中,有哪些数据没有被查到,从代码看,是zeroAsMissing, null, isNaN这几种情况
   DataSet<Tuple2<Integer, Long>> missingCount = input
      .mapPartition(new RichMapPartitionFunction<Row, Tuple2<Integer, Long>>() {
         public void mapPartition(Iterable<Row> values, Collector<Tuple2<Integer, Long>> out) {
            StreamSupport.stream(values.spliterator(), false)
               .flatMap(x -> {
                  long[] counts = new long[x.getArity()];

                  Arrays.fill(counts, 0L);
   
                  // 如果发现有数据没有查到,就增加counts
                  for (int i = 0; i < x.getArity(); ++i) {
                     if (x.getField(i) == null
                     || (zeroAsMissing && ((Number) x.getField(i)).doubleValue() == 0.0)
                     || Double.isNaN(((Number)x.getField(i)).doubleValue())) {
                        counts[i]++;
                     }
                  }

                  return IntStream.range(0, x.getArity())
                     .mapToObj(y -> Tuple2.of(y, counts[y]));
               })
               .collect(Collectors.groupingBy(
                  x -> x.f0,
                  Collectors.mapping(x -> x.f1, Collectors.reducing((a, b) -> a + b))
                  )
               )
               .entrySet()
               .stream()
               .map(x -> Tuple2.of(x.getKey(), x.getValue().get()))
               .forEach(out::collect);
         }
      })
      .groupBy(0) //按第一个元素分组
      .reduce(new RichReduceFunction<Tuple2<Integer, Long>>() {
         @Override
         public Tuple2<Integer, Long> reduce(Tuple2<Integer, Long> value1, Tuple2<Integer, Long> value2) {
            return Tuple2.of(value1.f0, value1.f1 + value2.f1); //累积求和
         }
      });

   /* flatten dataset to 1d */
   // 把输入数据打散。
   DataSet<PairComparable> flatten = input
      .mapPartition(new RichMapPartitionFunction<Row, PairComparable>() {
         PairComparable pairBuff;
         public void mapPartition(Iterable<Row> values, Collector<PairComparable> out) {
            for (Row value : values) { // 遍历分区内所有输入元素
               for (int i = 0; i < value.getArity(); ++i) { // 如果输入元素Row本身包含多个子元素
                  pairBuff.first = i; // 则对于这些子元素按照Row内顺序一一发送出来,这就做到了把Row类型给flatten了
                  if (value.getField(i) == null
                     || (zeroAsMissing && ((Number) value.getField(i)).doubleValue() == 0.0)
                     || Double.isNaN(((Number)value.getField(i)).doubleValue())) {
                     pairBuff.second = null;
                  } else {
                     pairBuff.second = (Number) value.getField(i);
                  }
                  out.collect(pairBuff); // 返回<idx in row, item in row>, 即<row中第几个元素,元素>
               }
            }
         }
      });

   /* sort data */
   // 将flatten数据进行排序,pSort是大规模分区排序,此时还没有分类
   // pSort返回的是二元组,f0: dataset which is indexed by partition id, f1: dataset which has partition id and count.
   Tuple2<DataSet<PairComparable>, DataSet<Tuple2<Integer, Long>>> sortedData
      = SortUtilsNext.pSort(flatten);

   /* calculate quantile */
   return sortedData.f0 //f0: dataset which is indexed by partition id
      .mapPartition(new MultiQuantile(quantileNum, roundMode))
      .withBroadcastSet(sortedData.f1, "counts") //f1: dataset which has partition id and count
      .withBroadcastSet(cnt, "totalCnt")
      .withBroadcastSet(missingCount, "missingCounts")
      .groupBy(0) // 依据 列idx 分组
      .reduceGroup(new RichGroupReduceFunction<Tuple2<Integer, Number>, Row>() {
         @Override
         public void reduce(Iterable<Tuple2<Integer, Number>> values, Collector<Row> out) {
            TreeSet<Number> set = new TreeSet<>(new Comparator<Number>() {
               @Override
               public int compare(Number o1, Number o2) {
                  return SortUtils.OBJECT_COMPARATOR.compare(o1, o2);
               }
            });

            int id = -1;
            for (Tuple2<Integer, Number> val : values) {
               // Tuple2<column idx, 数据>
               id = val.f0;
               set.add(val.f1); 
            }

// runtime变量           
set = {TreeSet@9379}  size = 5
 0 = {Long@9389} 167 // 就是第 0 列的第一段 idx
 1 = {Long@9392} 333 // 就是第 0 列的第二段 idx
 2 = {Long@9393} 500 
 3 = {Long@9394} 667
 4 = {Long@9382} 833
  
            out.collect(Row.of(id, set.toArray(new Number[0])));
         }
      });
}

Несколько ключевых функций описаны ниже.

4.2 countElementsPerPartition

Функция countElementsPerPartition состоит в том, чтобы получить количество всех элементов раздела в каждом разделе.

public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) {
   return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
      @Override
      public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
         long counter = 0;
         for (T value : values) {
            counter++; // 在每一个partition中获取该分区的所有元素个数
         }
         out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter));
      }
   });
}

4.3 MultiQuantile

MultiQuantile используется для расчета определенных квантилей.

В функции open из трансляции получаются переменные, и изначально обрабатываются counts (сортировка), totalCnt,missingCounts (сортировка) и т.д.

Функция mapPartition выполняет определенные вычисления, общие шаги таковы:

  • Накопить, чтобы получить начальную позицию текущей задачи, то есть какие данные начать вычислять из числа n входных данных;
  • Какие данные должна обрабатывать эта задача, получается из подсчетов по taskId, то есть начальной и конечной позиции данных;
  • Вставить данные в allRows.add(value), значение можно рассматривать как ;
  • Вызовите QIndex, чтобы вычислить квантильные метаданные; quantileNum делится на несколько сегментов, а q1 — это размер каждого сегмента. Если он разделен на 6 сегментов, размер каждого сегмента равен 1/6;
  • Перемещайтесь до количества бинов и каждый раз вызывайте qIndex.genIndex(j), чтобы получить индекс каждого бина. Затем получите реальное значение данных из входных данных в соответствии с индексом этого биннинга, и это реальное значение данных является индексом реальных данных. Например, непрерывная область 1001 ~ 2000, разделенная на 6 частей, затем первая вызывает qIndex.genIndex(j) для получения 167, затем по 167 получаются реальные данные 1001 + 167 = 1168, то есть в 1001 ~ 2000 гг. первый квантильный индекс равен 1168;

Конкретный код:

public static class MultiQuantile
   extends RichMapPartitionFunction<PairComparable, Tuple2<Integer, Number>> {
		private List<Tuple2<Integer, Long>> counts;
		private List<Tuple2<Integer, Long>> missingCounts;
		private long totalCnt = 0;
		private int[] quantileNum;
		private HasRoundMode.RoundMode roundType;
		private int taskId;

		@Override
		public void open(Configuration parameters) throws Exception {
      // 从广播中获取变量,初步处理counts(排序),totalCnt,missingCounts(排序)。
      // 之前设置广播变量.withBroadcastSet(sortedData.f1, "counts"),其中 f1 的格式是: dataset which has partition id and count,所以就是用 partition id来排序
			this.counts = getRuntimeContext().getBroadcastVariableWithInitializer(
				"counts",
				new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
					@Override
					public List<Tuple2<Integer, Long>> initializeBroadcastVariable(
						Iterable<Tuple2<Integer, Long>> data) {
						ArrayList<Tuple2<Integer, Long>> sortedData = new ArrayList<>();
						for (Tuple2<Integer, Long> datum : data) {
							sortedData.add(datum);
						}
            //排序
						sortedData.sort(Comparator.comparing(o -> o.f0));
            
// runtime的数据如下,本机有4核,所以数据分为4个 partition,每个partition的数据分别为251,250,250,250        
sortedData = {ArrayList@9347}  size = 4
 0 = {Tuple2@9350} "(0,251)" // partition 0, 数据个数是251
 1 = {Tuple2@9351} "(1,250)"
 2 = {Tuple2@9352} "(2,250)"
 3 = {Tuple2@9353} "(3,250)"         
            
						return sortedData;
					}
				});

			this.totalCnt = getRuntimeContext().getBroadcastVariableWithInitializer("totalCnt",
				new BroadcastVariableInitializer<Long, Long>() {
					@Override
					public Long initializeBroadcastVariable(Iterable<Long> data) {
						return data.iterator().next();
					}
				});

			this.missingCounts = getRuntimeContext().getBroadcastVariableWithInitializer(
				"missingCounts",
				new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
					@Override
					public List<Tuple2<Integer, Long>> initializeBroadcastVariable(
						Iterable<Tuple2<Integer, Long>> data) {
						return StreamSupport.stream(data.spliterator(), false)
							.sorted(Comparator.comparing(o -> o.f0))
							.collect(Collectors.toList());
					}
				}
			);

			taskId = getRuntimeContext().getIndexOfThisSubtask();
      
// runtime的数据如下        
this = {QuantileDiscretizerTrainBatchOp$MultiQuantile@9348} 
 counts = {ArrayList@9347}  size = 4
  0 = {Tuple2@9350} "(0,251)"
  1 = {Tuple2@9351} "(1,250)"
  2 = {Tuple2@9352} "(2,250)"
  3 = {Tuple2@9353} "(3,250)"
 missingCounts = {ArrayList@9375}  size = 1
  0 = {Tuple2@9381} "(0,0)"
 totalCnt = 1001
 quantileNum = {int[1]@9376} 
  0 = 6
 roundType = {HasRoundMode$RoundMode@9377} "ROUND"
 taskId = 2
		}

		@Override
		public void mapPartition(Iterable<PairComparable> values, Collector<Tuple2<Integer, Number>> out) throws Exception {

			long start = 0;
			long end;

			int curListIndex = -1;
			int size = counts.size(); // 分成4份,所以这里是4

			for (int i = 0; i < size; ++i) {
				int curId = counts.get(i).f0; // 取出输入元素中的 partition id

				if (curId == taskId) {
					curListIndex = i; // 当前 task 对应哪个 partition id
					break; // 到了当前task,就可以跳出了
				}

				start += counts.get(i).f1; // 累积,得到当前 task 的起始位置,即1000个数据中从哪个数据开始计算
			}

      // 根据 taskId 从counts中得到了本 task 应该处理哪些数据,即数据的start,end位置
      // 本 partition 是 0,其中有251个数据
			end = start + counts.get(curListIndex).f1; // end = 起始位置 + 此partition的数据个数 

			ArrayList<PairComparable> allRows = new ArrayList<>((int) (end - start));

			for (PairComparable value : values) {
				allRows.add(value); // value 可认为是 <partition id, 真实数据>
			}

			allRows.sort(Comparator.naturalOrder());

// runtime变量
start = 0
curListIndex = 0
size = 4
end = 251
allRows = {ArrayList@9406}  size = 251
 0 = {PairComparable@9408} 
  first = {Integer@9397} 0
  second = {Long@9434} 0
 1 = {PairComparable@9409} 
  first = {Integer@9397} 0
  second = {Long@9435} 1
 2 = {PairComparable@9410} 
  first = {Integer@9397} 0
  second = {Long@9439} 2
 ......
      
      // size = ((251 - 1) / 1001 - 0 / 1001) + 1 = 1
			size = (int) ((end - 1) / totalCnt - start / totalCnt) + 1;

			int localStart = 0;
			for (int i = 0; i < size; ++i) {
				int fIdx = (int) (start / totalCnt + i);
				int subStart = 0;
				int subEnd = (int) totalCnt;

				if (i == 0) {
					subStart = (int) (start % totalCnt); // 0
				}

				if (i == size - 1) {
					subEnd = (int) (end % totalCnt == 0 ? totalCnt : end % totalCnt); // 251
				}

				if (totalCnt - missingCounts.get(fIdx).f1 == 0) {
					localStart += subEnd - subStart;
					continue;
				}

				QIndex qIndex = new QIndex(
					totalCnt - missingCounts.get(fIdx).f1, quantileNum[fIdx], roundType);

// runtime变量
qIndex = {QuantileDiscretizerTrainBatchOp$QIndex@9548} 
 totalCount = 1001.0
 q1 = 0.16666666666666666
 roundMode = {HasRoundMode$RoundMode@9377} "ROUND"      
        
        // 遍历,一直到分箱数。
				for (int j = 1; j < quantileNum[fIdx]; ++j) {
          // 获取每个分箱的index 
					long index = qIndex.genIndex(j); // j = 1 ---> index = 167,就是把 1001 个分为6段,第一段终点是167
          //对应本 task = 0,subStart = 0,subEnd = 251。则index = 167,直接从allRows获取第167个,数值是 1168。因为连续区域是 1001 ~ 2000,所以第167个对应数值就是1168
          //如果本 task = 1,subStart = 251,subEnd = 501。则index = 333,直接从allRows获取第 (333 + 0 - 251)= 第 82 个,获取其中的数值。这里因为数值区域是 1001 ~ 2000, 所以数值是1334。
					if (index >= subStart && index < subEnd) { // idx刚刚好在本分区的数据中
						PairComparable pairComparable = allRows.get(
							(int) (index + localStart - subStart)); // 
            
              
// runtime变量            
pairComparable = {PairComparable@9581} 
 first = {Integer@9507} 0 // first是column idx
 second = {Long@9584} 167 // 真实数据     
   
						out.collect(Tuple2.of(pairComparable.first, pairComparable.second));
					}
				}

				localStart += subEnd - subStart;
			}
		}
	}

4.4 QIndex

Среди них ключевым моментом этой статьи является QIndex, который представляет собой конкретный расчет квантилей.

  • В конструкторе будет подсчитано количество всех элементов и размер каждого сегмента;
  • Функция genIndex будет вычисляться специально, например, если предположить, что отрезков еще 6, если взять первый отрезок, то k=1, а его индекс равен (1/6 * (1001 - 1) * 1 ) = 167
public static class QIndex {
   private double totalCount;
   private double q1;
   private HasRoundMode.RoundMode roundMode;

   public QIndex(double totalCount, int quantileNum, HasRoundMode.RoundMode type) {
      this.totalCount = totalCount; // 1001,所有元素的个数
      this.q1 = 1.0 / (double) quantileNum; // 1.0 / 6 = 16666666666666666。quantileNum是分成几段,q1就是每一段的大小。如果分成6段,则每一段的大小是1/6
      this.roundMode = type;
   }

   public long genIndex(int k) {
      // 假设还是6段,则如果取第一段,则k=1,其index为 (1/6 * (1001 - 1) * 1) = 167
      return roundMode.calc(this.q1 * (this.totalCount - 1.0) * (double) k);
   }
}

Модель вывода 0x05

Выходная модель создается путем вызова SerializeModel для reduceGroup.

Конкретная логика такова:

  • Сначала создайте метаданные точки биннинга;
  • Затем сериализуйте его в модель;
// 序列化模型
quantile = quantile.reduceGroup(
      new SerializeModel(
         getParams(),
         quantileColNames,
         TableUtil.findColTypesWithAssertAndHint(in.getSchema(), quantileColNames),
         BinTypes.BinDivideType.QUANTILE
      )
);

Конкретная реализация SerializeModel:

public static class SerializeModel implements GroupReduceFunction<Row, Row> {
   private Params meta;
   private String[] colNames;
   private TypeInformation<?>[] colTypes;
   private BinTypes.BinDivideType binDivideType;

   @Override
   public void reduce(Iterable<Row> values, Collector<Row> out) throws Exception {
      Map<String, FeatureBorder> m = new HashMap<>();
      for (Row val : values) {
         int index = (int) val.getField(0);
         Number[] splits = (Number[]) val.getField(1);
         m.put(
            colNames[index],
            QuantileDiscretizerModelDataConverter.arraySplit2FeatureBorder(
               colNames[index],
               colTypes[index],
               splits,
               meta.get(QuantileDiscretizerTrainParams.LEFT_OPEN),
               binDivideType
            )
         );
      }

      for (int i = 0; i < colNames.length; ++i) {
         if (m.containsKey(colNames[i])) {
            continue;
         }

         m.put(
            colNames[i],
            QuantileDiscretizerModelDataConverter.arraySplit2FeatureBorder(
               colNames[i],
               colTypes[i],
               null,
               meta.get(QuantileDiscretizerTrainParams.LEFT_OPEN),
               binDivideType
            )
         );
      }

      QuantileDiscretizerModelDataConverter model = new QuantileDiscretizerModelDataConverter(m, meta);

      model.save(model, out);
   }
}

Здесь используется класс FeatureBorder.

Биннинг данных — это классификация данных по определенным правилам. Так же, как фрукты можно сортировать по размеру и продавать по разным ценам.

FeatureBorder специально для Featureborder для бинирования, дискретного Featureborder и непрерывного Featureborder.

Мы видим, что имя столбца, индекс и каждая точка разделения соответствуют биннингу.

m = {HashMap@9380}  size = 1
 "col0" -> {FeatureBorder@9438} "{"binDivideType":"QUANTILE","featureName":"col0","bin":{"NORM":[{"index":0},{"index":1},{"index":2},{"index":3},{"index":4},{"index":5}],"NULL":{"index":6}},"featureType":"BIGINT","splitsArray":[1168,1334,1501,1667,1834],"isLeftOpen":true,"binCount":6}"

0x06 предсказание

Прогноз делается в QuantileDiscretizerModelMapper.

6.1 Загрузка модели

Данные модели

model = {QuantileDiscretizerModelDataConverter@9582} 
 meta = {Params@9670} "Params {selectedCols=["col0"], version="v2", numBuckets=6}"
 data = {HashMap@9584}  size = 1
  "col0" -> {FeatureBorder@9676} "{"binDivideType":"QUANTILE","featureName":"col0","bin":{"NORM":[{"index":0},{"index":1},{"index":2},{"index":3},{"index":4},{"index":5}],"NULL":{"index":6}},"featureType":"BIGINT","splitsArray":[1168,1334,1501,1667,1834],"isLeftOpen":true,"binCount":6}"

loadModel завершит загрузку.

@Override
public void loadModel(List<Row> modelRows) {
   QuantileDiscretizerModelDataConverter model = new QuantileDiscretizerModelDataConverter();
   model.load(modelRows);

   for (int i = 0; i < mapperBuilder.paramsBuilder.selectedCols.length; i++) {
      FeatureBorder border = model.data.get(mapperBuilder.paramsBuilder.selectedCols[i]);
      List<Bin.BaseBin> norm = border.bin.normBins;
      int size = norm.size();
      Long maxIndex = norm.get(0).getIndex();
      Long lastIndex = norm.get(size - 1).getIndex();
      for (int j = 0; j < norm.size(); ++j) {
         if (maxIndex < norm.get(j).getIndex()) {
            maxIndex = norm.get(j).getIndex();
         }
      }

      long maxIndexWithNull = Math.max(maxIndex, border.bin.nullBin.getIndex());

      switch (mapperBuilder.paramsBuilder.handleInvalidStrategy) {
         case KEEP:
            mapperBuilder.vectorSize.put(i, maxIndexWithNull + 1);
            break;
         case SKIP:
         case ERROR:
            mapperBuilder.vectorSize.put(i, maxIndex + 1);
            break;
         default:
            throw new UnsupportedOperationException("Unsupported now.");
      }

      if (mapperBuilder.paramsBuilder.dropLast) {
         mapperBuilder.dropIndex.put(i, lastIndex);
      }

      mapperBuilder.discretizers[i] = createQuantileDiscretizer(border, model.meta);
   }

   mapperBuilder.setAssembledVectorSize();
}

Во время загрузки, наконец, вызовите createQuantileDiscretizer, чтобы сгенерировать LongQuantileDiscretizer. Это дискретизатор для типа Long.

public static class LongQuantileDiscretizer implements NumericQuantileDiscretizer {
   long[] bounds;
   boolean isLeftOpen;
   int[] boundIndex;
   int nullIndex;
   boolean zeroAsMissing;

   @Override
   public int findIndex(Object number) {
      if (number == null) {
         return nullIndex;
      }

      long lVal = ((Number) number).longValue();

      if (isMissing(lVal, zeroAsMissing)) {
         return nullIndex;
      }

      int hit = Arrays.binarySearch(bounds, lVal);

      if (isLeftOpen) {
         hit = hit >= 0 ? hit - 1 : -hit - 2;
      } else {
         hit = hit >= 0 ? hit : -hit - 2;
      }

      return boundIndex[hit];
   }
}

Его значения следующие:

this = {QuantileDiscretizerModelMapper$LongQuantileDiscretizer@9768} 
 bounds = {long[7]@9757} 
  0 = -9223372036854775807
  1 = 1168
  2 = 1334
  3 = 1501
  4 = 1667
  5 = 1834
  6 = 9223372036854775807
 isLeftOpen = true
 boundIndex = {int[7]@9743} 
  0 = 0 // -9223372036854775807 ~ 1168 之间对应的最终分箱离散值是 0 
  1 = 1
  2 = 2
  3 = 3
  4 = 4
  5 = 5
  6 = 5 // 1834 ~ 9223372036854775807 之间对应的最终分箱离散值是 5 
 nullIndex = 6
 zeroAsMissing = false

6.2 Прогноз

Прогнозирование завершения DiscretizerMapperBuilder QuantileDiscretizerModelMapper.

Row map(Row row){
  
// 这里的 row 举例是: row = {Row@9743} "1003"
   for (int i = 0; i < paramsBuilder.selectedCols.length; i++) {
      int colIdxInData = selectedColIndicesInData[i];
      Object val = row.getField(colIdxInData);
      int foundIndex = discretizers[i].findIndex(val); // 找到 1003对应的index,就是调用Discretizer完成,这里找到 foundIndex 是0
      predictIndices[i] = (long) foundIndex;
   }

   return paramsBuilder.outputColsHelper.getResultRow(
      row,
      setResultRow(
         predictIndices,
         paramsBuilder.encode,
         dropIndex,
         vectorSize,
         paramsBuilder.dropLast,
         assembledVectorSize) // 最后返回离散值是0
   );
}

this = {QuantileDiscretizerModelMapper$DiscretizerMapperBuilder@9744} 
 paramsBuilder = {QuantileDiscretizerModelMapper$DiscretizerParamsBuilder@9752} 
 selectedColIndicesInData = {int[1]@9754} 
 vectorSize = {HashMap@9758}  size = 1
 dropIndex = {HashMap@9759}  size = 1
 assembledVectorSize = {Integer@9760} 6
 discretizers = {QuantileDiscretizerModelMapper$NumericQuantileDiscretizer[1]@9761} 
  0 = {QuantileDiscretizerModelMapper$LongQuantileDiscretizer@9768} 
   bounds = {long[7]@9776} 
   isLeftOpen = true
   boundIndex = {int[7]@9777} 
   nullIndex = 6
   zeroAsMissing = false
 predictIndices = {Long[1]@9763} 

0xEE Личная информация

★★★★★★Думая о жизни и технологиях★★★★★★

Публичный аккаунт WeChat:мысли Росси

Если вы хотите получать своевременные новости о статьях, написанных отдельными лицами, или хотите видеть технические материалы, рекомендованные отдельными лицами, обратите внимание.

ссылка 0xFF

Использование QuantileDiscretizer

Искра QuantileDiscretizer Quantile Discretizer

Машинное обучение — дискретизация данных (временная дискретизация, многозначная дискретизация, квантиль, кластеризация, частотный интервал, бинаризация)

Как понять квантили с точки зрения непрофессионала?

Квантиль популярного понимания

Python объясняет математическую серию — квантильный квантиль

Анализ исходного кода искры QuantileDiscretizer