На основе JAX мы можем легко реализовать структуру глубокого обучения.

искусственный интеллект Python
На основе JAX мы можем легко реализовать структуру глубокого обучения.

Это 10-й день моего участия в августовском испытании обновлений. Узнайте подробности события:Испытание августовского обновления

Обзор JAX

Что такое ДЖАКС

Предшественником JAX является Autograd, что означает, что JAX является обновленной версией Autograd, JAX может автоматически различать программы Python и NumPy. Его можно отличить по большому подмножеству функций Python, включая автоматический вывод циклов, ветвей, операторов рекурсии и замыкания, а также производных третьего порядка (производные третьего порядка — это производные от производных производных исходной функции. так называемая производная третьего порядка, то есть производная производной от производной исходной функции, исходная функция выводится трижды). Через grad JAX поддерживает вывод обратного режима и прямого режима, и эти два режима можно комбинировать в любом порядке с определенной гибкостью.

001.jpeg

Для кого JAX?

По сравнению с Tensorflow и Pytorch, JAX все еще относительно примитивен (базовый), и многие вещи все еще нужно реализовать самостоятельно.Может быть, вы спросите, нужно ли реализовывать фреймворк глубокого обучения самостоятельно?Преимущество самостоятельной реализации заключается в том, что у вас есть лучший контроль над проблемами Настройка сильнее, поэтому JAX предназначен для исследователей, а не разработчиков, что, я думаю, всем нужно ясно понимать, когда они начинают понимать эту библиотеку.

003.jpeg

Изучение мотивации JAX

  • Наиболее ориентированный на производительность, при использовании существующей среды машинного обучения и столкновении с узкими местами производительности, но при отсутствии понимания базовых структурных принципов C++ и GPU, вы можете рассмотреть JAX для реконструкции вашей собственной модели (сам я еще не пробовал)
  • Если вы хотите реализовать платформу глубокого обучения на основе Python, вы можете рассмотреть JAX

Для завершения крупномасштабных данных

  • Аппаратное ускорение
  • Автоматический вывод для операций оптимизации
  • Операции слияния, например np.sum((pres - target) ** 2)
  • Параллельная обработка данных и вычислений

006.png

JAX-установка

pip install --upgrade jax jaxlib

Установить графический процессор

pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

009.jpeg

JAX можно рассматривать как Numpy, работающий на GPU и TPU.

Numpy — это библиотека для научных вычислений. Даже сегодня популярные библиотеки глубокого обучения, такие как tensorflow и pytorch, используются для реализации моделей и сетей, и несколько строк кода с использованием numpy необходимы. Можно увидеть важность numpy, но когда numpy родился, он не использовал много GPU для поддержки операций, поэтому программы numpy не могут работать на GPU, но JAX на самом деле не numpy.Он просто использует API-интерфейс numpy, чтобы разработчики чувствовали себя что они используют Numpy, используя numpy без разбора.

import numpy as np

Введите numpy и создайте псевдоним np для введения numpy для удобства.

import numpy as np
x = np.random.rand(2000,2000)
print(x)

    [[0.56745022 0.4247945  0.32374621 ... 0.72424614 0.31471484 0.75709393]
     [0.76504917 0.41393967 0.1195595  ... 0.27311255 0.36763284 0.39811399]
     [0.30034904 0.8224698  0.0160814  ... 0.75720634 0.72237672 0.09741124]
     ...
     [0.14822982 0.918704   0.22328525 ... 0.67143212 0.91682163 0.65214596]
     [0.25847224 0.7675988  0.64836721 ... 0.19096599 0.89869396 0.22051008]
     [0.23031364 0.60925244 0.72548038 ... 0.63396252 0.13415147 0.0674989 ]]
2 * x

Таким образом, матричные операции интуитивно понятны.Например, умножение каждого элемента x на 2 может использовать вышеупомянутую интуитивно понятную операцию без обхода каждого элемента матрицы.

    array([[1.13490044, 0.849589  , 0.64749241, ..., 1.44849228, 0.62942968,
            1.51418785],
           [1.53009834, 0.82787934, 0.239119  , ..., 0.54622511, 0.73526569,
            0.79622798],
           [0.60069808, 1.6449396 , 0.03216279, ..., 1.51441268, 1.44475343,
            0.19482249],
           ...,
           [0.29645964, 1.83740799, 0.4465705 , ..., 1.34286423, 1.83364326,
            1.30429192],
           [0.51694448, 1.5351976 , 1.29673443, ..., 0.38193199, 1.79738792,
            0.44102015],
           [0.46062729, 1.21850487, 1.45096075, ..., 1.26792504, 0.26830294,
            0.1349978 ]])

