Недавно, когда я изучал tensorflow 2.0, я увидел несколько очень полезных расширенных функций, здесь я буду записывать их использование.
1.tf.gather()
tf.gather(params,indices,validate_indices=None,name=None,axis=0)Для простоты понимания сначала передайте тензор, который нужно обработать, а затем передайте операцию выбора над ним, то есть индексный тензор.
Вот пример:
Рассмотрим пример журнала оценок класса: 4 класса, 35 учеников в каждом классе, 8 предметов и тензорная фигура [4, 35, 8], содержащая журнал оценок.
#创建成绩册
record=tf.random.uniform([4,35,8],maxval=100)
record.numpy
Если нам сейчас нужно собрать журналы успеваемости 1-го и 2-го классов, мы можем сделать это, нарезая
record1_2=record[0:2]
record1_2.numpy
Вы также можете использовать tf.gather(), чтобы получить тот же результат.
#从第一个维度(班级)选择前两个班级
record1_2=tf.gather(record,[0,1],axis=0)
record1_2.numpy
Однако для другого требования вам необходимо проверить оценки учащихся № 1, 4, 9, 12, 13 и 27 во всех классах.В настоящее время получить результаты с помощью срезов непросто. очень простой в использовании сбор.
#从第二个维度(学生)抽取
score=tf.gather(record,[0,3,8,11,12,26],axis=1)
score.numpy
2.tf.gather_nd()
С помощью tf.gather_nd() цель выборки нескольких точек может быть достигнута путем указания координат каждой выборки. пример:Получите оценки по предмету 2 учащегося 1 класса, по предмету 3 учащегося 2, по предмету 4 учащегося 3, по предмету 4 учащегося 3 класса.
score=tf.gather_nd(record,[[0,0,1],[1,1,2],[2,2,3]])
score.numpy
3.tf.scatter_nd()
Часть данных тензора может быть эффективно обновлена через tf.scatter_nd(индексы, обновления, форма), но она может быть обновлена только на доске всех 0 тензоров, поэтому может потребоваться объединение других операций для реализации функции обновления данных. существующих тензоров.
#需要刷新的位置
indices = tf.constant([[4], [3], [1], [7]])
# 构造需要写入的数据
updates = tf.constant([4.4, 3.3, 1.1, 7.7])
# 在长度为 8 的全 0 向量上根据 indices 写入 updates
tf.scatter_nd(indices, updates, [8])
4.tf.meshgrid()
С помощью tf.meshgrid можно легко сгенерировать координаты точки выборки двумерной сетки, или это можно понять как выполнение матричного умножения, повторение времени столбца y по строке и повторение времени строки x по столбцу ( механизм передачи) пример:выполнить
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
plt.rcParams['axes.unicode_minus']=False
x = tf.linspace(-8.,8,100) # 设置 x 坐标的间隔
y = tf.linspace(-8.,8,100) # 设置 y 坐标的间隔
x,y = tf.meshgrid(x,y) # 生成网格点,并拆分后返回
print(x.shape,y.shape) # 打印拆分后的所有点的 x,y 坐标张量 shape
z = tf.sqrt(x**2+y**2)
z = tf.sin(z)/z # sinc 函数实现
fig = plt.figure()
ax = Axes3D(fig)
# 根据网格点绘制 sinc 函数 3D 曲面
ax.contour3D(x.numpy(), y.numpy(), z.numpy(), 50)
plt.show()
Или простой пример может лучше отразить его преобразование
x=tf.constant([1,2,3])
y=tf.constant([3,4,5])
x,y = tf.meshgrid(x,y)
print(x.numpy,y.numpy)
Таким образом, роль meshgrid ясна с первого взгляда.