veloren_common/
slowjob.rs

1use hashbrown::HashMap;
2use rayon::ThreadPool;
3use std::{
4    collections::VecDeque,
5    sync::{Arc, Mutex},
6    time::Instant,
7};
8use tracing::{error, warn};
9
/// Provides a wrapper around a rayon threadpool to execute slow jobs.
/// "Slow" means the job doesn't need to complete within the same tick.
/// DO NOT USE I/O blocking jobs, but only CPU heavy jobs.
/// Jobs run here will reduce the amount of threads rayon can use during the
/// main tick.
///
/// NOTE(review): the thread pool handed to [`SlowJobPool::new`] is currently
/// not used; the internal pool builds its own dedicated rayon pool instead.
///
/// ## Configuration
/// This pool allows you to configure certain names of jobs and assign them a
/// maximum number of threads.
///
/// ### Example
/// Your system has 16 cores, you assign 12 cores for slow-jobs.
/// Then you can configure all jobs with the name `CHUNK_GENERATOR` to spawn on
/// max 50% (= 6 cores).
///
/// ## Spawn Order
/// - At least 1 job of a configuration is allowed to run if the global limit
///   isn't hit.
/// - Remaining capacities are spread in relation to their limit, e.g. a
///   configuration with double the limit will be scheduled to spawn double the
///   tasks, starting with a round robin.
///
/// ## States
/// - queued
/// - spawned
/// - started
/// - finished
/// ```
/// # use veloren_common::slowjob::SlowJobPool;
/// # use std::sync::Arc;
///
/// let threadpool = rayon::ThreadPoolBuilder::new()
///     .num_threads(16)
///     .build()
///     .unwrap();
/// let pool = SlowJobPool::new(3, 10, Arc::new(threadpool));
/// pool.configure("CHUNK_GENERATOR", |n| n / 2);
/// pool.spawn("CHUNK_GENERATOR", move || println!("this is a job"));
/// ```
#[derive(Clone)]
pub struct SlowJobPool {
    /// Shared scheduler state; cloning the pool clones this handle, so all
    /// clones operate on the same queues and counters.
    internal: Arc<Mutex<InternalSlowJobPool>>,
}
51
/// Handle to a job that was handed to the pool.
///
/// It is returned by `spawn`/`try_run` and can be passed to
/// `SlowJobPool::cancel` to remove the job while it is still queued.
#[derive(Debug)]
pub struct SlowJob {
    /// Configuration name the job was spawned under.
    name: String,
    /// Pool-unique id, used to find the job inside its queue.
    id: u64,
}
57
/// Boxed, type-erased job closure as stored in the queue and handed to rayon.
type JobType = Box<dyn FnOnce() + Send + Sync + 'static>;
59
/// Scheduler state shared (behind a mutex) by all `SlowJobPool` handles and by
/// every queued task.
struct InternalSlowJobPool {
    /// Id that will be assigned to the next spawned job.
    next_id: u64,
    /// Per-configuration FIFO of jobs that were not yet handed to the
    /// threadpool.
    queue: HashMap<String, VecDeque<Queue>>,
    /// Per-configuration limit and running counter.
    configs: HashMap<String, Config>,
    /// Configuration names in round-robin order: a name is moved to the end
    /// whenever one of its jobs is fired.
    last_spawned_configs: Vec<String>,
    /// Number of jobs handed to the threadpool that have not finished yet.
    global_spawned_and_running: u64,
    /// Upper bound for concurrently running jobs (at least 1).
    global_limit: u64,
    /// Maximum number of metrics kept per configuration until they are taken.
    jobs_metrics_cnt: usize,
    /// Collected metrics of finished jobs, grouped by configuration name.
    jobs_metrics: HashMap<String, Vec<JobMetrics>>,
    /// Threadpool the jobs are executed on.
    threadpool: Arc<ThreadPool>,
    /// Handle to this very state; cloned into each queued task so the task can
    /// report completion and trigger the next spawn.
    internal: Option<Arc<Mutex<Self>>>,
}
72
/// Scheduling state of a single named configuration.
#[derive(Debug)]
struct Config {
    /// Maximum number of jobs of this configuration that may run at once.
    local_limit: u64,
    /// Jobs of this configuration handed to the threadpool and not yet
    /// finished.
    local_spawned_and_running: u64,
}
78
/// A single queued job, waiting to be handed to the threadpool.
struct Queue {
    /// Pool-unique id, matches the id of the `SlowJob` handle.
    id: u64,
    /// Configuration name the job belongs to.
    name: String,
    /// The wrapped closure; besides running the user job it records metrics
    /// and notifies the pool once done.
    task: JobType,
}
84
/// Timestamps recorded for one executed job.
pub struct JobMetrics {
    /// When the job was put into the queue.
    pub queue_created: Instant,
    /// Taken right before the job closure is called.
    pub execution_start: Instant,
    /// Taken right after the job closure returned.
    pub execution_end: Instant,
}
90
91impl Queue {
92    fn new<F>(name: &str, id: u64, internal: &Arc<Mutex<InternalSlowJobPool>>, f: F) -> Self
93    where
94        F: FnOnce() + Send + Sync + 'static,
95    {
96        let internal = Arc::clone(internal);
97        let name_cloned = name.to_owned();
98        let queue_created = Instant::now();
99        Self {
100            id,
101            name: name.to_owned(),
102            task: Box::new(move || {
103                common_base::prof_span_alloc!(_guard, &name_cloned);
104                let execution_start = Instant::now();
105                f();
106                let execution_end = Instant::now();
107                let metrics = JobMetrics {
108                    queue_created,
109                    execution_start,
110                    execution_end,
111                };
112                // directly maintain the next task afterwards
113                {
114                    let mut lock = internal.lock().expect("slowjob lock poisoned");
115                    lock.finish(&name_cloned, metrics);
116                    lock.spawn_queued();
117                }
118            }),
119        }
120    }
121}
122
123impl InternalSlowJobPool {
124    pub fn new(
125        global_limit: u64,
126        jobs_metrics_cnt: usize,
127        _threadpool: Arc<ThreadPool>,
128    ) -> Arc<Mutex<Self>> {
129        // rayon is having a bug where a ECS task could work-steal a slowjob if we use
130        // the same threadpool, which would cause lagspikes we dont want!
131        let threadpool = Arc::new(
132            rayon::ThreadPoolBuilder::new()
133                .num_threads(global_limit as usize)
134                .thread_name(move |i| format!("slowjob-{}", i))
135                .build()
136                .unwrap(),
137        );
138        let link = Arc::new(Mutex::new(Self {
139            next_id: 0,
140            queue: HashMap::new(),
141            configs: HashMap::new(),
142            last_spawned_configs: Vec::new(),
143            global_spawned_and_running: 0,
144            global_limit: global_limit.max(1),
145            jobs_metrics_cnt,
146            jobs_metrics: HashMap::new(),
147            threadpool,
148            internal: None,
149        }));
150
151        let link_clone = Arc::clone(&link);
152        link.lock()
153            .expect("poisoned on InternalSlowJobPool::new")
154            .internal = Some(link_clone);
155        link
156    }
157
158    /// returns order of configuration which are queued next
159    fn calc_queued_order(
160        &self,
161        mut queued: HashMap<&String, u64>,
162        mut limit: usize,
163    ) -> Vec<String> {
164        let mut roundrobin = self.last_spawned_configs.clone();
165        let mut result = vec![];
166        let spawned = self
167            .configs
168            .iter()
169            .map(|(n, c)| (n, c.local_spawned_and_running))
170            .collect::<HashMap<_, u64>>();
171        let mut queried_capped = self
172            .configs
173            .iter()
174            .map(|(n, c)| {
175                (
176                    n,
177                    queued
178                        .get(&n)
179                        .cloned()
180                        .unwrap_or(0)
181                        .min(c.local_limit - c.local_spawned_and_running),
182                )
183            })
184            .collect::<HashMap<_, _>>();
185        // grab all configs that are queued and not running. in roundrobin order
186        for n in roundrobin.clone().into_iter() {
187            if let Some(c) = queued.get_mut(&n) {
188                if *c > 0 && spawned.get(&n).cloned().unwrap_or(0) == 0 {
189                    result.push(n.clone());
190                    *c -= 1;
191                    limit -= 1;
192                    queried_capped.get_mut(&n).map(|v| *v -= 1);
193                    roundrobin
194                        .iter()
195                        .position(|e| e == &n)
196                        .map(|i| roundrobin.remove(i));
197                    roundrobin.push(n);
198                    if limit == 0 {
199                        return result;
200                    }
201                }
202            }
203        }
204        //schedule rest based on their possible limites, don't use round robin here
205        let total_limit = queried_capped.values().sum::<u64>() as f32;
206        if total_limit < f32::EPSILON {
207            return result;
208        }
209        let mut spawn_rates = queried_capped
210            .iter()
211            .map(|(&n, l)| (n, ((*l as f32 * limit as f32) / total_limit).min(*l as f32)))
212            .collect::<Vec<_>>();
213        while limit > 0 {
214            spawn_rates.sort_by(|(_, a), (_, b)| {
215                if b < a {
216                    core::cmp::Ordering::Less
217                } else if (b - a).abs() < f32::EPSILON {
218                    core::cmp::Ordering::Equal
219                } else {
220                    core::cmp::Ordering::Greater
221                }
222            });
223            match spawn_rates.first_mut() {
224                Some((n, r)) => {
225                    if *r > f32::EPSILON {
226                        result.push(n.clone());
227                        limit -= 1;
228                        *r -= 1.0;
229                    } else {
230                        break;
231                    }
232                },
233                None => break,
234            }
235        }
236        result
237    }
238
239    fn can_spawn(&self, name: &str) -> bool {
240        let queued = self
241            .queue
242            .iter()
243            .map(|(n, m)| (n, m.len() as u64))
244            .collect::<HashMap<_, u64>>();
245        let mut to_be_queued = queued.clone();
246        let name = name.to_owned();
247        *to_be_queued.entry(&name).or_default() += 1;
248        let limit = (self.global_limit - self.global_spawned_and_running) as usize;
249        // calculate to_be_queued first
250        let to_be_queued_order = self.calc_queued_order(to_be_queued, limit);
251        let queued_order = self.calc_queued_order(queued, limit);
252        // if its queued one time more then its okay to spawn
253        let to_be_queued_cnt = to_be_queued_order
254            .into_iter()
255            .filter(|n| n == &name)
256            .count();
257        let queued_cnt = queued_order.into_iter().filter(|n| n == &name).count();
258        to_be_queued_cnt > queued_cnt
259    }
260
261    pub fn spawn<F>(&mut self, name: &str, f: F) -> SlowJob
262    where
263        F: FnOnce() + Send + Sync + 'static,
264    {
265        let id = self.next_id;
266        self.next_id += 1;
267        let queue = Queue::new(name, id, self.internal.as_ref().expect("internal empty"), f);
268        self.queue
269            .entry(name.to_string())
270            .or_default()
271            .push_back(queue);
272        debug_assert!(
273            self.configs.contains_key(name),
274            "Can't spawn unconfigured task!"
275        );
276        //spawn already queued
277        self.spawn_queued();
278        SlowJob {
279            name: name.to_string(),
280            id,
281        }
282    }
283
284    fn finish(&mut self, name: &str, metrics: JobMetrics) {
285        let metric = self.jobs_metrics.entry(name.to_string()).or_default();
286
287        if metric.len() < self.jobs_metrics_cnt {
288            metric.push(metrics);
289        }
290        self.global_spawned_and_running -= 1;
291        if let Some(c) = self.configs.get_mut(name) {
292            c.local_spawned_and_running -= 1;
293        } else {
294            warn!(?name, "sync_maintain on a no longer existing config");
295        }
296    }
297
298    fn spawn_queued(&mut self) {
299        let queued = self
300            .queue
301            .iter()
302            .map(|(n, m)| (n, m.len() as u64))
303            .collect::<HashMap<_, u64>>();
304        let limit = self.global_limit as usize;
305        let queued_order = self.calc_queued_order(queued, limit);
306        for name in queued_order.into_iter() {
307            match self.queue.get_mut(&name) {
308                Some(deque) => match deque.pop_front() {
309                    Some(queue) => {
310                        //fire
311                        self.global_spawned_and_running += 1;
312                        self.configs
313                            .get_mut(&queue.name)
314                            .expect("cannot fire a unconfigured job")
315                            .local_spawned_and_running += 1;
316                        self.last_spawned_configs
317                            .iter()
318                            .position(|e| e == &queue.name)
319                            .map(|i| self.last_spawned_configs.remove(i));
320                        self.last_spawned_configs.push(queue.name.to_owned());
321                        self.threadpool.spawn(queue.task);
322                    },
323                    None => error!(
324                        "internal calculation is wrong, we extected a schedulable job to be \
325                         present in the queue"
326                    ),
327                },
328                None => error!(
329                    "internal calculation is wrong, we marked a queue as schedulable which \
330                     doesn't exist"
331                ),
332            }
333        }
334    }
335
336    pub fn take_metrics(&mut self) -> HashMap<String, Vec<JobMetrics>> {
337        core::mem::take(&mut self.jobs_metrics)
338    }
339}
340
341impl SlowJobPool {
342    pub fn new(global_limit: u64, jobs_metrics_cnt: usize, threadpool: Arc<ThreadPool>) -> Self {
343        Self {
344            internal: InternalSlowJobPool::new(global_limit, jobs_metrics_cnt, threadpool),
345        }
346    }
347
348    /// configure a NAME to spawn up to f(n) threads, depending on how many
349    /// threads we globally have available
350    pub fn configure<F>(&self, name: &str, f: F)
351    where
352        F: Fn(u64) -> u64,
353    {
354        let mut lock = self.internal.lock().expect("lock poisoned while configure");
355        let cnf = Config {
356            local_limit: f(lock.global_limit).max(1),
357            local_spawned_and_running: 0,
358        };
359        lock.configs.insert(name.to_owned(), cnf);
360        lock.last_spawned_configs.push(name.to_owned());
361    }
362
363    /// spawn a new slow job on a certain NAME IF it can run immediately
364    #[expect(clippy::result_unit_err)]
365    pub fn try_run<F>(&self, name: &str, f: F) -> Result<SlowJob, ()>
366    where
367        F: FnOnce() + Send + Sync + 'static,
368    {
369        let mut lock = self.internal.lock().expect("lock poisoned while try_run");
370        //spawn already queued
371        lock.spawn_queued();
372        if lock.can_spawn(name) {
373            Ok(lock.spawn(name, f))
374        } else {
375            Err(())
376        }
377    }
378
379    pub fn spawn<F>(&self, name: &str, f: F) -> SlowJob
380    where
381        F: FnOnce() + Send + Sync + 'static,
382    {
383        self.internal
384            .lock()
385            .expect("lock poisoned while spawn")
386            .spawn(name, f)
387    }
388
389    pub fn cancel(&self, job: SlowJob) -> Result<(), SlowJob> {
390        let mut lock = self.internal.lock().expect("lock poisoned while cancel");
391        if let Some(m) = lock.queue.get_mut(&job.name) {
392            let p = match m.iter().position(|p| p.id == job.id) {
393                Some(p) => p,
394                None => return Err(job),
395            };
396            if m.remove(p).is_some() {
397                return Ok(());
398            }
399        }
400        Err(job)
401    }
402
403    pub fn take_metrics(&self) -> HashMap<String, Vec<JobMetrics>> {
404        self.internal
405            .lock()
406            .expect("lock poisoned while take_metrics")
407            .take_metrics()
408    }
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::{
            Barrier,
            atomic::{AtomicBool, AtomicU64, Ordering},
        },
        time::Duration,
    };

    /// Builds a pool with `global_threads` as global limit and configures
    /// "FOO", "BAR" and "BAZ" to `limit / divisor` each; a divisor of 0 leaves
    /// the respective name unconfigured.
    fn mock_pool(
        pool_threads: usize,
        global_threads: u64,
        metrics: usize,
        foo: u64,
        bar: u64,
        baz: u64,
    ) -> SlowJobPool {
        let threadpool = rayon::ThreadPoolBuilder::new()
            .num_threads(pool_threads)
            .build()
            .unwrap();
        let pool = SlowJobPool::new(global_threads, metrics, Arc::new(threadpool));
        for (name, divisor) in [("FOO", foo), ("BAR", bar), ("BAZ", baz)] {
            if divisor != 0 {
                pool.configure(name, |x| x / divisor);
            }
        }
        pool
    }

    /// Owned `(name, queued jobs)` pairs, the backing storage for `as_queued`.
    fn queue_data(entries: &[(&str, u64)]) -> Vec<(String, u64)> {
        entries
            .iter()
            .map(|(name, count)| ((*name).to_owned(), *count))
            .collect()
    }

    /// Borrowing view on `data` in the shape `calc_queued_order` expects.
    fn as_queued(data: &[(String, u64)]) -> HashMap<&String, u64> {
        data.iter().map(|(name, count)| (name, *count)).collect()
    }

    /// Number of times `name` occurs in a calculated spawn order.
    fn count_of(order: &[String], name: &str) -> usize {
        order.iter().filter(|entry| entry.as_str() == name).count()
    }

    #[test]
    fn simple_queue() {
        let pool = mock_pool(4, 4, 0, 1, 0, 0);
        let internal = pool.internal.lock().unwrap();
        let data = queue_data(&[("FOO", 1)]);
        let result = internal.calc_queued_order(as_queued(&data), 4);
        assert_eq!(result, ["FOO"]);
    }

    #[test]
    fn multiple_queue() {
        let pool = mock_pool(4, 4, 0, 1, 0, 0);
        let internal = pool.internal.lock().unwrap();
        let data = queue_data(&[("FOO", 2)]);
        let result = internal.calc_queued_order(as_queued(&data), 4);
        assert_eq!(result, ["FOO", "FOO"]);
    }

    #[test]
    fn limit_queue() {
        let pool = mock_pool(5, 5, 0, 1, 0, 0);
        let internal = pool.internal.lock().unwrap();
        let data = queue_data(&[("FOO", 80)]);
        let result = internal.calc_queued_order(as_queued(&data), 4);
        assert_eq!(result, ["FOO", "FOO", "FOO", "FOO"]);
    }

    #[test]
    fn simple_queue_2() {
        let pool = mock_pool(4, 4, 0, 1, 1, 0);
        let internal = pool.internal.lock().unwrap();
        let data = queue_data(&[("FOO", 1), ("BAR", 1)]);
        let result = internal.calc_queued_order(as_queued(&data), 4);
        assert_eq!(result.len(), 2);
        assert_eq!(count_of(&result, "FOO"), 1);
        assert_eq!(count_of(&result, "BAR"), 1);
    }

    #[test]
    fn multiple_queue_3() {
        let pool = mock_pool(4, 4, 0, 1, 1, 0);
        let internal = pool.internal.lock().unwrap();
        let data = queue_data(&[("FOO", 2), ("BAR", 2)]);
        let result = internal.calc_queued_order(as_queued(&data), 4);
        assert_eq!(result.len(), 4);
        assert_eq!(count_of(&result, "FOO"), 2);
        assert_eq!(count_of(&result, "BAR"), 2);
    }

    #[test]
    fn multiple_queue_4() {
        let pool = mock_pool(4, 4, 0, 2, 1, 0);
        let internal = pool.internal.lock().unwrap();
        let data = queue_data(&[("FOO", 3), ("BAR", 3)]);
        let result = internal.calc_queued_order(as_queued(&data), 4);
        assert_eq!(result.len(), 4);
        assert_eq!(count_of(&result, "FOO"), 2);
        assert_eq!(count_of(&result, "BAR"), 2);
    }

    #[test]
    fn multiple_queue_5() {
        let pool = mock_pool(4, 4, 0, 2, 1, 0);
        let internal = pool.internal.lock().unwrap();
        let data = queue_data(&[("FOO", 5), ("BAR", 5)]);
        let result = internal.calc_queued_order(as_queued(&data), 5);
        assert_eq!(result.len(), 5);
        assert_eq!(count_of(&result, "FOO"), 2);
        assert_eq!(count_of(&result, "BAR"), 3);
    }

    #[test]
    fn multiple_queue_6() {
        let pool = mock_pool(40, 40, 0, 2, 1, 0);
        let internal = pool.internal.lock().unwrap();
        let data = queue_data(&[("FOO", 5), ("BAR", 5)]);
        let result = internal.calc_queued_order(as_queued(&data), 11);
        assert_eq!(result.len(), 10);
        assert_eq!(count_of(&result, "FOO"), 5);
        assert_eq!(count_of(&result, "BAR"), 5);
    }

    #[test]
    fn roundrobin() {
        let pool = mock_pool(4, 4, 0, 2, 2, 0);
        let data = queue_data(&[("FOO", 5), ("BAR", 5)]);
        // next configuration the pool would pick if only one slot was free
        let next_in_order = || {
            pool.internal
                .lock()
                .unwrap()
                .calc_queued_order(as_queued(&data), 1)
        };
        // a barrier in f doesnt work as we need to wait for the cleanup
        let wait_until_idle = || {
            while pool.internal.lock().unwrap().global_spawned_and_running != 0 {
                std::thread::yield_now();
            }
        };

        // Spawn a FOO task.
        pool.internal
            .lock()
            .unwrap()
            .spawn("FOO", || println!("foo"));
        wait_until_idle();
        assert_eq!(next_in_order(), ["BAR"]);
        // keep order if no new is spawned
        assert_eq!(next_in_order(), ["BAR"]);
        // spawn a BAR task
        pool.internal
            .lock()
            .unwrap()
            .spawn("BAR", || println!("bar"));
        wait_until_idle();
        assert_eq!(next_in_order(), ["FOO"]);
    }

    #[test]
    #[should_panic]
    fn unconfigured() {
        let pool = mock_pool(4, 4, 0, 2, 1, 0);
        let mut internal = pool.internal.lock().unwrap();
        internal.spawn("UNCONFIGURED", || println!());
    }

    #[test]
    fn correct_spawn_doesnt_panic() {
        let pool = mock_pool(4, 4, 0, 2, 1, 0);
        let mut internal = pool.internal.lock().unwrap();
        internal.spawn("FOO", || println!("foo"));
        internal.spawn("BAR", || println!("bar"));
    }

    #[test]
    fn can_spawn() {
        let pool = mock_pool(4, 4, 0, 2, 1, 0);
        let internal = pool.internal.lock().unwrap();
        assert!(internal.can_spawn("FOO"));
        assert!(internal.can_spawn("BAR"));
    }

    #[test]
    fn try_run_works() {
        let pool = mock_pool(4, 4, 0, 2, 1, 0);
        pool.try_run("FOO", || println!("foo")).unwrap();
        pool.try_run("BAR", || println!("bar")).unwrap();
    }

    #[test]
    fn try_run_exhausted() {
        let pool = mock_pool(8, 8, 0, 4, 2, 0);
        let func = || loop {
            std::thread::sleep(Duration::from_secs(1))
        };
        // (name, whether the job is expected to start), in submission order
        let attempts = [
            ("FOO", true),
            ("BAR", true),
            ("FOO", true),
            ("BAR", true),
            ("FOO", false),
            ("BAR", true),
            ("FOO", false),
            ("BAR", true),
            ("FOO", false),
            ("BAR", false),
            ("FOO", false),
        ];
        for (name, should_start) in attempts {
            if should_start {
                pool.try_run(name, func).unwrap();
            } else {
                pool.try_run(name, func).unwrap_err();
            }
        }
    }

    #[test]
    fn actually_runs_1() {
        let pool = mock_pool(4, 4, 0, 0, 0, 1);
        let barrier = Arc::new(Barrier::new(2));
        let job_barrier = Arc::clone(&barrier);
        pool.try_run("BAZ", move || {
            job_barrier.wait();
        })
        .unwrap();
        barrier.wait();
    }

    #[test]
    fn actually_runs_2() {
        let pool = mock_pool(4, 4, 0, 0, 0, 1);
        let barrier = Arc::new(Barrier::new(2));
        let job_barrier = Arc::clone(&barrier);
        pool.spawn("BAZ", move || {
            job_barrier.wait();
        });
        barrier.wait();
    }

    #[test]
    fn actually_waits() {
        let pool = mock_pool(4, 4, 0, 4, 0, 1);
        let second_ran = Arc::new(AtomicBool::new(false));
        let first_done = Arc::new(Barrier::new(2));
        let second_done = Arc::new(Barrier::new(2));

        let first_done_job = Arc::clone(&first_done);
        pool.try_run("FOO", move || {
            first_done_job.wait();
        })
        .unwrap();

        let second_ran_job = Arc::clone(&second_ran);
        let second_done_job = Arc::clone(&second_done);
        pool.spawn("FOO", move || {
            second_ran_job.store(true, Ordering::SeqCst);
            second_done_job.wait();
        });
        // in this case we have to sleep
        std::thread::sleep(Duration::from_secs(1));
        assert!(!second_ran.load(Ordering::SeqCst));
        // now finish the first job
        first_done.wait();
        // now wait on the second job to be actually finished
        second_done.wait();
    }

    #[test]
    fn verify_metrics() {
        let pool = mock_pool(4, 4, 2, 1, 0, 4);
        let barrier = Arc::new(Barrier::new(5));
        for name in &["FOO", "BAZ", "FOO", "FOO"] {
            let job_barrier = Arc::clone(&barrier);
            pool.spawn(name, move || {
                job_barrier.wait();
            });
        }
        // now finish all jobs
        barrier.wait();
        // in this case we have to sleep to give it some time to store all the metrics
        std::thread::sleep(Duration::from_secs(2));
        let metrics = pool.take_metrics();
        let foo = metrics.get("FOO").expect("FOO doesn't exist in metrics");
        //its limited to 2, even though we had 3 jobs
        assert_eq!(foo.len(), 2);
        assert!(metrics.get("BAR").is_none());
        let baz = metrics.get("BAZ").expect("BAZ doesn't exist in metrics");
        assert_eq!(baz.len(), 1);
    }

    /// Returns a job that sleeps for `ms` milliseconds and then increments
    /// `counter`.
    fn work_barrier(counter: &Arc<AtomicU64>, ms: u64) -> impl FnOnce() {
        let counter = Arc::clone(counter);
        println!("Create work_barrier");
        move || {
            println!(".{}..", ms);
            std::thread::sleep(Duration::from_millis(ms));
            println!(".{}..Done", ms);
            counter.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[test]
    fn verify_that_spawn_doesnt_block_par_iter() {
        let threadpool = Arc::new(
            rayon::ThreadPoolBuilder::new()
                .num_threads(20)
                .build()
                .unwrap(),
        );
        let pool = SlowJobPool::new(2, 100, Arc::<rayon::ThreadPool>::clone(&threadpool));
        pool.configure("BAZ", |_| 2);
        let counter = Arc::new(AtomicU64::new(0));
        let start = Instant::now();

        threadpool.install(|| {
            use rayon::prelude::*;
            (0..100).into_par_iter().for_each(|i| {
                std::thread::sleep(Duration::from_millis(10));
                if i == 50 {
                    pool.spawn("BAZ", work_barrier(&counter, 2000));
                }
                if i == 99 {
                    println!("The first ITER end, at {}ms", start.elapsed().as_millis());
                }
            });
            let elapsed = start.elapsed().as_millis();
            println!("The first ITER finished, at {}ms", elapsed);
            assert!(
                elapsed < 1900,
                "It seems like the par_iter waited on the 2s sleep task to finish"
            );
        });

        while counter.load(Ordering::SeqCst) == 0 {
            println!("waiting for BAZ task to finish");
            std::thread::sleep(Duration::from_secs(1));
        }
    }
}