Hva er Google JAX? Alt du trenger å vite

Google JAX eller Just After Execution er et rammeverk utviklet av Google for å øke hastigheten på maskinlæringsoppgaver.

Du kan vurdere det som et bibliotek for Python, som hjelper til med raskere oppgaveutførelse, vitenskapelig databehandling, funksjonstransformasjoner, dyp læring, nevrale nettverk og mye mer.

Om Google JAX

Den mest grunnleggende beregningspakken i Python er NumPy-pakken som har alle funksjonene som aggregasjoner, vektoroperasjoner, lineær algebra, n-dimensjonale array- og matrisemanipulasjoner og mange andre avanserte funksjoner.

Hva om vi kunne fremskynde beregningene som utføres ved hjelp av NumPy ytterligere – spesielt for enorme datasett?

Har vi noe som kan fungere like godt på forskjellige typer prosessorer som en GPU eller TPU, uten noen kodeendringer?

Hva med om systemet kunne utføre komponerbare funksjonstransformasjoner automatisk og mer effektivt?

Google JAX er et bibliotek (eller rammeverk, som Wikipedia sier) som gjør nettopp det og kanskje mye mer. Den ble bygget for å optimalisere ytelsen og effektivt utføre maskinlæring (ML) og dyplæringsoppgaver. Google JAX tilbyr følgende transformasjonsfunksjoner som gjør den unik fra andre ML-biblioteker og hjelper til med avansert vitenskapelig beregning for dyp læring og nevrale nettverk:

  • Automatisk differensiering
  • Auto vektorisering
  • Automatisk parallellisering
  • Just-in-time (JIT) kompilering

Google JAX sine unike funksjoner

Alle transformasjonene bruker XLA (Accelerated Linear Algebra) for høyere ytelse og minneoptimalisering. XLA er en domenespesifikk optimeringskompilatormotor som utfører lineær algebra og akselererer TensorFlow-modeller. Å bruke XLA på toppen av Python-koden krever ingen vesentlige kodeendringer!

La oss utforske i detalj hver av disse funksjonene.

Funksjoner i Google JAX

Google JAX kommer med viktige komponerbare transformasjonsfunksjoner for å forbedre ytelsen og utføre dyplæringsoppgaver mer effektivt. For eksempel automatisk differensiering for å få gradienten til en funksjon og finne deriverte av hvilken som helst rekkefølge. Tilsvarende, automatisk parallellisering og JIT for å utføre flere oppgaver parallelt. Disse transformasjonene er nøkkelen til applikasjoner som robotikk, spill og til og med forskning.

En komponerbar transformasjonsfunksjon er en ren funksjon som transformerer et sett med data til en annen form. De kalles komponerbare ettersom de er selvstendige (dvs. disse funksjonene har ingen avhengigheter med resten av programmet) og er tilstandsløse (dvs. den samme inngangen vil alltid resultere i samme utgang).

Y(x) = T: (f(x))

I ligningen ovenfor er f(x) den opprinnelige funksjonen som en transformasjon brukes på. Y(x) er den resulterende funksjonen etter at transformasjonen er brukt.

For eksempel, hvis du har en funksjon kalt «total_bill_amt», og du vil ha resultatet som en funksjonstransformasjon, kan du ganske enkelt bruke transformasjonen du ønsker, la oss si gradient (grad):

  En rask guide til HTTP-statuskoder med infografikk

grad_total_bill = grad(total_bill_amt)

Ved å transformere numeriske funksjoner ved å bruke funksjoner som grad(), kan vi enkelt få deres høyere ordens derivater, som vi kan bruke mye i dyplæringsoptimaliseringsalgoritmer som gradientnedstigning, og dermed gjøre algoritmene raskere og mer effektive. På samme måte, ved å bruke jit(), kan vi kompilere Python-programmer just-in-time (dovent).

#1. Automatisk differensiering

Python bruker autograd-funksjonen for automatisk å skille NumPy og opprinnelig Python-kode. JAX bruker en modifisert versjon av autograd (dvs. grad) og kombinerer XLA (Accelerated Linear Algebra) for å utføre automatisk differensiering og finne derivater av hvilken som helst rekkefølge for GPU (Graphic Processing Units) og TPU (Tensor Processing Units).]

