Lær å bruke NumPy argmax() for å finne indeksen til det største elementet i matriser
I denne veiledningen skal vi se på hvordan du kan benytte deg av NumPy sin argmax()-funksjon for å identifisere posisjonen til det største elementet i en matrise.
NumPy er et robust bibliotek for vitenskapelig databehandling i Python. Det tilbyr flerdimensjonale matriser som er mer effektive enn vanlige Python-lister. En vanlig operasjon ved bruk av NumPy er å finne den høyeste verdien i en matrise. Men noen ganger trenger vi ikke bare verdien, men også hvor i matrisen denne verdien befinner seg.
Argmax()-funksjonen hjelper deg med å lokalisere indeksen til det maksimale elementet, både i endimensjonale og flerdimensjonale matriser. La oss utforske hvordan den fungerer.
Steg-for-steg: Slik finner du indeksen til det maksimale elementet i en NumPy-matrise
For å følge denne veiledningen, trenger du Python og NumPy installert. Du kan skrive koden i en Python REPL eller i en Jupyter Notebook.
Først må vi importere NumPy under det vanlige aliaset «np».
import numpy as np
Du kan bruke NumPy sin max()-funksjon for å få tak i den høyeste verdien i en matrise. Det er også mulig å finne maksverdien langs en spesifikk akse.
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.max(array_1)) # Output 10
Som vi ser returnerer np.max(array_1) tallet 10, som er det største tallet i matrisen.
La oss anta at vi ønsker å finne posisjonen, altså indeksen, der det største tallet befinner seg i matrisen. Dette kan gjøres i to trinn:
- Finn det største tallet.
- Finn indeksen der dette tallet er plassert.
I `array_1` finner vi tallet 10 på indeks 4, siden indeksering i Python starter på 0. Det første elementet har indeks 0, det andre indeks 1, og så videre.
For å finne indeksen der det største elementet ligger, kan vi benytte NumPy sin where()-funksjon. `np.where(condition)` gir deg tilbake en matrise med alle indekser der betingelsen er sann.
Vi må så hente ut det første elementet i den returnerte matrisen. For å finne hvor maksimumsverdien er, setter vi betingelsen til `array_1==10`; husk at 10 er den høyeste verdien i `array_1`.
print(int(np.where(array_1==10)[0])) # Output 4
Vi har brukt `np.where()` med kun en betingelse, men dette er ikke den anbefalte måten å bruke funksjonen på.
📑 Notat: NumPy sin where()-funksjon:
`np.where(condition, x, y)` gir tilbake:
- Elementer fra `x` når betingelsen er `True`, og
- Elementer fra `y` når betingelsen er `False`.
Ved å kombinere `np.max()` og `np.where()`, kan vi finne det maksimale elementet og deretter finne indeksen det befinner seg på.
Men i stedet for denne totrinnsprosessen, kan vi bruke NumPy sin `argmax()`-funksjon for å få indeksen direkte.
Syntaks for NumPy sin argmax()-funksjon
Den grunnleggende syntaksen for å bruke `argmax()`-funksjonen ser slik ut:
np.argmax(array, axis, out) # numpy er importert som 'np'
Her er en forklaring av parameterne:
- `array` er en gyldig NumPy-matrise.
- `axis` er en valgfri parameter. Hvis du har en flerdimensjonal matrise, kan du bruke `axis` for å finne indeksen for det største elementet langs en spesifikk akse.
- `out` er også valgfri. Du kan sette `out`-parameteren til en annen NumPy-matrise for å lagre resultatet fra `argmax()`.
Merk: Fra NumPy versjon 1.22.0 finnes det en `keepdims`-parameter. Når vi angir `axis` i `argmax()`, blir matrisen redusert langs den aksen. Men hvis `keepdims` settes til `True`, vil den returnerte matrisen ha samme form som input-matrisen.
Bruke NumPy sin argmax() for å finne indeksen til det største elementet
#1. La oss bruke `argmax()` for å finne indeksen til det største tallet i `array_1`.
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.argmax(array_1)) # Output 4
`argmax()` returnerer 4, som er riktig! ✅
#2. Hvis vi endrer `array_1` slik at tallet 10 forekommer to ganger, vil `argmax()` kun returnere indeksen til den første forekomsten.
array_1 = np.array([1,5,7,2,10,10,8,4]) print(np.argmax(array_1)) # Output 4
For de resterende eksemplene bruker vi den originale `array_1` fra eksempel #1.
Bruke NumPy sin argmax() i en 2D-matrise
La oss nå endre formen på `array_1` til en todimensjonal matrise med to rader og fire kolonner.
array_2 = array_1.reshape(2,4) print(array_2) # Output [[ 1 5 7 2] [10 9 8 4]]
I en todimensjonal matrise, representerer akse 0 radene, og akse 1 representerer kolonnene. NumPy-matriser bruker nullindeksering. Så indeksene for rader og kolonner i `array_2` er:
La oss nå bruke `argmax()` på denne todimensjonale matrisen, `array_2`.
print(np.argmax(array_2)) # Output 4
Til tross for at vi brukte `argmax()` på en 2D-matrise, returnerer den fortsatt 4. Dette er det samme som for den endimensjonale `array_1`.
Hvorfor skjer dette? 🤔
Dette skjer fordi vi ikke har angitt noen verdi for `axis`-parameteren. Når denne ikke er angitt, vil `argmax()` som standard returnere indeksen til det største elementet i den flate matrisen.
Hva er en flat matrise? En N-dimensjonal matrise med form d1 x d2 x … x dN, der d1, d2, og dN er størrelsene langs de N dimensjonene. Den flate matrisen er en lang endimensjonal matrise med størrelse d1 * d2 * … * dN.
For å se hvordan den flate matrisen ser ut for `array_2`, kan du bruke `flatten()`-metoden:
array_2.flatten() # Output array([ 1, 5, 7, 2, 10, 9, 8, 4])
Indeks for det største elementet langs radene (axis = 0)
La oss nå finne indeksen for det største tallet langs radene (axis = 0).
np.argmax(array_2,axis=0) # Output array([1, 1, 1, 1])
Dette kan være litt vanskelig å forstå, men la oss se på hvordan det fungerer.
Vi har satt `axis` til 0, siden vi ønsker å finne indeksen til det største elementet langs radene. `argmax()` returnerer derfor radnummeret der det største tallet er – for hver av kolonnene.
La oss visualisere dette for å bedre forstå:
Fra diagrammet over og output, ser vi følgende:
- For den første kolonnen (indeks 0), er den største verdien 10 i den andre raden (indeks 1).
- For den andre kolonnen (indeks 1), er den største verdien 9 i den andre raden (indeks 1).
- For den tredje og fjerde kolonnen (indeks 2 og 3), er de største verdiene 8 og 4 i den andre raden (indeks 1).
Dette er grunnen til at vi får output `([1, 1, 1, 1])`, siden det største elementet langs radene er i den andre raden (for alle kolonnene).
Indeks for det største elementet langs kolonnene (axis = 1)
La oss så bruke `argmax()` for å finne indeksen til det største elementet langs kolonnene.
Kjør følgende kode og se på output:
np.argmax(array_2,axis=1)
array([2, 0])
Kan du analysere output?
Vi har satt `axis = 1` for å beregne indeksen for det største elementet langs kolonnene.
`argmax()` returnerer, for hver rad, kolonnenummeret der det største tallet er.
Her er en visuell forklaring:
Fra diagrammet og output, ser vi følgende:
- For den første raden (indeks 0), er det største tallet 7 i den tredje kolonnen (indeks 2).
- For den andre raden (indeks 1), er det største tallet 10 i den første kolonnen (indeks 0).
Jeg håper du nå forstår hva output `array([2, 0])` betyr.
Bruke den valgfrie `out`-parameteren i NumPy `argmax()`
Du kan bruke `out`-parameteren i NumPy sin `argmax()` for å lagre output i en NumPy-matrise.
La oss lage en matrise fylt med nuller for å lagre output fra det forrige `argmax()`-kallet – for å finne indeksen for det største elementet langs kolonnene (`axis= 1`).
out_arr = np.zeros((2,)) print(out_arr) [0. 0.]
La oss nå se på et eksempel der vi finner indeksen for det største elementet langs kolonnene (axis= 1) og setter `out` til `out_arr` som vi definerte ovenfor.
np.argmax(array_2,axis=1,out=out_arr)
Vi ser at Python gir en `TypeError`, siden `out_arr` som standard ble initialisert som en rekke flyttall.
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds) 56 try: ---> 57 return bound(*args, **kwds) 58 except TypeError: TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
Når du setter `out`-parameteren, er det derfor viktig å sørge for at output-matrisen har korrekt form og datatype. Siden matriseindekser alltid er heltall, bør vi sette parameteren `dtype` til `int` når vi definerer `out`-matrisen.
out_arr = np.zeros((2,),dtype=int) print(out_arr) # Output [0 0]
Vi kan nå gå videre og kalle `argmax()` med både `axis`- og `out`-parametere, og denne gangen kjører den uten feil.
np.argmax(array_2,axis=1,out=out_arr)
Output fra `argmax()` er nå tilgjengelig i matrisen `out_arr`.
print(out_arr) # Output [2 0]
Konklusjon
Jeg håper denne veiledningen har hjulpet deg med å forstå hvordan du bruker NumPy sin `argmax()`-funksjon. Du kan kjøre kodeeksemplene i en Jupyter Notebook.
La oss gå gjennom det vi har lært:
- NumPy sin `argmax()`-funksjon returnerer indeksen til det største elementet i en matrise. Hvis det største elementet finnes flere ganger, vil `np.argmax(a)` returnere indeksen til den første forekomsten.
- Når du jobber med flerdimensjonale matriser, kan du bruke den valgfrie `axis`-parameteren for å få indeksen til det største elementet langs en spesifikk akse. For eksempel, i en todimensjonal matrise: ved å sette `axis = 0` og `axis = 1`, kan du få indeksen til det største elementet langs radene og kolonnene, henholdsvis.
- Hvis du ønsker å lagre output i en annen matrise, kan du sette den valgfrie `out`-parameteren til output-matrisen. Output-matrisen må ha en kompatibel form og datatype.
Nå kan du sjekke ut vår dyptgående veiledning om Python-sett.