"""
Class EnvBase is the base class of the environment. This class will read the yaml file and create the world, robot, obstacle, and map objects.
Author: Ruihua Han
"""
import matplotlib
import platform
import importlib
import numpy as np
from typing import Optional, Union
from operator import attrgetter
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
from shapely import Polygon
from shapely.strtree import STRtree
from irsim.global_param import env_param, world_param
from irsim.env.env_config import EnvConfig
from irsim.world import World, ObjectBase
from .env_plot import EnvPlot
from irsim.world.object_factory import ObjectFactory
from .env_logger import EnvLogger
from irsim.lib import random_generate_polygon
from irsim.gui.mouse_control import MouseControl
try:
from irsim.gui.keyboard_control import KeyboardControl
keyboard_module = True
except ImportError:
keyboard_module = False
# Matplotlib backend preference lists, keyed by operating-system name.
# Each list is ordered from most to least preferred and ends with "Agg".
BACKEND_PREFERENCES = {
    "Darwin": ["MacOSX", "TkAgg", "Qt5Agg", "Agg"], # macOS
    "Windows": ["TkAgg", "Qt5Agg", "Agg"], # Windows
    "Linux": ["TkAgg", "Qt5Agg", "Agg"], # Linux
}
def _set_matplotlib_backend(backend_list):
    """Activate the first usable matplotlib backend from ``backend_list``.

    Args:
        backend_list (list): Backend names, tried in the given order.

    Returns:
        bool: True as soon as one ``matplotlib.use`` call succeeds. False when
        every candidate raised; in that case "Agg" is selected instead.
    """
    for candidate in backend_list:
        try:
            matplotlib.use(candidate)
        except Exception as err:
            # Report the failure and move on to the next candidate.
            print(f"Failed to use '{candidate}' backend: {err}")
        else:
            return True

    print("All backends failed. Falling back to 'Agg' backend. The environment will not be displayed.")
    matplotlib.use("Agg")
    return False
# Pick the preference list for the platform name stored in env_param and
# activate the first backend that works. Platform names missing from
# BACKEND_PREFERENCES get ["Agg"] only.
backends = BACKEND_PREFERENCES.get(env_param.platform_name, ["Agg"])
# backend_set is True when one of the listed backends was activated, False otherwise.
backend_set = _set_matplotlib_backend(backends)
[docs]
class EnvBase:
    """
    The base class of environment. This class will read the yaml file and create the world, robot, obstacle, and map objects.

    Args:
        world_name (str): Path to the world yaml file.
        display (bool): Flag to display the environment. When False, the "Agg" matplotlib backend is selected.
        disable_all_plot (bool): Flag to disable all plots and figures.
        save_ani (bool): Flag to save the animation.
        full (bool): Flag to full screen the figure (toggled on Linux and Windows).
        log_file (str): Name of the log file.
        log_level (str): Level of the log output.
    """
