Некоторые расширенные функции использования tensorflow2.0

TensorFlow

Недавно, когда я изучал tensorflow 2.0, я увидел несколько очень полезных расширенных функций, здесь я буду записывать их использование.

1.tf.gather()

tf.gather(params,indices,validate_indices=None,name=None,axis=0)Для простоты понимания сначала передайте тензор, который нужно обработать, а затем передайте операцию выбора над ним, то есть индексный тензор.

Вот пример:

Рассмотрим пример журнала оценок класса: 4 класса, 35 учеников в каждом классе, 8 предметов и тензорная фигура [4, 35, 8], содержащая журнал оценок.

#创建成绩册
record=tf.random.uniform([4,35,8],maxval=100)
record.numpy

在这里插入图片描述Если нам сейчас нужно собрать журналы успеваемости 1-го и 2-го классов, мы можем сделать это, нарезая

record1_2=record[0:2]
record1_2.numpy

在这里插入图片描述Вы также можете использовать tf.gather(), чтобы получить тот же результат.

#从第一个维度(班级)选择前两个班级
record1_2=tf.gather(record,[0,1],axis=0)
record1_2.numpy

在这里插入图片描述Однако для другого требования вам необходимо проверить оценки учащихся № 1, 4, 9, 12, 13 и 27 во всех классах.В настоящее время получить результаты с помощью срезов непросто. очень простой в использовании сбор.

#从第二个维度(学生)抽取
score=tf.gather(record,[0,3,8,11,12,26],axis=1)
score.numpy

在这里插入图片描述

2.tf.gather_nd()

С помощью tf.gather_nd() цель выборки нескольких точек может быть достигнута путем указания координат каждой выборки. пример:Получите оценки по предмету 2 учащегося 1 класса, по предмету 3 учащегося 2, по предмету 4 учащегося 3, по предмету 4 учащегося 3 класса.

score=tf.gather_nd(record,[[0,0,1],[1,1,2],[2,2,3]])
score.numpy

在这里插入图片描述

3.tf.scatter_nd()

Часть данных тензора может быть эффективно обновлена ​​через tf.scatter_nd(индексы, обновления, форма), но она может быть обновлена ​​только на доске всех 0 тензоров, поэтому может потребоваться объединение других операций для реализации функции обновления данных. существующих тензоров.

#需要刷新的位置
indices = tf.constant([[4], [3], [1], [7]])
# 构造需要写入的数据
updates = tf.constant([4.4, 3.3, 1.1, 7.7]) 
# 在长度为 8 的全 0 向量上根据 indices 写入 updates
tf.scatter_nd(indices, updates, [8])

在这里插入图片描述

4.tf.meshgrid()

С помощью tf.meshgrid можно легко сгенерировать координаты точки выборки двумерной сетки, или это можно понять как выполнение матричного умножения, повторение времени столбца y по строке и повторение времени строки x по столбцу ( механизм передачи) пример:выполнитьz=sin(x2+y2)x2+y2z=\frac{sin(x^2+y^2)}{x^2+y^2}

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
plt.rcParams['axes.unicode_minus']=False
x = tf.linspace(-8.,8,100) # 设置 x 坐标的间隔
y = tf.linspace(-8.,8,100) # 设置 y 坐标的间隔
x,y = tf.meshgrid(x,y) # 生成网格点,并拆分后返回
print(x.shape,y.shape) # 打印拆分后的所有点的 x,y 坐标张量 shape

z = tf.sqrt(x**2+y**2) 
z = tf.sin(z)/z # sinc 函数实现

fig = plt.figure()
ax = Axes3D(fig)
# 根据网格点绘制 sinc 函数 3D 曲面
ax.contour3D(x.numpy(), y.numpy(), z.numpy(), 50)
plt.show()

在这里插入图片描述Или простой пример может лучше отразить его преобразование

x=tf.constant([1,2,3])
y=tf.constant([3,4,5])
x,y = tf.meshgrid(x,y) 
print(x.numpy,y.numpy)

在这里插入图片描述Таким образом, роль meshgrid ясна с первого взгляда.