veloren_common/
astar.rs

1use crate::path::Path;
2use core::{
3    cmp::Ordering::{self, Equal},
4    fmt,
5    hash::{BuildHasher, Hash},
6};
7use hashbrown::HashMap;
8use std::collections::BinaryHeap;
9
/// An entry of the priority queue used by the search: a node together with
/// the priority it is queued under.
#[derive(Copy, Clone, Debug)]
pub struct PathEntry<S> {
    // Cost so far + heuristic.
    cost_estimate: f32,
    node: S,
}

// Equality only looks at the node; the cost estimate is ignored.
impl<S: Eq> PartialEq for PathEntry<S> {
    fn eq(&self, rhs: &PathEntry<S>) -> bool { self.node == rhs.node }
}

impl<S: Eq> Eq for PathEntry<S> {}

impl<S: Eq> Ord for PathEntry<S> {
    // The comparison is deliberately reversed (`rhs` against `self`) so that
    // the entry with the lowest cost estimate is ordered first.
    //
    // Cost estimates that cannot be compared (`NaN`) are treated as equal.
    fn cmp(&self, rhs: &PathEntry<S>) -> Ordering {
        match rhs.cost_estimate.partial_cmp(&self.cost_estimate) {
            Some(ordering) => ordering,
            None => Equal,
        }
    }
}

impl<S: Eq> PartialOrd for PathEntry<S> {
    fn partial_cmp(&self, rhs: &PathEntry<S>) -> Option<Ordering> { Some(Ord::cmp(self, rhs)) }