def __init__(
    self,
    world_name: Optional[str] = None,
    display: bool = True,
    disable_all_plot: bool = False,
    save_ani: bool = False,
    full: bool = False,
    log_file: Optional[str] = None,
    log_level: str = "INFO",
) -> None:
    """Create the logger, world, objects, plot and input handlers.

    The arguments are described in the class docstring.
    """
    # init env setting
    self.display = display
    if not self.display:
        # No figure window is wanted, so switch matplotlib to the "Agg" backend.
        matplotlib.use("Agg")
    self.disable_all_plot = disable_all_plot
    self.save_ani = save_ani
    # The logger is stored on the shared env_param and read back through self.logger.
    env_param.logger = EnvLogger(log_file, log_level)
    self.env_config = EnvConfig(world_name)
    self.object_factory = ObjectFactory()
    # init objects (world, obstacle, robot)
    self._world = World(world_name, **self.env_config.parse["world"])
    self._robot_collection = self.object_factory.create_from_parse(
        self.env_config.parse["robot"], "robot"
    )
    self._obstacle_collection = self.object_factory.create_from_parse(
        self.env_config.parse["obstacle"], "obstacle"
    )
    self._map_collection = self.object_factory.create_from_map(
        self._world.obstacle_positions, self._world.buffer_reso
    )
    # Robots, obstacles and map objects share one list, ordered by object id.
    self._objects = (
        self._robot_collection + self._obstacle_collection + self._map_collection
    )
    self._objects.sort(key=attrgetter("id"))
    self.build_tree()
    # env parameters
    self._env_plot = EnvPlot(
        self._world,
        self.objects,
        **self._world.plot_parse
    )
    env_param.objects = self.objects
    if world_param.control_mode == "keyboard":
        if not keyboard_module:
            # The optional keyboard dependency failed to import: fall back to auto control.
            self.logger.error(
                "Keyboard module is not installed. Auto control applied. Please install the dependency by 'pip install ir-sim[keyboard]'."
            )
            world_param.control_mode = "auto"
        else:
            self.keyboard = KeyboardControl(env_ref=self, **self.env_config.parse["keyboard"])
    # The mouse handler is attached to the plot axes regardless of the control mode.
    self.mouse = MouseControl(self._env_plot.ax)
    # flag
    self.pause_flag = False
    if full:
        # Both branches perform the same toggle; other platforms are left unchanged.
        system_platform = platform.system()
        if system_platform == "Linux":
            mng = plt.get_current_fig_manager()
            if mng is not None:
                mng.full_screen_toggle()
        elif system_platform == "Windows":
            mng = plt.get_current_fig_manager()
            if mng is not None:
                mng.full_screen_toggle()
    # Log simulation start
    self.logger.info(f"Simulation environment '{self._world.name}' has been initialized and started.")
def __del__(self):
    """Do nothing on finalisation; cleanup is performed by :py:meth:`end`."""
def __str__(self):
    """Return a short label containing the world name."""
    world_name = self._world.name
    return f"Environment: {world_name}"
[docs]
def step(
    self, action: Optional[Union[np.ndarray, list]] = None, action_id: Optional[Union[int, list]] = 0
):
    """
    Perform a simulation step in the environment.

    Every object is stepped exactly once per call. Nothing happens while the
    environment is paused.

    Args:
        action (list or numpy array 2*1): Action to be performed in the environment.
            A list is treated as one action per object.

            - differential robot action: linear velocity, angular velocity
            - omnidirectional robot action: v_x -- velocity in x; v_y -- velocity in y
            - Ackermann robot action: linear velocity, Steering angle
        action_id (int or list of int): Apply the action(s) to the robot(s) with the given id(s).
            The ids are used as positions in ``self.objects``. When a single action is
            given together with a list of ids, only the first id is used.
    """
    if self.pause_flag:
        return

    if isinstance(action, list):
        if isinstance(action_id, list):
            # Pair each action with its target and leave every other slot as None,
            # so that each object advances once. Calling _object_step once per pair
            # would step all the other objects again for every pair.
            if self.objects:
                per_object_actions = [None] * len(self.objects)
                for single_action, target in zip(action, action_id):
                    per_object_actions[target] = single_action
                self._objects_step(per_object_actions)
        else:
            self._objects_step(action)
    elif world_param.control_mode == "keyboard":
        # Keyboard control overrides the action passed by the caller.
        self._object_step(self.key_vel, self.key_id)
    elif isinstance(action_id, list):
        self._object_step(action, action_id[0])
    else:
        self._object_step(action, action_id)

    self.build_tree()
    self._objects_check_status()
    self._world.step()
    self.step_status()
def _objects_step(self, action: Optional[list] = None):
    """Step every object, giving the i-th object the i-th entry of ``action``.

    Args:
        action (list): One action per object, in the order of ``self.objects``.
            Missing trailing entries (or ``None`` for the whole argument) mean
            "no action" for the corresponding objects.
    """
    actions = [] if action is None else list(action)
    # Pad with None so that objects without an explicit action are still stepped.
    actions += [None] * (len(self.objects) - len(actions))
    for obj, obj_action in zip(self.objects, actions):
        obj.step(obj_action)
