Чтение данных с помощью набора данных и итератора Tensorflow!

машинное обучение искусственный интеллект TensorFlow

Когда я сегодня писал код NCF, я обнаружил, что код в Интернете имеет новый способ чтения данных, здесь я вырежу соответствующие фрагменты и поделюсь ими с вами.

Ссылка на статью NCF:woo woo Краткое описание.com/afraid/6173 представляет 4…

Необработанные данные
Наши исходные данные хранятся в файле npy, который представляет собой словарь с тремя ключами, а именно пользователем, элементом и меткой:

data = np.load('data/test_data.npy').item()
print(type(data))

#output
<class 'dict'>

Создайте набор данных tf
Используйте метод tf.data.Dataset.from_tensor_slices, чтобы превратить наши данные в набор данных с тензорным потоком:

dataset = tf.data.Dataset.from_tensor_slices(data)
print(type(dataset))
#output
<class 'tensorflow.python.data.ops.dataset_ops.TensorSliceDataset'>

Кроме того, превратите наш набор данных в набор данных BatchDataset, чтобы при повторении данных мы могли возвращать размер пакета данных за раз:

dataset = dataset.shuffle(1000).batch(100)
print(type(dataset))

#output
<class 'tensorflow.python.data.ops.dataset_ops.BatchDataset'>

Видно, что мы используем перемешивание для перетасовки данных до того, как они станут пакетом, а 100 означает размер буфера, то есть перемешиваются каждые 1000 штук.

На данный момент набор данных имеет два атрибута, output_shapes и output_types, Мы создадим итератор на основе этих двух атрибутов для перебора данных.

print(dataset.output_shapes)
print(dataset.output_types)

#output
{'user': TensorShape([Dimension(None)]), 'item': TensorShape([Dimension(None)]), 'label': TensorShape([Dimension(None)])}
{'user': tf.int32, 'item': tf.int32, 'label': tf.int32}

построить итератор
Мы используем свойства двух наборов данных, упомянутых выше, и используем метод tf.data.Iterator.from_structure для создания итератора:

iterator = tf.data.Iterator.from_structure(dataset.output_types,
                                            dataset.output_shapes)

Итератор должен быть инициализирован:

 sess.run(iterator.make_initializer(dataset))

На этом этапе вы можете использовать get_next() для непрерывного чтения пакетных данных.

def getBatch():
    sample = iterator.get_next()
    print(sample)
    user = sample['user']
    item = sample['item']
    return user,item

Правильная позиция для использования итераторов
Здесь мы вычисляем среднее количество пользователей и элементов в каждой возвращенной партии:

users,items = getBatch()
usersum = tf.reduce_mean(users,axis=-1)
itemsum = tf.reduce_mean(items,axis=-1)

Итератор-итератор может перемещаться только вперед.Если get_next() вызывается после обхода, будет сообщено об ошибке tf.errors.OutOfRangeError, поэтому требуется try-catch.

try:
    while True:
        print(sess.run([usersum,itemsum]))
except tf.errors.OutOfRangeError:
    print("outOfRange")  

Если вы хотите просмотреть данные несколько раз, оберните слой цикла вокруг инициализации:

for i in range(2):
    sess.run(iterator.make_initializer(dataset))
    try:
        while True:
            print(sess.run([usersum,itemsum]))
    except tf.errors.OutOfRangeError:
        print("outOfRange")

полный код

import numpy as np
import tensorflow as tf


data = np.load('data/test_data.npy').item()
print(type(data))


dataset = tf.data.Dataset.from_tensor_slices(data)
print(type(dataset))
dataset = dataset.shuffle(10000).batch(100)
print(type(dataset))

print(dataset.output_shapes)
print(dataset.output_types)

iterator = tf.data.Iterator.from_structure(dataset.output_types,
                                            dataset.output_shapes)

print(type(iterator))


def getBatch():
    sample = iterator.get_next()
    print(sample)
    user = sample['user']
    item = sample['item']
    return user,item


users,items = getBatch()
usersum = tf.reduce_mean(users,axis=-1)
itemsum = tf.reduce_mean(items,axis=-1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for i in range(2):
        sess.run(iterator.make_initializer(dataset))
        try:
            while True:
                print(sess.run([usersum,itemsum]))
        except tf.errors.OutOfRangeError:
            print("outOfRange")