NumPy argmax(): Finn maks-indeksen i arrayer!

Lær å bruke NumPy argmax() for å finne indeksen til det største elementet i matriser

I denne veiledningen skal vi se på hvordan du kan benytte deg av NumPy sin argmax()-funksjon for å identifisere posisjonen til det største elementet i en matrise.

NumPy er et robust bibliotek for vitenskapelig databehandling i Python. Det tilbyr flerdimensjonale matriser som er mer effektive enn vanlige Python-lister. En vanlig operasjon ved bruk av NumPy er å finne den høyeste verdien i en matrise. Men noen ganger trenger vi ikke bare verdien, men også hvor i matrisen denne verdien befinner seg.

Argmax()-funksjonen hjelper deg med å lokalisere indeksen til det maksimale elementet, både i endimensjonale og flerdimensjonale matriser. La oss utforske hvordan den fungerer.

Steg-for-steg: Slik finner du indeksen til det maksimale elementet i en NumPy-matrise

For å følge denne veiledningen, trenger du Python og NumPy installert. Du kan skrive koden i en Python REPL eller i en Jupyter Notebook.

Først må vi importere NumPy under det vanlige aliaset «np».

import numpy as np

Du kan bruke NumPy sin max()-funksjon for å få tak i den høyeste verdien i en matrise. Det er også mulig å finne maksverdien langs en spesifikk akse.

    array_1 = np.array([1,5,7,2,10,9,8,4])
    print(np.max(array_1))

    # Output
    10
  

Som vi ser returnerer np.max(array_1) tallet 10, som er det største tallet i matrisen.

La oss anta at vi ønsker å finne posisjonen, altså indeksen, der det største tallet befinner seg i matrisen. Dette kan gjøres i to trinn:

  1. Finn det største tallet.
  2. Finn indeksen der dette tallet er plassert.

I `array_1` finner vi tallet 10 på indeks 4, siden indeksering i Python starter på 0. Det første elementet har indeks 0, det andre indeks 1, og så videre.

For å finne indeksen der det største elementet ligger, kan vi benytte NumPy sin where()-funksjon. `np.where(condition)` gir deg tilbake en matrise med alle indekser der betingelsen er sann.

Vi må så hente ut det første elementet i den returnerte matrisen. For å finne hvor maksimumsverdien er, setter vi betingelsen til `array_1==10`; husk at 10 er den høyeste verdien i `array_1`.

    print(int(np.where(array_1==10)[0]))

    # Output
    4
  

Vi har brukt `np.where()` med kun en betingelse, men dette er ikke den anbefalte måten å bruke funksjonen på.

📑 Notat: NumPy sin where()-funksjon:
`np.where(condition, x, y)` gir tilbake:

  • Elementer fra `x` når betingelsen er `True`, og
  • Elementer fra `y` når betingelsen er `False`.

Ved å kombinere `np.max()` og `np.where()`, kan vi finne det maksimale elementet og deretter finne indeksen det befinner seg på.

Men i stedet for denne totrinnsprosessen, kan vi bruke NumPy sin `argmax()`-funksjon for å få indeksen direkte.

Syntaks for NumPy sin argmax()-funksjon

Den grunnleggende syntaksen for å bruke `argmax()`-funksjonen ser slik ut:

np.argmax(array, axis, out)
    # numpy er importert som 'np'
  

Her er en forklaring av parameterne:

  • `array` er en gyldig NumPy-matrise.
  • `axis` er en valgfri parameter. Hvis du har en flerdimensjonal matrise, kan du bruke `axis` for å finne indeksen for det største elementet langs en spesifikk akse.
  • `out` er også valgfri. Du kan sette `out`-parameteren til en annen NumPy-matrise for å lagre resultatet fra `argmax()`.

Merk: Fra NumPy versjon 1.22.0 finnes det en `keepdims`-parameter. Når vi angir `axis` i `argmax()`, blir matrisen redusert langs den aksen. Men hvis `keepdims` settes til `True`, vil den returnerte matrisen ha samme form som input-matrisen.

Bruke NumPy sin argmax() for å finne indeksen til det største elementet

#1. La oss bruke `argmax()` for å finne indeksen til det største tallet i `array_1`.

    array_1 = np.array([1,5,7,2,10,9,8,4])
    print(np.argmax(array_1))

    # Output
    4
  

`argmax()` returnerer 4, som er riktig! ✅

#2. Hvis vi endrer `array_1` slik at tallet 10 forekommer to ganger, vil `argmax()` kun returnere indeksen til den første forekomsten.

    array_1 = np.array([1,5,7,2,10,10,8,4])
    print(np.argmax(array_1))

    # Output
    4
  

For de resterende eksemplene bruker vi den originale `array_1` fra eksempel #1.

Bruke NumPy sin argmax() i en 2D-matrise

La oss nå endre formen på `array_1` til en todimensjonal matrise med to rader og fire kolonner.

    array_2 = array_1.reshape(2,4)
    print(array_2)

    # Output
    [[ 1  5  7  2]
    [10  9  8  4]]
  

I en todimensjonal matrise, representerer akse 0 radene, og akse 1 representerer kolonnene. NumPy-matriser bruker nullindeksering. Så indeksene for rader og kolonner i `array_2` er:

La oss nå bruke `argmax()` på denne todimensjonale matrisen, `array_2`.

    print(np.argmax(array_2))

    # Output
    4
  