def _object_step(self, action: np.ndarray, obj_id: int = 0):
    """Step one object with ``action`` and all remaining objects without one.

    Args:
        action (np.ndarray): Action handed to the selected object.
        obj_id (int): Position of the selected object in ``self.objects``.
    """
    if len(self.objects) == 0:
        return

    target = self.objects[obj_id]
    target.step(action)
    # Exclude the target by identity rather than by comparing ids with the list
    # position: when the two differ, an id comparison steps the target twice or
    # skips another object.
    for obj in self.objects:
        if obj is not target:
            obj.step()
def _objects_check_status(self):
    """Ask every object in the environment to refresh its status."""
    for member in self.objects:
        member.check_status()
# render
[docs]
def render(
    self,
    interval: float = 0.02,
    figure_kwargs: Optional[dict] = None,
    mode: str = "dynamic",
    **kwargs,
):
    """
    Render the environment.

    Nothing is drawn when all plots are disabled or when the world is not
    sampling on the current step.

    Args:
        interval(float) : Time interval between frames in seconds.
        figure_kwargs(dict) : Additional keyword arguments for saving figures, see `savefig <https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.savefig.html>`_ for detail. ``None`` means no extra arguments.
        mode(str) : "dynamic", "static", "all" to specify which type of objects to draw and clear.
        kwargs: Additional keyword arguments for drawing components. see :py:meth:`.ObjectBase.plot` function for detail.
    """
    if self.disable_all_plot or not self._world.sampling:
        return

    # Use a fresh dict per call instead of a shared mutable default argument.
    figure_kwargs = {} if figure_kwargs is None else figure_kwargs

    if self.display:
        plt.pause(interval)

    if self.save_ani:
        # NOTE(review): save_figure is not defined in the visible part of this
        # class; confirm it is provided elsewhere.
        self.save_figure(save_gif=True, **figure_kwargs)

    self._env_plot.step(mode, self.objects, **kwargs)
[docs]
def show(self):
    """
    Display the environment figure by delegating to the plot object.
    """
    plotter = self._env_plot
    plotter.show()
# draw various components
[docs]
def draw_trajectory(self, traj: list, traj_type: str = "g-", **kwargs):
    """
    Plot a trajectory on the environment figure.

    Args:
        traj (list): Trajectory points, each a 2 * 1 vector.
        traj_type: Line style of the trajectory, as understood by the matplotlib plot function.
        **kwargs: Extra drawing options, forwarded to :py:meth:`.EnvPlot.draw_trajectory`.
    """
    plotter = self._env_plot
    plotter.draw_trajectory(traj, traj_type, **kwargs)
[docs]
def draw_points(
    self, points: list, s: int = 30, c: str = "b", refresh: bool = True, **kwargs
):
    """
    Scatter points on the environment figure.

    Args:
        points (list): Points (2*1 each) to draw, or an np.array of shape (2, Num).
        s (int): Marker size.
        c (str): Marker color.
        refresh (bool): Flag to refresh the points in the figure.
        **kwargs: Extra options forwarded to the plot object, see `ax.scatter <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.scatter.html>`_ for detail.
    """
    plotter = self._env_plot
    plotter.draw_points(points, s, c, refresh, **kwargs)
[docs]
def draw_box(self, vertex: np.ndarray, refresh: bool = False, color: str = "-b"):
    """
    Draw a box from its vertices.

    Args:
        vertex (np.ndarray): matrix of vertices, point_dim*vertex_num
        refresh (bool): whether to refresh the plot, default False
        color (str): color of the box, default '-b'
    """
    plotter = self._env_plot
    plotter.draw_box(vertex, refresh, color)
[docs]
def draw_quiver(self, point, refresh=False, **kwargs):
    """
    Draw one quiver (arrow) on the environment figure.

    Args:
        point: Point data describing the quiver.
        refresh (bool): Flag to refresh the quiver in the figure, default False
        **kwargs: Extra options forwarded to the plot object.
    """
    plotter = self._env_plot
    plotter.draw_quiver(point, refresh, **kwargs)
