«Документация JAX на китайском языке» Краткое руководство по JAX

искусственный интеллект глубокое обучение

商标

Новейшие

оригинал:Не терпится подумать. прочитайте документ S.IO/ru/latest/you…

Быстрый старт JAX


Сначала ответьте на вопрос:Что такое ДЖАКС?

Проще говоря, это numpy с ускорением на GPU, который поддерживает автоматическое дифференцирование (autodiff). Как мы все знаем, numpy — это основная библиотека числовых операций в Python, и она широко используется. Никто не может заниматься научными вычислениями или машинным обучением с помощью Python. Но numpy не поддерживает GPU или другие аппаратные ускорители, и нет встроенной поддержки обратного распространения в сочетании с ограничениями скорости самого Python, поэтому мало кто будет напрямую обучать или развертывать модели глубокого обучения с numpy в производственной среде. Вот почему существуют фреймворки глубокого обучения, такие как Theano, TensorFlow и Caffe. Однако у numpy есть свои уникальные преимущества: низкоуровневый, гибкий, простой в отладке, стабильный и знакомый API (в том же духе, что и MATLAB), и исследователи предпочитают его. Основная отправная точка JAX — объединить вышеуказанные преимущества numpy с аппаратным ускорением. Теперь JAX с открытым исходным кодом (github.com/google/jax) заключается в достижении аппаратного ускорения с помощью графического процессора (CUDA). От:Ууху. Call.com/question/30…

 

Сяо Сун сказал: JAX на самом деле является библиотекой научных вычислений (numpy, scipy) и библиотекой нейронных сетей (предоставляющей relu, sigmoid, conv и т. д.), которая поддерживает ускорители (GPU и TPU), По сравнению с PyTorch и TensorFlow, она более гибкая. и универсальный.хорошо. Это также причина, по которой автор рекомендует изучить и выполнить эту работу по переводу и призывает всех изучить и освоить эту структуру.

Поскольку автор не является специалистом по английскому языку, некоторые переводы Nei Rong неизбежно будут ошибочными, комментарии и исправления приветствуются. Для некоторых переводов, в которых автор не уверен, используйте знаки подчеркивания и скобки, чтобы процитировать исходные слова для дополнения, например:автоматическийдифференциал ( differentiation )

 

Официальное определение:JAX — это NumPy на CPU, GPU и TPU с отличным автоматическимразница( differentiation ), который можно использовать для исследований высокопроизводительного машинного обучения.

 

как обновленная версияAutograd, JAX может автоматически различать собственный код Python и NumPy. Его можно отличить по большинству функций Python, включая циклы, if, рекурсию и замыкания, и он может даже принимать производные классы от производных классов. Он поддерживает дифференциацию обратного и прямого режима, и они могут быть составлены в любом порядке.

Новая функция заключается в том, что JAX используетXLAСкомпилируйте и запустите свой код NumPy на ускорителях, таких как GPU и TPU. По умолчанию компиляция происходит в фоновом режиме, а вызовы библиотек компилируются и выполняются вовремя. Однако JAX даже позволяет компилировать ваши собственные функции Python в XLA-оптимизированные ядра на лету, используя однофункциональный API. Компиляцию и автоматическую дифференциацию можно комбинировать произвольно, поэтому вы можете выражать сложные алгоритмы и получать максимальную производительность, не выходя из Python.

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

матрица умножения

 

В следующем примере мы будем генерировать случайные данные. Одно большое различие между NumPy и JAX заключается в том, как генерируются случайные числа. Подробнее см.Распространенные ошибки в JAX.

key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
[-0.372111    0.2642311  -0.18252774 -0.7368198  -0.44030386 -0.15214427
 -0.6713536  -0.59086424  0.73168874  0.56730247]

Умножьте две большие матрицы.

size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU
489 ms ± 3.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Мы добавляем,block_until_readyПоскольку по умолчанию JAX использует асинхронное выполнение (см.Асинхронное планирование).

Функции JAX NumPy работают с обычными массивами NumPy.

import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
488 ms ± 942 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Это медленнее, потому что каждый раз приходится передавать данные на графический процессор. Вы можете использовать, чтобы убедиться, что NDArray поддерживается памятью устройстваdevice_put().

