'''
This module implements EnvPlot, which visualizes the environment and its objects with matplotlib.
Author: Ruihua Han
'''
import matplotlib.pyplot as plt
import logging
from irsim.global_param.path_param import path_manager as pm
from irsim.global_param import world_param, env_param
import os
import imageio.v2 as imageio
import shutil
import glob
from math import sin, cos
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from typing import Optional
class EnvPlot:
    """
    EnvPlot class for visualizing the environment.

    Args:
        world: The world object containing environment information including
            grid_map, x_range, y_range (also time and status, used for the title).
        objects (list, optional): List of objects in the environment. Default is no objects.
        saved_figure (dict, optional): Keyword arguments for saving the figure, merged over
            the defaults ``{"dpi": 100, "bbox_inches": "tight"}``.
            See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html for details.
        figure_pixels (list, optional): Width and height of the figure in pixels. Default is [1000, 800].
        show_title (bool, optional): Whether to show the title. Default is True.
        kwargs: Additional options such as color_map, no_axis, and tight.
    """

    def __init__(
        self,
        world,
        objects=(),
        saved_figure: Optional[dict] = None,
        figure_pixels: Optional[list] = None,
        show_title: bool = True,
        **kwargs,
    ) -> None:
        """
        Initialize the EnvPlot instance.

        Sets up the matplotlib figure, configures plotting parameters,
        and initializes the plot with world data and objects.
        """
        if figure_pixels is None:
            figure_pixels = [1000, 800]

        # Store world and basic properties
        self.world = world
        self.x_range = world.x_range
        self.y_range = world.y_range
        self.show_title = show_title
        # Custom title text; when None, update_title shows simulation time/status.
        self.title = None

        # Figure saving options: defaults overridden by the caller's saved_figure.
        self.saved_figure_kwargs = {
            "dpi": 100,
            "bbox_inches": "tight",
        }
        self.saved_figure_kwargs.update(saved_figure or {})

        # Size the figure in inches so it renders at figure_pixels at this dpi.
        dpi = self.saved_figure_kwargs["dpi"]
        self.fig, self.ax = plt.subplots(
            figsize=(figure_pixels[0] / dpi, figure_pixels[1] / dpi), dpi=dpi
        )

        # Default colors per object category, overridable via kwargs["color_map"].
        self.color_map = {
            "robot": "g",
            "obstacle": "k",
            "landmark": "b",
            "target": "pink",
        }
        self.color_map.update(kwargs.get("color_map", {}))

        # Base writer options for save_animate (copied per call, never mutated there).
        self.saved_ani_kwargs = {}

        # Artists drawn with refresh=True; removed by clear_components.
        self.dyna_line_list = []
        self.dyna_point_list = []
        self.dyna_quiver_list = []

        # Initialize the plot with world data
        self.init_plot(world.grid_map, objects, **kwargs)

    def _is_3d(self) -> bool:
        """Return True when the plot axes are a 3D (``Axes3D``) axes."""
        return isinstance(self.ax, Axes3D)

    @staticmethod
    def _select_objects(mode, objects):
        """Return the objects matching ``mode`` ('all'/'static'/'dynamic'), or None if mode is invalid."""
        if mode == "all":
            return list(objects)
        if mode == "static":
            return [obj for obj in objects if obj.static]
        if mode == "dynamic":
            return [obj for obj in objects if not obj.static]
        return None

    def init_plot(self, grid_map, objects, no_axis=False, tight=True, **kwargs):
        """
        Initialize the plot with the given grid map and objects.

        Args:
            grid_map (optional): The grid map of the environment.
            objects (list): List of objects to plot.
            no_axis (bool, optional): Whether to hide the axis. Default is False.
            tight (bool, optional): Whether to apply ``fig.tight_layout()``. Default is True.
            kwargs: Forwarded to each object's ``_init_plot`` / ``_step_plot``.
        """
        if self._is_3d():
            self.ax.set_box_aspect([1, 1, 1])
        else:
            self.ax.set_aspect("equal")

        self.ax.set_xlim(self.x_range)
        self.ax.set_ylim(self.y_range)
        self.ax.set_xlabel("x [m]")
        self.ax.set_ylabel("y [m]")

        self.init_objects_plot(objects, **kwargs)
        self.draw_grid_map(grid_map)

        if no_axis:
            # Target this plot's axes, not whatever pyplot considers current.
            self.ax.axis("off")
        if tight:
            self.fig.tight_layout()

    def step(self, mode="dynamic", objects=(), **kwargs):
        """
        Refresh the plot for one frame.

        Args:
            mode (str): 'static', 'dynamic', or 'all' to select which objects to update.
            objects (list): Objects in the environment.
            kwargs: Forwarded to the objects' plotting methods.
        """
        if self._is_3d():
            # 3D: clear the selected objects' artists and redraw them via obj.plot.
            self.clear_components(mode, objects)
            self.draw_components(mode, objects, **kwargs)
        else:
            # 2D: only the refresh artists are removed here; objects update
            # through their own _step_plot.
            self.clear_components(mode)
            self.step_objects_plot(mode, objects, **kwargs)

        if self.show_title:
            self.update_title()

    def init_objects_plot(self, objects, **kwargs):
        """Create each object's artists on ``self.ax``, then draw all of them once."""
        for obj in objects:
            obj._init_plot(self.ax, **kwargs)
        self.step_objects_plot("all", objects, **kwargs)

    def step_objects_plot(self, mode="dynamic", objects=(), **kwargs):
        """
        Update the plot for the objects selected by ``mode``.

        Args:
            mode (str): 'static', 'dynamic', or 'all'. Any other value logs an error.
            objects (list): Objects whose ``_step_plot`` is called.
            kwargs: Forwarded to ``_step_plot``.
        """
        selected = self._select_objects(mode, objects)
        if selected is None:
            self.logger.error("Error: Invalid draw mode")
            return
        for obj in selected:
            obj._step_plot(**kwargs)

    def draw_components(self, mode="all", objects=(), **kwargs):
        """
        Draw the components in the environment.

        Args:
            mode (str): 'static', 'dynamic', or 'all' to specify which objects to draw.
            objects (list): List of objects to draw.
            kwargs: Additional plotting options.
        """
        selected = self._select_objects(mode, objects)
        if selected is None:
            self.logger.error("Error: Invalid draw mode")
            return
        for obj in selected:
            obj.plot(self.ax, **kwargs)

    def clear_components(self, mode="all", objects=()):
        """
        Clear the components in the environment.

        'dynamic' and 'all' also remove every artist drawn with ``refresh=True``;
        'static' leaves those artists in place.

        Args:
            mode (str): 'static', 'dynamic', or 'all' to specify which objects to clear.
            objects (list): List of objects to clear.
        """
        selected = self._select_objects(mode, objects)
        if selected is None:
            self.logger.error("Error: Invalid clear mode")
            return

        clear_kwargs = {"all": True} if mode == "all" else {}
        for obj in selected:
            obj.plot_clear(**clear_kwargs)

        if mode != "static":
            self._clear_dynamic_artists()

    def _clear_dynamic_artists(self):
        """Remove every artist registered with refresh=True and reset the tracking lists."""
        for lines in self.dyna_line_list:
            # Each entry is the full list returned by ax.plot; remove all of it.
            for line in lines:
                line.remove()
        for artist in self.dyna_point_list + self.dyna_quiver_list:
            artist.remove()
        self.dyna_line_list = []
        self.dyna_point_list = []
        self.dyna_quiver_list = []

    def draw_grid_map(self, grid_map=None, **kwargs):
        """
        Draw the grid map on the plot.

        The map is transposed and drawn with origin 'lower' over the world's
        x/y range. It is not drawn on 3D axes; a warning is logged instead.

        Args:
            grid_map (optional): The grid map to draw. Nothing is drawn when None.
        """
        if grid_map is None:
            return
        if self._is_3d():
            self.logger.warning("Map will not show in 3D plot")
            return
        self.ax.imshow(
            grid_map.T,
            cmap="Greys",
            origin="lower",
            # Unpack so list, tuple, or ndarray ranges all yield [x0, x1, y0, y1].
            extent=[*self.x_range, *self.y_range],
            zorder=0,
        )

    def draw_trajectory(
        self,
        traj,
        traj_type="g-",
        label="trajectory",
        show_direction=False,
        refresh=False,
        **kwargs,
    ):
        """
        Draw a trajectory on the plot.

        Args:
            traj (list or np.ndarray): List of column-vector states (each indexed as
                ``p[i, 0]``), or an array whose columns are states [x, y, theta].
            traj_type (str): Type of trajectory line (e.g., 'g-').
                See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.plot.html for details.
            label (str): Label for the trajectory.
            show_direction (bool): Whether to draw heading arrows (uses theta, row 2).
            refresh (bool): Whether the line (and arrows) are removed on the next clear.
            kwargs: Additional plotting options for ax.plot()

        Raises:
            TypeError: If traj is neither a list/tuple nor an np.ndarray.
        """
        if isinstance(traj, np.ndarray):
            path_x = traj[0]
            path_y = traj[1]
            headings = traj[2] if show_direction else None
        elif isinstance(traj, (list, tuple)):
            path_x = [p[0, 0] for p in traj]
            path_y = [p[1, 0] for p in traj]
            headings = [p[2, 0] for p in traj] if show_direction else None
        else:
            raise TypeError(
                f"traj must be a list or np.ndarray, got {type(traj).__name__}"
            )

        line = self.ax.plot(path_x, path_y, traj_type, label=label, **kwargs)
        if refresh:
            self.dyna_line_list.append(line)

        if show_direction:
            u_list = np.cos(headings)
            v_list = np.sin(headings)
            if self._is_3d():
                path_z_list = [0] * len(path_x)
                w_list = [0] * len(u_list)
                arrows = self.ax.quiver(
                    path_x, path_y, path_z_list, u_list, v_list, w_list
                )
            else:
                arrows = self.ax.quiver(path_x, path_y, u_list, v_list, width=0.003)
            # Track the arrows too, otherwise they survive every refresh.
            if refresh:
                self.dyna_quiver_list.append(arrows)

    def draw_points(self, points, s=10, c="m", refresh=True, **kwargs):
        """
        Draw points on the plot.

        Args:
            points (list or np.ndarray): List of points, each point as [x, y] or a (2, 1)
                array; or an array of shape (2, N) (N: number of points) or (2,).
                Nothing is drawn when None.
            s (int): Size of the points.
            c (str): Color of the points.
            refresh (bool): Whether the points are removed on the next clear.
            kwargs: Additional plotting options.
        """
        if points is None:
            return

        if isinstance(points, np.ndarray):
            # Row 0 holds x and row 1 holds y (scalars for a 1-D [x, y] array).
            x_coordinates = points[0]
            y_coordinates = points[1]
        else:
            x_coordinates = [point[0] for point in points]
            y_coordinates = [point[1] for point in points]

        points_plot = self.ax.scatter(x_coordinates, y_coordinates, s=s, c=c, **kwargs)

        if refresh:
            self.dyna_point_list.append(points_plot)

    def draw_quiver(self, point, refresh=False, color="black", **kwargs):
        """
        Draw a single arrow (with a marker at its base) on the plot.

        Args:
            point (np.ndarray): 4x1 array [x, y, u, v]; u, v are the vector components.
                Nothing is drawn when None.
            refresh (bool): Whether the marker and arrow are removed on the next clear.
            color (str): Color of the marker and arrow.
            kwargs: Additional options for ax.quiver().
        """
        if point is None:
            return

        ax_point = self.ax.scatter(point[0], point[1], color=color)
        ax_quiver = self.ax.quiver(
            point[0], point[1],  # starting positions
            point[2], point[3],  # vector components (direction)
            color=color,
            **kwargs,
        )

        if refresh:
            self.dyna_quiver_list.append(ax_quiver)
            self.dyna_point_list.append(ax_point)

    def draw_quivers(self, points, refresh=False, color="black", **kwargs):
        """
        Draw a series of arrows on the plot.

        Args:
            points (list or np.ndarray): Iterable of points, each as [x, y, u, v].
                Nothing is drawn when None.
            refresh (bool): Whether the arrows are removed on the next clear.
            color (str): Color of the arrows.
            kwargs: Additional options for ax.quiver().
        """
        if points is None:
            return
        for point in points:
            self.draw_quiver(point, refresh, color=color, **kwargs)

    def draw_box(self, vertices, refresh=False, color="b-"):
        """
        Draw a closed outline through the vertices.

        Args:
            vertices (np.ndarray): Array with x in row 0 and y in row 1 (2xN);
                any extra rows are ignored.
            refresh (bool): Whether the outline is removed on the next clear.
            color (str): Color and line type of the box.
        """
        xy = np.asarray(vertices)[:2]
        # Append the first vertex so the outline is closed.
        closed = np.hstack([xy, xy[:, :1]])
        box_line = self.ax.plot(closed[0, :], closed[1, :], color)

        if refresh:
            self.dyna_line_list.append(box_line)

    def update_title(self):
        """Set the axes title to ``self.title``, or to the world's time and status when unset."""
        if self.title is not None:
            self.ax.set_title(self.title, pad=3)
        else:
            self.ax.set_title(
                f"Simulation Time: {self.world.time:.2f}s, Status: {self.world.status}",
                pad=3,
            )

    def save_animate(
        self,
        ani_name="animation",
        suffix=".gif",
        keep_len=30,
        rm_fig_path=True,
        **kwargs,
    ):
        """
        Save the animation built from the buffered PNG frames.

        Frames are read from ``pm.ani_buffer_path`` in sorted file-name order. The
        first buffered frame is skipped. The file is written to ``pm.ani_path``.

        Args:
            ani_name (str): Name of the animation. Default is 'animation'.
            suffix (str): Suffix of the animation file. Default is '.gif'.
            keep_len (int): Number of extra copies of the last frame appended. Default is 30.
            rm_fig_path (bool): Whether to remove the frame buffer after saving. Default is True.
            kwargs: Additional arguments for saving the animation.
                See `format_gif <https://imageio.readthedocs.io/en/v2.8.0/format_gif-pil.html>`_ for details.
        """
        # Per-call writer options: GIF-only options must not leak into later
        # calls that use a different suffix.
        ani_kwargs = dict(self.saved_ani_kwargs)
        if suffix == ".gif":
            ani_kwargs.update({"subrectangles": True, "loop": 0})
        ani_kwargs.update(kwargs)

        self.logger.info("Start to create animation")

        ap = pm.ani_path
        fp = pm.ani_buffer_path
        os.makedirs(ap, exist_ok=True)

        images = sorted(glob.glob(os.path.join(fp, "*.png")))
        image_list = [imageio.imread(str(file_name)) for file_name in images[1:]]

        if not image_list:
            self.logger.warning(f"No frames found in {fp}; animation not created")
            return

        # Hold the final frame; reuse the already-loaded image instead of re-reading it.
        image_list.extend([image_list[-1]] * keep_len)

        imageio.mimsave(os.path.join(ap, ani_name + suffix), image_list, **ani_kwargs)
        self.logger.info(f"Animation created successfully, saved in {ap}")

        if rm_fig_path:
            shutil.rmtree(fp)

    def show(self):
        """
        Display the plot.
        """
        plt.show()

    def close(self):
        """
        Close this plot's figure.
        """
        # Close self.fig explicitly; plt.close() alone closes the *current* figure.
        plt.close(self.fig)

    @property
    def logger(self):
        """Logger shared through ``env_param``."""
        return env_param.logger
def linewidth_from_data_units(linewidth, axis, reference="y"):
    """
    Convert a linewidth in data units to linewidth in points.

    Parameters
    ----------
    linewidth: float
        Linewidth in data units of the respective reference-axis
    axis: matplotlib axis
        The axis which is used to extract the relevant transformation
        data (data limits and size must not change afterwards)
    reference: string
        The axis that is taken as a reference for the data width.
        Possible values: 'x' and 'y'. Defaults to 'y'.

    Returns
    -------
    linewidth: float
        Linewidth in points

    Raises
    ------
    ValueError
        If ``reference`` is neither 'x' nor 'y'.
    """
    fig = axis.get_figure()
    if reference == "x":
        length = fig.bbox_inches.width * axis.get_position().width
        low, high = axis.get_xlim()
    elif reference == "y":
        length = fig.bbox_inches.height * axis.get_position().height
        low, high = axis.get_ylim()
    else:
        raise ValueError(f"reference must be 'x' or 'y', got {reference!r}")

    # Scalar span of the data limits (np.diff would yield a 1-element array).
    value_range = high - low
    # Convert length from inches to points (72 points per inch).
    length *= 72
    # Scale linewidth to value range
    return linewidth * (length / value_range)