[docs]
def draw_quivers(self, points, refresh=False, **kwargs):
    """
    Draw several quivers (arrows) on the environment figure.

    Args:
        points: Points data describing the quivers.
        refresh (bool): Flag to refresh the quivers in the figure, default False
        **kwargs: Extra options forwarded to the plot object.
    """
    plotter = self._env_plot
    plotter.draw_quivers(points, refresh, **kwargs)
[docs]
def end(self, ending_time: float = 3.0, **kwargs):
    """
    End the simulation, save the animation, and close the environment.

    Args:
        ending_time (float): Time in seconds to wait before closing the figure, default is 3 seconds.
        **kwargs: Additional keyword arguments for saving the animation, see :py:meth:`.EnvPlot.save_animate` for detail.
    """
    if self.disable_all_plot:
        # NOTE(review): this early return also skips the object/id reset and the
        # keyboard listener shutdown below; confirm that is intended.
        return

    if self.save_ani:
        self._env_plot.save_animate(**kwargs)

    if self.display:
        # Announce the delay before waiting, so the message is shown while the
        # figure is still open.
        self.logger.info(f"Figure will be closed within {ending_time:.2f} seconds.")
        plt.pause(ending_time)

    plt.close("all")
    # Drop the shared object list and restart the id counter.
    env_param.objects = []
    ObjectBase.reset_id_iter()

    if world_param.control_mode == "keyboard":
        self.keyboard.listener.stop()

    self.logger.info(f"The simulated environment has ended. Total simulation time: {round(self._world.time, 2)} seconds.")
[docs]
def done(self, mode: str = "all"):
    """
    Check if the simulation is done.

    Args:
        mode (str): Mode to check if all or any of the objects are done.

            - all (str): Check if all objects are done.
            - any (str): Check if any of the objects are done.

    Returns:
        bool: True if the simulation is done based on the specified mode, False otherwise.
        Always False when the environment contains no robot.

    Raises:
        ValueError: If ``mode`` is neither "all" nor "any".
    """
    done_list = [obj.done() for obj in self.objects if obj.role == "robot"]

    if not done_list:
        return False

    if mode == "all":
        return all(done_list)
    if mode == "any":
        return any(done_list)

    # An unknown mode used to fall through and return None without any warning.
    raise ValueError(f"Unsupported mode '{mode}', expected 'all' or 'any'.")
[docs]
def step_status(self):
    '''
    Update the world status from the arrive/collision flags of the robots.

    The status becomes "Arrived" when every robot has arrived, otherwise
    "Collision" when any robot is in collision, otherwise "Running".
    Without robots the status is "Running".
    '''
    robots = [member for member in self.objects if member.role == "robot"]
    # An empty list would make all() true, hence the [False] fallback.
    arrived_flags = [robot.arrive for robot in robots] or [False]
    collided_flags = [robot.collision for robot in robots] or [False]

    if all(arrived_flags):
        new_status = "Arrived"
    elif any(collided_flags):
        new_status = "Collision"
    else:
        new_status = "Running"

    self._world.status = new_status
[docs]
def pause(self):
    '''
    Pause the environment: raise the pause flag and mark the world as "Pause".
    '''
    self.pause_flag = True
    self._world.status = "Pause"
[docs]
def resume(self):
    '''
    Resume the environment: clear the pause flag and mark the world as "Running".
    '''
    self.pause_flag = False
    self._world.status = "Running"
[docs]
def reset(self):
    """
    Reset the environment.

    All objects are reset, one step with a zero 2*1 action is performed, then
    the world and the figure are reset and the status is set to "Reset".
    """
    self._reset_all()

    zero_action = np.zeros((2, 1))
    self.step(action=zero_action)

    self._world.reset()
    self.reset_plot()
    self._world.status = "Reset"
def _reset_all(self):
    """Call ``reset`` on every object in the environment."""
    for member in self.objects:
        member.reset()
[docs]
def reset_plot(self):
    """
    Clear every drawn component and re-initialise the figure from the world grid map.
    """
    plotter = self._env_plot
    plotter.clear_components("all", self.objects)
    plotter.init_plot(self._world.grid_map, self.objects)
