베타리더 후기 viii
지은이 소개 x
JAX/Flax LAB 소개 xii
들어가며 xiii
이 책에 대하여 xiv
CHAPTER 1 JAX/Flax를 공부하기 전에 1
1.1 JAX/Flax에 대한 소개와 예시 1
__1.1.1 JAX란 1
__1.1.2 Flax란 2
__1.1.3 JAX로 이루어진 기타 프레임워크들 3
__1.1.4 JAX 프레임워크 사용 예시 3
1.2 함수형 프로그래밍에 대한 이해 5
__1.2.1 부수 효과와 순수 함수 5
__1.2.2 불변성과 순수 함수 7
__1.2.3 정리하며 8
1.3 JAX/Flax에서 자주 사용하는 파이썬 표준 라이브러리 9
__1.3.1 functools.partial( 10
__1.3.2 typing 모듈 12
__1.3.3 정리하며 13
1.4 JAX/Flax 설치 방법 14
__1.4.1 로컬에 JAX/Flax 설치하기 14
__1.4.2 코랩에서 TPU 사용하기 14
CHAPTER 2 JAX의 특징 17
2.1 NumPy에서부터 JAX 시작하기 18
__2.1.1 JAX와 NumPy 비교하기 18
__2.1.2 JAX에서 미분 계산하기 19
__2.1.3 손실 함수의 그레이디언트 계산하기 21
__2.1.4 손실 함수의 중간 과정 확인하기 22
__2.1.5 JAX의 함수형 언어적 특징 이해하기 23
__2.1.6 JAX로 간단한 학습 돌려보기 25
2.2 JAX의 JIT 컴파일 28
__2.2.1 JAX 변환 이해하기 29
__2.2.2 함수를 JIT 컴파일하기 32
__2.2.3 JIT 컴파일이 안 되는 경우 34
__2.2.4 JIT 컴파일과 캐싱 37
2.3 자동 벡터화 39
__2.3.1 수동으로 벡터화하기 39
__2.3.2 자동으로 벡터화하기 41
2.4 자동 미분 42
__2.4.1 고차 도함수 43
__2.4.2 그레이디언트 중지 46
__2.4.3 샘플당 그레이디언트 49
2.5 JAX의 난수 52
주요 내용
함수형 프로그래밍, 파이썬 라이브러리 등 JAX 사용 시 알아야 할 기초
JIT 컴파일, 자동 벡터화, pytree, 병렬처리 등 JAX의 주요 특징
CNN 튜토리얼로 알아보는 Flax 기초
ResNet, DCGAN, CLIP 모델을 구축하며 Flax에 익숙해지기
코랩, 캐글에서 TPU 환경 설정하기
책 속에서
JAX는 구글에서 개발한 고성능 수치 계산 라이브러리로, 특히 병렬 가속화 기능을 통해 대규모 모델의 효율적인 학습과 추론이 가능합니다. Flax는 JAX 기반의 심플한 신경망 라이브러리로, JAX의 장점을 살려 유연하고 확장 가능한 모델 구축을 지원합니다. 이 책은 모두의연구소 JAX/Flax LAB이 다양한 경험과 지식을 바탕으로 JAX를 어떻게 실용적으로 활용할 수 있는지에 중점을 두고 집필한 책입니다. 이론만을 설명하는 것이 아니라 실제 예제를 통해 적용 방법까지 소개합니다.
--- p.xiv
JAX에서는 해당 메서드를 데커레이터--- p.decorator로 활용합니다. 데커레이터로 사용하면 코드 간결성이나 코드 재사용이 늘어난다는 장점을 갖고 있습니다. --- p.… 이번 예제는 @partial에만 집중해서 살펴보겠습니다. 해당 데커레이터는 jax.jit라는 함수에서 고정시키고 싶은 인수인 n을 static_argnames로 고정시키고 컴파일됩니다. 이 방식을 취하면 n은 컴파일되어 추가적인 계산을 진행하지 않습니다. / 따라서 JAX에서 partial 데커레이터를 사용하면 굳이 선언할 필요 없이 병렬처리를 할 수 있게 도와줍니다.
--- p.11
먼저 SELU--- p.scaled exponential linear unit를 구현한 예시를 봅시다. --- p.… 이 출력 결과는 구글 코랩의 T4 가속기에서 실행한 결과입니다. 이제 XLA 컴파일러를 이용해보겠습니다. JAX는 jax.jit 변환을 통해 JAX와 호환되는 함수들을 JIT 컴파일합니다. 얼마나 빨라지는지 확인해보겠습니다. --- p.