Rask merknad om TPU, GPU og CPU: CPU eller sentral prosesseringsenhet styrer alle operasjonene på datamaskinen. GPU er en ekstra prosessor som forbedrer datakraften og kjører avanserte operasjoner. TPU er en kraftig enhet spesielt utviklet for komplekse og tunge arbeidsbelastninger som AI og dyplæringsalgoritmer.

På samme måte som autograd-funksjonen, som kan differensiere gjennom løkker, rekursjoner, forgreninger og så videre, bruker JAX grad()-funksjonen for revers-modus-gradienter (tilbakepropagasjon). Vi kan også differensiere en funksjon til hvilken som helst rekkefølge ved å bruke grad:

grad(grad(grad(sin θ))) (1.0)

Automatisk differensiering av høyere orden

Som vi nevnte før, er grad ganske nyttig for å finne partielle deriverte av en funksjon. Vi kan bruke en delvis derivert for å beregne gradientnedgangen til en kostnadsfunksjon med hensyn til de nevrale nettverksparametrene i dyp læring for å minimere tap.

Beregning av partiell derivert

Anta at en funksjon har flere variabler, x, y og z. Å finne den deriverte av en variabel ved å holde de andre variablene konstante kalles en partiell derivert. Anta at vi har en funksjon,

f(x,y,z) = x + 2y + z2

Eksempel for å vise partiell derivert

Den partielle deriverte av x vil være ∂f/∂x, som forteller oss hvordan en funksjon endres for en variabel når andre er konstante. Hvis vi utfører dette manuelt, må vi skrive et program for å differensiere, bruke det for hver variabel og deretter beregne gradientnedstigningen. Dette ville blitt en kompleks og tidkrevende affære for flere variabler.

Automatisk differensiering bryter ned funksjonen i et sett med elementære operasjoner, som +, -, *, / eller sin, cos, tan, exp, etc., og bruker deretter kjederegelen for å beregne den deriverte. Vi kan gjøre dette i både forover- og reversmodus.

Dette er ikke det! Alle disse beregningene skjer så fort (vel, tenk på en million beregninger som ligner på de ovennevnte og tiden det kan ta!). XLA tar seg av hastigheten og ytelsen.

  Slik ser du blokkerte meldinger på iPhone

#2. Akselerert lineær algebra

La oss ta den forrige ligningen. Uten XLA vil beregningen ta tre (eller flere) kjerner, hvor hver kjerne vil utføre en mindre oppgave. For eksempel,

Kjerne k1 –> x * 2y (multiplikasjon)

k2 –> x * 2y + z (tillegg)

k3 –> Reduksjon

Hvis den samme oppgaven utføres av XLA, tar en enkelt kjerne seg av alle mellomoperasjonene ved å smelte dem sammen. De mellomliggende resultatene av elementære operasjoner streames i stedet for å lagre dem i minnet, og sparer dermed minne og øker hastigheten.

#3. Just-in-time samling

JAX bruker internt XLA-kompilatoren for å øke utførelseshastigheten. XLA kan øke hastigheten til CPU, GPU og TPU. Alt dette er mulig ved å bruke JIT-kodekjøringen. For å bruke dette kan vi bruke jit via import:

from jax import jit
def my_function(x):
	…………some lines of code
my_function_jit = jit(my_function)

En annen måte er å dekorere jit over funksjonsdefinisjonen:

@jit
def my_function(x):
	…………some lines of code

Denne koden er mye raskere fordi transformasjonen vil returnere den kompilerte versjonen av koden til den som ringer i stedet for å bruke Python-tolken. Dette er spesielt nyttig for vektorinndata, som matriser og matriser.

Det samme gjelder for alle eksisterende python-funksjoner også. For eksempel funksjoner fra NumPy-pakken. I dette tilfellet bør vi importere jax.numpy som jnp i stedet for NumPy:

import jax
import jax.numpy as jnp

x = jnp.array([[1,2,3,4], [5,6,7,8]])

Når du har gjort dette, erstatter kjerneobjektet JAX-matrise kalt DeviceArray standard NumPy-matrisen. DeviceArray er lat – verdiene holdes i akseleratoren til de trengs. Dette betyr også at JAX-programmet ikke venter på at resultatene går tilbake til det anropende (Python)-programmet, og følger dermed en asynkron sending.

#4. Automatisk vektorisering (vmap)

