[도서 리뷰] Jax / Flax로 딥러닝 레벨업

본 리뷰는 Jpub으로부터 도서를 제공받아 작성되었습니다.✏️


 

Jax / Flax를 들어보긴 들어봤는데.. 정확히 어떤 프레임워크인지, 뭐에 사용하는지 잘 몰랐다가 (llm 튜닝에 쓰인다고 해서)

궁금해서 서평 신청을 했다가 운이 좋게 당첨되어 읽어보게 되었습니다! 덕분에 jax / flax 공부도 했고 ㅎㅎ

 

회사 점심시간에 짬짬히 읽었다!

Jax / Flax 와 함께 보내주신 에세이

 

 

우선 이 책은 '모두의 연구소'에서 활동하시는 jax/flax 랩에서 작성해주셨다.

지난번 Google Build AI 컨퍼런스 들으러 갔을 때, 여기 랩짱 이영빈님께서 강연해주시면서 jax/flax에 대해 알게 되었는데 그분들이 작성하신 책이다..!

 

목차

 

목차는 위와 같다
jax / flax로 하는 numpy 부터 tensor까지

 

우선 jax와 flax라는 프레임워크에 대해 소개하고, 이를 통해 직접 딥러닝 모델들을 실습해보는 구조로 되어 있다.

구글 colab 쿡북 예제들도 github 레포로 다 제공되니 충분히 실습할 만 하다..

 

Jax / Flax?

jax / flax에 대해 생소한 사람도 있을것 같다.

jax 는 자동 미분과 XLA를 결합한 고성능 머신러닝 프레임워크다.

고성능 프레임워크라 함은 감이 잘 안올 수 있는데, numpy 함수에도 자동 미분 적용이 가능하고, 구글 딥러닝 전용 하드웨어인 TPU에서 특히 빠른 학습 & 추론이 가능하다고 한다.

 

https://www.datacamp.com/tutorial/combine-google-gemma-with-tpus-fine-tune-and-run-inference-with-enhanced-performance-and-speed



위 이미지에서 볼 수 있듯,

기존 nvidia GPU를 사용해서 머신러닝 모델들을 돌리던 것 대비 훨씬 빠른 속도로 추론이 가능하다.

 

Flax는 구글 브레인 팀에서 구글 리서치 Jax 팀과 협업하여 만든 오픈소스 프레임워크로, jax보다 더 쉬운 구현이 가능하도록 만든 프레임워크라고 한다.

 

사실 내 주변엔 (주변이라고 하기엔 지인이 없다)

거의 다 머신러닝 추론엔 GPU를 사용하고 있기 때문에 jax / Flax를 하시는 분을 거의 못 뵀는데

진짜 고수분들은 다 이걸 한번씩 써보셨더라..

 

 

JAX 의 특징

Jax는 기본적으로 numpy 가 지원하는 연산들을 거의 지원하는 것 같다.

하지만 jax는 함수형 프로그래밍을 따르기 때문에 부수효과가 없고 불변성이 지켜져야 한다는 특징을 유념해야 한다.

그래서 같은 기능을 구현하기 위해서 JAX에서 제공하는 부수효과 없는 함수들을 사용해야 한다.

 

 

JAX를 이용한 Numpy, Tensor

JAX 를 이용해 기존에 사용하던 pytorch 코드, numpy 코드 등을 새로이 작성해 볼 수 있다.

이를 위한 다양한 딥러닝 코드 실습 예제들도 준비되어 있다.

CNN부터 ResNet, DCGAN, CLIP.. 등

 

짧은 시간안에 서평안에 모든 걸 적을 순 없을 것 같아서 새로 jax / flax / pytorch 메뉴를 열고 천천히 스터디를 진행해보려고 한다🥲

CLIP 논문 리뷰를 하면서 함께 jax로 코드를 작성해 볼까 한다..! ㅎㅎ