Давайте поговорим о терминаторе Numpy, JAX

искусственный интеллект Python
Давайте поговорим о терминаторе Numpy, JAX

Это 11-й день моего участия в августовском испытании обновлений.Подробности о событии:Испытание августовского обновления

По сути, JAX — это библиотека, которая предоставляет API, аналогичный NumPy, в основном для написания программ манипулирования массивами дляконвертировать. Некоторые люди даже думают, что JAX можно рассматривать как Numpy v2, который не только ускоряет Numpy, но также обеспечивает функцию автоматического деривации (градации) для Numpy, позволяя нам реализовать инфраструктуру машинного обучения только с JAX.

022.png

Следующий шаг — объяснить, почему JAX предоставляет API, аналогичный NumPy. Теперь вы можете думать о JAX как о запущенном NumPy с автоматическим выводом поверх ускорителя.

import jax
import jax.numpy as jnp

x = jnp.arange(10)
print(x)

Если вы знакомы или написали что-то с numpy, приведенный выше код не должен быть незнакомым.В этом прелесть JAX.Плавный переход от numpy к JAX заключается в том, что вам не нужно изучать новый API. Код, который ранее был реализован в numpy, можно преобразовать с помощьюjnpзаменятьnp, программа также может работать, конечно, есть отличия, которые будут введены позже. существуетjnpявляется переменной типа DeviceArray, именно так JAX представляет массивы.

Теперь мы вычислим скалярное произведение двух векторов,block_until_readyЗапустите код на устройстве GPU без изменения кода без изменения кода. использовать%timeitчтобы проверить работоспособность.

Технические подробности: при вызове функции JAX соответствующая операция отправляется ускорителю, который вычисляется асинхронно. Следовательно, массив, возвращаемый вычислением, не обязательно "заполнен" к моменту возврата функции. Поэтому, если результат не требуется немедленно, выполнение Python не будет заблокировано, поскольку расчет асинхронный. Поэтому, если не установлен block_until_ready, мы будем синхронизировать только отправку, а не фактическое вычисление. См. документацию JAX.Асинхронное планирование

long_vector = jnp.arange(int(1e7))

%timeit jnp.dot(long_vector, long_vector).block_until_ready()
The slowest run took 4.37 times longer than the fastest. This could mean that an intermediate result is being cached.
100 loops, best of 5: 6.37 ms per loop

Первая трансформация JAX: град

Фундаментальной особенностью JAX является то, что он позволяетфункция преобразования. Одним из наиболее часто используемых преобразований являетсяjax.grad, который принимает числовую функцию, написанную на Python, и возвращает новую функцию Python, которая вычисляет градиент исходной функции. определить функциюsum_of_squares, который принимает массив и возвращает сумму квадратов каждого элемента массива.

def sum_of_squares(x):
  return jnp.sum(x**2)

правильноsum_of_squaresприменениеjax.gradвернет другую функцию, эта функцияsum_of_squaresГрадиент относительно его первого аргумента x .

Затем передайте массив в эту производную функцию, чтобы вернуть производную по отношению к каждому элементу в массиве.

sum_of_squares_dx = jax.grad(sum_of_squares)

x = jnp.asarray([1.0, 2.0, 3.0, 4.0])

print(sum_of_squares(x))

print(sum_of_squares_dx(x))
0.0
[2. 4. 6. 8.]

Сделать это можно по аналогии с векторным исчислениемnablanablaОператор jax.grad, если функцияf(x)f(x)ввод далjax.grad, что эквивалентно возвратуnablanablaфункция Функция, используемая для вычисления градиента ?.

(f)(xi)=fxi(xi)(\nabla f)(x_i) = \frac{\partial f}{\partial x_i}(x_i)

Так же,jax.grad(f)это функция, которая вычисляет градиент, поэтомуjax.grad(f)(x)даfсуществуетxградиент в . (и\nablaТакой же,jax.gradРаботает только для функций со скалярным выводом, иначе будет выдана ошибка)

Это сильно отличает JAX API от других фреймворков глубокого обучения, поддерживающих автоматический вывод, таких как Tensorflow и PyTorch, где мы можем использовать сам тензор потерь для вычисления градиента (например, вызывая loss.backward() для вычисления градиента). JAX API работает непосредственно с функциями, ближе к базовой математике. Как только вы привыкнете к такому способу ведения дел, это станет естественным: ваша функция потерь в коде на самом деле является функцией параметров и данных, и вы найдете ее градиент точно так же, как в математике.

Такой способ ведения дел упрощает и упрощает управление такими вещами, как дифференцирование переменных. По умолчанию jax.grad найдет градиент относительно первого параметра. В приведенном ниже примере результатом sum_squared_error_dx будет градиент sum_squared_error по отношению к x.