    // This is particularly hot in `BinaryHeap::pop`, so a direct
    // implementation is provided instead of going through `partial_cmp`.
    //
    // NOTE: `NaN` cost estimates are not handled consistently with `Ord`
    // here (this returns `false`, while `cmp` reports `Equal`).
    //
    // Reversed for the same reason as `Ord::cmp`.
    fn le(&self, rhs: &PathEntry<S>) -> bool { self.cost_estimate >= rhs.cost_estimate }
}
46
47pub enum PathResult<T> {
48    /// No reachable nodes were satisfactory.
49    ///
50    /// Contains path to node with the lowest heuristic value (out of the
51    /// explored nodes).
52    None(Path<T>),
53    /// Either max_iters or max_cost was reached.
54    ///
55    /// Contains path to node with the lowest heuristic value (out of the
56    /// explored nodes).
57    Exhausted(Path<T>),
58    /// Path succefully found.
59    ///
60    /// Second field is cost.
61    Path(Path<T>, f32),
62    Pending,
63}
64
65impl<T> PathResult<T> {
66    /// Returns `Some((path, cost))` if a path reaching the target was
67    /// successfully found.
68    pub fn into_path(self) -> Option<(Path<T>, f32)> {
69        match self {
70            PathResult::Path(path, cost) => Some((path, cost)),
71            _ => None,
72        }
73    }
74
75    pub fn map<U>(self, f: impl FnOnce(Path<T>) -> Path<U>) -> PathResult<U> {
76        match self {
77            PathResult::None(p) => PathResult::None(f(p)),
78            PathResult::Exhausted(p) => PathResult::Exhausted(f(p)),
79            PathResult::Path(p, cost) => PathResult::Path(f(p), cost),
80            PathResult::Pending => PathResult::Pending,
81        }
82    }
83}
84
// Bookkeeping stored per node in the map of visited nodes.
//
// If a node entry exists, the node was visited!
#[derive(Clone, Debug)]
struct NodeEntry<S> {
    /// Previous node in the cheapest path (known so far) that goes from the
    /// start to this node.
    ///
    /// If `came_from == self` this is the start node! (to avoid inflating the
    /// size with `Option`)
    came_from: S,
    /// Cost to reach this node from the start by following the cheapest path
    /// known so far. This is the sum of the transition costs between all the
    /// nodes on this path.
    cost: f32,
}
99
/// State of an A* search that can be advanced incrementally with `poll`.
#[derive(Clone)]
pub struct Astar<S, Hasher> {
    /// Number of iterations performed so far, summed over all `poll` calls.
    iter: usize,
    /// Upper bound for `iter`; once it is reached `poll` reports
    /// `PathResult::Exhausted`.
    max_iters: usize,
    /// Largest cost estimate that is still expanded; `f32::MAX` unless
    /// changed through `with_max_cost`.
    max_cost: f32,
    /// Nodes waiting to be expanded, as (cost estimate, node) entries. The
    /// entry with the lowest cost estimate is popped first.
    potential_nodes: BinaryHeap<PathEntry<S>>,
    /// For every node reached so far: its predecessor and the cost of the
    /// cheapest known path to it.
    visited_nodes: HashMap<S, NodeEntry<S>, Hasher>,
    /// Node with the lowest heuristic value so far.
    ///
    /// (node, heuristic value)
    closest_node: Option<(S, f32)>,
}
112
113/// NOTE: Must manually derive since Hasher doesn't implement it.
114impl<S: Clone + Eq + Hash + fmt::Debug, H: BuildHasher> fmt::Debug for Astar<S, H> {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        f.debug_struct("Astar")
117            .field("iter", &self.iter)
118            .field("max_iters", &self.max_iters)
119            .field("potential_nodes", &self.potential_nodes)
120            .field("visited_nodes", &self.visited_nodes)
121            .field("closest_node", &self.closest_node)
122            .finish()
123    }
124}
125
126impl<S: Clone + Eq + Hash, H: BuildHasher + Clone> Astar<S, H> {
127    pub fn new(max_iters: usize, start: S, hasher: H) -> Self {
128        Self {
129            max_iters,
130            max_cost: f32::MAX,
131            iter: 0,
132            potential_nodes: core::iter::once(PathEntry {
133                cost_estimate: 0.0,
134                node: start.clone(),
135            })
136            .collect(),
137            visited_nodes: {
138                let mut s = HashMap::with_capacity_and_hasher(1, hasher);
139                s.extend(core::iter::once((start.clone(), NodeEntry {
140                    came_from: start,
141                    cost: 0.0,
142                })));
143                s
144            },
145            closest_node: None,
146        }
147    }
148
149    pub fn with_max_cost(mut self, max_cost: f32) -> Self {
150        self.max_cost = max_cost;
151        self
152    }
153
154    /// To guarantee an optimal path the heuristic function needs to be
155    /// [admissible](https://en.wikipedia.org/wiki/A*_search_algorithm#Admissibility).
156    pub fn poll<I>(
157        &mut self,
158        iters: usize,
159        // Estimate how far we are from the target.
160        mut heuristic: impl FnMut(&S) -> f32,
161        // get neighboring nodes
162        mut neighbors: impl FnMut(&S) -> I,
163        // have we reached target?
164        mut satisfied: impl FnMut(&S) -> bool,
165    ) -> PathResult<S>
166    where
167        I: Iterator<Item = (S, f32)>, // (node, transition cost)
168    {
169        let iter_limit = self.max_iters.min(self.iter + iters);
170        while self.iter < iter_limit {
171            if let Some(PathEntry {
172                node,
173                cost_estimate,
174            }) = self.potential_nodes.pop()
175            {
176                let (node_cost, came_from) = self
177                    .visited_nodes
178                    .get(&node)
179                    .map(|n| (n.cost, n.came_from.clone()))
180                    .expect("All nodes in the queue should be included in visisted_nodes");
181
182                if satisfied(&node) {
183                    return PathResult::Path(self.reconstruct_path_to(node), node_cost);
184                // Note, we assume that cost_estimate isn't an overestimation
185                // (i.e. that `heuristic` doesn't overestimate).
186                } else if cost_estimate > self.max_cost {
187                    return PathResult::Exhausted(
188                        self.closest_node
189                            .clone()
190                            .map(|(lc, _)| self.reconstruct_path_to(lc))
191                            .unwrap_or_default(),
192                    );
193                } else {
194                    for (neighbor, transition_cost) in neighbors(&node) {
195                        if neighbor == came_from {
196                            continue;
197                        }
198                        let neighbor_cost = self
199                            .visited_nodes
200                            .get(&neighbor)
201                            .map_or(f32::MAX, |n| n.cost);
202
203                        // compute cost to traverse to each neighbor
204                        let cost = node_cost + transition_cost;
205
206                        if cost < neighbor_cost {
207                            let previously_visited = self
208                                .visited_nodes
209                                .insert(neighbor.clone(), NodeEntry {
210                                    came_from: node.clone(),
211                                    cost,
212                                })
213                                .is_some();
214                            let h = heuristic(&neighbor);
215                            // note that `cost` field does not include the heuristic
216                            // priority queue does include heuristic
217                            let cost_estimate = cost + h;
218
219                            if self
220                                .closest_node
221                                .as_ref()
222                                .map(|&(_, ch)| h < ch)
223                                .unwrap_or(true)
224                            {
225                                self.closest_node = Some((node.clone(), h));
226                            };
227
228                            // We don't need to reconsider already visited nodes as astar finds the
229                            // shortest path to a node the first time it's visited, assuming the
230                            // heuristic function is admissible.
231                            if !previously_visited {
232                                self.potential_nodes.push(PathEntry {
233                                    cost_estimate,
234                                    node: neighbor,
235                                });
236                            }
237                        }
238                    }
239                }
240            } else {
241                return PathResult::None(
242                    self.closest_node
243                        .clone()
244                        .map(|(lc, _)| self.reconstruct_path_to(lc))
245                        .unwrap_or_default(),
246                );
247            }
248
249            self.iter += 1
250        }
251
252        if self.iter >= self.max_iters {
253            PathResult::Exhausted(
254                self.closest_node
255                    .clone()
256                    .map(|(lc, _)| self.reconstruct_path_to(lc))
257                    .unwrap_or_default(),
258            )
259        } else {
260            PathResult::Pending
261        }
262    }
263
264    fn reconstruct_path_to(&mut self, end: S) -> Path<S> {
265        let mut path = vec![end.clone()];
266        let mut cnode = &end;
267        while let Some(node) = self
268            .visited_nodes
269            .get(cnode)
270            .map(|n| &n.came_from)
271            .filter(|n| *n != cnode)
272        {
273            path.push(node.clone());
274            cnode = node;
275        }
276        path.into_iter().rev().collect()
277    }
278}