veloren_common/
astar.rs

1use crate::path::Path;
2use core::{
3    cmp::Ordering::{self, Equal},
4    fmt,
5    hash::{BuildHasher, Hash},
6};
7use hashbrown::HashMap;
8use std::collections::BinaryHeap;
9
/// A node waiting in the A* priority queue, together with its priority.
#[derive(Clone, Copy, Debug)]
pub struct PathEntry<S> {
    /// Cost accumulated from the start to `node` plus the heuristic estimate
    /// of the remaining cost; this is the value the queue is ordered by.
    cost_estimate: f32,
    /// The queued node.
    node: S,
}
16
17impl<S: Eq> PartialEq for PathEntry<S> {
18    fn eq(&self, other: &PathEntry<S>) -> bool { self.node.eq(&other.node) }
19}
20
21impl<S: Eq> Eq for PathEntry<S> {}
22
23impl<S: Eq> Ord for PathEntry<S> {
24    // This method implements reverse ordering, so that the lowest cost
25    // will be ordered first
26    fn cmp(&self, other: &PathEntry<S>) -> Ordering {
27        other
28            .cost_estimate
29            .partial_cmp(&self.cost_estimate)
30            .unwrap_or(Equal)
31    }
32}
33
34impl<S: Eq> PartialOrd for PathEntry<S> {
35    fn partial_cmp(&self, other: &PathEntry<S>) -> Option<Ordering> { Some(self.cmp(other)) }
36
37    // This is particularily hot in `BinaryHeap::pop`, so we provide this
38    // implementation.
39    //
40    // NOTE: This probably doesn't handle edge cases like `NaNs` in a consistent
41    // manner with `Ord`, but I don't think we need to care about that here(?)
42    //
43    // See note about reverse ordering above.
44    fn le(&self, other: &PathEntry<S>) -> bool { other.cost_estimate <= self.cost_estimate }
45}
46
47pub enum PathResult<T> {
48    /// No reachable nodes were satisfactory.
49    ///
50    /// Contains path to node with the lowest heuristic value (out of the
51    /// explored nodes).
52    None(Path<T>),
53    /// Either max_iters or max_cost was reached.
54    ///
55    /// Contains path to node with the lowest heuristic value (out of the
56    /// explored nodes).
57    Exhausted(Path<T>),
58    /// Path succefully found.
59    ///
60    /// Second field is cost.
61    Path(Path<T>, f32),
62    Pending,
63}
64
65impl<T> PathResult<T> {
66    /// Returns `Some((path, cost))` if a path reaching the target was
67    /// successfully found.
68    pub fn into_path(self) -> Option<(Path<T>, f32)> {
69        match self {
70            PathResult::Path(path, cost) => Some((path, cost)),
71            _ => None,
72        }
73    }
74
75    pub fn map<U>(self, f: impl FnOnce(Path<T>) -> Path<U>) -> PathResult<U> {
76        match self {
77            PathResult::None(p) => PathResult::None(f(p)),
78            PathResult::Exhausted(p) => PathResult::Exhausted(f(p)),
79            PathResult::Path(p, cost) => PathResult::Path(f(p), cost),
80            PathResult::Pending => PathResult::Pending,
81        }
82    }
83}
84
/// Search bookkeeping for a single node.
///
/// Having an entry in the visited-node map is what marks a node as visited.
#[derive(Debug, Clone)]
struct NodeEntry<S> {
    /// Predecessor of this node on the cheapest path from the start that is
    /// known so far.
    ///
    /// The start node stores itself here (`came_from == self`). This avoids
    /// the extra size an `Option` would add.
    came_from: S,
    /// Cost of reaching this node from the start along that cheapest known
    /// path, i.e. the sum of the transition costs between its nodes.
    cost: f32,
}
99
100#[derive(Clone)]
101pub struct Astar<S, Hasher> {
102    iter: usize,
103    max_iters: usize,
104    max_cost: f32,
105    potential_nodes: BinaryHeap<PathEntry<S>>, // cost, node pairs
106    visited_nodes: HashMap<S, NodeEntry<S>, Hasher>,
107    /// Node with the lowest heuristic value so far.
108    ///
109    /// (node, heuristic value)
110    closest_node: Option<(S, f32)>,
111}
112
113/// NOTE: Must manually derive since Hasher doesn't implement it.
114impl<S: Clone + Eq + Hash + fmt::Debug, H: BuildHasher> fmt::Debug for Astar<S, H> {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        f.debug_struct("Astar")
117            .field("iter", &self.iter)
118            .field("max_iters", &self.max_iters)
119            .field("potential_nodes", &self.potential_nodes)
120            .field("visited_nodes", &self.visited_nodes)
121            .field("closest_node", &self.closest_node)
122            .finish()
123    }
124}
125
126impl<S: Clone + Eq + Hash, H: BuildHasher + Clone> Astar<S, H> {
127    pub fn new(max_iters: usize, start: S, hasher: H) -> Self {
128        Self {
129            max_iters,
130            max_cost: f32::MAX,
131            iter: 0,
132            potential_nodes: core::iter::once(PathEntry {
133                cost_estimate: 0.0,
134                node: start.clone(),
135            })
136            .collect(),
137            visited_nodes: {
138                let mut s = HashMap::with_capacity_and_hasher(1, hasher);
139                s.extend(core::iter::once((start.clone(), NodeEntry {
140                    came_from: start,
141                    cost: 0.0,
142                })));
143                s
144            },
145            closest_node: None,
146        }
147    }
148
149    pub fn with_max_cost(mut self, max_cost: f32) -> Self {
150        self.max_cost = max_cost;
151        self
152    }
153
154    pub fn set_max_iters(&mut self, max_iters: usize) { self.max_iters = max_iters; }
155
156    /// To guarantee an optimal path the heuristic function needs to be
157    /// [admissible](https://en.wikipedia.org/wiki/A*_search_algorithm#Admissibility).
158    pub fn poll<I>(
159        &mut self,
160        iters: usize,
161        // Estimate how far we are from the target.
162        mut heuristic: impl FnMut(&S) -> f32,
163        // get neighboring nodes
164        mut neighbors: impl FnMut(&S) -> I,
165        // have we reached target?
166        mut satisfied: impl FnMut(&S) -> bool,
167    ) -> PathResult<S>
168    where
169        I: Iterator<Item = (S, f32)>, // (node, transition cost)
170    {
171        let iter_limit = self.max_iters.min(self.iter + iters);
172        while self.iter < iter_limit {
173            if let Some(PathEntry {
174                node,
175                cost_estimate,
176            }) = self.potential_nodes.pop()
177            {
178                let (node_cost, came_from) = self
179                    .visited_nodes
180                    .get(&node)
181                    .map(|n| (n.cost, n.came_from.clone()))
182                    .expect("All nodes in the queue should be included in visisted_nodes");
183
184                if satisfied(&node) {
185                    return PathResult::Path(self.reconstruct_path_to(node), node_cost);
186                // Note, we assume that cost_estimate isn't an overestimation
187                // (i.e. that `heuristic` doesn't overestimate).
188                } else if cost_estimate > self.max_cost {
189                    return PathResult::Exhausted(
190                        self.closest_node
191                            .clone()
192                            .map(|(lc, _)| self.reconstruct_path_to(lc))
193                            .unwrap_or_default(),
194                    );
195                } else {
196                    for (neighbor, transition_cost) in neighbors(&node) {
197                        if neighbor == came_from {
198                            continue;
199                        }
200                        let neighbor_cost = self
201                            .visited_nodes
202                            .get(&neighbor)
203                            .map_or(f32::MAX, |n| n.cost);
204
205                        // compute cost to traverse to each neighbor
206                        let cost = node_cost + transition_cost;
207
208                        if cost < neighbor_cost {
209                            let previously_visited = self
210                                .visited_nodes
211                                .insert(neighbor.clone(), NodeEntry {
212                                    came_from: node.clone(),
213                                    cost,
214                                })
215                                .is_some();
216                            let h = heuristic(&neighbor);
217                            // note that `cost` field does not include the heuristic
218                            // priority queue does include heuristic
219                            let cost_estimate = cost + h;
220
221                            if self
222                                .closest_node
223                                .as_ref()
224                                .map(|&(_, ch)| h < ch)
225                                .unwrap_or(true)
226                            {
227                                self.closest_node = Some((node.clone(), h));
228                            };
229
230                            // We don't need to reconsider already visited nodes as astar finds the
231                            // shortest path to a node the first time it's visited, assuming the
232                            // heuristic function is admissible.
233                            if !previously_visited {
234                                self.potential_nodes.push(PathEntry {
235                                    cost_estimate,
236                                    node: neighbor,
237                                });
238                            }
239                        }
240                    }
241                }
242            } else {
243                return PathResult::None(
244                    self.closest_node
245                        .clone()
246                        .map(|(lc, _)| self.reconstruct_path_to(lc))
247                        .unwrap_or_default(),
248                );
249            }
250
251            self.iter += 1
252        }
253
254        if self.iter >= self.max_iters {
255            PathResult::Exhausted(
256                self.closest_node
257                    .clone()
258                    .map(|(lc, _)| self.reconstruct_path_to(lc))
259                    .unwrap_or_default(),
260            )
261        } else {
262            PathResult::Pending
263        }
264    }
265
266    fn reconstruct_path_to(&mut self, end: S) -> Path<S> {
267        let mut path = vec![end.clone()];
268        let mut cnode = &end;
269        while let Some(node) = self
270            .visited_nodes
271            .get(cnode)
272            .map(|n| &n.came_from)
273            .filter(|n| *n != cnode)
274        {
275            path.push(node.clone());
276            cnode = node;
277        }
278        path.into_iter().rev().collect()
279    }
280}