def sum_squared_error(x, y):
  return jnp.sum((x-y)**2)

sum_squared_error_dx = jax.grad(sum_squared_error)

y = jnp.asarray([1.1, 2.1, 3.1, 4.1])

print(sum_squared_error_dx(x, y))

Если вам нужно рассчитать градиент разных параметров (или нескольких параметров), вы можете установить argnums для достижения.

[-0.20000005 -0.19999981 -0.19999981 -0.19999981]
jax.grad(sum_squared_error, argnums=(0, 1))(x, y)  # Find gradient wrt both x & y
(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),
 DeviceArray([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))

Означает ли это, что при машинном обучении модели должны писать функции с огромными списками параметров, по одному на каждый массив параметров модели? JAX оснащен механизмами объединения массивов в структуры данных, называемые "pytrees".jax.gradИспользование такое.

Стоимость и Град

jax.value_and_grad(sum_squared_error)(x, y)
(DeviceArray(0.03999995, dtype=float32),
 DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32))

дополнительные данные

Помимо желания записать числовые значения, мы часто хотим сообщить о некоторых промежуточных результатах, полученных при вычислении функции потерь. Но если мы попытаемся использовать обычныеjax.gradДля этого вы столкнетесь с неприятностями.

def squared_error_with_aux(x, y):
  return sum_squared_error(x, y), x-y

jax.grad(squared_error_with_aux)(x, y)

Приведенное выше выполнение кода сообщит об ошибке, и вам нужноgradЗадайте параметр в функции.

jax.grad(squared_error_with_aux, has_aux=True)(x, y)

Это потому чтоjax.gradОн определен только для скалярных функций, и преобразованная функция вернет кортеж. Поскольку члены группы содержат некоторые вспомогательные данные, этоhas_auxэффект.

Чем JAX отличается от NumPy

В приведенных выше примерах мы обнаружили, что дизайн API jax.numpy в основном соответствует API NumPy. Однако не все имеют некоторые отличия. Далее мы представим различия между JAX и Numpy. Наиболее важным отличием является то, что JAX — это скорее функциональный стиль программирования, что является основной причиной того, что Numpy и JAX не только в некоторых моментах одинаковы. Введение в функциональное программирование (FP) выходит за рамки этого руководства. Если вы уже знакомы с FP, то использовать JAX будет удобнее, потому что JAX предназначен для функционального программирования.

import numpy as np

x = np.array([1, 2, 3])

def in_place_modify(x):
  x[0] = 123
  return None

in_place_modify(x)
x

Если вы знакомы с функциональным программированием, когда вы видите выводarray([123, 2, 3]), проблема будет найдена,in_place_modifyделает некоторые побочные эффекты, обновляя значение x внутри него. Поскольку данные в функциональном программировании должны быть неизменяемыми (неизменяемыми), каждый раз, когда данные изменяются, они не изменяются в исходных данных, а изменяется копия.

in_place_modify(jnp.array(x)

Полезно, эта ошибка дает проход JAX jax.ops.index_* opsDo – метод без побочных эффектов. Подобно изменению на месте, которое не должно выполняться в исходном массиве по индексу, а вместо этого создается новый массив и изменяется соответствующим образом. Таким образом, вышеуказанная операция сообщит об ошибке в JAX.

def jax_in_place_modify(x):
  return jax.ops.index_update(x, 0, 123)

y = jnp.array([1, 2, 3])
jax_in_place_modify(y)
DeviceArray([123,   2,   3], dtype=int32)

В этот момент мы снова смотрим на y и видим, что он не изменился.

y #DeviceArray([1, 2, 3], dtype=int32)

Side-effect-free code is sometimes called functionally pure, or just pure.

Код без побочных эффектов иногда называют функционально чистым, который не является функционально чистым, но не выполняет какое-либо обновление состояния приложения, ввод-вывод или другую работу.

Разве чистая версия не менее эффективна? Строго говоря, да. Это то, что вместо изменения исходных данных мы создаем новый массив для их изменения. Однако вычисления JAX обычно перед запуском преобразуются с помощью другой программы, т.е.jax.jitСкомпилировать. если мы используемjax.ops.index_update()Изменяя исходный массив «на месте» и не используя его, компилятор распознает, что на самом деле он компилируется вМодификация на месте, что приводит к эффективному коду.

Конечно, можно смешивать код Python с побочными эффектами и функциональный код JAX, поддерживающий функции хранения.На самом деле, трудно или почти невозможно писать чисто функциональные программы.Поскольку вы все больше и больше знакомитесь с JAX, тем больше вы знакомы, тем лучше вы будете знать, когда использовать JAX, и мы поговорим об этом позже, но пока мы будем помнить, чтобы избежать побочных эффектов в JAX.

025.png