Source code for irsim.lib.path_planners.jps

"""
Jump Point Search (JPS) grid planning.

An optimization of A* for uniform-cost grids that prunes symmetric paths
and expands "jump points" only, preserving optimality while reducing nodes expanded.

Collision precedence:
  1. Grid lookup when ``env_map.grid`` is not ``None``; if occupied, collision.
  2. When the grid reports free or is unavailable, Shapely vs. obstacle_list.
  (Grid and obstacle_list are combined when both are present.)

References
----------
- D. Harabor and A. Grastien. Online Graph Pruning for Pathfinding on Grid Maps.
  In AAAI, 2011. https://en.wikipedia.org/wiki/Jump_point_search
- 2D implementation reference: KumarRobotics/jps3d (C++ JPS on 2D/3D maps).
  https://github.com/KumarRobotics/jps3d

"""

from __future__ import annotations

import contextlib
import itertools
import math
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np

from irsim.lib.handler.geometry_handler import GeometryFactory
from irsim.world.map import EnvGridMap

# A jump successor pairs the landing cell and the step direction used to
# reach it, ``(jx, jy, dx, dy)``, with the travel cost from the expanded node.
_JumpTarget = tuple[int, int, int, int]
JpsSuccessor = tuple[_JumpTarget, float]


# --- JPS 2D neighbor tables (aligned with jps3d JPS2DNeib) ---
# Direction id: (dx+1) + 3*(dy+1) for dx,dy in {-1,0,1}
# nsz[norm1][0] = number of natural neighbors, nsz[norm1][1] = number of forced checks
# norm1 = |dx| + |dy|  =>  0: start, 1: cardinal, 2: diagonal
_JPS2D_NSZ = ((8, 0), (1, 2), (3, 2))


def _build_jps2d_ns() -> list[list[list[int]]]:
    """Build the natural-neighbor table, mirroring jps3d ``JPS2DNeib::Neib``.

    ``ns[id][0][dev]`` / ``ns[id][1][dev]`` give the x / y step of the
    ``dev``-th natural neighbor to try when a node was reached with the
    direction whose id is ``id``:

    * start node (0, 0): all 8 directions,
    * cardinal (dx, dy): only straight ahead,
    * diagonal (dx, dy): (dx, 0), (0, dy), then (dx, dy).
    """
    start_moves = (
        (1, 0),
        (-1, 0),
        (0, 1),
        (1, 1),
        (-1, 1),
        (0, -1),
        (1, -1),
        (-1, -1),
    )
    table: list[list[list[int]]] = []
    for dir_index in range(9):
        # dir_index == (dx + 1) + 3 * (dy + 1)
        row, col = divmod(dir_index, 3)
        dx, dy = col - 1, row - 1
        norm1 = abs(dx) + abs(dy)
        if norm1 == 0:
            moves = start_moves
        elif norm1 == 1:
            moves = ((dx, dy),)
        else:
            moves = ((dx, 0), (0, dy), (dx, dy))
        count = _JPS2D_NSZ[norm1][0]
        chosen = [moves[dev] for dev in range(count)]
        table.append([[mx for mx, _ in chosen], [my for _, my in chosen]])
    return table


