Source code for irsim.env.env_plot

"""
This file is the implementation of the environment plot to visualize the environment objects.

Author: Ruihua Han
"""

from __future__ import annotations

import glob
import os
import shutil
from collections.abc import Iterable, Sequence
from math import cos, sin
from typing import Any, Optional

import imageio.v3 as imageio
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

from irsim.config import env_param, world_param
from irsim.config.path_param import path_manager as pm


[docs] class EnvPlot: """ EnvPlot class for visualizing the environment. Args: world: The world object containing environment information including grid_map, x_range, y_range. objects (list, optional): List of objects in the environment. Default is []. saved_figure (dict, optional): Keyword arguments for saving the figure. See https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.savefig.html for details. Default is dict(). figure_pixels (list, optional): Width and height of the figure in pixels. Default is [1180, 1080]. show_title (bool, optional): Whether to show the title. Default is True. kwargs: Additional options such as color_map, no_axis, and tight. """ def __init__( self, world: Any, objects: Optional[list[Any]] = None, saved_figure: Optional[dict[str, Any]] = None, figure_pixels: Optional[Sequence[int]] = None, show_title: bool = True, **kwargs: Any, ) -> None: """ Initialize the EnvPlot instance. Sets up the matplotlib figure, configures plotting parameters, and initializes the plot with world data and objects. """ # Store world and basic properties if saved_figure is None: saved_figure = {} if figure_pixels is None: figure_pixels = [1000, 800] if objects is None: objects = [] self.world = world self.x_range = world.x_range self.y_range = world.y_range self.show_title = show_title self.title = None # Configure figure saving options self.saved_figure_kwargs: dict[str, Any] = { "dpi": 100, "bbox_inches": "tight", } self.saved_figure_kwargs.update(saved_figure) # Create matplotlib figure and axes self.fig, self.ax = plt.subplots( figsize=( figure_pixels[0] / self.saved_figure_kwargs["dpi"], figure_pixels[1] / self.saved_figure_kwargs["dpi"], ), dpi=self.saved_figure_kwargs["dpi"], ) # Initialize plot settings and appearance self.color_map: dict[str, str] = { "robot": "g", "obstacle": "k", "landmark": "b", "target": "pink", } self.color_map.update(kwargs.get("color_map", {})) # Configure save options self.saved_ani_kwargs: dict[str, Any] = {} # Initialize dynamic plotting lists self.dyna_line_list: list[Any] = [] self.dyna_point_list: list[Any] = [] self.dyna_quiver_list: list[Any] = [] # Initialize the plot with world data self.init_plot(world.grid_map, objects, **kwargs)
[docs] def init_plot( self, grid_map: Optional[Any], objects: list[Any], no_axis: bool = False, tight: bool = True, **kwargs: Any, ) -> None: """ Initialize the plot with the given grid map and objects. Args: grid_map (optional): The grid map of the environment. objects (list): List of objects to plot. no_axis (bool, optional): Whether to show the axis. Default is False. tight (bool, optional): Whether to show the axis tightly. Default is True. """ if isinstance(self.ax, Axes3D): self.ax.set_box_aspect([1, 1, 1]) else: self.ax.set_aspect("equal") self.ax.set_xlim(self.x_range) self.ax.set_ylim(self.y_range) self.ax.set_xlabel("x [m]") self.ax.set_ylabel("y [m]") # self.draw_components("all", objects) self.init_objects_plot(objects) self.draw_grid_map(grid_map) if no_axis: plt.axis("off") if tight: self.fig.tight_layout()
[docs] def step( self, mode: str = "dynamic", objects: Optional[list[Any]] = None, **kwargs: Any ) -> None: """Advance the plot by one step for the given objects. Args: mode (str): Which objects to update: "dynamic", "static", or "all". objects (list | None): The objects to update/draw. Defaults to empty list. **kwargs: Extra drawing options passed through to objects' plot methods. """ if objects is None: objects = [] if self.show_title: self.update_title() if isinstance(self.ax, Axes3D): self.clear_components(mode, objects) self.draw_components(mode, objects, **kwargs) else: self.clear_components(mode) self.step_objects_plot(mode, objects, **kwargs)
[docs] def init_objects_plot(self, objects: list[Any], **kwargs: Any) -> None: """Initialize plot state for provided objects, then render once. Args: objects (list): Objects to be initialized on the axes. **kwargs: Extra drawing options passed to initialization/plot. """ if self.show_title: self.update_title() [obj._init_plot(self.ax, **kwargs) for obj in objects] self.step_objects_plot("all", objects, **kwargs)
[docs] def step_objects_plot( self, mode: str = "dynamic", objects: Optional[list[Any]] = None, **kwargs: Any ) -> None: """ Update the plot for the objects. """ if objects is None: objects = [] if mode == "dynamic": [obj._step_plot(**kwargs) for obj in objects if not obj.static] elif mode == "static": [obj._step_plot(**kwargs) for obj in objects if obj.static] elif mode == "all": [obj._step_plot(**kwargs) for obj in objects] else: self.logger.error("Error: Invalid draw mode")
[docs] def draw_components( self, mode: str = "all", objects: Optional[list[Any]] = None, **kwargs: Any ) -> None: """ Draw the components in the environment. Args: mode (str): 'static', 'dynamic', or 'all' to specify which objects to draw. objects (list): List of objects to draw. kwargs: Additional plotting options. """ if objects is None: objects = [] if mode == "static": [obj.plot(self.ax, **kwargs) for obj in objects if obj.static] elif mode == "dynamic": [obj.plot(self.ax, **kwargs) for obj in objects if not obj.static] elif mode == "all": [obj.plot(self.ax, **kwargs) for obj in objects] else: self.logger.error("Error: Invalid draw mode")
[docs] def clear_components( self, mode: str = "all", objects: Optional[list[Any]] = None ) -> None: """ Clear the components in the environment. Args: mode (str): 'static', 'dynamic', or 'all' to specify which objects to clear. objects (list): List of objects to clear. """ if objects is None: objects = [] if mode == "dynamic": [obj.plot_clear() for obj in objects if not obj.static] [line.pop(0).remove() for line in self.dyna_line_list] [points.remove() for points in self.dyna_point_list] [quiver.remove() for quiver in self.dyna_quiver_list] self.dyna_line_list = [] self.dyna_point_list = [] self.dyna_quiver_list = [] elif mode == "static": [obj.plot_clear() for obj in objects if obj.static] elif mode == "all": [obj.plot_clear(all=True) for obj in objects] [line.pop(0).remove() for line in self.dyna_line_list] [points.remove() for points in self.dyna_point_list] [quiver.remove() for quiver in self.dyna_quiver_list] self.dyna_line_list = [] self.dyna_point_list = [] self.dyna_quiver_list = []
[docs] def draw_grid_map(self, grid_map: Optional[Any] = None, **kwargs: Any) -> None: """ Draw the grid map on the plot. Args: grid_map (optional): The grid map to draw. """ if grid_map is not None: self.ax.imshow( grid_map.T, cmap="Greys", origin="lower", extent=self.x_range + self.y_range, zorder=0, ) if isinstance(self.ax, Axes3D): print("Map will not show in 3D plot")
[docs] def draw_trajectory( self, traj: list[Any] | np.ndarray, traj_type: str = "g-", label: str = "trajectory", show_direction: bool = False, refresh: bool = False, **kwargs: Any, ) -> None: """ Draw a trajectory on the plot. Args: traj (list or np.ndarray): List of points or array of points [x, y, theta]. traj_type (str): Type of trajectory line (e.g., 'g-'). See https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.plot.html for details. label (str): Label for the trajectory. show_direction (bool): Whether to show the direction of the trajectory. refresh (bool): Whether to refresh the plot. kwargs: Additional plotting options for ax.plot() """ if isinstance(traj, list): path_x_list = [p[0, 0] for p in traj] path_y_list = [p[1, 0] for p in traj] elif isinstance(traj, np.ndarray): path_x_list = [p[0] for p in traj.T] path_y_list = [p[1] for p in traj.T] line = self.ax.plot(path_x_list, path_y_list, traj_type, label=label, **kwargs) if show_direction: if isinstance(traj, list): u_list = [cos(p[2, 0]) for p in traj] v_list = [sin(p[2, 0]) for p in traj] elif isinstance(traj, np.ndarray): u_list = [cos(p[2]) for p in traj.T] v_list = [sin(p[2]) for p in traj.T] if isinstance(self.ax, Axes3D): path_z_list = [0] * len(path_x_list) w_list = [0] * len(u_list) self.ax.quiver( path_x_list, path_y_list, path_z_list, u_list, v_list, w_list ) else: self.ax.quiver(path_x_list, path_y_list, u_list, v_list, width=0.003) if refresh: self.dyna_line_list.append(line)
[docs] def draw_points( self, points: Optional[list[Any] | np.ndarray], s: int = 10, c: str = "m", refresh: bool = True, **kwargs: Any, ) -> None: """ Draw points on the plot. Args: points (list): List of points, each point as [x, y] or (2, 1) array or (np.array): points array: (2, N), NL number of points s (int): Size of the points. c (str): Color of the points. refresh (bool): Whether to refresh the plot. kwargs: Additional plotting options. """ if points is None: return if isinstance(points, list): x_coordinates = [point[0] for point in points] y_coordinates = [point[1] for point in points] elif isinstance(points, np.ndarray): if points.shape[1] > 1: x_coordinates = [point[0] for point in points.T] y_coordinates = [point[1] for point in points.T] else: x_coordinates = points[0] y_coordinates = points[1] points_plot = self.ax.scatter(x_coordinates, y_coordinates, s, c, **kwargs) if refresh: self.dyna_point_list.append(points_plot)
[docs] def draw_quiver( self, point: Optional[np.ndarray], refresh: bool = False, color: str = "black", **kwargs: Any, ) -> None: """ Draw a quiver plot on the plot. Args: points (4*1 np.ndarray): List of points, each point as [x, y, u, v]. u, v are the components of the vector. kwargs: Additional plotting options. """ if point is None: return ax_point = self.ax.scatter(point[0], point[1], color=color) ax_quiver = self.ax.quiver( point[0], point[1], # starting positions point[2], point[3], # vector components (direction) color=color, **kwargs, ) if refresh: self.dyna_quiver_list.append(ax_quiver) self.dyna_point_list.append(ax_point)
[docs] def draw_quivers( self, points: Iterable[np.ndarray], refresh: bool = False, color: str = "black", **kwargs: Any, ) -> None: """ Draw a series of quiver plot on the plot. Args: points (list or np.ndarray): List of points, each point as [x, y, u, v]. u, v are the components of the vector. """ for point in points: self.draw_quiver(point, refresh, color=color, **kwargs)
[docs] def draw_box( self, vertices: np.ndarray, refresh: bool = False, color: str = "b-" ) -> None: """ Draw a box by the vertices. Args: vertices (np.ndarray): 2xN array of vertices. refresh (bool): Whether to refresh the plot. color (str): Color and line type of the box. """ temp_vertex = np.c_[vertices, vertices[0:2, 0]] box_line = self.ax.plot(temp_vertex[0, :], temp_vertex[1, :], color) if refresh: self.dyna_line_list.append(box_line)
[docs] def update_title(self) -> None: """Update the figure title with current time/status or a custom title.""" if self.title is not None: self.ax.set_title(self.title, pad=3) else: self.ax.set_title( f"Simulation Time: {self.world.time:.2f}s, Status: {self.world.status}", pad=3, )
[docs] def save_figure( self, file_name: str = "", file_format: str = "png", include_index: bool = False, save_gif: bool = False, **kwargs: Any, ) -> None: """ Save the current figure. Args: file_name (str): Name of the figure. Default is ''. file_format (str): Format of the figure. Default is 'png'. kwargs: Additional arguments for saving the figure. See https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.savefig.html for details. """ fp = pm.ani_buffer_path if save_gif else pm.fig_path if not os.path.exists(fp): os.makedirs(fp) self.saved_figure_kwargs.update(kwargs) if include_index or save_gif: order = str(world_param.count).zfill(3) full_name = fp + "/" + file_name + "_" + order + "." + file_format else: full_name = fp + "/" + file_name + "." + file_format self.fig.savefig(full_name, format=file_format, **self.saved_figure_kwargs)
[docs] def save_animate( self, ani_name: str = "animation", suffix: str = ".gif", last_frame_duration: int = 1, rm_fig_path: bool = True, **kwargs: Any, ) -> None: """ Save the animation. Args: ani_name (str): Name of the animation. Default is 'animation'. last_frame_duration (int): Duration of the last frame for the gif. Default is 1 second. suffix (str): Suffix of the animation file. Default is '.gif'. rm_fig_path (bool): Whether to remove the figure path after saving. Default is True. kwargs: Additional arguments for saving the animation. See `imageio.imwrite <https://imageio.readthedocs.io/en/stable/_autosummary/imageio.v3.imwrite.html#imageio.v3.imwrite>`_ for details. """ self.saved_ani_kwargs.update(kwargs) self.logger.info("Start to create animation") ap = pm.ani_path fp = pm.ani_buffer_path if not os.path.exists(ap): os.makedirs(ap) images = list(glob.glob(fp + "/*.png")) images.sort() image_list = [imageio.imread(str(file_name)) for file_name in images] if suffix == ".gif": # default arguments for gif durations = [100] * (len(image_list) - 1) + [last_frame_duration * 1000] self.saved_ani_kwargs.update( {"plugin": "pillow", "duration": durations, "loop": 0} ) full_name = ap + "/" + ani_name + suffix imageio.imwrite(full_name, image_list, **self.saved_ani_kwargs) self.logger.info(f"{ani_name} created successfully, saved in {ap}") if rm_fig_path: shutil.rmtree(fp)
[docs] def show(self) -> None: """ Display the plot. """ plt.show()
[docs] def close(self) -> None: """ Close the plot. """ plt.close()
@property def logger(self): return env_param.logger
[docs] def linewidth_from_data_units( linewidth: float, axis: Any, reference: str = "y" ) -> float: """ Convert a linewidth in data units to linewidth in points. Parameters ---------- linewidth: float Linewidth in data units of the respective reference-axis axis: matplotlib axis The axis which is used to extract the relevant transformation data (data limits and size must not change afterwards) reference: string The axis that is taken as a reference for the data width. Possible values: 'x' and 'y'. Defaults to 'y'. Returns ------- linewidth: float Linewidth in points """ fig = axis.get_figure() if reference == "x": length = fig.bbox_inches.width * axis.get_position().width value_range = np.diff(axis.get_xlim()).item() elif reference == "y": length = fig.bbox_inches.height * axis.get_position().height value_range = np.diff(axis.get_ylim()).item() # Convert length to points length *= 72 # Scale linewidth to value range return linewidth * (length / value_range)