Одним из навыков JAX является JIT

искусственный интеллект Python
Одним из навыков JAX является JIT

В этом разделе мы более подробно рассмотрим, как работает JAX. Расскажу о JAXjax.jit()Transform, который будет выполнять JIT-компиляцию функции JAX Python для эффективного выполнения в XLA.

В предыдущем разделе JAX мы узнали, что JAX может преобразовывать функции Python для получения новой функции. Это достигается путем преобразования функции Python в простой промежуточный язык, который называется jaxpr. Затем преобразование работает с представлением jaxpr.

Следующее использованиеjax.make_jaxprjaxpr, чтобы показать, что функция представляет собой функцию Python.

Концептуально первое, что нужно сделать в преобразовании JAX, — это преобразовать функцию Python в облегченную промежуточную форму с хорошим представлением.Этот процесс можно понимать как конкретную трассировку, и Jaxpr выполняет преобразование через внутренний интерпретатор. Одна из причин, по которой JAX может втиснуть так много функций в такой небольшой пакет, заключается не только в том, что он начинает со знакомого гибкого интерфейса программирования (Python с NumPy), но и в том, что он использует настоящий интерпретатор Python для выполнения большей части тяжелой работы. работа по преобразованию сущности вычислений в простой статически типизированный язык выражений с ограниченными функциями высшего порядка. Этот язык является языком jaxpr.

import jax
import jax.numpy as jnp

global_list = []

def log2(x):
  global_list.append(x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2)(3.0))
{ lambda  ; a.
  let b = log a
      c = log 2.0
      d = div b c
  in (d,) }

Раздел «Понимание Jaxprs» в документации содержит дополнительную информацию о значении приведенного выше вывода.

Важно отметить, что jaxpr не выполняет побочных эффектов функцийtrace: не найдено в конвертированном jaxprglobal_list.append(x)Содержание. Это особенность, а не ошибка. JAX предназначен для понимания кода без побочных эффектов (он же чистые функции).

Внутреннее представление JAX является чисто функциональным, но, учитывая высокую динамичность языка Python, для пользователей существуют некоторые программные ограничения. Например, автоматическая дифференциация функций Python в JAX поддерживает только чистые функции, и пользователи должны убедиться в этом самостоятельно. Если пользовательский код записывает побочные эффекты, результат выполнения функции, сгенерированной JAX-преобразованием, может не соответствовать ожиданиям. Поскольку функция трассировки JAX является чистой функцией, при изменении глобальных переменных и информации о конфигурации может потребоваться повторная трассировка.

В процессе трассировки JAX оборачивает каждый параметр в объект трассировщика, и эти трассировщики записывают все операции JAX, выполненные с параметром во время вызова функции (это происходит в обычном Python). Затем JAX использует запись трассировщика для рефакторинга всей функции. Результатом этого рефакторинга является промежуточный файл jaxpr. Поскольку трекер не записывает побочные эффекты Python, код побочных эффектов не отображается в jaxpr. В процессе отслеживания все же возникают побочные эффекты.

def log2_with_print(x):
  print("printed x:", x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2_with_print)(3.))

Примечание. Функция Python print() не является функцией хранения, поскольку операции ввода-вывода и вывода текста можно рассматривать как побочные эффекты, поэтомуprintИ это не чистая функция. Следовательно, никакой print() не появляется в jaxpr.

printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ lambda  ; a.
  let b = log a
      c = log 2.0
      d = div b c
  in (d,) }

Видите, что распечатанный x — это объект слежения? Это то, что JAX делает внутри. Тот факт, что код Python запускается хотя бы один раз, является исключительно деталью реализации, и на него нельзя полагаться. Однако это полезно понимать, так как вы можете использовать для вывода промежуточных значений вычислений при отладке.