# region: environment change
[docs]
def random_obstacle_position(
    self,
    range_low: list = [0, 0, -3.14],
    range_high: list = [10, 10, 3.14],
    ids: Optional[list] = None,
    non_overlapping: bool = False,
):
    """
    Random obstacle positions in the environment.

    Args:
        range_low (list [x, y, theta]): Lower bound of the random range for the obstacle states. Default is [0, 0, -3.14].
            Any 3-element sequence or array is accepted.
        range_high (list [x, y, theta]): Upper bound of the random range for the obstacle states. Default is [10, 10, 3.14].
            Any 3-element sequence or array is accepted.
        ids (list): A list of IDs of objects for which to set random positions. Default is None, which selects every obstacle.
        non_overlapping (bool): If set, the obstacles that will be reset to random obstacles will not overlap with other obstacles. Default is False.
            Up to 100 random states are tried per obstacle; if all of them collide, the last one is kept.
    """
    max_attempts = 100

    obstacles = self.obstacle_list
    if ids is None:
        ids = [obs.id for obs in obstacles]
    selected_ids = set(ids)

    # Normalise both bounds to 3*1 columns so lists, tuples and arrays all work.
    # The default lists are only read, never mutated.
    low = np.asarray(range_low).reshape(3, 1)
    high = np.asarray(range_high).reshape(3, 1)

    selected_obs = [obs for obs in obstacles if obs.id in selected_ids]
    existing_obj = [obj for obj in self.objects if obj.id not in selected_ids]

    for obj in selected_obs:
        if not non_overlapping:
            obj.set_state(np.random.uniform(low, high, (3, 1)), init=True)
            continue

        for _ in range(max_attempts):
            obj.set_state(np.random.uniform(low, high, (3, 1)), init=True)
            if not any(obj.check_collision(other) for other in existing_obj):
                # A placed obstacle constrains the ones that are placed after it.
                existing_obj.append(obj)
                break

    self._env_plot.clear_components("all", self.obstacle_list)
    self._env_plot.draw_components("all", self.obstacle_list)
[docs]
def random_polygon_shape(
    self,
    center_range: list = [0, 0, 10, 10],
    avg_radius_range: list = [0.1, 1],
    irregularity_range: list = [0, 1],
    spikeyness_range: list = [0, 1],
    num_vertices_range: list = [4, 10],
):
    """
    Give every polygon obstacle in the environment a new random shape.

    One random polygon is generated per obstacle; only obstacles whose shape
    is "polygon" receive the generated geometry. The obstacles are redrawn
    afterwards.

    Args:
        center_range (list): Range of the center of the polygon. Default is [0, 0, 10, 10].
        avg_radius_range (list): Range of the average radius of the polygon. Default is [0.1, 1].
        irregularity_range (list): Range of the irregularity of the polygon. Default is [0, 1].
        spikeyness_range (list): Range of the spikeyness of the polygon. Default is [0, 1].
        num_vertices_range (list): Range of the number of vertices of the polygon. Default is [4, 10].
    """
    generated = random_generate_polygon(
        self.obstacle_number,
        center_range,
        avg_radius_range,
        irregularity_range,
        spikeyness_range,
        num_vertices_range,
    )

    for index, obstacle in enumerate(self.obstacle_list):
        if obstacle.shape != "polygon":
            continue
        obstacle.set_original_geometry(Polygon(generated[index]))

    plotter = self._env_plot
    plotter.clear_components("all", self.obstacle_list)
    plotter.draw_components("all", self.obstacle_list)
# endregion: environment change
# region: object operation
[docs]
def create_obstacle(self, **kwargs):
    """
    Create an obstacle through the object factory.

    The obstacle is only returned; this method does not add it to the environment.

    Args:
        **kwargs: Parameters for obstacle creation,
            see ObjectFactory.create_obstacle for detail

    Returns:
        Obstacle: An instance of an obstacle.
    """
    factory = self.object_factory
    return factory.create_obstacle(**kwargs)
