Эта статья является третьей в серии заметок об исходном коде keras. В первых двух статьях мы проанализировали, как keras обрабатывает такие понятия, как тензор и слой, и объяснили, как они работают для формирования ориентированного ациклического графа. В этой статье основное внимание уделяется абстракции на уровне модели многоуровневой сети, то есть интерфейсе, наиболее близком к пользователю.Файл с исходным кодом находится/keras/engine/training.pyи/keras/model.py, класс для наблюденияModel
иSequential
.
Первая статья из этой серии:[Примечания к источнику] Tensor, Node и Layer для анализа исходного кода keras
Вторая статья:[Source Notes] Контейнер для анализа исходного кода keras
Model
: Добавлена информация о тренировкахContainer
Model.compile()
В основном завершить настройкуoptimizer
, loss
, metrics
и так далее, покаfit
, evaluate
подожди там нетcompile
настройка в процессе.
def compile(self, optimizer, loss, metrics=None, loss_weights=None,
sample_weight_mode=None, **kwargs):
loss = loss or {}
self.optimizer = optimizers.get(optimizer)
self.sample_weight_mode = sample_weight_mode
self.loss = loss
self.loss_weights = loss_weights
loss_function = losses.get(loss)
loss_functions = [loss_function for _ in range(len(self.outputs))]
self.loss_functions = loss_functions
# Prepare targets of model.
self.targets = []
self._feed_targets = []
for i in range(len(self.outputs)):
shape = self.internal_output_shapes[i]
name = self.output_names[i]
target = K.placeholder(ndim=len(shape),
name=name + '_target',
sparse=K.is_sparse(self.outputs[i]),
dtype=K.dtype(self.outputs[i]))
self.targets.append(target)
self._feed_targets.append(target)
# Prepare metrics.
self.metrics = metrics
self.metrics_names = ['loss']
self.metrics_tensors = []
# Compute total loss.
total_loss = None
for i in range(len(self.outputs)):
y_true = self.targets[i]
y_pred = self.outputs[i]
loss_weight = loss_weights_list[i]
if total_loss is None:
total_loss = loss_weight * output_loss
else:
total_loss += loss_weight * output_loss
for loss_tensor in self.losses:
total_loss += loss_tensor
self.total_loss = total_loss
self.sample_weights = sample_weights
Model
объектfit()
инкапсулированный метод_fit_loop()
внутренний метод, в то время как_fit_loop()
Ключевыми этапами метода являются_make_train_function()
метод завершает, возвращаетhistory
Объект, используемый для обработки callback-функции.
def fit(self, x=None, y=None, ...):
self._make_train_function()
f = self.train_function
return self._fit_loop(f, ins, ...)
существует_fit_loop()
В методе функция обратного вызова выполняет такие задачи, как мониторинг и запись процесса обучения.train_function
Также применяется к входящим данным:
def _fit_loop(self, f, ins, out_labels=None, batch_size=32,
epochs=100, verbose=1, callbacks=None,
val_f=None, val_ins=None, shuffle=True,
callback_metrics=None, initial_epoch=0):
self.history = cbks.History()
callbacks = [cbks.BaseLogger()] + (callbacks or []) + [self.history]
callbacks = cbks.CallbackList(callbacks)
out_labels = out_labels or []
callbacks.set_model(callback_model)
callbacks.set_params({
'batch_size': batch_size,
'epochs': epochs,
'samples': num_train_samples,
'verbose': verbose,
'do_validation': do_validation,
'metrics': callback_metrics or [],
})
callbacks.on_train_begin()
callback_model.stop_training = False
for epoch in range(initial_epoch, epochs):
callbacks.on_epoch_begin(epoch)
batches = _make_batches(num_train_samples, batch_size)
epoch_logs = {}
for batch_index, (batch_start, batch_end) in enumerate(batches):
batch_ids = index_array[batch_start:batch_end]
batch_logs = {}
batch_logs['batch'] = batch_index
batch_logs['size'] = len(batch_ids)
callbacks.on_batch_begin(batch_index, batch_logs)
# 应用传入的train_function
outs = f(ins_batch)
callbacks.on_batch_end(batch_index, batch_logs)
callbacks.on_epoch_end(epoch, epoch_logs)
callbacks.on_train_end()
return self.history
_make_train_function()
метод изoptimizer
Получите информацию о параметре, которую нужно обновить, и передайтеbackend
изfunction
Объект:
def _make_train_function(self):
if self.train_function is None:
inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
training_updates = self.optimizer.get_updates(
self._collected_trainable_weights,
self.constraints,
self.total_loss)
updates = self.updates + training_updates
# Gets loss and metrics. Updates weights at each call.
self.train_function = K.function(inputs,
[self.total_loss] + self.metrics_tensors,
updates=updates,
name='train_function',
**self._function_kwargs)
Model
другие методыevaluate()
и т.д., сfit()
структура похожа.
Sequential
: Построить внешний интерфейс модели
Sequential
объектModel
Дальнейшая инкапсуляция объекта также является интерфейсом, с которым непосредственно сталкивается пользователь.compile()
, fit()
, predict()
и другие методыModel
практически идентичны, за исключением добавленияadd()
метод, который также является самой основной операцией, которую мы используем для построения сети.
Sequential.add()
Исходный код метода выглядит следующим образом:
def add(self, layer):
# 第一层必须是InputLayer对象
if not self.outputs:
if not layer.inbound_nodes:
x = Input(batch_shape=layer.batch_input_shape,
dtype=layer.dtype, name=layer.name + '_input')
layer(x)
self.outputs = [layer.inbound_nodes[0].output_tensors[0]]
self.inputs = topology.get_source_inputs(self.outputs[0])
topology.Node(outbound_layer=self, ...)
else:
output_tensor = layer(self.outputs[0])
self.outputs = [output_tensor]
self.inbound_nodes[0].output_tensors = self.outputs
self.layers.append(layer)
можно увидеть,add()
метод всегда гарантирует, что первый слой сетиInputLayer
объект и примените только что добавленный слой кoutputs
, чтобы обновить его. Итак, по существу, вModel
Добавление новых слоев в модель по-прежнему обновляет модель.outputs
.
@ddlee
Эта статья соответствуетCreative Commons Attribution-ShareAlike 4.0 International License.
Это означает, что вы можете воспроизводить эту статью с указанием авторства, сопровождаемой этим соглашением.
Если вы хотите получать регулярные обновления о моих сообщениях в блоге, пожалуйста, подпишитесьДонгдон ежемесячно.
Ссылка на эту статью: блог Дошел до /posts/электромобиля 5 не…