Предоставлено КерасомCallbackИнтерфейс для отслеживания результатов каждого шага в процессе обучения, включая каждую партию и каждую эпоху. Хотя она называется «функция обратного вызова», на самом деле, если вы хотите расширить эту функцию, вам нужно наследоватьkeras.callbacks.Callback
класс, который предоставляет два свойства, связанные с процессом обучения модели:
-
params
: Параметры задаются при компиляции модели; -
model
: Объект модели.
Визуализация в реальном времени через этот интерфейсfit
Размер ошибки изменяется во время каждой партии и каждой итерации эпохи в процессе. от"Neural Networks and Deep Learning - Chap3 Improving the way neural networks learn«Например, предположим, что мы хотим обучить простейшую нейронную сеть:
Эта нейронная сеть только с одним нейроном имеет только один весw
и предвзятостьb
Два параметра для обучения, предполагая, что данные для обучения только(1, 0)
, где сравниваются эффекты обучения двух функций стоимости, MSE и Cross Entropy.
Сначала соберите эту модель:
from keras import Sequential, initializers, optimizers
from keras.layers import Activation, Dense
import numpy as np
def viz_keras_fit(w, b, runtime_plot=False, loss="mean_squared_error", act="sigmoid"):
d = DrawCallback(runtime_plot=runtime_plot)
# 初始化参数
w = initializers.Constant([w])
b = initializers.Constant([b])
x_train, y_train = np.array([1]), np.array([0])
model = Sequential()
model.add(Dense(1,
activation=act,
input_shape=(1,),
kernel_initializer=w,
bias_initializer=b))
# Learning Rate = 0.15
sgd = optimizers.SGD(lr=0.15)
model.compile(optimizer=sgd, loss=loss)
model.fit(x = x_train,
y = y_train,
epochs=150,
verbose=0,
callbacks=[d]) # Callback List
return d
Исходные параметры остаются(2, 2)
После перехода на Cross Entropy как функцию потерь:
Хотя достигается визуализация в реальном времени, отрисовка может занять больше эпохи, поэтому лучше записывать потерю каждого шага перед отрисовкой:
Наблюдение за обучением модели в режиме реального времени может помочь нам выбрать функцию потерь, функцию активации, структуру модели и гиперпараметры на ранней стадии. Ниже приведеныDrawCallback
Реализация:
import pylab as pl
from IPython import display
from keras.callbacks import Callback
class DrawCallback(Callback):
def __init__(self, runtime_plot=True):
super().__init__()
self.init_loss = None
self.runtime_plot = runtime_plot
self.xdata = []
self.ydata = []
def _plot(self, epoch=None):
epochs = self.params.get("epochs")
pl.ylim(0, int(self.init_loss*2))
pl.xlim(0, epochs)
pl.plot(self.xdata, self.ydata)
pl.xlabel('Epoch {}/{}'.format(epoch or epochs, epochs))
pl.ylabel('Loss {:.4f}'.format(self.ydata[-1]))
def _runtime_plot(self, epoch):
self._plot(epoch)
display.clear_output(wait=True)
display.display(pl.gcf())
pl.gcf().clear()
def plot(self):
self._plot()
pl.show()
def on_epoch_end(self, epoch, logs = None):
logs = logs or {}
loss = logs.get("loss")
if self.init_loss is None:
self.init_loss = loss
self.xdata.append(epoch)
self.ydata.append(loss)
if self.runtime_plot:
self._runtime_plot(epoch)