Utils

animate_snapshots(snapshots, wildfire_map=None, out_filename='runs/wildfire_simulation.mp4', fps=24)

Animate snapshots

Parameters:
  • snapshots (list[Dict[str, Any]]) –

    List of snapshots.

  • wildfire_map (torch.Tensor | None, default: None ) –

    The wildfire map tensor.

  • out_filename (str, default: 'runs/wildfire_simulation.mp4' ) –

    Output filename. Defaults to 'runs/wildfire_simulation.mp4'.

  • fps (int, default: 24 ) –

    Frames per second. Defaults to 24.

Raises:
  • ImportError

    If numpy is not available.

  • ImportError

    If imageio is not available.

Returns:
  • np.ndarray

    The movie array.

Examples:

>>> animate_snapshots(snapshots, fps=10)
Source code in wildtorch/utils.py
def animate_snapshots(
        snapshots: list[Dict[str, Any]],
        wildfire_map: torch.Tensor | None = None,
        out_filename: str = 'runs/wildfire_simulation.mp4',
        fps: int = 24,
) -> 'np.ndarray':
    """
    Animate snapshots into a video file.

    Each snapshot's ``'fire_state'`` tensor is colorized (fixed range 0..3) and,
    when ``wildfire_map`` is given, blended with the colorized map (alpha 0.6 on
    the fire state). The frames are stacked in snapshot order as ``uint8`` RGB,
    encoded with libx264 to ``out_filename`` and returned.

    Parameters:
        snapshots:
            List of snapshots. Each must contain a ``'fire_state'`` tensor.
        wildfire_map:
            The wildfire map tensor, overlaid on every frame when given.
        out_filename:
            Output filename. Defaults to 'runs/wildfire_simulation.mp4'.
            Missing parent directories are created.
        fps:
            Frames per second. Defaults to 24.

    Raises:
        ImportError:
            If numpy is not available.
        ImportError:
            If imageio is not available.
        ValueError:
            If ``snapshots`` is empty.

    Returns:
        The movie array, one frame per snapshot in input order.

    Examples:
        >>> animate_snapshots(snapshots, fps=10)
    """
    import os

    check_np()
    check_iio()

    if not snapshots:
        raise ValueError("snapshots must contain at least one frame")

    # The wildfire map background is identical for every frame, so colorize it
    # once instead of once per snapshot.
    vis_wildfire_map = None
    if wildfire_map is not None:
        vis_wildfire_map = colorize_array(compose_vis_wildfire_map(wildfire_map))

    def process_frame(snapshot):
        # Render one snapshot into a uint8 RGB frame.
        fire_state = snapshot['fire_state'].cpu()
        vis_fire_state = colorize_array(compose_vis_fire_state(fire_state), vmin=0, vmax=3)

        if vis_wildfire_map is not None:
            output = overlay_arrays(vis_fire_state, vis_wildfire_map, 0.6)
        else:
            output = vis_fire_state

        return (output * 255).astype(np.uint8)

    # Frames are rendered in parallel; Executor.map yields results in input
    # order, so no index bookkeeping or re-sorting is needed.
    with ThreadPoolExecutor() as executor:
        movie_array = np.array(list(executor.map(process_frame, snapshots)))

    # The default output lives in 'runs/', which may not exist yet.
    out_dir = os.path.dirname(out_filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with iio.imopen(out_filename, "w", plugin="pyav") as file:
        file.init_video_stream("libx264", fps=fps, force_keyframes=True)
        file.container_metadata["comment"] = "This video was created using WildTorch."
        file.write(movie_array,
                   # ffmpeg scale '-2:1080': 1080 px tall, width keeps the aspect
                   # ratio rounded to an even number (avoids "height/width not
                   # divisible by 2" errors from libx264).
                   filter_sequence=[('scale', '-2:1080')]
                   )

    return movie_array

check_iio()

Check if imageio is available

Raises:
  • ImportError

    If imageio is not available.

Returns:
  • bool

    True if imageio is available

Source code in wildtorch/utils.py
def check_iio() -> bool:
    """
    Ensure that imageio can be used.

    Raises:
        ImportError:
            If imageio is not available.

    Returns:
        True if imageio is available
    """
    if imageio_available:
        return True
    raise ImportError("imageio is not available. Please install it using `pip install imageio`")

check_np()

Check if numpy is available

Raises:
  • ImportError

    If numpy is not available.

Returns:
  • bool

    True if numpy is available

Source code in wildtorch/utils.py
def check_np() -> bool:
    """
    Ensure that numpy can be used.

    Raises:
        ImportError:
            If numpy is not available.

    Returns:
        True if numpy is available
    """
    if numpy_available:
        return True
    raise ImportError("numpy is not available. Please install it using `pip install numpy`")

check_plt()

Check if matplotlib is available

Raises:
  • ImportError

    If matplotlib is not available.

Returns:
  • bool

    True if matplotlib is available

Source code in wildtorch/utils.py
def check_plt() -> bool:
    """
    Ensure that matplotlib can be used.

    Raises:
        ImportError:
            If matplotlib is not available.

    Returns:
        True if matplotlib is available
    """
    if not matplotlib_available:
        message = "matplotlib is not available. Please install it using `pip install matplotlib`"
        raise ImportError(message)
    return True

colorize_array(array, cmap='viridis', vmin=None, vmax=None)

Colorize array using a colormap

Parameters:
  • array (np.ndarray) –

    The array to colorize.

  • cmap (str, default: 'viridis' ) –

    The colormap to use.

  • vmin (int, default: None ) –

    The minimum value.

  • vmax (int, default: None ) –

    The maximum value.

Raises:
  • ImportError

    If matplotlib is not available.

Returns:
  • np.ndarray

    The colorized array.

Source code in wildtorch/utils.py
def colorize_array(array: 'np.ndarray',
                   cmap: str = 'viridis',
                   vmin: float | None = None,
                   vmax: float | None = None,
                   ) -> 'np.ndarray':
    """
    Colorize array using a colormap

    Values are normalized linearly to [0, 1] using ``vmin``/``vmax`` and then
    mapped through ``cmap``; the alpha channel is dropped.

    Parameters:
        array:
            The array to colorize.
        cmap:
            The colormap name (or anything ``plt.get_cmap`` accepts).
        vmin:
            The value mapped to the low end of the colormap. ``None`` lets
            ``matplotlib.colors.Normalize`` take it from the data minimum.
        vmax:
            The value mapped to the high end of the colormap. ``None`` lets
            ``matplotlib.colors.Normalize`` take it from the data maximum.

    Raises:
        ImportError:
            If matplotlib is not available.

    Returns:
        The colorized array: RGB floats with shape ``array.shape + (3,)``.
    """
    check_plt()

    norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
    colormap = plt.get_cmap(cmap)
    # The colormap returns RGBA; keep only the RGB channels.
    colored = colormap(norm(array))[..., :3]

    return colored  # (h, w, c)

compose_vis_fire_state(fire_state)

Compose fire state tensor into a single array for visualization

Parameters:
  • fire_state (torch.Tensor) –

    The fire state tensor.

Raises:
  • ImportError

    If numpy is not available.

  • AssertionError

    If the fire state shape is invalid.

Returns:
  • np.ndarray

    The composed fire state array.

Source code in wildtorch/utils.py
def compose_vis_fire_state(fire_state: torch.Tensor) -> 'np.ndarray':
    """
    Compose fire state tensor into a single array for visualization

    The three channels of ``fire_state`` are treated as masks (burning,
    burned, firebreak) and merged into one int32 label map:
    0 = none, 1 = burning, 2 = burned, 3 = firebreak. Where masks overlap,
    the later channel wins.

    Parameters:
        fire_state:
            The fire state tensor, shaped ``(3, H, W)``. Any nonzero value is
            treated as set; the tensor may live on any device.

    Raises:
        ImportError:
            If numpy is not available.
        AssertionError:
            If the fire state shape is invalid.

    Returns:
        The composed fire state array, shaped ``(H, W)``.
    """
    check_np()
    assert fire_state.dim() == 3 and fire_state.shape[0] == 3, f"Invalid fire state shape: {fire_state.shape}"
    # Bring the data to host memory as boolean masks. Indexing a numpy array
    # with a non-bool array would perform integer (fancy) indexing instead of
    # masking, and CUDA tensors cannot be handed to numpy directly.
    burning, burned, firebreak = fire_state.detach().cpu().numpy().astype(bool)
    composed_state = np.zeros(burning.shape, dtype=np.int32)

    composed_state[burning] = 1
    composed_state[burned] = 2
    composed_state[firebreak] = 3

    return composed_state

compose_vis_wildfire_map(wildfire_map)

Compose wildfire map into a single array for visualization

Parameters:
  • wildfire_map (torch.Tensor) –

    The wildfire map tensor.

Raises:
  • AssertionError

    If the wildfire map shape is invalid.

Returns:
  • np.ndarray

    The composed wildfire map array.

Source code in wildtorch/utils.py
def compose_vis_wildfire_map(wildfire_map: torch.Tensor) -> 'np.ndarray':
    """
    Collapse a three-channel wildfire map into one array for visualization.

    The result is the elementwise product of channels 0 and 1, converted to a
    numpy array; channel 2 is not used.

    Parameters:
        wildfire_map:
            The wildfire map tensor, shaped ``(3, H, W)``.

    Raises:
        AssertionError:
            If the wildfire map shape is invalid.

    Returns:
        The composed wildfire map array.
    """
    is_valid = wildfire_map.dim() == 3 and wildfire_map.shape[0] == 3
    assert is_valid, f"Invalid wildfire_map shape: {wildfire_map.shape}"
    product = wildfire_map[0] * wildfire_map[1]
    return to_ndarray(product)

create_ignition(shape=(128, 128), pos=((0.2, 0.2), (0.8, 0.8)), size=(0.01, 0.01), mode='center', count=5)

Create ignition map

Parameters:
  • shape (tuple[int, int], default: (128, 128) ) –

    The size should match the size of the map.

  • pos (tuple[tuple[float, float], tuple[float, float]], default: ((0.2, 0.2), (0.8, 0.8)) ) –

    The position of ignition.

  • size (tuple[float, float], default: (0.01, 0.01) ) –

    The size of ignition.

  • mode (str, default: 'center' ) –

    Mode used to generate the ignition map: 'center', 'random-single', or 'random-multi'.

  • count (int, default: 5 ) –

    The number of ignition points. Only used in 'random-multi' mode.

Raises:
  • AssertionError

    If the mode is not valid.

Returns:
  • torch.Tensor

    Ignition map

Source code in wildtorch/utils.py
def create_ignition(shape: tuple[int, int] = (128, 128),
                    pos: tuple[tuple[float, float], tuple[float, float]] = ((0.2, 0.2), (0.8, 0.8)),
                    size: tuple[float, float] = (0.01, 0.01),
                    mode: str = 'center',
                    count: int = 5,
                    ) -> torch.Tensor:
    """
    Create ignition map

    Ignition patches are axis-aligned rectangles of ``True`` cells in a
    boolean tensor. Each patch extends a half-extent of
    ``size[i] * shape[i] // 2`` cells (clamped to [3, 20]) on either side of
    its centre along dimension ``i``. Patches touching an edge are clipped to
    the map.

    Parameters:
        shape:
            The size should match the size of the map.
        pos:
            The position of ignition as a box ``((a0, b0), (a1, b1))`` of
            fractions: ``a`` values scale ``shape[0]`` and ``b`` values scale
            ``shape[1]``. 'center' mode ignites the middle of the box; the
            random modes sample patch centres inside it.
        size:
            The size of ignition, as fractions of ``shape``.
        mode:
            Mode used to generate the ignition map: 'center', 'random-single'
            or 'random-multi'.
        count:
            The number of ignition points. Only used in 'random-multi' mode.

    Raises:
        AssertionError:
            If the mode is not valid.

    Returns:
        Ignition map, a boolean tensor of shape ``shape``.
    """
    assert mode in ['center', 'random-single', 'random-multi'], f"Invalid mode: {mode}"

    field = torch.zeros(shape, dtype=torch.bool)

    start_x, end_x = int(pos[0][0] * shape[0]), int(pos[1][0] * shape[0])
    start_y, end_y = int(pos[0][1] * shape[1]), int(pos[1][1] * shape[1])

    # Half-extent of each patch, clamped so patches are neither tiny nor huge.
    width = min(20, max(3, int(size[0] * shape[0] // 2)))
    height = min(20, max(3, int(size[1] * shape[1] // 2)))

    def ignite(cx: int, cy: int) -> None:
        # Clamp lower bounds at 0: a negative slice start would wrap around to
        # the far edge and produce an empty (or wrong) region. Upper bounds
        # past the end are already clipped by slicing.
        field[max(0, cx - width):cx + width, max(0, cy - height):cy + height] = True

    if mode == 'center':
        ignite((start_x + end_x) // 2, (start_y + end_y) // 2)
    else:
        n_points = count if mode == 'random-multi' else 1
        for _ in range(n_points):
            # randint's upper bound is exclusive and must exceed the lower one;
            # a zero-extent box therefore samples its single start coordinate.
            x = torch.randint(start_x, max(end_x, start_x + 1), (1,)).item()
            y = torch.randint(start_y, max(end_y, start_y + 1), (1,)).item()
            ignite(x, y)

    return field

overlay_arrays(array1, array2, alpha=0.5)

Overlay two arrays

$$ out = \alpha \cdot arr_1+(1-\alpha)\cdot arr_2 $$

Parameters:
  • array1 (np.ndarray) –

    The first array.

  • array2 (np.ndarray) –

    The second array.

  • alpha (float, default: 0.5 ) –

    The overlay alpha value.

Raises:
  • AssertionError

    If the shapes of the arrays are not equal.

  • AssertionError

    If the alpha value is invalid.

Returns:
  • np.ndarray

    The overlaid array.

Source code in wildtorch/utils.py
def overlay_arrays(array1: 'np.ndarray', array2: 'np.ndarray', alpha: float = 0.5) -> 'np.ndarray':
    """
    Blend two equally shaped arrays with a linear weight.

    $$ out = \\alpha \\cdot arr_1+(1-\\alpha)\\cdot arr_2 $$

    Parameters:
        array1:
            The first array, weighted by ``alpha``.
        array2:
            The second array, weighted by ``1 - alpha``.
        alpha:
            The overlay alpha value, in [0, 1].

    Raises:
        AssertionError:
            If the shapes of the arrays are not equal.
        AssertionError:
            If the alpha value is invalid.

    Returns:
        The overlaid array.
    """
    first_shape, second_shape = array1.shape, array2.shape
    assert first_shape == second_shape, f"Shape mismatch: {first_shape} != {second_shape}"
    assert 0 <= alpha <= 1, f"Invalid alpha value: {alpha}"

    complement = 1 - alpha
    return alpha * array1 + complement * array2

plot_stats(logs, keys)

Plot statistics from logs

Parameters:
  • logs (list[Dict[str, Any]]) –

    List of log entries.

  • keys (list[str]) –

    List of keys to plot.

Raises:
  • ImportError

    If matplotlib is not available.

Examples:

>>> plot_stats(logs, keys=['burning_cells', 'burned_cells'])
Source code in wildtorch/utils.py
def plot_stats(logs: list[Dict[str, Any]],
               keys: list[str],
               ):
    """
    Plot statistics from logs

    Every requested key is drawn as one labelled line (x = log index) on a
    single figure with a shared legend, which is then shown. Does nothing
    when ``keys`` is empty.

    Parameters:
        logs:
            List of log entries. Each entry must contain every key in ``keys``.
        keys:
            List of keys to plot.

    Raises:
        ImportError:
            If matplotlib is not available.
        KeyError:
            If a log entry lacks one of ``keys``.

    Examples:
        >>> plot_stats(logs, keys=['burning_cells', 'burned_cells'])
    """
    check_plt()
    if not keys:
        return
    for key in keys:
        plt.plot([log[key] for log in logs], label=key)
    # Legend and show once, after all series are drawn, so the keys are
    # compared on one figure instead of one single-line figure per key.
    plt.legend()
    plt.show()

to_ndarray(tensor)

Convert a torch.Tensor or numpy.ndarray to a numpy.ndarray

Parameters:
  • tensor (Union[torch.Tensor, np.ndarray]) –

    The tensor to convert.

Raises:
  • ImportError

    If numpy is not available.

  • TypeError

    If the tensor type is not supported.

Returns:
  • np.ndarray

    The converted numpy array.

Source code in wildtorch/utils.py
def to_ndarray(tensor: 'Union[torch.Tensor, np.ndarray]') -> 'np.ndarray':
    """
    Convert `torch.Tensor` or `numpy.ndarray` to `numpy.ndarray`

    Tensors are detached from the autograd graph and moved to the CPU before
    conversion, so tensors that require grad or live on an accelerator are
    accepted. numpy arrays are returned unchanged (no copy).

    Parameters:
        tensor:
            The tensor to convert.

    Raises:
        ImportError:
            If numpy is not available.
        TypeError:
            If the tensor type is not supported.

    Returns:
        The converted numpy array.
    """
    check_np()

    if isinstance(tensor, torch.Tensor):
        # Tensor.numpy() refuses tensors that require grad; detach() is a
        # cheap no-op view for tensors that do not.
        return tensor.detach().cpu().numpy()
    if isinstance(tensor, np.ndarray):
        return tensor
    raise TypeError(f"Invalid input type: {type(tensor)}")

visualize_array(array, **kwargs)

Visualize array using matplotlib

Parameters:
  • array (np.ndarray) –

    The array to visualize.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.

Raises:
  • ImportError

    If matplotlib is not available.

Example

visualize_array(array)

Source code in wildtorch/utils.py
def visualize_array(array: 'np.ndarray', **kwargs: Any):
    """
    Show an array as an image with a colorbar using matplotlib.

    Parameters:
        array:
            The array to visualize.
        **kwargs:
            Extra keyword arguments forwarded to ``plt.imshow``.

    Raises:
        ImportError:
            If matplotlib is not available.

    Examples:
        >>> visualize_array(array)
    """
    check_plt()
    image = plt.imshow(array, **kwargs)
    # Attach the colorbar explicitly to the image just drawn.
    plt.colorbar(image)
    plt.show()