Til tross for at vi brukte `argmax()` på en 2D-matrise, returnerer den fortsatt 4. Dette er det samme som for den endimensjonale `array_1`.

Hvorfor skjer dette? 🤔

Dette skjer fordi vi ikke har angitt noen verdi for `axis`-parameteren. Når denne ikke er angitt, vil `argmax()` som standard returnere indeksen til det største elementet i den flate matrisen.

Hva er en flat matrise? En N-dimensjonal matrise med form d1 x d2 x … x dN, der d1, d2, og dN er størrelsene langs de N dimensjonene. Den flate matrisen er en lang endimensjonal matrise med størrelse d1 * d2 * … * dN.

For å se hvordan den flate matrisen ser ut for `array_2`, kan du bruke `flatten()`-metoden:

    array_2.flatten()

    # Output
    array([ 1,  5,  7,  2, 10,  9,  8,  4])
  

Indeks for det største elementet langs radene (axis = 0)

La oss nå finne indeksen for det største tallet langs radene (axis = 0).

    np.argmax(array_2,axis=0)

    # Output
    array([1, 1, 1, 1])
  

Dette kan være litt vanskelig å forstå, men la oss se på hvordan det fungerer.

Vi har satt `axis` til 0, siden vi ønsker å finne indeksen til det største elementet langs radene. `argmax()` returnerer derfor radnummeret der det største tallet er – for hver av kolonnene.

La oss visualisere dette for å bedre forstå:

Fra diagrammet over og output, ser vi følgende:

  • For den første kolonnen (indeks 0), er den største verdien 10 i den andre raden (indeks 1).
  • For den andre kolonnen (indeks 1), er den største verdien 9 i den andre raden (indeks 1).
  • For den tredje og fjerde kolonnen (indeks 2 og 3), er de største verdiene 8 og 4 i den andre raden (indeks 1).

Dette er grunnen til at vi får output `([1, 1, 1, 1])`, siden det største elementet langs radene er i den andre raden (for alle kolonnene).

Indeks for det største elementet langs kolonnene (axis = 1)

La oss så bruke `argmax()` for å finne indeksen til det største elementet langs kolonnene.

Kjør følgende kode og se på output:

np.argmax(array_2,axis=1)
array([2, 0])

Kan du analysere output?

Vi har satt `axis = 1` for å beregne indeksen for det største elementet langs kolonnene.

`argmax()` returnerer, for hver rad, kolonnenummeret der det største tallet er.

Her er en visuell forklaring:

Fra diagrammet og output, ser vi følgende:

  • For den første raden (indeks 0), er det største tallet 7 i den tredje kolonnen (indeks 2).
  • For den andre raden (indeks 1), er det største tallet 10 i den første kolonnen (indeks 0).

Jeg håper du nå forstår hva output `array([2, 0])` betyr.

Bruke den valgfrie `out`-parameteren i NumPy `argmax()`

Du kan bruke `out`-parameteren i NumPy sin `argmax()` for å lagre output i en NumPy-matrise.

La oss lage en matrise fylt med nuller for å lagre output fra det forrige `argmax()`-kallet – for å finne indeksen for det største elementet langs kolonnene (`axis= 1`).

    out_arr = np.zeros((2,))
    print(out_arr)
    [0. 0.]
  

La oss nå se på et eksempel der vi finner indeksen for det største elementet langs kolonnene (axis= 1) og setter `out` til `out_arr` som vi definerte ovenfor.

np.argmax(array_2,axis=1,out=out_arr)

Vi ser at Python gir en `TypeError`, siden `out_arr` som standard ble initialisert som en rekke flyttall.

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
        56     try:
    ---> 57         return bound(*args, **kwds)
        58     except TypeError:

    TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
  

Når du setter `out`-parameteren, er det derfor viktig å sørge for at output-matrisen har korrekt form og datatype. Siden matriseindekser alltid er heltall, bør vi sette parameteren `dtype` til `int` når vi definerer `out`-matrisen.

    out_arr = np.zeros((2,),dtype=int)
    print(out_arr)

    # Output
    [0 0]
  

Vi kan nå gå videre og kalle `argmax()` med både `axis`- og `out`-parametere, og denne gangen kjører den uten feil.

np.argmax(array_2,axis=1,out=out_arr)

Output fra `argmax()` er nå tilgjengelig i matrisen `out_arr`.

    print(out_arr)
    # Output
    [2 0]
  

Konklusjon

Jeg håper denne veiledningen har hjulpet deg med å forstå hvordan du bruker NumPy sin `argmax()`-funksjon. Du kan kjøre kodeeksemplene i en Jupyter Notebook.

La oss gå gjennom det vi har lært:

  • NumPy sin `argmax()`-funksjon returnerer indeksen til det største elementet i en matrise. Hvis det største elementet finnes flere ganger, vil `np.argmax(a)` returnere indeksen til den første forekomsten.
  • Når du jobber med flerdimensjonale matriser, kan du bruke den valgfrie `axis`-parameteren for å få indeksen til det største elementet langs en spesifikk akse. For eksempel, i en todimensjonal matrise: ved å sette `axis = 0` og `axis = 1`, kan du få indeksen til det største elementet langs radene og kolonnene, henholdsvis.
  • Hvis du ønsker å lagre output i en annen matrise, kan du sette den valgfrie `out`-parameteren til output-matrisen. Output-matrisen må ha en kompatibel form og datatype.

Nå kan du sjekke ut vår dyptgående veiledning om Python-sett.