def _build_jps2d_f1_f2() -> tuple[list[list[list[int]]], list[list[list[int]]]]:
    """Build the forced-neighbor tables, mirroring jps3d ``JPS2DNeib::FNeib``.

    For a node reached with direction id ``id``:

    * ``f1[id]`` holds the offsets of cells that, when occupied, create a
      forced neighbor;
    * ``f2[id]`` holds the matching step direction to jump in.

    Both are stored as ``[list of x components, list of y components]``.
    """
    f1_table: list[list[list[int]]] = []
    f2_table: list[list[list[int]]] = []
    for dir_index in range(9):
        row, col = divmod(dir_index, 3)
        dx, dy = col - 1, row - 1
        norm1 = abs(dx) + abs(dy)
        checks: list[tuple[int, int]] = []
        jumps: list[tuple[int, int]] = []
        for dev in range(_JPS2D_NSZ[norm1][1]):
            if norm1 == 1:
                side = 1 if dev == 0 else -1
                # Check the two cells perpendicular to the travel direction.
                off = (side, 0) if dx == 0 else (0, side)
                checks.append(off)
                jumps.append((dx + off[0], dy + off[1]))
            elif dev == 0:
                # Diagonal: cell "behind" along x.
                checks.append((-dx, 0))
                jumps.append((-dx, dy))
            else:
                # Diagonal: cell "behind" along y.
                checks.append((0, -dy))
                jumps.append((dx, -dy))
        f1_table.append([[cx for cx, _ in checks], [cy for _, cy in checks]])
        f2_table.append([[jx for jx, _ in jumps], [jy for _, jy in jumps]])
    return f1_table, f2_table


# Lookup tables are built once at import time; the planner only reads them.
_JPS2D_NS: list[list[list[int]]] = _build_jps2d_ns()
_JPS2D_F1: list[list[list[int]]]
_JPS2D_F2: list[list[list[int]]]
_JPS2D_F1, _JPS2D_F2 = _build_jps2d_f1_f2()


@dataclass
class _JpsNode:
    """Grid node for JPS: (x, y), cost, parent index, and direction (dx, dy) for pruning."""

    x: int
    y: int
    cost: float
    parent_index: int
    dx: int = 0
    dy: int = 0


def _dir_id(dx: int, dy: int) -> int:
    """Direction id: (dx+1) + 3*(dy+1) for dx, dy in {-1, 0, 1}."""
    return (dx + 1) + 3 * (dy + 1)


