0x00 сводка
Alink — это платформа алгоритмов машинного обучения нового поколения, разработанная Alibaba на основе вычислительного движка реального времени Flink, и первая в отрасли платформа машинного обучения, которая поддерживает как пакетные, так и потоковые алгоритмы. В этой статье рассказывается о том, как алгоритм онлайн-обучения FTRL реализован в Alink вместе с вышеизложенным, надеюсь, она будет полезна всем.
0x01 Обзор
книга забратьAlink's Talk (12): Общий дизайн алгоритма онлайн-обучения FTRL. На данный момент ввод обработан, и следующий шаг — онлайн-обучение. Основная цель оптимизации обучения - найти направление.После того, как параметры перемещаются в этом направлении, значение функции потерь может быть уменьшено.Это направление часто получается различными комбинациями частных производных первого порядка или частных производных второго порядка .
Для того, чтобы все поняли лучше, мы еще раз разместили общую блок-схему:
0x02 Онлайн-обучение
Основная логика онлайн-обучения такова:
- 1) Загрузите инициализированную модель в dataBridge: dataBridge = DirectReader.collect(model);
- 2) Получить соответствующие параметры. Например, vectorSize по умолчанию равен 30000, независимо от того, hasInterceptItem;
- 3) Получить информацию о сегментации. splitInfo = getSplitInfo(featureSize, hasInterceptItem, parallelism); вскоре будет использоваться.
- 4) Разрезать многомерные векторы. Данные инициализации хэшируются с признаками, которые будут генерировать многомерный вектор, который здесь необходимо вырезать. initData.flatMap (новый SplitVector (splitInfo, hasInterceptItem, vectorSize, vectorTrainIdx, featureIdx, labelIdx));
- 5) Построить итерацию IterativeStream.ConnectedIterativeStreams, которая будет строить (или соединять) два потока данных: поток обратной связи и обучающий поток;
- 6) Использовать итерацию для построения iterativeBody, состоящего из двух частей: CalcTask, ReduceTask;
- 6.1) CalcTask разделен на две части. flatMap1 — это прогноз, необходимый для итерации FTRL для распределенных вычислений, а flatMap2 — часть параметра обновления FTRL;
- 6.2) ReduceTask разделен на две функции: «Объединить эти прогнозные результаты расчета» / «Если условия выполнены, объединить модель и вывести модель нижестоящему оператору»;
- 7) результат = iterativeBody.filter, он в основном оценивается на основе временного интервала (он также может рассматриваться как управляемый по времени), данные «время не истекло и вектор имеет смысл» будут отправлены обратно в данные обратной связи поток, и итерация продолжится.Вернитесь к шагу 6), введите flatMap2;
- 8) output = iterativeBody.filter; данные, соответствующие стандарту (время истекло), выскочат из итерации, а затем алгоритм вызовет WriteModel, чтобы преобразовать LineModelData в несколько строк и направить их нижестоящему оператору (то есть онлайн-оператору). этап предсказания);То есть регулярно обновлять модель до стадии онлайн-прогноза..
2.1 Предустановленные модели
Как упоминалось ранее, FTRL должен сначала обучитьмодель логистической регрессииКак начальная модель алгоритма FTRL, это для нужд холодного старта системы.
2.1.1 Обучение модели
Конкретная настройка/обучение модели логистической регрессии:
// train initial batch model
LogisticRegressionTrainBatchOp lr = new LogisticRegressionTrainBatchOp()
.setVectorCol(vecColName)
.setLabelCol(labelColName)
.setWithIntercept(true)
.setMaxIter(10);
BatchOperator<?> initModel = featurePipelineModel.transform(trainBatchData).link(lr);
После обучения информация о модели имеет тип DataSet и находится в переменной BatchOperator> initModel, которая является пакетным оператором.
2.1.2 Загрузка модели
FtrlTrainStreamOp принимает initModel в качестве параметра инициализации.
FtrlTrainStreamOp model = new FtrlTrainStreamOp(initModel)
Эта модель будет загружена в конструктор FtrlTrainStreamOp;
dataBridge = DirectReader.collect(initModel);
При загрузке данные в инициализированной модели DataSet получаются напрямую через MemoryDataBridge.
public MemoryDataBridge generate(BatchOperator batchOperator, Params globalParams) {
return new MemoryDataBridge(batchOperator.collect());
}
2.2 Сегментация многомерных векторов
Как видно из предыдущей статьи, размерность вектора признаков, заданного алгоритмом Alink FTRL, составляет 30000. Таким образом, первым шагом алгоритма является нарезка многомерных векторов для распределенных вычислений.
String vecColName = "vec";
int numHashFeatures = 30000;
Первый шаг — получить информацию о сегментации.Код выглядит следующим образом: нужно разделить количество признаков featureSize на параллелизм параллелизма, а затем получить начальную позицию соответствующего коэффициента каждой задачи.
private static int[] getSplitInfo(int featureSize, boolean hasInterceptItem, int parallelism) {
int coefSize = (hasInterceptItem) ? featureSize + 1 : featureSize;
int subSize = coefSize / parallelism;
int[] poses = new int[parallelism + 1];
int offset = coefSize % parallelism;
for (int i = 0; i < offset; ++i) {
poses[i + 1] = poses[i] + subSize + 1;
}
for (int i = offset; i < parallelism; ++i) {
poses[i + 1] = poses[i] + subSize;
}
return poses;
}
//程序运行时变量如下
featureSize = 30000
hasInterceptItem = true
parallelism = 4
coefSize = 30001
subSize = 7500
poses = {int[5]@11660}
0 = 0
1 = 7501
2 = 15001
3 = 22501
4 = 30001
offset = 1
Затем многомерный вектор разрезается в соответствии с информацией о вырезании.
// Tuple5<SampleId, taskId, numSubVec, SubVec, label>
DataStream<Tuple5<Long, Integer, Integer, Vector, Object>> input
= initData.flatMap(new SplitVector(splitInfo, hasInterceptItem, vectorSize,
vectorTrainIdx, featureIdx, labelIdx))
.partitionCustom(new CustomBlockPartitioner(), 1);
Конкретная сегментация выполняется в функции SplitVector.flatMap, и в результате многомерный вектор разделяется на каждую задачу CalcTask.
Резюме кода выглядит следующим образом:
public void flatMap(Row row, Collector<Tuple5<Long, Integer, Integer, Vector, Object>> collector) throws Exception {
long sampleId = counter;
counter += parallelism;
Vector vec;
if (vectorTrainIdx == -1) {
.....
} else {
// 输入row的第vectorTrainIdx个field就是那个30000大小的系数向量
vec = VectorUtil.getVector(row.getField(vectorTrainIdx));
}
if (vec instanceof SparseVector) {
Map<Integer, Vector> tmpVec = new HashMap<>();
for (int i = 0; i < indices.length; ++i) {
.....
// 此处迭代完成后,tmpVec中就是task number个元素,每一个元素是分割好的系数向量。
}
for (Integer key : tmpVec.keySet()) {
//此处遍历,给后面所有CalcTask发送五元组数据。
collector.collect(Tuple5.of(sampleId, key, subNum, tmpVec.get(key), row.getField(labelIdx)));
}
} else {
......
}
}
}
Этот Tuple5.of(sampleId, key, subNum, tmpVec.get(key), row.getField(labelIdx)) является входом следующей задачи CalcTask.
2.3 Итеративное обучение
Теоретически здесь есть следующие ключевые моменты:
-
Метод прогнозирования: в каждом раунде t для выборки признаков xt и параметра модели wt после итерации (первый раз является заданным начальным значением) мы можем предсказать значение метки выборки: pt=σ(wt,xt), где σ(a)=1/(1+exp(−a)) — сигмовидная функция.
-
Функция потерь: для выборки признаков xt, соответствующей метке которой является yt ∈ 0,1, в качестве функции потерь используются логистические потери.
-
Итерационная формула: Наша цель — сделать функцию потерь как можно меньше, то есть для решения параметров можно использовать оценку максимального правдоподобия. Сначала найдите градиент, затем используйте FTRL для итерации.
Идея псевдокода примерно следующая
double p = learner.predict(x); //预测
learner.updateModel(x, p, y); //更新模型
double loss = LogLossEvalutor.calLogLoss(p, y); //计算损失
evalutor.addLogLoss(loss); //更新损失
totalLoss += loss;
trainedNum += 1;
конкретная реализацияAlink имеет свои особенности и настройки.
2.3.1 Функция итерации Flink Stream
Машинное обучение требует итеративного обучения, и здесь Alink использует итеративную функцию Flink Stream.
Экземпляр IterativeStream создается с помощью метода итерации DataStream. Есть две перегрузки метода iterate:
- Один параметр отсутствует, что означает, что максимальное время ожидания не ограничено;
- Один предоставляет длинный параметр maxWaitTimeMillis, который позволяет пользователю указать максимальный интервал времени для ожидания следующего входного элемента фронта обратной связи.
Алинка выбрала второе.
При создании ConnectedIterativeStreams используйте исходный ввод итеративного потока в качестве первого входного потока и поток обратной связи в качестве второго входного потока.
Каждый поток данных (DataStream) будет иметь соответствующее преобразование потока (StreamTransformation). Соответствующим преобразованием IterativeStream является FeedbackTransformation.
Соответствующим преобразованием IterativeStream является FeedbackTransformation, которое представляет точку обратной связи (т. е. итеративную головку) в топологии. Точка обратной связи содержит входное ребро и несколько ребер обратной связи, и Flink требует, чтобы параллелизм каждого ребра обратной связи был таким же, как и параллелизм входного ребра, который будет проверяться при добавлении ребер обратной связи в преобразование.
При создании объекта IterativeStream создается экземпляр FeedbackTransformation, который передается конструктору DataStream.
Итеративное закрытие достигается вызовом метода экземпляра IterativeStream closeWith. Эта функция указывает, что поток будет концом итератора, и этот поток будет возвращен в итерацию в качестве второго входа.
2.3.2 Итеративная сборка
Для Alink итеративный код сборки:
// train data format = <sampleId, subSampleTaskId, subNum, SparseVector(subSample), label>
// feedback format = Tuple7<sampleId, subSampleTaskId, subNum, SparseVector(subSample), label, wx, timeStamps>
IterativeStream.ConnectedIterativeStreams<
Tuple5<Long, Integer, Integer, Vector, Object>,
Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>
iteration = input.iterate(Long.MAX_VALUE)
.withFeedbackType(TypeInformation
.of(new TypeHint<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>() {}));
// 即iteration是一个 IterativeStream.ConnectedIterativeStreams<...>
2.3.2.1 Входные данные для итерации
Как вы можете видеть из кода и комментариев, два входа для итерации:
- формат обучающих данных =
; это фактически обучающие данные; - Tuple7
; на самом деле это данные обратной связи, которые являются «итеративным потоком обратной связи» в качестве второго входа;
2.3.2.2 Итеративная обратная связь
Настройка потока обратной связи достигается вызовом метода экземпляра closeWith из IterativeStream. ссылка здесь
DataStream<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>
result = iterativeBody.filter(
return (t3.f0 > 0 && t3.f2 > 0); // 这里是省略版本代码
);
iteration.closeWith(result);
Как упоминалось ранее, оценка фильтра результатовreturn (t3.f0 > 0 && t3.f2 > 0)
, если условия соблюдены, это означает, что время не истекло и вектор имеет смысл, поэтому его следует вернуть в это время для продолжения обучения.
Формат потока обратной связи:
- Tuple7
;
2.3.3 Итератор CalcTask/ReduceTask
Итеративное тело состоит из двух частей: CalcTask/ReduceTask.
Каждый экземпляр CalcTask имеет модель инициализации dataBridge.
DataStream iterativeBody = iteration.flatMap(
new CalcTask(dataBridge, splitInfo, getParams()))
2.3.3.1 Итеративная инициализация
Итерация запускается функцией CalcTask.open, которая в основном выполняет следующие действия.
- Задайте различные параметры, такие как
- Количество рабочих задач, numWorkers = getRuntimeContext().getNumberOfParallelSubtasks();
- Идентификатор этой задачи, workerId = getRuntimeContext().getIndexOfThisSubtask();
- прочитать модель инициализации
- List modelRows = DirectReader.directRead(dataBridge);
- Преобразование данных типа Row в линейную модель Модель LinearModelData = new LinearModelDataConverter().load(modelRows);
- Считайте коэффициент coef[i - startIdx], соответствующий этой задаче, вот и разделите всю модель на столько-то задач numWorkers и обновляйте их параллельно.
- Укажите время начала этой задачи startTime = System.currentTimeMillis();
2.3.3.2 Обработка входных данных
CalcTask.flatMap1 в основном реализует прогнозную часть алгоритма FTRL (обратите внимание, что это не прогнозирование FTRL).
объяснять: pt=σ(Xt⋅w) — функция предсказания LR, Единственная цель нахождения pt — найти первую производную целевой функции (с использованием кросс-энтропийной функции потерь в качестве целевой функции в LR) к параметру w, g, gi = (pt−yt)xi. Этот шаг также применим к FTRL-оптимизации других целевых функций, единственное отличие состоит в том, чтобы найти субградиент g (субградиент — это набор между левой и правой производной, функцию можно дифференцировать — когда левая производная равна правой производной, субградиент равен одному шагу градиента) метод другой.
Входными данными для функции являются «входные данные для обучения», т.е.SplitVector.flatMap的输出 ----> CalcCalcTask的输入
. Входные данные представляют собой пятерку в формате обучающих данных format =
Следует отметить три вещи:
- Да, если заходишь впервые, нужно сохранить FristModel;
- Здесь ввод обрабатывается, а затем немедленно выводится (в отличие от flatMap2, flatMap2 обрабатывается при наличии ввода, но выводится не сразу, а выводится по истечении времени);
- Реализация прогноза:
((SparseVector)vec).getValues()[i] * coef[indices[i] - startIdx];
Все скажут: нет! Функция прогнозирования должна бытьsigmoid = 1.0 / (1.0 + np.exp(-w.dot(x)))
. Да, сигмовидной операции пока нет. Когда ReduceTask выполняет агрегирование, он возвращает агрегированное значение p обратно в итератор, а затем в CalcTask.flatMap2 выполняется сигмовидная операция.
public void flatMap1(Tuple5<Long, Integer, Integer, Vector, Object> value,
Collector<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>> out) throws Exception {
if (!savedFristModel) { //第一次进入需要存模型
out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(),
new DenseVector(coef), labelValues, -1.0, modelId++));
savedFristModel = true;
}
Long timeStamps = System.currentTimeMillis();
double wx = 0.0;
Long sampleId = value.f0;
Vector vec = value.f3;
if (vec instanceof SparseVector) {
int[] indices = ((SparseVector)vec).getIndices();
// 这里就是具体的Predict
for (int i = 0; i < indices.length; ++i) {
wx += ((SparseVector)vec).getValues()[i] * coef[indices[i] - startIdx];
}
} else {
......
}
//处理了就输出
out.collect(Tuple7.of(sampleId, value.f1, value.f2, value.f3, value.f4, wx, timeStamps));
}
2.3.3.3 Агрегированные данные
ReduceTask.flatMap отвечает за слияние данных.
public static class ReduceTask extends
RichFlatMapFunction<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>,
Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>> {
private int parallelism;
private int[] poses;
private Map<Long, List<Object>> buffer;
private Map<Long, List<Tuple2<Integer, DenseVector>>> models = new HashMap<>();
}
Функция flatMap грубо завершает следующие функции, то есть два слияния:
- Используется для выходной модели. Определите, истекает ли время, если (value.f0
- Сгенерируйте List
> model = models.get(value.f6); со значением.f6, то есть меткой времени, в качестве ключа, и вставьте его в HashMap. - Если все коллекции собраны, модель выводится нижестоящему оператору, а промежуточная модель удаляется из HashMap.
- Обновите f5 Tuple7 с меткой y, которая является меткой в Tuple7
, которая является предсказанной y. - Отправьте этот новый Tuple7 каждому нижестоящему оператору (то есть каждому CalcTask, но в качестве входных данных для flatMap2);
При использовании именно в качестве выходной модели его переменные выглядят следующим образом:
models = {HashMap@13258} size = 1
{Long@13456} 1 -> {ArrayList@13678} size = 1
key = {Long@13456} 1
value = {ArrayList@13678} size = 1
0 = {Tuple2@13698} "(1,0.0 -8.244533295515879E-5 0.0 -1.103997743166529E-4 0.0 -3.336931546279811E-5....."
2.3.3.4 Оценка необходимости обратной связи
Этот результат фильтра используется для сужения, следует ли обратная связь. Здесь T3.F0 - образец, а T3.F2 является подсумом.
DataStream<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>
result = iterativeBody.filter(
new FilterFunction<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>() {
@Override
public boolean filter(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> t3)
throws Exception {
// if t3.f0 > 0 && t3.f2 > 0 then feedback
return (t3.f0 > 0 && t3.f2 > 0);
}
});
заt3.f0, есть два места, где код имеет отрицательное значение.
-
установит "-1" один раз в saveFirstModel, то есть
if (!savedFristModel) { out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(), new DenseVector(coef), labelValues, -1.0, modelId++)); savedFristModel = true; }
-
Также установите «-1» по истечении времени.
if (System.currentTimeMillis() - startTime > modelSaveTimeInterval) { startTime = System.currentTimeMillis(); out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(), new DenseVector(coef), labelValues, -1.0, modelId++)); }
заt3.f2, если subNum больше нуля, это означает, что значимое значение получается при делении многомерного вектора.
следовательноreturn (t3.f0 > 0 && t3.f2 > 0)
Это означает, что время не истекло и вектор имеет значение, поэтому его следует вернуть в это время, чтобы продолжить обучение.
2.3.3.5 Определите, следует ли выводить модель
Вот результат фильтра.
value.f0 < 0
Указывает, что время истекло и модель должна быть выведена.
DataStream<Row> output = iterativeBody.filter(
new FilterFunction<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>() {
@Override
public boolean filter(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> value)
{
/* if value.f0 small than 0, then output */
return value.f0 < 0;
}
}).flatMap(new WriteModel(labelType, getVectorCol(), featureCols, hasInterceptItem));
2.3.3.6 Данные обратной связи процесса/параметры обновления
На самом деле CalcTask.flatMap2 выполняет остальную часть алгоритма FTRL, обновляя часть параметров. Основная логика следующая:
- Вычислить временной интервал timeInterval = System.currentTimeMillis() - value.f6;
- Формальный расчет предсказывает, p = 1 / (1 + Math.exp(-p)), то есть сигмовидная операция;
- Вычислить градиент g = (p - метка) * values[i] / Math.sqrt(timeInterval), здесь разделить на временной интервал;
- обновить параметры;
- входить. Обратите внимание, что ввод обрабатывается здесь, ноне выводится сразу, но кумулятивные параметры, которые выводятся по истечении времени, то есть достигается обычная модель вывода;
существуетLogistic Regression, сигмовидная функция σ(a) = 1 / (1 + exp(-a)) , а расчетное значение pt = σ(xt . wt), то функция LogLoss имеет вид
можно вычислить напрямую
Конкретный алгоритм LR + FTRL реализован следующим образом:
@Override
public void flatMap2(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> value,
Collector<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>> out)
throws Exception {
double p = value.f5;
// 计算时间间隔
long timeInterval = System.currentTimeMillis() - value.f6;
Vector vec = value.f3;
/* eta */
// 正式计算predict,之前只是计算了一半,这里计算后半部,即
p = 1 / (1 + Math.exp(-p));
.....
if (vec instanceof SparseVector) {
// 这里是更新参数
int[] indices = ((SparseVector)vec).getIndices();
double[] values = ((SparseVector)vec).getValues();
for (int i = 0; i < indices.length; ++i) {
// update zParam nParam
int id = indices[i] - startIdx;
// values[i]是xi
// 下面的计算基本和Google伪代码一致
double g = (p - label) * values[i] / Math.sqrt(timeInterval);
double sigma = (Math.sqrt(nParam[id] + g * g) - Math.sqrt(nParam[id])) / alpha;
zParam[id] += g - sigma * coef[id];
nParam[id] += g * g;
// update model coefficient
if (Math.abs(zParam[id]) <= l1) {
coef[id] = 0.0;
} else {
coef[id] = ((zParam[id] < 0 ? -1 : 1) * l1 - zParam[id])
/ ((beta + Math.sqrt(nParam[id]) / alpha + l2));
}
}
} else {
......
}
// 当时间到期了再输出,即做到了定期输出模型
if (System.currentTimeMillis() - startTime > modelSaveTimeInterval) {
startTime = System.currentTimeMillis();
out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(),
new DenseVector(coef), labelValues, -1.0, modelId++));
}
}
2.4 Выходная модель
Класс WriteModel реализует функцию выходной модели, и общая логика выглядит следующим образом:
- Создайте LinearModelData и заполните эту LinearModelData обученным Tuple7. Два важных момента:
- modelData.coefVector = (DenseVector)value.f3;
- modelData.labelValues = (Object[])value.f4;
- Преобразуйте данные модели в строки списка. LinearModelDataConverter().save(modelData, listCollector);
- Сериализуется и отправляется нижестоящим операторам. Поскольку модель может быть очень большой, распределение отправляется нижестоящему оператору после разделения здесь.
public void flatMap(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> value, Collector<Row> out){
//输入value变量打印如下:
value = {Tuple7@13296}
f0 = {Long@13306} -1
f1 = {Integer@13307} 0
f2 = {Integer@13308} 2
f3 = {DenseVector@13309} "-0.7383426732137565 0.0 0.0 0.0 1.5885293675862715E-4 -4.834608575902742E-5 0.0 0.0 -6.754208708318647E-5 ......"
data = {double[30001]@13314}
f4 = {Object[2]@13310}
f5 = {Double@13311} -1.0
f6 = {Long@13312} 0
//生成模型
LinearModelData modelData = new LinearModelData();
......
modelData.coefVector = (DenseVector)value.f3;
modelData.labelValues = (Object[])value.f4;
//把模型数据转换成List<Row> rows
RowCollector listCollector = new RowCollector();
new LinearModelDataConverter().save(modelData, listCollector);
List<Row> rows = listCollector.getRows();
for (Row r : rows) {
int rowSize = r.getArity();
for (int j = 0; j < rowSize; ++j) {
.....
//序列化
}
out.collect(row);
}
iter++;
}
}
0x03 Онлайн-предсказание
Функция предсказания выполняется в ftrlPredictStreamOp.
// ftrl predict
FtrlPredictStreamOp predictResult = new FtrlPredictStreamOp(initModel)
.setVectorCol(vecColName)
.setPredictionCol("pred")
.setReservedCols(new String[]{labelColName})
.setPredictionDetailCol("details")
.linkFrom(model, featurePipelineModel.transform(splitter.getSideOutput(0)));
Из приведенного выше кода мы видим
- Для функции ftrlPredict также требуется исходная модель initModel, и мы также назначаем ей модель логистической регрессии. Это также для холодного запуска, то есть до того, как модуль обучения FTRL сгенерирует модель, модуль прогнозирования FTRL также может делать прогнозы на своих входных данных.
- Модель является выходом FtrlTrainStreamOp, выходом обучения FTRL. Таким образом, WriteModel передает выходные данные непосредственно в функцию FTrlPredict.
- splitter.getSideOutput(0) Вот тестовые входные данные, упомянутые ранее, которые представляют собой набор тестовых данных.
Функция linkFrom завершает бизнес-логику, общая функция выглядит следующим образом:
- использовать
inputs[0].getDataStream().flatMap ------> partition ----> map ----> flatMap(new CollectModel())
Получил модель LinearModelData modelstr; - Используйте DataStream.connect, чтобы связать набор входных тестовых данных с моделью LinearModelData modelstr, чтобы каждая задача имела интерактивную модель modelstr, доступ к которой можно получить через
flatMap(new PredictProcess(...)
делать распределенные прогнозы; - Используйте setOutputTable и LinearModelMapper для вывода результатов прогнозирования;
То есть функция предсказания FTRL имеет три входа:
- начальная модель initModel-----> Окончательно загружен PredictProcess.open как модель прогнозирования холодного запуска;
- поток тестовых данных-----> Обрабатывается PredictProcess.flatMap1 для предсказания;
- Поток данных модели, созданный на этапе обучения FTRL----> Обрабатывается PredictProcess.flatMap2 для онлайн-обновления модели;
3.1 Инициализация
В конструкторе завершается инициализация, то есть получается предварительно обученная модель логистической регрессии.
public FtrlPredictStreamOp(BatchOperator model) {
super(new Params());
if (model != null) {
dataBridge = DirectReader.collect(model);
} else {
throw new IllegalArgumentException("Ftrl algo: initial model is null. Please set a valid initial model.");
}
}
3.2 Получите модель онлайн-обучения
CollectModel дополняет функцию получения онлайн-модели обучения.
Логика в основном такова: модель разбита на несколько блоков, при этом (long) INROW.GETFIELD(1) Сколько блоков здесь записано. Поэтому функция FlatMap накапливает эти блоки, окончательно собирает модель и отправляет ее нижестоящему оператору.
В частности, временная сборка / окончательная сборка завершена через HASHMAP буферы.
public static class CollectModel implements FlatMapFunction<Row, LinearModelData> {
private Map<Long, List<Row>> buffers = new HashMap<>(0);
@Override
public void flatMap(Row inRow, Collector<LinearModelData> out) throws Exception {
// 输入参数如下
inRow = {Row@13389} "0,19,0,{"hasInterceptItem":"true","vectorCol":"\"vec\"","modelName":"\"Logistic Regression\"","labelCol":null,"linearModelType":"\"LR\"","vectorSize":"30000"},null"
fields = {Object[5]@13405}
0 = {Long@13406} 0
1 = {Long@13403} 19
2 = {Long@13406} 0
3 = "{"hasInterceptItem":"true","vectorCol":"\"vec\"","modelName":"\"Logistic Regression\"","labelCol":null,"linearModelType":"\"LR\"","vectorSize":"30000"}"
"
long id = (long)inRow.getField(0);
Long nTab = (long)inRow.getField(1);
Row row = new Row(inRow.getArity() - 2);
for (int i = 0; i < row.getArity(); ++i) {
row.setField(i, inRow.getField(i + 2));
}
if (buffers.containsKey(id) && buffers.get(id).size() == nTab.intValue() - 1) {
buffers.get(id).add(row);
// 如果累积完成,则组装成模型
LinearModelData ret = new LinearModelDataConverter().load(buffers.get(id));
buffers.get(id).clear();
// 发送给下游算子。
out.collect(ret);
} else {
if (buffers.containsKey(id)) {
//如果有key。则往list添加。
buffers.get(id).add(row);
} else {
// 如果没有key,则添加list
List<Row> buffer = new ArrayList<>(0);
buffer.add(row);
buffers.put(id, buffer);
}
}
}
}
//变量类似这种
this = {FtrlPredictStreamOp$CollectModel@13388}
buffers = {HashMap@13393} size = 1
{Long@13406} 0 -> {ArrayList@13431} size = 2
key = {Long@13406} 0
value = 0
value = {ArrayList@13431} size = 2
0 = {Row@13409} "0,{"hasInterceptItem":"true","vectorCol":"\"vec\"","modelName":"\"Logistic Regression\"","labelCol":null,"linearModelType":"\"LR\"","vectorSize":"30000"},null"
1 = {Row@13471} "1048576,{"featureColNames":null,"featureColTypes":null,"coefVector":{"data":[-0.7383426732137549,0.0,0.0,0.0,1.5885293675862704E-4,-4.834608575902738E-5,0.0,0.0,-6.754208708318643E-5,-1.5904172331763155E-4,0.0,-1.315219790338925E-4,0.0,-4.994749246390495E-4,0.0,2.755456604395511E-4,-9.616429481614131E-4,-9.601054004112163E-5,0.0,-1.6679174640370486E-4,0.0,......"
3.3 Онлайн-прогноз
PredictProcess выполняет функцию онлайн-прогнозирования, а LinearModelMapper — конкретную реализацию прогнозирования.
public static class PredictProcess extends RichCoFlatMapFunction<Row, LinearModelData, Row> {
private LinearModelMapper predictor = null;
private String modelSchemaJson;
private String dataSchemaJson;
private Params params;
private int iter = 0;
private DataBridge dataBridge;
}
3.3.1 Загрузка предустановленных моделей
Его конструктор получает dataBridge класса FtrlPredictStreamOp, который является предварительно обученной моделью логистической регрессии. Каждая задача имеет полную модель.
Функция open загружает модель логистической регрессии.
public void open(Configuration parameters) throws Exception {
this.predictor = new LinearModelMapper(TableUtil.fromSchemaJson(modelSchemaJson),
TableUtil.fromSchemaJson(dataSchemaJson), this.params);
if (dataBridge != null) {
// read init model
List<Row> modelRows = DirectReader.directRead(dataBridge);
LinearModelData model = new LinearModelDataConverter().load(modelRows);
this.predictor.loadModel(model);
}
}
3.3.2 Онлайн-прогноз
Функция ftrlpredictStreamop.FlatMap1 выполняет онлайн-прогнозирование.
public void flatMap1(Row row, Collector<Row> collector) throws Exception {
collector.collect(this.predictor.map(row));
}
Стек вызовов выглядит следующим образом:
predictWithProb:157, LinearModelMapper (com.alibaba.alink.operator.common.linear)
predictResultDetail:114, LinearModelMapper (com.alibaba.alink.operator.common.linear)
map:90, RichModelMapper (com.alibaba.alink.common.mapper)
flatMap1:174, FtrlPredictStreamOp$PredictProcess (com.alibaba.alink.operator.stream.onlinelearning)
flatMap1:143, FtrlPredictStreamOp$PredictProcess (com.alibaba.alink.operator.stream.onlinelearning)
processElement1:53, CoStreamFlatMap (org.apache.flink.streaming.api.operators.co)
processRecord1:135, StreamTwoInputProcessor (org.apache.flink.streaming.runtime.io)
В частности, это делается через LinearModelMapper.
public abstract class RichModelMapper extends ModelMapper {
public Row map(Row row) throws Exception {
if (isPredDetail) {
// 我们的示例代码在这里
Tuple2<Object, String> t2 = predictResultDetail(row);
return this.outputColsHelper.getResultRow(row, Row.of(t2.f0, t2.f1));
} else {
return this.outputColsHelper.getResultRow(row, Row.of(predictResult(row)));
}
}
}
Код предсказания выглядит следующим образом, и видно, что используется сигмоид.
/**
* Predict the label information with the probability of each label.
*/
public Tuple2 <Object, Double[]> predictWithProb(Vector vector) {
double dotValue = MatVecOp.dot(vector, model.coefVector);
switch (model.linearModelType) {
case LR:
case SVM:
double prob = sigmoid(dotValue);
return new Tuple2 <>(dotValue >= 0 ? model.labelValues[0] : model.labelValues[1],
new Double[] {prob, 1 - prob});
}
}
3.3.3 Модель онлайн-обновления
Функция FtrlPredictStreamOp.flatMap2 завершает обработку потока данных модели, выводимых онлайн-обучением, и обновляет модель онлайн.
Параметр LinearModelData загружается и передается CollectModel.
Процесс загрузки модели непредсказуем, и соответствующий механизм защиты не виден. Пожалуйста, укажите, если я что-то упустил.
public void flatMap2(LinearModelData linearModel, Collector<Row> collector) throws Exception {
this.predictor.loadModel(linearModel);
}
0x04 Вопрос Ответы
Отвечая на вопросы, которые мы задавали ранее, мы можем резюмировать их следующим образом:
- Существуют ли готовые модели для фаз обучения и прогнозирования для работы с «холодными пусками»? Есть сборные модели;
- Как связаны этап обучения и этап прогнозирования? Используйте linkFrom для прямого подключения операторов на этапе обучения и прогнозирования;
- Как передать обученную модель на стадию прогнозирования? На этапе обучения используйте Flink Collector.collect для отправки модели нижестоящим операторам;
- Что делать при экспорте модели, если модель слишком большая? онлайн-тренировкамодельЗатем рассылка отправляется нижестоящим операторам;
- Какой механизм используется для обновления модели, обученной онлайн? Это обычное обновление драйверов? регулярно обновлять;
- Можете ли вы по-прежнему предсказать, когда модель будет загружена на этапе прогнозирования? Существует ли механизм, гарантирующий, что этот период времени также можно предсказать? Подобный механизм защиты пока не найден;
- На каких этапах этапа обучения используется параллельная обработка? Процесс обучения состоит в основном из двух частей: «прогнозирования» и «обновления параметров» алгоритма FTRL, а также модели отправки;
- Какой из этапов прогнозирования использует параллельную обработку? Процесс прогнозирования в основном представляет собой распределенную модель принятия и распределенное прогнозирование;
- Как работать с многомерными векторами? Разрезать его? обработка сегментации;
0xEE Личная информация
★★★★★★Думая о жизни и технологиях★★★★★★
Публичный аккаунт WeChat: мысли Росси
Если вы хотите получать своевременные новости о статьях, написанных отдельными лицами, или хотите видеть технические материалы, рекомендованные отдельными лицами, обратите внимание.
ссылка 0xFF
[Машинное обучение] Логистическая регрессия (очень подробно)
[Машинное обучение] Распределенная (параллельная) реализация LR
Параллельная логистическая регрессия
Обсуждение алгоритмов машинного обучения и их распараллеливания
Online LR — понимание алгоритма FTRL
Принцип и реализация алгоритма онлайн-оптимизации FTRL
Принцип алгоритма LR+FTRL и его инженерная реализация
Итеративный анализ API потоковой обработки Flink
Интернет-машина обучения FTRL (последующий регуляризованный) алгоритм
LR+FTRL в бою FTRL (плотные данные, используемые кодом)
Алгоритм онлайн-обучения FTRL-проксимальный принцип
Алгоритм онлайн-прогнозирования CTR на основе FTRL
FTRL-Проксимальный алгоритм прогнозирования CTR
Подробное объяснение FTRL, алгоритма онлайн-обучения, широко используемого крупными компаниями.
Онлайн-оптимизация Решение 5: FTRL
Краткое изложение алгоритма FOLLOW THE REGULARIZED LEADER (FTRL)