1use crate::path::Path;
2use core::{
3 cmp::Ordering::{self, Equal},
4 fmt,
5 hash::{BuildHasher, Hash},
6};
7use hashbrown::HashMap;
8use std::collections::BinaryHeap;
9
/// An entry in the A* open set: a node together with its estimated total path
/// cost (the cost accumulated so far plus the heuristic for the node).
#[derive(Copy, Clone, Debug)]
pub struct PathEntry<S> {
    /// Accumulated cost to reach `node` plus the heuristic estimate from it.
    cost_estimate: f32,
    /// The search node this entry refers to.
    node: S,
}

/// Entries are equal when they refer to the same node; the cost estimate is
/// not considered.
impl<S: Eq> PartialEq for PathEntry<S> {
    fn eq(&self, other: &PathEntry<S>) -> bool { self.node == other.node }
}

impl<S: Eq> Eq for PathEntry<S> {}

/// Orders entries by *descending* cost estimate, so that `BinaryHeap` (a
/// max-heap) hands out the cheapest entry first. Estimates that cannot be
/// compared (NaN) are treated as `Equal`.
///
/// NOTE(review): this ordering is not consistent with `PartialEq`, which only
/// looks at the node. `BinaryHeap` relies on the ordering alone.
impl<S: Eq> Ord for PathEntry<S> {
    fn cmp(&self, other: &PathEntry<S>) -> Ordering {
        match self.cost_estimate.partial_cmp(&other.cost_estimate) {
            Some(ordering) => ordering.reverse(),
            None => Equal,
        }
    }
}

impl<S: Eq> PartialOrd for PathEntry<S> {
    fn partial_cmp(&self, other: &PathEntry<S>) -> Option<Ordering> { Some(Ord::cmp(self, other)) }

