Metrics

calculate_fire_state_difference(fire_state_a, fire_state_b)

Calculate the difference between two fire states.

| Value | a[x,y] is burnable | a[x,y] is not burnable |
|:---:|:---:|:---:|
| b[x,y] is burnable | 0 | 1 |
| b[x,y] is not burnable | -1 | 0 |
IF a[x,y] is burnable and b[x,y] is not burnable THEN
    return -1
ELSE IF a[x,y] is NOT burnable and b[x,y] is burnable THEN
    return 1
ELSE
    return 0
Parameters:
  • fire_state_a (torch.Tensor) –

    fire state of shape (3, H, W)

  • fire_state_b (torch.Tensor) –

    fire state of shape (3, H, W)

Returns:
  • torch.Tensor

    difference between fire states, shape (H, W)

Source code in wildtorch/metrics.py
def calculate_fire_state_difference(fire_state_a: torch.Tensor, fire_state_b: torch.Tensor) -> torch.Tensor:
    """
    Calculate the difference between two fire states.

    A cell is considered burnable when it is set in neither channel 0 (on fire)
    nor channel 1 (burned out) of the fire state.

    | Value | a[x,y] is burnable | a[x,y] is not burnable |
    |:---:|:---:|:---:|
    | b[x,y] is burnable | 0 | 1 |
    | b[x,y] is not burnable | -1 | 0 |

    ```
    IF a[x,y] is burnable and b[x,y] is not burnable THEN
        return -1
    ELSE IF a[x,y] is NOT burnable and b[x,y] is burnable THEN
        return 1
    ELSE
        return 0
    ```

    Parameters:
        fire_state_a:
            fire state of shape (3, H, W). Boolean tensors are used as-is;
            numeric tensors are interpreted as masks (non-zero means set).
        fire_state_b:
            fire state of shape (3, H, W), same shape as `fire_state_a`

    Returns:
        difference between fire states, shape (H, W), dtype `torch.int`,
        on the same device as the inputs

    Raises:
        AssertionError: if the two fire states have different shapes
    """
    assert fire_state_a.shape == fire_state_b.shape, "Fire states must have the same shape"

    # Cast to bool so `|` and `~` are logical operations. On integer tensors `~`
    # is a bitwise NOT (0 -> -1, 1 -> -2), which would corrupt the masks, and on
    # float tensors `|` is not defined at all.
    burnable_a = ~(fire_state_a[0].bool() | fire_state_a[1].bool())
    burnable_b = ~(fire_state_b[0].bool() | fire_state_b[1].bool())

    # With both masks in {0, 1}, b - a reproduces the table above exactly:
    # +1 where only b is burnable, -1 where only a is burnable, 0 otherwise.
    return burnable_b.int() - burnable_a.int()

cell_burned_out(fire_state)

Calculate the number of cells burned out.

Higher values indicate more cells burned out, lower values indicate fewer cells burned out

Parameters:
  • fire_state (torch.Tensor) –

    fire state of shape (3, H, W)

Returns:
  • torch.Tensor

    number of cells burned out, shape ()

Source code in wildtorch/metrics.py
def cell_burned_out(fire_state: torch.Tensor) -> torch.Tensor:
    """
    Count the cells that have burned out.

    The count is the sum of channel 1 of the fire state, so larger results
    mean more of the grid has burned out.

    Parameters:
        fire_state:
            fire state of shape (3, H, W)

    Returns:
        number of cells burned out, shape ()
    """
    burned_out_layer = fire_state[1]
    return burned_out_layer.sum()

cell_on_fire(fire_state)

Calculate the number of cells on fire.

Higher values indicate more cells on fire, lower values indicate fewer cells on fire

Parameters:
  • fire_state (torch.Tensor) –

    fire state of shape (3, H, W)

Returns:
  • torch.Tensor

    number of cells on fire, shape ()

Source code in wildtorch/metrics.py
def cell_on_fire(fire_state: torch.Tensor) -> torch.Tensor:
    """
    Count the cells that are currently on fire.

    The count is the sum of channel 0 of the fire state, so larger results
    mean more of the grid is burning.

    Parameters:
        fire_state:
            fire state of shape (3, H, W)

    Returns:
        number of cells on fire, shape ()
    """
    burning_layer = fire_state[0]
    return burning_layer.sum()