from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()
487 ms ± 9.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Выводdevice_put()Все еще как NDArray, но он работает только тогда, когда это необходимо.распечатать, нарисовать, сохранить(печать, рисование, сохранение) на диск, ветки и т. д. копировать значения обратно в ЦП только тогда, когда их значения необходимы. поведениеdevice_put()Эквивалентно функции, но быстрее.jit(lambda x: x)

Если у вас есть GPU (или TPU!), эти вызовы будут выполняться на ускорителе и, вероятно, намного быстрее, чем на CPU.

x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)
235 ms ± 546 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

JAX — это не просто NumPy с поддержкой графического процессора. Он также поставляется с некоторыми процедурными преобразованиями, которые полезны при написании числового кода. В настоящее время выделяют три основных:

  • jit(), чтобы ускорить ваш код
  • grad(), заНайдите градиент(производные)
  • vmap(), для автоматической векторизации или пакетной обработки.

Давайте представим их один за другим. Мы также закончим тем, что напишем их интересными способами.

 

использоватьjit()функция ускорения

JAX прозрачно работает на графическом процессоре (если нет, то на процессоре, а скоро появится и TPU!). Однако в приведенном выше примере JAX назначает ядро ​​графическому процессору по одной операции за раз. Если у нас есть последовательность операций, мы можем использовать@jitиспользование декораторовXLAСкомпилируйте несколько операций вместе. Давайте попробуем.

def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()
4.4 ms ± 107 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Мы можем ускорить его с помощью@jit, это будет в первую очередьseluВызовите jit-compile и кэшируйте его потом.

selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()
860 µs ± 27.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

пройти через grad()计算梯度

Помимо вычисления числовой функции, мы также хотим ее преобразовать. Преобразованиеавтоматическая дифференциация. В JAX, как и вкак в Автограде, вы можете использовать егоgrad()функция для вычисления градиента.

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
[0.25       0.19661197 0.10499357]

Давайте начнем спредельный дифференциал(конечные разности) подтверждают правильность наших результатов.

def first_finite_differences(f, x):
  eps = 1e-3
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])


print(first_finite_differences(sum_logistic, x_small))
[0.24998187 0.1964569  0.10502338]

Решение для градиента можно выполнить, просто позвонивgrad().grad()иjit()Можно смешивать произвольно. В приведенном выше примере мы сначала дизеримsum_logisticЗатем возьмите его производную. Перейдем к эксперименту по глубокому обучению:

print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
-0.035325594

Для более продвинутого автодиффа это может бытьjax.vjp()для суммы произведений Якоби вектора обратной модыjax.jvp()Произведение Якоби прямого режима. Их можно произвольно комбинировать друг с другом и с другими преобразованиями JAX. Вот один из способов объединить их, чтобы сформировать функцию, эффективно вычисляющую полный гессиан:

from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

автоматическая векторизацияvmap()

JAX имеет еще одно преобразование в своем API, которое может оказаться полезным:vmap()Векторизованная карта. Он имеет функции, отображающие оси массива.знаком с семантикой(знакомая семантика), но вместо того, чтобы держать цикл снаружи, цикл помещается в исходную операцию функции для повышения производительности. в сочетании сjit(), это может быть так же быстро, как добавление размеров пакетов вручную.

Мы будем использовать простой пример и преобразовывать произведение матрицы-вектора в произведение матрицы-матрицы, используяvmap(). Хотя в этом конкретном случае это легко сделать вручную, тот же метод можно применить к более сложным функциям.

mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
  return jnp.dot(mat, v)

Дана такая функция, какapply_matrix, мы можем зацикливаться на пакетных измерениях в Python, но производительность при этом обычно низкая.

def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched
4.43 ms ± 9.91 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Мы знаем, как пакетировать эту операцию вручную. при этих обстоятельствах,jnp.dotПрозрачно обрабатывает дополнительные размеры пакетов.

@jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched
51.9 µs ± 1.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Однако, предполагая отсутствие поддержки пакетной обработки, наша функция становится более сложной. мы можем использоватьvmap()Автоматически добавлять пакетную поддержку.

@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Auto-vectorized with vmap
79.7 µs ± 249 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Конечно,vmap()можно сочетать с любымjit(),grad()и любые другие преобразования JAX.

Это просто то, что JAX может сделать. Мы рады видеть вас в действии!