    // Direct `<=` on the estimates, reversed just like `cmp`. Unlike `cmp`,
    // this yields `false` (not "equal") when either estimate is NaN.
    fn le(&self, other: &PathEntry<S>) -> bool { self.cost_estimate >= other.cost_estimate }
}
46
47pub enum PathResult<T> {
48 None(Path<T>),
53 Exhausted(Path<T>),
58 Path(Path<T>, f32),
62 Pending,
63}
64
65impl<T> PathResult<T> {
66 pub fn into_path(self) -> Option<(Path<T>, f32)> {
69 match self {
70 PathResult::Path(path, cost) => Some((path, cost)),
71 _ => None,
72 }
73 }
74
75 pub fn map<U>(self, f: impl FnOnce(Path<T>) -> Path<U>) -> PathResult<U> {
76 match self {
77 PathResult::None(p) => PathResult::None(f(p)),
78 PathResult::Exhausted(p) => PathResult::Exhausted(f(p)),
79 PathResult::Path(p, cost) => PathResult::Path(f(p), cost),
80 PathResult::Pending => PathResult::Pending,
81 }
82 }
83}
84
/// Book-keeping for a node the search has reached: its predecessor on the
/// cheapest path found so far and that path's accumulated cost.
#[derive(Debug, Clone)]
struct NodeEntry<S> {
    /// Predecessor on the cheapest known path; the start node is its own
    /// predecessor.
    came_from: S,
    /// Accumulated cost from the start node along that path.
    cost: f32,
}
99
100#[derive(Clone)]
101pub struct Astar<S, Hasher> {
102 iter: usize,
103 max_iters: usize,
104 max_cost: f32,
105 potential_nodes: BinaryHeap<PathEntry<S>>, visited_nodes: HashMap<S, NodeEntry<S>, Hasher>,
107 closest_node: Option<(S, f32)>,
111}
112
113impl<S: Clone + Eq + Hash + fmt::Debug, H: BuildHasher> fmt::Debug for Astar<S, H> {
115 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116 f.debug_struct("Astar")
117 .field("iter", &self.iter)
118 .field("max_iters", &self.max_iters)
119 .field("potential_nodes", &self.potential_nodes)
120 .field("visited_nodes", &self.visited_nodes)
121 .field("closest_node", &self.closest_node)
122 .finish()
123 }
124}
125
126impl<S: Clone + Eq + Hash, H: BuildHasher + Clone> Astar<S, H> {
127 pub fn new(max_iters: usize, start: S, hasher: H) -> Self {
128 Self {
129 max_iters,
130 max_cost: f32::MAX,
131 iter: 0,
132 potential_nodes: core::iter::once(PathEntry {
133 cost_estimate: 0.0,
134 node: start.clone(),
135 })
136 .collect(),
137 visited_nodes: {
138 let mut s = HashMap::with_capacity_and_hasher(1, hasher);
139 s.extend(core::iter::once((start.clone(), NodeEntry {
140 came_from: start,
141 cost: 0.0,
142 })));
143 s
144 },
145 closest_node: None,
146 }
147 }
148
149 pub fn with_max_cost(mut self, max_cost: f32) -> Self {
150 self.max_cost = max_cost;
151 self
152 }
153
154 pub fn poll<I>(
157 &mut self,
158 iters: usize,
159 mut heuristic: impl FnMut(&S) -> f32,
161 mut neighbors: impl FnMut(&S) -> I,
163 mut satisfied: impl FnMut(&S) -> bool,
165 ) -> PathResult<S>
166 where
167 I: Iterator<Item = (S, f32)>, {
169 let iter_limit = self.max_iters.min(self.iter + iters);
170 while self.iter < iter_limit {
171 if let Some(PathEntry {
172 node,
173 cost_estimate,
174 }) = self.potential_nodes.pop()
175 {
176 let (node_cost, came_from) = self
177 .visited_nodes
178 .get(&node)
179 .map(|n| (n.cost, n.came_from.clone()))
180 .expect("All nodes in the queue should be included in visisted_nodes");
181
182 if satisfied(&node) {
183 return PathResult::Path(self.reconstruct_path_to(node), node_cost);
184 } else if cost_estimate > self.max_cost {
187 return PathResult::Exhausted(
188 self.closest_node
189 .clone()
190 .map(|(lc, _)| self.reconstruct_path_to(lc))
191 .unwrap_or_default(),
192 );
193 } else {
194 for (neighbor, transition_cost) in neighbors(&node) {
195 if neighbor == came_from {
196 continue;
197 }
198 let neighbor_cost = self
199 .visited_nodes
200 .get(&neighbor)
201 .map_or(f32::MAX, |n| n.cost);
202
203 let cost = node_cost + transition_cost;
205
206 if cost < neighbor_cost {
207 let previously_visited = self
208 .visited_nodes
209 .insert(neighbor.clone(), NodeEntry {
210 came_from: node.clone(),
211 cost,
212 })
213 .is_some();
214 let h = heuristic(&neighbor);
215 let cost_estimate = cost + h;
218
219 if self
220 .closest_node
221 .as_ref()
222 .map(|&(_, ch)| h < ch)
223 .unwrap_or(true)
224 {
225 self.closest_node = Some((node.clone(), h));
226 };
227
228 if !previously_visited {
232 self.potential_nodes.push(PathEntry {
233 cost_estimate,
234 node: neighbor,
235 });
236 }
237 }
238 }
239 }
240 } else {
241 return PathResult::None(
242 self.closest_node
243 .clone()
244 .map(|(lc, _)| self.reconstruct_path_to(lc))
245 .unwrap_or_default(),
246 );
247 }
248
249 self.iter += 1
250 }
251
252 if self.iter >= self.max_iters {
253 PathResult::Exhausted(
254 self.closest_node
255 .clone()
256 .map(|(lc, _)| self.reconstruct_path_to(lc))
257 .unwrap_or_default(),
258 )
259 } else {
260 PathResult::Pending
261 }
262 }
263
264 fn reconstruct_path_to(&mut self, end: S) -> Path<S> {
265 let mut path = vec![end.clone()];
266 let mut cnode = &end;
267 while let Some(node) = self
268 .visited_nodes
269 .get(cnode)
270 .map(|n| &n.came_from)
271 .filter(|n| *n != cnode)
272 {
273 path.push(node.clone());
274 cnode = node;
275 }
276 path.into_iter().rev().collect()
277 }
278}