[docs] class JPSPlanner: """Jump Point Search planner for uniform-cost 8-connected grids. When the environment map carries an occupancy grid (``env_map.grid``), collision checks use a fast O(1) grid lookup. Otherwise, the planner falls back to Shapely geometry intersection (same as :class:`AStarPlanner`). """ def __init__(self, env_map: EnvGridMap) -> None: """ Initialize JPS planner. Args: env_map: Environment map (any :class:`~irsim.world.map.EnvGridMap` compatible object). Resolution and bounds are taken from the map (same as :class:`AStarPlanner`). """ self._map = env_map off = np.asarray(env_map.world_offset, dtype=float).flatten() self.origin_x = float(off[0]) self.origin_y = float(off[1]) self.min_x, self.min_y = 0, 0 # grid indices are 0-based self.max_x = self.origin_x + env_map.width self.max_y = self.origin_y + env_map.height # When map has a grid, use its actual resolution and shape so planner grid # matches collision lookups (avoids "Open set is empty" on resolution mismatch). grid = getattr(env_map, "grid", None) gr = None if grid is not None and hasattr(env_map, "grid_resolution"): with contextlib.suppress(Exception): gr = env_map.grid_resolution if grid is not None and gr is not None: self.resolution = gr[0] # m/cell; assume square cells (gr[0]==gr[1]) self.x_width = grid.shape[0] self.y_width = grid.shape[1] else: self.resolution = env_map.resolution self.x_width = round((self.max_x - self.origin_x) / self.resolution) self.y_width = round((self.max_y - self.origin_y) / self.resolution) self.obstacle_list = env_map.obstacle_list[:]
[docs] def planning( self, start_pose: np.ndarray, goal_pose: np.ndarray, show_animation: bool = True, ) -> np.ndarray | None: """ JPS path search. Args: start_pose (np.ndarray): start pose [x, y] goal_pose (np.ndarray): goal pose [x, y] show_animation (bool): If true, shows the animation of planning process Returns: np.ndarray | None: shape (2, N) array [rx, ry] of the final path, or None if the start or goal cell is not walkable, or if no path exists (open set exhausted). """ start_pose = np.asarray(start_pose, dtype=float).flatten() goal_pose = np.asarray(goal_pose, dtype=float).flatten() sx = self.calc_xy_index(float(start_pose[0]), self.origin_x) sy = self.calc_xy_index(float(start_pose[1]), self.origin_y) start_node = _JpsNode(sx, sy, 0.0, -1, 0, 0) gx = self.calc_xy_index(float(goal_pose[0]), self.origin_x) gy = self.calc_xy_index(float(goal_pose[1]), self.origin_y) goal_node = _JpsNode(gx, gy, 0.0, -1) if not self._is_walkable(start_node.x, start_node.y): return None if not self._is_walkable(goal_node.x, goal_node.y): return None open_set: dict[int, _JpsNode] = {} closed_set: dict[int, _JpsNode] = {} open_set[self.calc_grid_index_from_xy(start_node.x, start_node.y)] = start_node while open_set: c_id = min( open_set, key=lambda o: ( open_set[o].cost + self._heuristic( goal_node.x, goal_node.y, open_set[o].x, open_set[o].y ) ), ) current = open_set[c_id] if show_animation: # pragma: no cover plt.plot( self.calc_grid_position(current.x, self.origin_x), self.calc_grid_position(current.y, self.origin_y), "xc", ) plt.gcf().canvas.mpl_connect( "key_release_event", lambda event: ( plt.close(event.canvas.figure) if event.key == "escape" else None ), ) if len(closed_set) % 10 == 0: plt.pause(0.01) if current.x == goal_node.x and current.y == goal_node.y: print("Find goal") goal_node.parent_index = current.parent_index goal_node.cost = current.cost break del open_set[c_id] closed_set[c_id] = current for (jx, jy, dx, dy), move_cost in self._get_jps_successors( current, 
goal_node.x, goal_node.y ): node = _JpsNode(jx, jy, current.cost + move_cost, c_id, dx, dy) n_id = self.calc_grid_index_from_xy(jx, jy) if n_id in closed_set: continue if n_id not in open_set or open_set[n_id].cost > node.cost: open_set[n_id] = node if goal_node.parent_index == -1: print("Open set is empty..") return None rx, ry = self._calc_final_path(goal_node, closed_set) return np.array([rx, ry])
def _get_jps_successors( self, current: _JpsNode, gx: int, gy: int ) -> list[JpsSuccessor]: """Return list of ((jx, jy, dx, dy), cost) for each jump point successor (jps3d getJpsSucc style).""" dx, dy = current.dx, current.dy norm1 = abs(dx) + abs(dy) num_neib, num_fneib = _JPS2D_NSZ[norm1] id_ = _dir_id(dx, dy) x, y = current.x, current.y out: list[JpsSuccessor] = [] for dev in range(num_neib + num_fneib): if dev < num_neib: ddx = _JPS2D_NS[id_][0][dev] ddy = _JPS2D_NS[id_][1][dev] if (jp := self._jump(x, y, ddx, ddy, gx, gy)) is None: continue jx, jy = jp else: fn = dev - num_neib nx = x + _JPS2D_F1[id_][0][fn] ny = y + _JPS2D_F1[id_][1][fn] if not self._is_occupied(nx, ny): continue ddx = _JPS2D_F2[id_][0][fn] ddy = _JPS2D_F2[id_][1][fn] if (jp := self._jump(x, y, ddx, ddy, gx, gy)) is None: continue jx, jy = jp cost = math.hypot(jx - x, jy - y) out.append(((jx, jy, ddx, ddy), cost)) return out def _jump( self, x: int, y: int, dx: int, dy: int, gx: int, gy: int, ) -> tuple[int, int] | None: """Jump from (x,y) in direction (dx,dy); return (jx, jy) or None. Uses iteration along the primary direction to avoid recursion depth limits.""" nx, ny = x + dx, y + dy while True: if not self._is_walkable(nx, ny): return None # Match A*: no corner-cutting check (A* only verifies the target cell). 
if (nx, ny) == (gx, gy) or self._has_forced(nx, ny, dx, dy): return (nx, ny) id_ = _dir_id(dx, dy) norm1 = abs(dx) + abs(dy) num_neib = _JPS2D_NSZ[norm1][0] for k in range(num_neib - 1): ddx = _JPS2D_NS[id_][0][k] ddy = _JPS2D_NS[id_][1][k] if self._jump(nx, ny, ddx, ddy, gx, gy) is not None: return (nx, ny) next_nx, next_ny = nx + dx, ny + dy if not self._is_walkable(next_nx, next_ny): return (nx, ny) nx, ny = next_nx, next_ny def _has_forced(self, x: int, y: int, dx: int, dy: int) -> bool: """True if (x,y) has a forced neighbor when approached along (dx,dy); uses f1 table.""" id_ = _dir_id(dx, dy) for fn in range(_JPS2D_NSZ[abs(dx) + abs(dy)][1]): nx = x + _JPS2D_F1[id_][0][fn] ny = y + _JPS2D_F1[id_][1][fn] if self._is_occupied(nx, ny): return True return False def _is_occupied(self, ix: int, iy: int) -> bool: """True if grid cell (ix, iy) is in bounds and occupied.""" if ix < 0 or iy < 0 or ix >= self.x_width or iy >= self.y_width: return False px = self.calc_grid_position(ix, self.origin_x) py = self.calc_grid_position(iy, self.origin_y) return self.is_collision(px, py) def _is_walkable(self, ix: int, iy: int) -> bool: """True if grid cell (ix, iy) is in bounds and not in collision.""" if ix < 0 or iy < 0 or ix >= self.x_width or iy >= self.y_width: return False px = self.calc_grid_position(ix, self.origin_x) py = self.calc_grid_position(iy, self.origin_y) return not self.is_collision(px, py) def _heuristic(self, gx: int, gy: int, x: int, y: int) -> float: """Octile heuristic for 8-directional grid.""" dx = abs(gx - x) dy = abs(gy - y) return (dx + dy) + (math.sqrt(2) - 1) * min(dx, dy) def _calc_final_path( self, goal_node: _JpsNode, closed_set: dict[int, _JpsNode] ) -> tuple[list[float], list[float]]: """Build the final path with intermediate cells between jump points. JPS only stores jump points in ``closed_set``. 
Between consecutive jump points, the path follows one of the 8 grid directions, so we interpolate all intermediate grid cells to produce a dense trajectory that stays on verified-walkable cells. """ waypoints: list[tuple[int, int]] = [(goal_node.x, goal_node.y)] idx = goal_node.parent_index while idx != -1: n = closed_set[idx] waypoints.append((n.x, n.y)) idx = n.parent_index rx, ry = [], [] for (cx, cy), (px, py) in itertools.pairwise(waypoints): ddx = 0 if px == cx else (1 if px > cx else -1) ddy = 0 if py == cy else (1 if py > cy else -1) ix, iy = cx, cy while (ix, iy) != (px, py): rx.append(self.calc_grid_position(ix, self.origin_x)) ry.append(self.calc_grid_position(iy, self.origin_y)) ix += ddx iy += ddy lx, ly = waypoints[-1] rx.append(self.calc_grid_position(lx, self.origin_x)) ry.append(self.calc_grid_position(ly, self.origin_y)) return rx, ry
[docs] def calc_grid_position(self, index: int, min_position: float) -> float: return index * self.resolution + min_position
[docs] def calc_xy_index(self, position: float, min_pos: float) -> int: return round((position - min_pos) / self.resolution)
[docs] def calc_grid_index_from_xy(self, x: int, y: int) -> int: return y * self.x_width + x
[docs] def is_collision(self, x: float, y: float) -> bool: """True if world position ``(x, y)`` is in collision.""" shape = { "name": "rectangle", "length": self.resolution, "width": self.resolution, } geometry = GeometryFactory.create_geometry(**shape).step(np.array([[x, y]]).T) return self._map.is_collision(geometry)