np.sin(x)

Для некоторых сложных операций, таких какnp.sinnumpy также хорошо с этим справляется.

    array([[0.53748363, 0.41213356, 0.31812038, ..., 0.66257099, 0.30954533,
            0.68681211],
           [0.69257247, 0.40221938, 0.11927486, ..., 0.26972993, 0.35940746,
            0.38768052],
           [0.29585364, 0.73282855, 0.0160807 , ..., 0.68689382, 0.66116964,
            0.09725726],
           ...,
           [0.14768759, 0.79481581, 0.2214345 , ..., 0.62210787, 0.79367208,
            0.60689338],
           [0.25560384, 0.69440939, 0.60388576, ..., 0.18980742, 0.78251439,
            0.21872738],
           [0.2282829 , 0.57225456, 0.66349493, ..., 0.59234195, 0.13374945,
            0.06744766]])
x - x.mean(0)
    array([[ 0.05966959, -0.07397188, -0.18537367, ...,  0.21733322,
            -0.18467283,  0.25997255],
           [ 0.25726854, -0.08482671, -0.38956037, ..., -0.23380037,
            -0.13175483, -0.09900739],
           [-0.20743159,  0.32370341, -0.49303848, ...,  0.25029342,
             0.22298905, -0.39971013],
           ...,
           [-0.35955081,  0.41993761, -0.28583463, ...,  0.1645192 ,
             0.41743396,  0.15502459],
           [-0.24930839,  0.26883241,  0.13924734, ..., -0.31594693,
             0.39930629, -0.2766113 ],
           [-0.27746699,  0.11048605,  0.2163605 , ...,  0.1270496 ,
            -0.3652362 , -0.42962248]])

np.dot(x,x)

Точечное умножение между матрицами тоже очень удобно

    array([[499.08919102, 490.98247709, 495.18751355, ..., 498.40635521,
            494.50937914, 485.34695773],
           [510.29685902, 499.95239357, 511.85978277, ..., 509.82817989,
            495.05226925, 507.41925595],
           [502.82328413, 501.8213885 , 506.67580735, ..., 508.35889233,
            492.64972834, 493.06081799],
           ...,
           [502.20453325, 496.38140482, 508.98725444, ..., 505.05666502,
            490.64576912, 491.95629717],
           [515.66634283, 498.26014692, 516.70676734, ..., 508.06152946,
            506.435225  , 500.36645682],
           [509.67692906, 502.64662385, 509.47906271, ..., 509.0583251 ,
            505.48856182, 493.5220343 ]])

Далее давайте посмотрим на метод, предоставляемый модулем numpy jax, который похож на метод numpy, Давайте сравним приведенные выше операции numpy.jax.numpyсделать это снова.

import jax.numpy as jnp
y = jnp.array(x)
y

конвертировать объекты numpy вDeviceArray

    DeviceArray([[0.5674502 , 0.4247945 , 0.3237462 , ..., 0.72424614,
                  0.31471485, 0.7570939 ],
                 [0.76504916, 0.41393968, 0.1195595 , ..., 0.27311257,
                  0.36763284, 0.398114  ],
                 [0.30034903, 0.8224698 , 0.0160814 , ..., 0.7572063 ,
                  0.7223767 , 0.09741125],
                 ...,
                 [0.14822982, 0.918704  , 0.22328524, ..., 0.67143214,
                  0.9168216 , 0.652146  ],
                 [0.25847223, 0.7675988 , 0.6483672 , ..., 0.190966  ,
                  0.898694  , 0.22051008],
                 [0.23031364, 0.60925245, 0.7254804 , ..., 0.6339625 ,
                  0.13415147, 0.0674989 ]], dtype=float32)
2 * y

    DeviceArray([[1.1349005 , 0.849589  , 0.6474924 , ..., 1.4484923 ,
                  0.6294297 , 1.5141878 ],
                 [1.5300983 , 0.82787937, 0.23911901, ..., 0.54622513,
                  0.7352657 , 0.796228  ],
                 [0.60069805, 1.6449395 , 0.03216279, ..., 1.5144126 ,
                  1.4447534 , 0.19482249],
                 ...,
                 [0.29645965, 1.837408  , 0.4465705 , ..., 1.3428643 ,
                  1.8336432 , 1.304292  ],
                 [0.51694447, 1.5351976 , 1.2967345 , ..., 0.381932  ,
                  1.797388  , 0.44102016],
                 [0.4606273 , 1.2185049 , 1.4509608 , ..., 1.267925  ,
                  0.26830295, 0.1349978 ]], dtype=float32)
