WildTorchSimulator

WildTorch main simulator class

Channel definition for fire_state:

Index  Description  Data type  Default value
0      burning      bool       False
1      burned       bool       False
2      firebreak    bool       False

Cell state in the simulator:

Description  burning_bit  burned_bit
burnable     0            0
burning      1            0
burned       0            1
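
The bit encoding in the table above can be read back into a label with a few lines of Python. This is a minimal sketch: `cell_state` is a hypothetical helper, not part of WildTorch, and treating the combination (1, 1) as invalid is an assumption, since the table does not list it.

```python
def cell_state(burning_bit: bool, burned_bit: bool) -> str:
    """Decode the two state bits into a label (hypothetical helper)."""
    # Assumption: a cell that is both burning and burned is not a valid state,
    # because the table lists only three combinations.
    if burning_bit and burned_bit:
        raise ValueError("a cell cannot be burning and burned at once")
    if burning_bit:
        return "burning"
    if burned_bit:
        return "burned"
    return "burnable"


for bits in [(0, 0), (1, 0), (0, 1)]:
    print(bits, cell_state(*map(bool, bits)))
```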

Optimized state transition:

some_cells.state = burning # Make initial ignition

IF current.state == burning THEN
    IF NOT rand < p_continue_burn THEN
        current.state = burned # burn out
    FOR each neighbour in current.neighbours DO
        IF NOT neighbour.state == burned THEN
            IF rand < p_burn[neighbour.x, neighbour.y, direction_x, direction_y] THEN
                neighbour.state = burning # propagate burning
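
To make the transition rule concrete, here is a small pure-Python sketch of one update on a toy grid. It is not the library's vectorised implementation: `toy_step`, the integer state codes, and the use of a single uniform `p_propagate` for every neighbour and direction are all assumptions made for illustration.

```python
import random

# Hypothetical integer codes for the three cell states.
BURNABLE, BURNING, BURNED = 0, 1, 2


def toy_step(grid, p_propagate, p_continue, rng):
    """Apply one synchronous update to a toy fire grid and return a new grid."""
    height, width = len(grid), len(grid[0])
    new_grid = [row[:] for row in grid]
    for y in range(height):
        for x in range(width):
            if grid[y][x] != BURNING:
                continue
            # Try to ignite each of the (up to) 8 neighbours.
            for dy in (-1, 0, 1):
                for dx in (-1, 0, 1):
                    ny, nx = y + dy, x + dx
                    if (dy, dx) == (0, 0):
                        continue
                    if not (0 <= ny < height and 0 <= nx < width):
                        continue
                    if grid[ny][nx] == BURNABLE and rng.random() < p_propagate:
                        new_grid[ny][nx] = BURNING
            # The burning cell keeps burning with probability p_continue,
            # otherwise it burns out.
            if rng.random() >= p_continue:
                new_grid[y][x] = BURNED
    return new_grid


rng = random.Random(0)
grid = [[BURNABLE] * 3 for _ in range(3)]
grid[1][1] = BURNING  # initial ignition in the centre
for row in toy_step(grid, p_propagate=0.5, p_continue=0.3, rng=rng):
    print(row)
```

Reading from the old grid and writing to a copy keeps the update synchronous, so a cell ignited in this step does not spread fire until the next one.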

The device and data type from the simulator constants are used for the simulator.

Attributes:
  • simulator_constants (SimulatorConstants) –

    The simulator constants.

  • wildfire_map (torch.Tensor) –

    The wildfire map tensor.

  • device (torch.device) –

    The device currently in use.

  • dtype (torch.dtype) –

    The data type currently in use.

  • p_propagate_constant (torch.Tensor) –

    The constant part of the probability of propagation.

  • p_burn (torch.Tensor) –

    The probability of burning.

  • fire_state (torch.Tensor) –

    The current fire state tensor, with channels [burning, burned, firebreak].

  • seed (int) –

    The initial random seed.

  • maximum_step (int) –

    The maximum number of steps to simulate.

  • initial_ignition (torch.Tensor | None) –

    The initial ignition map tensor.

  • current_step (int) –

    The current step of the simulation.

  • terminated (bool) –

    Whether the simulation is terminated.

  • truncated (bool) –

    Whether the simulation is truncated.

Initialize the simulator with the given wildfire map and simulator constants.

Parameters:
  • wildfire_map (torch.Tensor, default: generate_empty_dataset() ) –

    The wildfire map tensor.

  • simulator_constants (SimulatorConstants, default: SimulatorConstants() ) –

    The simulator constants.

  • maximum_step (int, default: 200 ) –

    The maximum number of steps to simulate.

  • initial_ignition (torch.Tensor | None, default: None ) –

    The initial ignition map tensor.

  • seed (int | None, default: None ) –

    The random seed to use.

Source code in wildtorch/main.py
def __init__(self,
             wildfire_map: torch.Tensor = generate_empty_dataset(),
             simulator_constants: SimulatorConstants = SimulatorConstants(),
             maximum_step: int = 200,
             initial_ignition: torch.Tensor | None = None,
             seed: int | None = None,
             ) -> None:
    """
    Initialize the simulator with the given wildfire map and simulator constants.

    Parameters:
        wildfire_map:
            The wildfire map tensor.
        simulator_constants:
            The simulator constants.
        maximum_step:
            The maximum number of steps to simulate.
        initial_ignition:
            The initial ignition map tensor.
        seed:
            The random seed to use.
    """
    self._simulator_constants = simulator_constants

    self._p_propagate_constant = None
    self.wildfire_map = wildfire_map
    self._fire_state = torch.zeros((3, *wildfire_map[0].shape), dtype=torch.bool, device=self.device)

    self.maximum_step = maximum_step
    if initial_ignition is None:
        self.initial_ignition = torch.zeros_like(wildfire_map[0], dtype=torch.bool, device=self.device)
    else:
        self.initial_ignition = initial_ignition.to(device=self.device, dtype=self.dtype)
        self._fire_state[0] = self.initial_ignition
    self.seed = seed

    self._current_step = 0

checkpoint property

Get the checkpoint of the simulator.

The checkpoint is a dictionary with the keys 'seed', 'current_step', and 'fire_state'.

Returns:
  • Dict[str, Any]

    The checkpoint of the simulator.

p_burn property

Compute the probability of burning for each cell.

This property follows the formula:

$$ p_{burn}=1-\prod_{i=1}^{8}{(1-p_{propagate_i})} $$

Returns:
  • torch.Tensor

    The probability of burning for each cell.
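
The formula is the probability that at least one of eight independent ignition attempts succeeds. A minimal scalar sketch, assuming that a neighbour which is not burning contributes a propagation probability of 0 (`p_burn_from_neighbours` is a hypothetical name, not a WildTorch function):

```python
import math


def p_burn_from_neighbours(p_propagate):
    """Return 1 - prod(1 - p_i) over the neighbours' propagation probabilities."""
    return 1.0 - math.prod(1.0 - p for p in p_propagate)


# Two burning neighbours with p = 0.5 each; the other six are not burning.
print(p_burn_from_neighbours([0.5, 0.5, 0, 0, 0, 0, 0, 0]))
```

With two burning neighbours at 0.5 each, the cell stays unburned only if both attempts fail (0.5 × 0.5 = 0.25), so the burning probability is 0.75.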

batch_forward(step=20)

Perform multiple steps of simulation.

Parameters:
  • step (int, default: 20 ) –

    The number of steps to perform.

Source code in wildtorch/main.py
def batch_forward(self,
                  step: int = 20,
                  ) -> None:
    """
    Perform multiple steps of simulation.

    Parameters:
        step: The number of steps to perform.
    """
    for i in range(step):
        self.step()

compute_p_propagate_constant()

Compute the constant part of the probability of propagation.

This method is usually called after changing the wildfire map or simulator constants.

This function follows the following formula:

$$ p_{propagate}=p_h(1+p_{veg})(1+p_{den})p_wp_s $$

in which,

$$ p_w=\exp(c_1V)\exp(c_2V(\cos(\theta)-1)) $$

$$ p_s=\exp(a\theta_s) $$

Returns:
  • torch.Tensor

    The constant part of the probability of propagation.
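
For a single cell and a single direction, the formulas above reduce to a few scalar operations. The sketch below uses only the standard library; the function names and the numeric values are illustrative assumptions, not WildTorch defaults. Angles are given in degrees and converted to radians, as the source code shown for this method does with `torch.deg2rad`.

```python
import math


def wind_factor(c_1, c_2, wind_speed, theta_deg):
    """p_w = exp(c_1 * V) * exp(c_2 * V * (cos(theta) - 1))."""
    theta = math.radians(theta_deg)
    return math.exp(c_1 * wind_speed) * math.exp(
        c_2 * wind_speed * (math.cos(theta) - 1.0)
    )


def slope_factor(a, slope_deg):
    """p_s = exp(a * theta_s), with the slope angle converted to radians."""
    return math.exp(a * math.radians(slope_deg))


def p_propagate_constant(p_h, p_veg, p_den, p_w, p_s):
    """p_propagate = p_h * (1 + p_veg) * (1 + p_den) * p_w * p_s."""
    return p_h * (1.0 + p_veg) * (1.0 + p_den) * p_w * p_s


# Placeholder values chosen only to exercise the functions.
p_w = wind_factor(c_1=0.05, c_2=0.1, wind_speed=10.0, theta_deg=45.0)
p_s = slope_factor(a=0.08, slope_deg=15.0)
print(p_propagate_constant(p_h=0.3, p_veg=0.2, p_den=-0.1, p_w=p_w, p_s=p_s))
```

With zero wind speed and zero slope both factors equal 1, so the result reduces to `p_h * (1 + p_veg) * (1 + p_den)`.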

Source code in wildtorch/main.py
def compute_p_propagate_constant(self) -> torch.Tensor:
    """
    Compute the constant part of the probability of propagation.

    This method is usually called after changing the wildfire map or simulator constants.

    This function follows the following formula:

    $$ p_{propagate}=p_h(1+p_{veg})(1+p_{den})p_wp_s $$

    in which,

    $$ p_w=\\exp(c_1V)\\exp(c_2V(\\cos(\\theta)-1)) $$

    $$ p_s=\\exp(a\\theta_s) $$

    Returns:
        The constant part of the probability of propagation.
    """
    p_h = self.simulator_constants.p_h
    p_veg = self.wildfire_map[0]
    p_den = self.wildfire_map[1]
    p_w = (torch.exp(self.simulator_constants.c_1 * self.simulator_constants.V) * torch.exp(
        self.simulator_constants.c_2 * self.simulator_constants.V * (
                torch.cos(torch.deg2rad(self.simulator_constants.theta)) - 1)))
    p_s = torch.exp(self.simulator_constants.a * torch.deg2rad(self.wildfire_map[2]))

    p_veg = rearrange(p_veg, 'h w -> h w 1 1')
    p_den = rearrange(p_den, 'h w -> h w 1 1')
    p_s = rearrange(p_s, 'h w -> h w 1 1')
    p_w = rearrange(p_w, 'o1 o2 -> 1 1 o1 o2', o1=3, o2=3)

    return p_h * (1 + p_veg) * (1 + p_den) * p_w * p_s

load_checkpoint(checkpoint, restore_seed=True)

Reset the simulator to the checkpoint.

Parameters:
  • checkpoint (Dict[str, Any]) –

    The checkpoint to restore.

  • restore_seed (bool, default: True ) –

    Whether to restore the random seed from the checkpoint as well. If False, the seed is cleared.

Source code in wildtorch/main.py
def load_checkpoint(self,
                    checkpoint: Dict[str, Any],
                    restore_seed: bool = True,
                    ) -> None:
    """
    Reset the simulator to the checkpoint.

    Parameters:
        checkpoint: The checkpoint to restore.
        restore_seed: Whether to restore the random seed from the checkpoint as well. If False, the seed is cleared.
    """
    if restore_seed:
        self.seed = checkpoint['seed']
    else:
        self._seed = None
    self._current_step = checkpoint['current_step']
    self._fire_state = checkpoint['fire_state'].clone().detach()

reset()

Reset the simulator.

  • The current step is set to 0.
  • The random seed is set to None.
  • The fire state is set to the initial ignition map.

Source code in wildtorch/main.py
def reset(self) -> None:
    """
    Reset the simulator.

    - The current step is set to 0.
    - The random seed is set to None.
    - The fire state is set to the initial ignition map.
    """
    self._current_step = 0
    self.seed = None
    self._fire_state = torch.zeros((3, *self.wildfire_map[0].shape), dtype=torch.bool, device=self.device)
    self._fire_state[0] = self.initial_ignition

step(force=False)

Perform one step of simulation.

Parameters:
  • force (bool, default: False ) –

    Whether to compute the step even if the simulation is terminated or truncated.

Source code in wildtorch/main.py
def step(self, force: bool = False) -> None:
    """
    Perform one step of simulation.

    Parameters:
        force:
            Whether to force calculating even if the simulation is terminated or truncated.
    """
    if not force and (self.terminated or self.truncated):
        return

    burning, burned, firebreak = self.fire_state
    p_burn = self.p_burn

    rand_propagate = torch.rand_like(p_burn, dtype=self.dtype, device=self.device)
    rand_continue = torch.rand_like(p_burn, dtype=self.dtype, device=self.device)

    # burnable patches have p_burn probability to be burning
    burnable = ~(burning | burned)
    new_burning = burnable & (rand_propagate < p_burn)
    burning[new_burning] = True

    # burning patches have p_continue_burn probability to continue burning
    will_burn_out = burning & (rand_continue >= self.simulator_constants.p_continue_burn)
    burning[will_burn_out] = False
    burned[will_burn_out] = True

    # burned remain burned
    self._current_step += 1