saved_cells(fire_state_diff)

Calculate the number of cells saved.

Higher values indicate more cells saved, lower values indicate more cells burned out

Parameters:
  • fire_state_diff (torch.Tensor) –

    fire state difference of shape (H, W)

Returns:
  • torch.Tensor

    number of cells saved, shape ()

Source code in wildtorch/metrics.py
def saved_cells(fire_state_diff: torch.Tensor) -> torch.Tensor:
    """
    Compute the net number of saved cells from a fire state difference.

    Every entry of the difference map is added up: +1 entries count as saved
    cells and -1 entries count against them. A larger result therefore means
    more cells saved, and a smaller one means more cells burned out.

    Parameters:
        fire_state_diff:
            fire state difference of shape (H, W)

    Returns:
        number of cells saved, shape ()
    """
    net_saved = fire_state_diff.sum()
    return net_saved

weighted_cell_burned_out(fire_state, weights)

Calculate the weighted number of cells burned out.

Higher values indicate more valuable cells burned out, lower values indicate fewer valuable cells burned out

Parameters:
  • fire_state (torch.Tensor) –

    fire state of shape (3, H, W)

  • weights (torch.Tensor) –

    weight matrix of shape (H, W)

Returns:
  • torch.Tensor

    weighted number of cells burned out, shape ()

Source code in wildtorch/metrics.py
def weighted_cell_burned_out(fire_state: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """
    Compute the weighted count of burned-out cells.

    Channel 1 of the fire state is multiplied element-wise by `weights`, and
    the products are summed. Larger results mean more valuable cells have
    burned out.

    Parameters:
        fire_state:
            fire state of shape (3, H, W)
        weights:
            weight matrix of shape (H, W)

    Returns:
        weighted number of cells burned out, shape ()
    """
    burned_out_layer = fire_state[1]
    weighted_layer = burned_out_layer * weights
    return weighted_layer.sum()

weighted_cell_on_fire(fire_state, weights)

Calculate the weighted number of cells on fire.

Higher values indicate more valuable cells on fire, lower values indicate fewer valuable cells on fire

Parameters:
  • fire_state (torch.Tensor) –

    fire state of shape (3, H, W)

  • weights (torch.Tensor) –

    weight matrix of shape (H, W)

Returns:
  • torch.Tensor

    weighted number of cells on fire, shape ()

Source code in wildtorch/metrics.py
def weighted_cell_on_fire(fire_state: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """
    Compute the weighted count of cells on fire.

    Channel 0 of the fire state is multiplied element-wise by `weights`, and
    the products are summed. Larger results mean more valuable cells are
    burning.

    Parameters:
        fire_state:
            fire state of shape (3, H, W)
        weights:
            weight matrix of shape (H, W)

    Returns:
        weighted number of cells on fire, shape ()
    """
    burning_layer = fire_state[0]
    weighted_layer = burning_layer * weights
    return weighted_layer.sum()

weighted_saved_cells(fire_state_diff, weights)

Calculate the weighted number of cells saved.

Higher values indicate more valuable cells saved, lower values indicate more valuable cells burned out

Parameters:
  • fire_state_diff (torch.Tensor) –

    fire state difference of shape (H, W)

  • weights (torch.Tensor) –

    weight matrix of shape (H, W)

Returns:
  • torch.Tensor

    weighted number of cells saved, shape ()

Source code in wildtorch/metrics.py
def weighted_saved_cells(fire_state_diff: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """
    Compute the weighted net number of saved cells.

    The difference map is multiplied element-wise by `weights`, and the
    products are summed. Larger results mean more valuable cells saved, and
    smaller ones mean more valuable cells burned out.

    Parameters:
        fire_state_diff:
            fire state difference of shape (H, W)
        weights:
            weight matrix of shape (H, W)

    Returns:
        weighted number of cells saved, shape ()
    """
    weighted_diff = fire_state_diff * weights
    return weighted_diff.sum()