jnp.sin(y)

    DeviceArray([[0.53748363, 0.41213354, 0.31812036, ..., 0.662571  ,
                  0.30954534, 0.6868121 ],
                 [0.6925725 , 0.40221938, 0.11927487, ..., 0.26972994,
                  0.35940745, 0.38768053],
                 [0.2958536 , 0.73282856, 0.0160807 , ..., 0.6868938 ,
                  0.66116965, 0.09725726],
                 ...,
                 [0.1476876 , 0.79481584, 0.2214345 , ..., 0.6221079 ,
                  0.7936721 , 0.6068934 ],
                 [0.25560382, 0.69440943, 0.60388577, ..., 0.18980742,
                  0.78251445, 0.21872738],
                 [0.2282829 , 0.5722546 , 0.66349494, ..., 0.59234196,
                  0.13374946, 0.06744765]], dtype=float32)
y - y.mean(0)

    DeviceArray([[ 0.05966955, -0.0739719 , -0.18537366, ...,  0.2173332 ,
                  -0.18467283,  0.2599725 ],
                 [ 0.2572685 , -0.08482671, -0.38956037, ..., -0.23380038,
                  -0.13175485, -0.0990074 ],
                 [-0.20743164,  0.32370338, -0.49303848, ...,  0.25029337,
                   0.22298902, -0.39971015],
                 ...,
                 [-0.35955083,  0.41993758, -0.2858346 , ...,  0.16451919,
                   0.41743392,  0.15502459],
                 [-0.24930844,  0.26883242,  0.13924736, ..., -0.31594694,
                   0.3993063 , -0.27661133],
                 [-0.277467  ,  0.11048606,  0.21636051, ...,  0.12704957,
                  -0.36523622, -0.4296225 ]], dtype=float32)
jnp.dot(y,y)

    DeviceArray([[499.08923, 490.98248, 495.18756, ..., 498.4064 , 494.50937,
                  485.347  ],
                 [510.2968 , 499.95236, 511.85983, ..., 509.8281 , 495.0523 ,
                  507.4191 ],
                 [502.82324, 501.82147, 506.67572, ..., 508.35886, 492.64972,
                  493.06076],
                 ...,
                 [502.20465, 496.3814 , 508.9873 , ..., 505.0567 , 490.6458 ,
                  491.95618],
                 [515.66626, 498.2601 , 516.70667, ..., 508.06168, 506.43524,
                  500.3665 ],
                 [509.67685, 502.64664, 509.47913, ..., 509.0583 , 505.48856,
                  493.52206]], dtype=float32)

%timeit np.dot(x,x)
%timeit jnp.dot(y,y)
    47.2 ms ± 5.78 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    2.16 ms ± 21.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Из трудоемкого сравнения выполнения двух методов,jpn.dotВсе равно большое преимущество.

007.jpeg

JIT-компиляция (даже если она скомпилирована)

Чтобы использовать возможности XLA, код компилируется в ядро ​​XLA. Здесь в игру вступает jit. Чтобы использовать XLA и jit, вы можете использовать функцию jit() или аннотацию @jit.

def f(x):
    for i in range(10):
        x -= 0.1 * x
    return x

Здесь мы определяем функциюf(x), сама функция не имеет реального значения и предназначена для иллюстрации JIT-компиляции

f(x)

Мы можем использовать numpy для выполнения операций над матрицами


    array([[0.19785766, 0.14811668, 0.11288332, ..., 0.25252901, 0.10973428,
            0.26398233],
           [0.26675615, 0.14433184, 0.04168782, ..., 0.09522846, 0.12818565,
            0.13881376],
           [0.10472524, 0.28677749, 0.00560724, ..., 0.26402153, 0.25187719,
            0.0339652 ],
           ...,
           [0.05168454, 0.32033228, 0.07785475, ..., 0.2341139 , 0.31967594,
            0.22738924],
           [0.0901237 , 0.26764515, 0.22607167, ..., 0.06658573, 0.31335521,
            0.07688711],
           [0.0803054 , 0.21243319, 0.25295937, ..., 0.22104906, 0.04677572,
            0.02353541]])

f(y)

