Source code for irsim.env.env_plot3d

import warnings
from typing import Any, Optional, Union

import numpy as np

from .env_plot import EnvPlot


class EnvPlot3D(EnvPlot):
    """
    Environment plotter that draws on a 3D matplotlib axes.

    After the ``EnvPlot`` initializer runs, a 3D subplot is added to
    ``self.fig`` and stored as ``self.ax``. The plot is then initialized on
    it, and the z-axis limits are set from ``world.z_range``.
    """

    def __init__(
        self,
        world: Any,
        objects: Optional[list[Any]] = None,
        saved_figure: Optional[dict[str, Any]] = None,
        figure_pixels: Optional[list[int]] = None,
        show_title: bool = True,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            world: World object. Its ``z_range`` and ``grid_map`` attributes
                are read here, and it is also forwarded to ``EnvPlot``.
            objects: Objects forwarded to ``EnvPlot`` and ``init_plot``.
                Defaults to an empty list.
            saved_figure: Options forwarded to ``EnvPlot``. Defaults to an
                empty dict.
            figure_pixels: Value forwarded to ``EnvPlot``. Defaults to
                ``[1180, 1080]``.
            show_title: Whether ``update_title`` draws a title.
            **kwargs: Extra options forwarded to ``EnvPlot`` and ``init_plot``.
        """
        # None sentinels avoid sharing mutable defaults between instances.
        if objects is None:
            objects = []
        if saved_figure is None:
            saved_figure = {}
        if figure_pixels is None:
            figure_pixels = [1180, 1080]

        super().__init__(
            world, objects, saved_figure, figure_pixels, show_title, **kwargs
        )

        # Replace the plotting axes with a 3D one before initializing the plot.
        self.ax = self.fig.add_subplot(projection="3d")
        self.z_range = world.z_range
        self.init_plot(world.grid_map, objects, **kwargs)
        self.ax.set_zlim(self.z_range)

    def draw_points(
        self,
        points: Optional[Union[list, np.ndarray]],
        s: int = 10,
        c: str = "m",
        refresh: bool = True,
        **kwargs: Any,
    ) -> None:
        """
        Draw points on the plot.

        Args:
            points (list or np.ndarray): Points to draw. Nothing is drawn for
                None. An array is read column-wise: rows 0, 1 and 2 hold the
                x, y and z coordinates, e.g. shape (3, N). A flat array of
                shape (3,) is treated as a single point. Any other iterable
                yields one point per entry, indexed as [x, y, z].
            s (int): Size of the points.
            c (str): Color of the points.
            refresh (bool): If True, the scatter artist is appended to
                ``self.dyna_point_list``.
            kwargs: Additional options passed to ``Axes3D.scatter``.
        """
        if points is None:
            return

        if isinstance(points, np.ndarray):
            # Row indexing works for both (3, N) and flat (3,) arrays.
            x_coordinates = points[0]
            y_coordinates = points[1]
            z_coordinates = points[2]
        else:
            x_coordinates = [point[0] for point in points]
            y_coordinates = [point[1] for point in points]
            z_coordinates = [point[2] for point in points]

        scatter_artist = self.ax.scatter(
            x_coordinates,
            y_coordinates,
            z_coordinates,
            zdir="z",
            s=s,
            c=c,
            **kwargs,
        )

        if refresh:
            self.dyna_point_list.append(scatter_artist)

    def draw_quiver(
        self, point: Optional[np.ndarray], refresh: bool = False, **kwargs: Any
    ) -> None:
        """
        Draw a single point together with a direction arrow.

        Args:
            point (np.ndarray): Six values [x, y, z, u, v, w]. (x, y, z) is
                the arrow start and (u, v, w) its direction. The arrow is
                drawn with length 0.2 and ``normalize=True``. Nothing is drawn
                for None.
            refresh (bool): If True, the arrow and point artists are appended
                to ``self.dyna_quiver_list`` and ``self.dyna_point_list``.
            kwargs: Only ``point_color`` (default "blue") and
                ``quiver_color`` (default "red") are read. Other keys are
                ignored.
        """
        if point is None:
            return

        ax_point = self.ax.scatter(
            point[0],
            point[1],
            point[2],
            color=kwargs.get("point_color", "blue"),
            label="Points",
        )

        ax_quiver = self.ax.quiver(
            point[0],
            point[1],
            point[2],  # starting position
            point[3],
            point[4],
            point[5],  # vector components (direction)
            length=0.2,
            normalize=True,
            color=kwargs.get("quiver_color", "red"),
            label="Direction",
        )

        if refresh:
            self.dyna_quiver_list.append(ax_quiver)
            self.dyna_point_list.append(ax_point)

    def draw_quivers(
        self,
        points: Optional[Union[list, np.ndarray]],
        refresh: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Draw a series of points with direction arrows via ``draw_quiver``.

        Args:
            points (list or np.ndarray): Points [x, y, z, u, v, w]. An array
                holds one point per column, e.g. shape (6, N). A list holds
                one point per entry. Nothing is drawn for None.
            refresh (bool): Forwarded to ``draw_quiver``.
            kwargs: Forwarded to ``draw_quiver``.
        """
        if points is None:
            return

        # Arrays store points column-wise, so iterate over the transpose.
        iterable = points.T if isinstance(points, np.ndarray) else points
        for point in iterable:
            self.draw_quiver(point, refresh=refresh, **kwargs)

    def draw_trajectory(
        self,
        traj: Union[list, np.ndarray],
        traj_type: str = "g-",
        label: str = "trajectory",
        show_direction: bool = False,
        refresh: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Draw a trajectory on the plot.

        Args:
            traj (list or np.ndarray): Trajectory points.
                - Array: read column-wise, with rows 0, 1 and 2 holding the
                  x, y and z coordinates.
                - List: entries may be column vectors, such as shape (3, 1),
                  whose first column is used. They may also be flat
                  [x, y, z] sequences.
            traj_type (str): Matplotlib format string for the line, e.g. 'g-'.
                See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.plot.html
            label (str): Label for the trajectory.
            show_direction (bool): Direction arrows are not supported for 3D
                plots. If True, a warning is emitted instead.
            refresh (bool): If True, the list returned by ``ax.plot`` is
                appended to ``self.dyna_line_list``.
            kwargs: Additional plotting options for ``ax.plot()``.

        Raises:
            TypeError: If ``traj`` is neither a list nor an ``np.ndarray``.
        """
        if isinstance(traj, np.ndarray):
            path_x_list, path_y_list, path_z_list = traj[0], traj[1], traj[2]
        elif isinstance(traj, list):
            # Use the first column for 2D entries (matching p[k, 0]). Flat
            # entries are used as-is.
            arrays = [np.asarray(p) for p in traj]
            columns = [arr[:, 0] if arr.ndim > 1 else arr for arr in arrays]
            path_x_list = [col[0] for col in columns]
            path_y_list = [col[1] for col in columns]
            path_z_list = [col[2] for col in columns]
        else:
            raise TypeError(
                f"traj must be a list or np.ndarray, got {type(traj).__name__}"
            )

        line = self.ax.plot(
            path_x_list, path_y_list, path_z_list, traj_type, label=label, **kwargs
        )

        if show_direction:
            warnings.warn(
                "show_direction is not currently supported for 3D trajectories",
                stacklevel=2,
            )

        if refresh:
            self.dyna_line_list.append(line)

    def update_title(self) -> None:
        """
        Set the figure title, overriding ``EnvPlot.update_title`` for 3D plots.

        Does nothing when ``show_title`` is False. If ``self.title`` is set,
        that title is used. Otherwise the title shows the world's simulation
        time and status. The title is drawn with ``fig.suptitle``.
        """
        if not self.show_title:
            return

        if self.title is not None:
            self.fig.suptitle(self.title, fontsize=12)
        else:
            self.fig.suptitle(
                f"Simulation Time: {self.world.time:.2f}s, Status: {self.world.status}",
                fontsize=12,
            )