[docs]
def add_object(self, obj: ObjectBase):
    """
    Append one object to the environment and rebuild the geometry tree.

    Args:
        obj (ObjectBase): The object to be added to the environment.
    """
    registry = self._objects
    registry.append(obj)
    self.build_tree()
[docs]
def add_objects(self, objs: list):
    """
    Append several objects to the environment and rebuild the geometry tree once.

    Args:
        objs (list): List of objects to be added to the environment.
    """
    registry = self._objects
    registry.extend(objs)
    self.build_tree()
[docs]
def delete_object(self, target_id: int):
    """
    Remove the first object whose id equals ``target_id``.

    The object's plot is cleared before removal. The geometry tree is rebuilt
    whether or not a matching object was found.

    Args:
        target_id (int): ID of the object to be deleted.
    """
    match = next((item for item in self._objects if item.id == target_id), None)
    if match is not None:
        match.plot_clear()
        self._objects.remove(match)

    self.build_tree()
[docs]
def delete_objects(self, target_ids: list):
    """
    Remove every object whose id appears in ``target_ids``.

    Each removed object has its plot cleared first; the geometry tree is
    rebuilt once at the end.

    Args:
        target_ids (list): List of IDs of objects to be deleted.
    """
    # Collect the matches first so the list is not modified while it is scanned.
    doomed = []
    for item in self._objects:
        if item.id in target_ids:
            doomed.append(item)

    for item in doomed:
        item.plot_clear()
        self._objects.remove(item)

    self.build_tree()
[docs]
def build_tree(self):
    """
    Rebuild the shared geometry tree from the current objects.

    The tree is stored on ``env_param.GeometryTree`` and is used to detect the possible collision objects.
    """
    geometries = [member.geometry for member in self.objects]
    env_param.GeometryTree = STRtree(geometries)
# endregion: object operation
# region: get information
[docs]
def get_robot_state(self):
    """
    Return the current state of the first robot.

    Returns:
        state: 3*1 vector [x, y, theta]
    """
    first_robot = self.robot
    return first_robot._state
[docs]
def get_lidar_scan(self, id: int = 0):
    """
    Return the lidar scan of one robot.

    Args:
        id (int): Position of the robot in the robot list.

    Returns:
        Dict: Dict of lidar scan points, see :py:meth:`.world.sensors.lidar2d.Lidar2D.get_scan` for detail.
    """
    selected = self.robot_list[id]
    return selected.get_lidar_scan()
[docs]
def get_lidar_offset(self, id: int = 0):
    """
    Return the lidar offset of one robot.

    Args:
        id (int): Position of the robot in the robot list.

    Returns:
        list of float: Lidar offset of the robot, [x, y, theta]
    """
    selected = self.robot_list[id]
    return selected.get_lidar_offset()
[docs]
def get_obstacle_info_list(self):
    """
    Collect the obstacle information of every obstacle in the environment.

    Returns:
        list of dict: List of obstacle information, see :py:meth:`.ObjectBase.get_obstacle_info` for detail.
    """
    info_list = []
    for obstacle in self.obstacle_list:
        info_list.append(obstacle.get_obstacle_info())
    return info_list
[docs]
def get_robot_info(self, id: int = 0):
    """
    Return the information of one robot.

    Args:
        id (int): Position of the robot in the robot list.

    Returns:
        see :py:meth:`.ObjectBase.get_info` for detail
    """
    selected = self.robot_list[id]
    return selected.get_info()
[docs]
def get_robot_info_list(self):
    """
    Collect the information of every robot in the environment.

    Returns:
        list of dict: List of robot information, see :py:meth:`.ObjectBase.get_info` for detail.
    """
    info_list = []
    for robot in self.robot_list:
        info_list.append(robot.get_info())
    return info_list
[docs]
def get_map(self, resolution: float = 0.1):
    """
    Ask the world for a map built from the current obstacles.

    Args:
        resolution (float): Resolution of the map. Default is 0.1.

    Returns:
        The map of the environment with the specified resolution.
    """
    obstacles = self.obstacle_list
    return self._world.get_map(resolution, obstacles)