Вы можете использовать jax.numpy для выполнения этих операций с матрицами.

    DeviceArray([[0.19785768, 0.1481167 , 0.11288333, ..., 0.25252903,
                  0.10973427, 0.26398236],
                 [0.26675615, 0.14433186, 0.04168782, ..., 0.09522847,
                  0.12818564, 0.13881375],
                 [0.10472523, 0.2867775 , 0.00560724, ..., 0.26402152,
                  0.2518772 , 0.0339652 ],
                 ...,
                 [0.05168454, 0.32033232, 0.07785475, ..., 0.23411393,
                  0.31967595, 0.22738926],
                 [0.09012369, 0.26764515, 0.22607167, ..., 0.06658573,
                  0.31335524, 0.07688711],
                 [0.08030539, 0.21243319, 0.25295934, ..., 0.22104907,
                  0.04677573, 0.02353541]], dtype=float32)

Здесь мы вычисляем некоторыеf(y)Выполнение занимает время, потому что оно выполняется синхронно, поэтому событие выглядит длинным.Далее давайте используем JIT для выполнения этой функции.

%timeit f(y)
    3.42 ms ± 31.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Прежде чем использовать JIT, нам нужно представить пакет jit, который также более удобен в использовании.jitпарная функцияfПросто оберните и получите JIT

from jax import jit
g = jit(f)
%timeit g(y)
    88.2 µs ± 560 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Функции, содержащие несколько операций numpy, могут бытьjax.jit()провести just-in-timeКомпилируется в единую программу CUDA для последующего выполнения, что еще больше ускоряет работу.

автоматическая дифференциация

Автоматическое дифференцирование с помощью функции grad(), полезной для приложений глубокого обучения, упрощает выполнение обратного распространения ошибки. В глубоком обучении мы обновляем параметры с помощью градиентов, поэтому автоматическая деривация является основной задачей и трудностью реализации фреймворка глубокого обучения.

def f(x):
    return x * jnp.sin(x)

определить функцию здесь

f(x)=xsin(x)f(x) = x \sin(x)
f(3)
    DeviceArray(0.42336, dtype=float32)

Для вывода этой функции наше цепное правило и вывод общих функций можно получить следующим образом

f'(x)=sin(x)+xcos(x)f^{\prime}(x) = \sin(x) + x \cos(x)
def grad_f(x):
    return jnp.sin(x) + x * jnp.cos(x)
grad_f(3)
    DeviceArray(-2.8288574, dtype=float32)

Введите grad jax, а затем grad оборачивает f, чтобы вернуть функцию вывода. Автоматический вывод помогает поддерживать вывод цепочки, упрощая обратное распространение для разработки программы. На самом деле, сложность среды глубокого обучения заключается в обратном выводе

from jax import grad
grad_f_jax = grad(f)
grad_f_jax(3.0)
    DeviceArray(-2.8288574, dtype=float32)

векторизация

vmap — это преобразование функции, а JAX предоставляет алгоритм автоматической векторизации посредством преобразования vmap, что значительно упрощает этот тип вычислений, что позволяет исследователям работать с новыми алгоритмами, не сталкиваясь с проблемами пакетной обработки. Пример выглядит следующим образом:

def square(x):
    return jnp.sum(x ** 2)

определениеsquareФункция возводит в квадрат каждый элемент вектора, а затем суммирует вектор.Вы можете думать об этом как о том, что сначала выполняется сопоставление с вектором, а затем выполняется операция сокращения.

square(jnp.arange(100))
    DeviceArray(328350, dtype=int32)

Чтобы объяснить vmap, мы можем увидеть, как numpy реализует то, что такое vmap. В JAX API также есть преобразование, возможно, вы не знаете о преимуществах векторного отображения vmap(). Возможно, вы знакомы с функцией map для управления каждым элементом массива вдоль оси массива.В vmap вместо того, чтобы помещать цикл наружу, цикл помещается в исходную операцию функции, чтобы получить лучшую производительность.

x = jnp.arange(100).reshape(10,10)
[square(row) for row in x]
    [DeviceArray(285, dtype=int32),
     DeviceArray(2185, dtype=int32),
     DeviceArray(6085, dtype=int32),
     DeviceArray(11985, dtype=int32),
     DeviceArray(19885, dtype=int32),
     DeviceArray(29785, dtype=int32),
     DeviceArray(41685, dtype=int32),
     DeviceArray(55585, dtype=int32),
     DeviceArray(71485, dtype=int32),
     DeviceArray(89385, dtype=int32)]
from jax import vmap
vmap(square)
    <function __main__.square(x)>
vmap(square)(x)
    DeviceArray([  285,  2185,  6085, 11985, 19885, 29785, 41685, 55585,
                 71485, 89385], dtype=int32)