Это 11-й день моего участия в августовском испытании обновлений.Подробности о событии:Испытание августовского обновления
По сути, JAX — это библиотека, которая предоставляет API, аналогичный NumPy, в основном для написания программ манипулирования массивами дляконвертировать. Некоторые люди даже думают, что JAX можно рассматривать как Numpy v2, который не только ускоряет Numpy, но также обеспечивает функцию автоматического деривации (градации) для Numpy, позволяя нам реализовать инфраструктуру машинного обучения только с JAX.
Следующий шаг — объяснить, почему JAX предоставляет API, аналогичный NumPy. Теперь вы можете думать о JAX как о запущенном NumPy с автоматическим выводом поверх ускорителя.
import jax
import jax.numpy as jnp
x = jnp.arange(10)
print(x)
Если вы знакомы или написали что-то с numpy, приведенный выше код не должен быть незнакомым.В этом прелесть JAX.Плавный переход от numpy к JAX заключается в том, что вам не нужно изучать новый API. Код, который ранее был реализован в numpy, можно преобразовать с помощьюjnp
заменятьnp
, программа также может работать, конечно, есть отличия, которые будут введены позже. существуетjnp
является переменной типа DeviceArray, именно так JAX представляет массивы.
Теперь мы вычислим скалярное произведение двух векторов,block_until_ready
Запустите код на устройстве GPU без изменения кода без изменения кода. использовать%timeit
чтобы проверить работоспособность.
Технические подробности: при вызове функции JAX соответствующая операция отправляется ускорителю, который вычисляется асинхронно. Следовательно, массив, возвращаемый вычислением, не обязательно "заполнен" к моменту возврата функции. Поэтому, если результат не требуется немедленно, выполнение Python не будет заблокировано, поскольку расчет асинхронный. Поэтому, если не установлен block_until_ready, мы будем синхронизировать только отправку, а не фактическое вычисление. См. документацию JAX.Асинхронное планирование
long_vector = jnp.arange(int(1e7))
%timeit jnp.dot(long_vector, long_vector).block_until_ready()
The slowest run took 4.37 times longer than the fastest. This could mean that an intermediate result is being cached.
100 loops, best of 5: 6.37 ms per loop
Первая трансформация JAX: град
Фундаментальной особенностью JAX является то, что он позволяетфункция преобразования. Одним из наиболее часто используемых преобразований являетсяjax.grad
, который принимает числовую функцию, написанную на Python, и возвращает новую функцию Python, которая вычисляет градиент исходной функции. определить функциюsum_of_squares
, который принимает массив и возвращает сумму квадратов каждого элемента массива.
def sum_of_squares(x):
return jnp.sum(x**2)
правильноsum_of_squares
применениеjax.grad
вернет другую функцию, эта функцияsum_of_squares
Градиент относительно его первого аргумента x .
Затем передайте массив в эту производную функцию, чтобы вернуть производную по отношению к каждому элементу в массиве.
sum_of_squares_dx = jax.grad(sum_of_squares)
x = jnp.asarray([1.0, 2.0, 3.0, 4.0])
print(sum_of_squares(x))
print(sum_of_squares_dx(x))
0.0
[2. 4. 6. 8.]
Сделать это можно по аналогии с векторным исчислениемОператор jax.grad, если функцияввод далjax.grad
, что эквивалентно возвратуфункция Функция, используемая для вычисления градиента ?.
Так же,jax.grad(f)
это функция, которая вычисляет градиент, поэтомуjax.grad(f)(x)
даf
существуетx
градиент в . (иТакой же,jax.grad
Работает только для функций со скалярным выводом, иначе будет выдана ошибка)
Это сильно отличает JAX API от других фреймворков глубокого обучения, поддерживающих автоматический вывод, таких как Tensorflow и PyTorch, где мы можем использовать сам тензор потерь для вычисления градиента (например, вызывая loss.backward() для вычисления градиента). JAX API работает непосредственно с функциями, ближе к базовой математике. Как только вы привыкнете к такому способу ведения дел, это станет естественным: ваша функция потерь в коде на самом деле является функцией параметров и данных, и вы найдете ее градиент точно так же, как в математике.
Такой способ ведения дел упрощает и упрощает управление такими вещами, как дифференцирование переменных. По умолчанию jax.grad найдет градиент относительно первого параметра. В приведенном ниже примере результатом sum_squared_error_dx будет градиент sum_squared_error по отношению к x.
def sum_squared_error(x, y):
return jnp.sum((x-y)**2)
sum_squared_error_dx = jax.grad(sum_squared_error)
y = jnp.asarray([1.1, 2.1, 3.1, 4.1])
print(sum_squared_error_dx(x, y))
Если вам нужно рассчитать градиент разных параметров (или нескольких параметров), вы можете установить argnums для достижения.
[-0.20000005 -0.19999981 -0.19999981 -0.19999981]
jax.grad(sum_squared_error, argnums=(0, 1))(x, y) # Find gradient wrt both x & y
(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),
DeviceArray([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))
Означает ли это, что при машинном обучении модели должны писать функции с огромными списками параметров, по одному на каждый массив параметров модели? JAX оснащен механизмами объединения массивов в структуры данных, называемые "pytrees".jax.grad
Использование такое.
Стоимость и Град
jax.value_and_grad(sum_squared_error)(x, y)
(DeviceArray(0.03999995, dtype=float32),
DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32))
дополнительные данные
Помимо желания записать числовые значения, мы часто хотим сообщить о некоторых промежуточных результатах, полученных при вычислении функции потерь. Но если мы попытаемся использовать обычныеjax.grad
Для этого вы столкнетесь с неприятностями.
def squared_error_with_aux(x, y):
return sum_squared_error(x, y), x-y
jax.grad(squared_error_with_aux)(x, y)
Приведенное выше выполнение кода сообщит об ошибке, и вам нужноgrad
Задайте параметр в функции.
jax.grad(squared_error_with_aux, has_aux=True)(x, y)
Это потому чтоjax.grad
Он определен только для скалярных функций, и преобразованная функция вернет кортеж. Поскольку члены группы содержат некоторые вспомогательные данные, этоhas_aux
эффект.
Чем JAX отличается от NumPy
В приведенных выше примерах мы обнаружили, что дизайн API jax.numpy в основном соответствует API NumPy. Однако не все имеют некоторые отличия. Далее мы представим различия между JAX и Numpy. Наиболее важным отличием является то, что JAX — это скорее функциональный стиль программирования, что является основной причиной того, что Numpy и JAX не только в некоторых моментах одинаковы. Введение в функциональное программирование (FP) выходит за рамки этого руководства. Если вы уже знакомы с FP, то использовать JAX будет удобнее, потому что JAX предназначен для функционального программирования.
import numpy as np
x = np.array([1, 2, 3])
def in_place_modify(x):
x[0] = 123
return None
in_place_modify(x)
x
Если вы знакомы с функциональным программированием, когда вы видите выводarray([123, 2, 3])
, проблема будет найдена,in_place_modify
делает некоторые побочные эффекты, обновляя значение x внутри него. Поскольку данные в функциональном программировании должны быть неизменяемыми (неизменяемыми), каждый раз, когда данные изменяются, они не изменяются в исходных данных, а изменяется копия.
in_place_modify(jnp.array(x)
Полезно, эта ошибка дает проход JAX jax.ops.index_* ops
Do – метод без побочных эффектов. Подобно изменению на месте, которое не должно выполняться в исходном массиве по индексу, а вместо этого создается новый массив и изменяется соответствующим образом. Таким образом, вышеуказанная операция сообщит об ошибке в JAX.
def jax_in_place_modify(x):
return jax.ops.index_update(x, 0, 123)
y = jnp.array([1, 2, 3])
jax_in_place_modify(y)
DeviceArray([123, 2, 3], dtype=int32)
В этот момент мы снова смотрим на y и видим, что он не изменился.
y #DeviceArray([1, 2, 3], dtype=int32)
Side-effect-free code is sometimes called functionally pure, or just pure.
Код без побочных эффектов иногда называют функционально чистым, который не является функционально чистым, но не выполняет какое-либо обновление состояния приложения, ввод-вывод или другую работу.
Разве чистая версия не менее эффективна? Строго говоря, да. Это то, что вместо изменения исходных данных мы создаем новый массив для их изменения. Однако вычисления JAX обычно перед запуском преобразуются с помощью другой программы, т.е.jax.jit
Скомпилировать. если мы используемjax.ops.index_update()
Изменяя исходный массив «на месте» и не используя его, компилятор распознает, что на самом деле он компилируется вМодификация на месте, что приводит к эффективному коду.
Конечно, можно смешивать код Python с побочными эффектами и функциональный код JAX, поддерживающий функции хранения.На самом деле, трудно или почти невозможно писать чисто функциональные программы.Поскольку вы все больше и больше знакомитесь с JAX, тем больше вы знакомы, тем лучше вы будете знать, когда использовать JAX, и мы поговорим об этом позже, но пока мы будем помнить, чтобы избежать побочных эффектов в JAX.