"""
Rapidly-exploring Random Tree (RRT) path planner.
Collision precedence (delegated to ``env_map.is_collision(geometry)``):
1. Grid lookup when ``env_map.grid`` is not ``None``; if occupied, collision.
2. When the grid reports free or is unavailable, geometry vs. obstacle_list.
The planner builds the robot shape as a geometry (e.g. circle, or polygon for
non-circular robots) and calls the unified interface; the map supports any
Shapely geometry.
Reference:
S. M. LaValle, "Rapidly-Exploring Random Trees: A New Tool for Path
Planning," 1998.
Implementation reference:
ZJU-FAST-Lab/sampling-based-path-finding
https://github.com/ZJU-FAST-Lab/sampling-based-path-finding
(C++ implementation with kd-tree and efficient tree management.)
Adapted for ir-sim.
"""
from __future__ import annotations
import logging
import math
import random
from collections import deque
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import matplotlib.pyplot as plt
import numpy as np
from shapely import minimum_bounding_radius
from irsim.util.util import geometry_transform
from irsim.world.map import EnvGridMap
from irsim.world.object_base import ObjectBase
if TYPE_CHECKING:
from matplotlib.lines import Line2D
from scipy.spatial import KDTree as _SciKDTree
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Tree node (shared by RRT, RRT*, Informed RRT*)
# ---------------------------------------------------------------------------
[docs]
@dataclass(slots=True, eq=False)
class TreeNode:
"""A node in the RRT search tree.
Identity-based equality (``eq=False``) is used so that nodes can be
safely compared with ``==`` / ``is`` and used as dictionary keys
without triggering recursive field comparisons.
Attributes:
x: X position.
y: Y position.
cost: Cumulative cost from the start node (``cost_from_start``).
cost_from_parent: Cost of the edge from parent to this node.
parent: Parent node (``None`` for the root).
children: Direct children in the tree.
path_x: Discretised X coordinates from parent to this node
(used for visualisation and per-point collision checking).
path_y: Discretised Y coordinates from parent to this node.
"""
x: float
y: float
cost: float = 0.0
cost_from_parent: float = 0.0
parent: TreeNode | None = None
children: list[TreeNode] = field(default_factory=list, repr=False)
path_x: list[float] = field(default_factory=list, repr=False)
path_y: list[float] = field(default_factory=list, repr=False)
# ---------------------------------------------------------------------------
# RRT planner
# ---------------------------------------------------------------------------
[docs]
class RRT:
"""Rapidly-exploring Random Tree (RRT) path planner.
Algorithmic improvements over the naive implementation (inspired by
ZJU-FAST-Lab/sampling-based-path-finding):
* Nodes track their children, enabling O(subtree) cost propagation
via BFS instead of O(n) linear scans.
* ``scipy.spatial.KDTree`` is used for nearest-neighbour and range
queries when available, falling back to numpy otherwise.
* A dedicated *goal node* is kept in the tree; whenever a new node
can reach the goal at a lower cost the goal's parent is rewired.
"""
# Expose TreeNode as a class-level alias for backward compatibility
Node = TreeNode
[docs]
class AreaBounds:
"""Rectangular play-area bounds in world coordinates."""
__slots__ = ("xmax", "xmin", "ymax", "ymin")
def __init__(self, env_map: EnvGridMap) -> None:
off = np.asarray(env_map.world_offset, dtype=float).flatten()
self.xmin: float = float(off[0])
self.ymin: float = float(off[1])
w = float(np.asarray(env_map.width).flat[0])
h = float(np.asarray(env_map.height).flat[0])
self.xmax: float = self.xmin + w
self.ymax: float = self.ymin + h
# ------------------------------------------------------------------
# Construction
# ------------------------------------------------------------------
def __init__(
self,
env_map: EnvGridMap,
robot: ObjectBase,
expand_dis: float = 1.0,
path_resolution: float = 0.25,
goal_sample_rate: int = 5,
max_iter: int = 500,
) -> None:
"""Initialise the RRT planner.
Robot shape is taken from ``robot.original_geometry`` (e.g. pass
``robot=env.robot``).
Args:
env_map: Environment map (:class:`~irsim.world.map.EnvGridMap`).
robot: Robot object; its :attr:`~irsim.world.object_base.ObjectBase.original_geometry` is used for collision.
expand_dis: Maximum extension distance per steer step.
path_resolution: Discretisation resolution along steered edges.
goal_sample_rate: Percentage chance of sampling the goal.
max_iter: Maximum number of iterations.
"""
if robot is None:
raise ValueError("robot is required (e.g. robot=env.robot).")
self._map = env_map
self.obstacle_list = env_map.obstacle_list[:]
off = np.asarray(env_map.world_offset, dtype=float).flatten()
self._origin_x = float(off[0])
self._origin_y = float(off[1])
self.play_area = self.AreaBounds(env_map)
self.max_x = self.play_area.xmax
self.max_y = self.play_area.ymax
self.min_rand = min(self.play_area.xmin, self.play_area.ymin)
self.max_rand = max(self.play_area.xmax, self.play_area.ymax)
self.expand_dis = expand_dis
self.path_resolution = path_resolution
self.goal_sample_rate = goal_sample_rate
self.max_iter = max_iter
self.node_list = []
self.start = None
self.end = None
self._collision_geometry = robot.original_geometry
self.robot_radius = float(minimum_bounding_radius(self._collision_geometry))
# --- KDTree state ---
self._kd_coords: np.ndarray | None = None
self._kd_tree: _SciKDTree | None = None
self._kd_dirty: bool = True
# --- visualisation state ---
self._vis_temp: list = []
self._vis_setup_done: bool = False
self._tree_line: Line2D | None = None
# ------------------------------------------------------------------
# Tree management helpers (reference-style)
# ------------------------------------------------------------------
def _add_tree_node(
self,
parent: TreeNode,
x: float,
y: float,
cost_from_parent: float,
path_x: list[float] | None = None,
path_y: list[float] | None = None,
) -> TreeNode:
"""Create a ``TreeNode``, attach it to *parent*, and register it."""
node = TreeNode(
x=x,
y=y,
cost=parent.cost + cost_from_parent,
cost_from_parent=cost_from_parent,
parent=parent,
children=[],
path_x=path_x if path_x is not None else [],
path_y=path_y if path_y is not None else [],
)
parent.children.append(node)
self.node_list.append(node)
self._kd_dirty = True
return node
def _change_node_parent(
self,
node: TreeNode,
new_parent: TreeNode,
cost_from_parent: float,
) -> None:
"""Re-parent *node* and BFS-propagate costs to all descendants.
Mirrors ``changeNodeParent`` from the C++ reference.
"""
if node.parent is not None:
node.parent.children.remove(node)
node.parent = new_parent
node.cost_from_parent = cost_from_parent
node.cost = new_parent.cost + cost_from_parent
new_parent.children.append(node)
# BFS cost propagation
queue: deque[TreeNode] = deque(node.children)
while queue:
child = queue.popleft()
child.cost = child.parent.cost + child.cost_from_parent
queue.extend(child.children)
def _fill_path(self, node: TreeNode) -> np.ndarray:
"""Trace parent chain from *node* to root, return ``(2, N)`` array.
Path order is *node* to root (e.g. goal -> start).
"""
path_x: list[float] = []
path_y: list[float] = []
current: TreeNode | None = node
while current is not None:
path_x.append(current.x)
path_y.append(current.y)
current = current.parent
return np.array([path_x, path_y])
# ------------------------------------------------------------------
# KDTree helpers
# ------------------------------------------------------------------
def _rebuild_kd_tree(self) -> None:
"""Rebuild the spatial index from the current ``node_list``."""
if not self._kd_dirty or not self.node_list:
return
self._kd_coords = np.array([[n.x, n.y] for n in self.node_list])
self._kd_tree = _SciKDTree(self._kd_coords)
self._kd_dirty = False
def _nearest(self, x: float, y: float) -> int:
"""Index of the closest node to ``(x, y)``."""
self._rebuild_kd_tree()
_, idx = self._kd_tree.query([x, y])
return int(idx)
def _near_in_radius(self, x: float, y: float, radius: float) -> list[int]:
"""Indices of all nodes within *radius* of ``(x, y)``."""
self._rebuild_kd_tree()
return self._kd_tree.query_ball_point([x, y], radius)
# ------------------------------------------------------------------
# Planning
# ------------------------------------------------------------------
[docs]
def planning(
self,
start_pose: list[float],
goal_pose: list[float],
show_animation: bool = True,
) -> np.ndarray | None:
"""Run RRT path planning.
Args:
start_pose: Start position ``[x, y]``.
goal_pose: Goal position ``[x, y]``.
show_animation: Render the exploration tree during planning.
Returns:
``(2, N)`` numpy array ``[rx, ry]`` or *None*.
"""
start_pose = np.asarray(start_pose, dtype=float).flatten()
goal_pose = np.asarray(goal_pose, dtype=float).flatten()
sx, sy = float(start_pose[0]), float(start_pose[1])
gx, gy = float(goal_pose[0]), float(goal_pose[1])
self.start = TreeNode(x=sx, y=sy, cost=0.0)
self.end = TreeNode(x=gx, y=gy, cost=float("inf"))
# Reset
self.node_list = [self.start]
self._kd_dirty = True
self._vis_setup_done = False
self._tree_line = None
self._vis_temp = []
for _ in range(self.max_iter):
# 1. Sample
rnd_node = self.get_random_node()
# 2. Nearest
nearest_ind = self._nearest(rnd_node.x, rnd_node.y)
nearest_node = self.node_list[nearest_ind]
# 3. Steer
new_node = self.steer(nearest_node, rnd_node, self.expand_dis)
# 4. Bounds + collision
if not self._check_bounds(new_node.x, new_node.y):
continue
if not self.is_collision(new_node):
continue
# 5. Add to tree
cost_fp = math.hypot(
new_node.x - nearest_node.x,
new_node.y - nearest_node.y,
)
added = self._add_tree_node(
nearest_node,
new_node.x,
new_node.y,
cost_fp,
path_x=new_node.path_x,
path_y=new_node.path_y,
)
if show_animation:
self.draw_graph(added)
# 6. Try to connect to goal
dist_to_goal = math.hypot(
added.x - self.end.x,
added.y - self.end.y,
)
if dist_to_goal <= self.expand_dis:
goal_edge = self.steer(added, self.end, self.expand_dis)
if self.is_collision(goal_edge):
# For basic RRT, return on the first feasible connection
self.end.parent = added
self.end.cost_from_parent = dist_to_goal
self.end.cost = added.cost + dist_to_goal
self.end.path_x = goal_edge.path_x
self.end.path_y = goal_edge.path_y
return self._fill_path(self.end)
return None # cannot find path
# ------------------------------------------------------------------
# Steer
# ------------------------------------------------------------------
[docs]
def steer(
self,
from_node: TreeNode,
to_node: TreeNode,
extend_length: float = float("inf"),
) -> TreeNode:
"""Steer from *from_node* towards *to_node*.
Returns a **temporary** ``TreeNode`` (not yet registered in the
tree) with ``path_x``/``path_y`` populated and ``parent`` set to
*from_node*.
"""
d, theta = self.calc_distance_and_angle(from_node, to_node)
if extend_length > d:
extend_length = d
new_node = TreeNode(x=from_node.x, y=from_node.y)
new_node.path_x = [new_node.x]
new_node.path_y = [new_node.y]
n_expand = math.floor(extend_length / self.path_resolution)
cos_t = math.cos(theta)
sin_t = math.sin(theta)
for _ in range(n_expand):
new_node.x += self.path_resolution * cos_t
new_node.y += self.path_resolution * sin_t
new_node.path_x.append(new_node.x)
new_node.path_y.append(new_node.y)
d2, _ = self.calc_distance_and_angle(new_node, to_node)
if d2 <= self.path_resolution:
new_node.x = to_node.x
new_node.y = to_node.y
new_node.path_x.append(to_node.x)
new_node.path_y.append(to_node.y)
new_node.parent = from_node
return new_node
# ------------------------------------------------------------------
# Sampling
# ------------------------------------------------------------------
[docs]
def get_random_node(self) -> TreeNode:
"""Uniform random sample with goal-bias."""
if random.randint(0, 100) > self.goal_sample_rate:
return TreeNode(
x=random.uniform(self.min_rand, self.max_rand),
y=random.uniform(self.min_rand, self.max_rand),
)
return TreeNode(x=self.end.x, y=self.end.y)
# ------------------------------------------------------------------
# Collision detection
# ------------------------------------------------------------------
def _check_bounds(self, x: float, y: float) -> bool:
"""Return *True* if ``(x, y)`` lies inside the play area."""
pa = self.play_area
return pa.xmin <= x <= pa.xmax and pa.ymin <= y <= pa.ymax
[docs]
def is_collision(self, node: TreeNode) -> bool:
"""Check whether *node*'s edge is collision-free.
Uses :attr:`_collision_geometry` translated to each path point.
Returns *True* if the path is **collision-free**.
"""
if node is None:
return False
# Check each point along the discretised edge
for px, py in zip(node.path_x, node.path_y, strict=True):
if self._check_point(px, py):
return False
# Also check the node endpoint itself
return not self._check_point(node.x, node.y)
def _check_point(self, x: float, y: float, theta: float = 0.0) -> bool:
"""Single-point collision check.
Translates :attr:`_collision_geometry` to *(x, y)* with orientation
*theta* via :func:`~irsim.util.util.geometry_transform` and calls
``env_map.is_collision(geometry)``. Supports any Shapely geometry
(circle, rectangle, polygon, linestring from ir-sim).
Returns *True* if a **collision is detected**.
"""
state = np.array([x, y, theta], dtype=float)
moved = geometry_transform(self._collision_geometry, state)
return self._map.is_collision(moved)
# ------------------------------------------------------------------
# Visualisation
# ------------------------------------------------------------------
[docs]
def draw_graph(self, rnd: TreeNode | None = None) -> None:
"""Render the RRT tree on the active matplotlib axes."""
ax = plt.gca()
# Remove transient markers from previous frame
for a in self._vis_temp:
a.remove()
self._vis_temp.clear()
# One-time setup
if not self._vis_setup_done:
ax.figure.canvas.mpl_connect(
"key_release_event",
lambda event: (
plt.close(event.canvas.figure) if event.key == "escape" else None
),
)
ax.plot(self.start.x, self.start.y, "xr", markersize=8, zorder=5)
ax.plot(self.end.x, self.end.y, "xr", markersize=8, zorder=5)
if self.play_area is not None:
pa = self.play_area
ax.plot(
[pa.xmin, pa.xmax, pa.xmax, pa.xmin, pa.xmin],
[pa.ymin, pa.ymin, pa.ymax, pa.ymax, pa.ymin],
"-k",
linewidth=0.6,
)
self._vis_setup_done = True
plt.pause(0.05)
# Transient markers
if rnd is not None:
(marker,) = ax.plot(rnd.x, rnd.y, "^k")
self._vis_temp.append(marker)
# Draw robot shape at random node (theta=0)
state = np.array([rnd.x, rnd.y, 0.0], dtype=float)
moved = geometry_transform(self._collision_geometry, state)
if hasattr(moved, "exterior"):
x_coords, y_coords = moved.exterior.xy
elif hasattr(moved, "xy"):
x_coords, y_coords = moved.xy
else:
x_coords, y_coords = [], []
if len(x_coords) > 0:
(shape_line,) = ax.plot(x_coords, y_coords, "-r", linewidth=0.8)
self._vis_temp.append(shape_line)
# Tree edges (single Line2D, updated in-place)
xs: list[float] = []
ys: list[float] = []
for node in self.node_list:
if node.parent and node.path_x:
xs.extend(node.path_x)
ys.extend(node.path_y)
xs.append(float("nan"))
ys.append(float("nan"))
if self._tree_line is None and xs:
(self._tree_line,) = ax.plot(xs, ys, "-g", linewidth=0.5)
elif self._tree_line is not None:
self._tree_line.set_data(xs, ys)
plt.pause(0.01)
@staticmethod
def _plot_circle(
x: float,
y: float,
size: float,
color: str = "-b",
ax: plt.Axes | None = None,
) -> Line2D:
"""Plot a circle and return the ``Line2D`` artist."""
if ax is None:
ax = plt.gca()
deg = list(range(0, 360, 5))
deg.append(0)
xl = [x + size * math.cos(np.deg2rad(d)) for d in deg]
yl = [y + size * math.sin(np.deg2rad(d)) for d in deg]
(line,) = ax.plot(xl, yl, color)
return line
# ------------------------------------------------------------------
# Utility
# ------------------------------------------------------------------
[docs]
def calc_dist_to_goal(self, x: float, y: float) -> float:
"""Euclidean distance from ``(x, y)`` to the goal."""
return math.hypot(x - self.end.x, y - self.end.y)
[docs]
@staticmethod
def calc_distance_and_angle(
from_node: TreeNode,
to_node: TreeNode,
) -> tuple[float, float]:
"""Euclidean distance and heading between two nodes."""
dx = to_node.x - from_node.x
dy = to_node.y - from_node.y
return math.hypot(dx, dy), math.atan2(dy, dx)