# endregion: get information
[docs]
def set_title(self, title: str):
    """
    Set the title of the plot.

    Args:
        title (str): Text stored as the ``title`` of the plot object.
    """
    plotter = self._env_plot
    plotter.title = title
[docs]
def load_behavior(self, behaviors: str = "behavior_methods"):
    """
    Import the behavior script so that its behavior methods become available.
    Please refer to the behavior_methods.py file for more details.
    Please make sure the python file is placed in the same folder with the implemented script.

    An import failure is printed and otherwise ignored.

    Args:
        behaviors (str): name of the behavior script.
    """
    try:
        importlib.import_module(behaviors)
    except ImportError as err:
        failure = f"Failed to load module '{behaviors}': {err}"
        print(failure)
# region: property
@property
def robot_list(self):
    """
    Get the robots in the environment, i.e. the objects whose role is "robot".

    Returns:
        list: List of robot objects [].
    """
    robots = []
    for member in self.objects:
        if member.role == "robot":
            robots.append(member)
    return robots
@property
def obstacle_list(self):
    """
    Get the obstacles in the environment, i.e. the objects whose role is "obstacle".

    Returns:
        list: List of obstacle objects.
    """
    obstacles = []
    for member in self.objects:
        if member.role == "obstacle":
            obstacles.append(member)
    return obstacles
@property
def objects(self):
    """
    Get all objects in the environment.

    Returns:
        list: The internal object list itself (not a copy).
    """
    everything = self._objects
    return everything
@property
def static_objects(self):
    """
    Get the objects whose ``static`` flag is set.

    Returns:
        list: List of static objects in the environment.
    """
    frozen = []
    for member in self.objects:
        if member.static:
            frozen.append(member)
    return frozen
@property
def dynamic_objects(self):
    """
    Get the objects whose ``static`` flag is not set.

    Returns:
        list: List of dynamic objects in the environment.
    """
    moving = []
    for member in self.objects:
        if not member.static:
            moving.append(member)
    return moving
@property
def step_time(self):
    """
    Get the step time of the simulation.

    Returns:
        float: The ``step_time`` value held by the world.
    """
    world = self._world
    return world.step_time
@property
def time(self):
    """
    Get the time of the simulation, as held by the world.
    """
    world = self._world
    return world.time
@property
def status(self):
    """
    Get the status of the environment, as held by the world.
    """
    world = self._world
    return world.status
@property
def robot(self):
    """
    Get the first robot in the environment.

    Returns:
        Robot: The first robot object in the robot list.
    """
    robots = self.robot_list
    return robots[0]
@property
def obstacle_number(self):
    """
    Count the obstacles in the environment.

    Returns:
        int: Number of obstacles in the environment.
    """
    obstacles = self.obstacle_list
    return len(obstacles)
@property
def robot_number(self):
    """
    Count the robots in the environment.

    Returns:
        int: Number of robots in the environment.
    """
    robots = self.robot_list
    return len(robots)
@property
def logger(self):
    """
    Get the environment logger stored on the shared ``env_param``.

    Returns:
        EnvLogger: The logger instance for the environment.
    """
    shared_logger = env_param.logger
    return shared_logger
@property
def key_vel(self):
    """Get the ``key_vel`` value of the keyboard controller."""
    controller = self.keyboard
    return controller.key_vel
@property
def key_id(self):
    """Get the ``key_id`` value of the keyboard controller."""
    controller = self.keyboard
    return controller.key_id
@property
def mouse_pos(self):
    """Get the ``mouse_pos`` value of the mouse controller."""
    controller = self.mouse
    return controller.mouse_pos
@property
def mouse_left_pos(self):
    """Get the ``left_click_pos`` value of the mouse controller."""
    controller = self.mouse
    return controller.left_click_pos
@property
def mouse_right_pos(self):
    """Get the ``right_click_pos`` value of the mouse controller."""
    controller = self.mouse
    return controller.right_click_pos
# endregion: property