I en typisk maskinlæringsverden har vi datasett med en million eller flere datapunkter. Mest sannsynlig vil vi utføre noen beregninger eller manipulasjoner på hvert eller de fleste av disse datapunktene – noe som er en veldig tid- og minnekrevende oppgave! Hvis du for eksempel vil finne kvadratet til hvert av datapunktene i datasettet, er det første du tenker på å lage en løkke og ta kvadratet ett etter ett – argh!

Hvis vi lager disse punktene som vektorer, kan vi gjøre alle rutene på en gang ved å utføre vektor- eller matrisemanipulasjoner på datapunktene med vår favoritt NumPy. Og hvis programmet ditt kunne gjøre dette automatisk – kan du be om noe mer? Det er akkurat det JAX gjør! Den kan automatisk vektorisere alle datapunktene dine slik at du enkelt kan utføre alle operasjoner på dem – noe som gjør algoritmene dine mye raskere og mer effektive.

JAX bruker vmap-funksjonen for autovektorisering. Tenk på følgende array:

x = jnp.array([1,2,3,4,5,6,7,8,9,10])
y = jnp.square(x)

Ved å gjøre bare det ovenfor, vil kvadratmetoden utføres for hvert punkt i matrisen. Men hvis du gjør følgende:

vmap(jnp.square(x))

Metodefirkanten vil kun utføres én gang fordi datapunktene nå vektoriseres automatisk ved hjelp av vmap-metoden før funksjonen utføres, og looping blir presset ned til det elementære operasjonsnivået – noe som resulterer i en matrisemultiplikasjon i stedet for skalarmultiplikasjon, og gir dermed bedre ytelse .

  Slik avslutter du Vi eller Vim Editor

#5. SPMD-programmering (pmap)

SPMD – eller Single Program Multiple Data-programmering er viktig i dype læringskontekster – du vil ofte bruke de samme funksjonene på forskjellige sett med data som ligger på flere GPUer eller TPUer. JAX har en funksjon kalt pumpe, som muliggjør parallell programmering på flere GPUer eller en hvilken som helst akselerator. I likhet med JIT vil programmer som bruker pmap bli kompilert av XLA og kjøres samtidig på tvers av systemene. Denne automatiske parallelliseringen fungerer for både forover- og bakoverberegninger.

Hvordan fungerer pmap

Vi kan også bruke flere transformasjoner på en gang i hvilken som helst rekkefølge på en hvilken som helst funksjon som:

pmap(vmap(jit(grad (f(x)))))

Flere komponerbare transformasjoner

Begrensninger for Google JAX

Google JAX-utviklere har tenkt godt på å øke hastigheten på dyplæringsalgoritmer mens de introduserte alle disse fantastiske transformasjonene. De vitenskapelige beregningsfunksjonene og pakkene er på linje med NumPy, så du trenger ikke å bekymre deg for læringskurven. JAX har imidlertid følgende begrensninger:

  • Google JAX er fortsatt i de tidlige utviklingsstadiene, og selv om hovedformålet er ytelsesoptimalisering, gir det ikke mye nytte for CPU-databehandling. NumPy ser ut til å yte bedre, og bruk av JAX kan bare øke overheaden.
  • JAX er fortsatt i sin forskning eller tidlige stadier og trenger mer finjustering for å nå infrastrukturstandardene til rammeverk som TensorFlow, som er mer etablerte og har mer forhåndsdefinerte modeller, åpen kildekode-prosjekter og læremateriell.
  • Per nå støtter ikke JAX Windows-operativsystem – du trenger en virtuell maskin for å få det til å fungere.
  • JAX fungerer kun på rene funksjoner – de som ikke har noen bivirkninger. For funksjoner med bivirkninger er JAX kanskje ikke et godt alternativ.

Slik installerer du JAX i ditt Python-miljø

Hvis du har python-oppsett på systemet og ønsker å kjøre JAX på din lokale maskin (CPU), bruk følgende kommandoer:

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

Hvis du vil kjøre Google JAX på en GPU eller TPU, følg instruksjonene gitt på GitHub JAX side. For å sette opp Python, gå til python offisielle nedlastinger side.

Konklusjon

Google JAX er flott for å skrive effektive dyplæringsalgoritmer, robotikk og forskning. Til tross for begrensningene, brukes den mye med andre rammer som Haiku, Flax og mange flere. Du vil være i stand til å sette pris på hva JAX gjør når du kjører programmer og se tidsforskjellene i å utføre kode med og uten JAX. Du kan begynne med å lese offisiell Google JAX-dokumentasjonsom er ganske omfattende.