veloren_common/
slowjob.rs

1use hashbrown::HashMap;
2use rayon::ThreadPool;
3use std::{
4    collections::VecDeque,
5    sync::{Arc, Mutex},
6    time::Instant,
7};
8use tracing::{error, warn};
9
10/// Provides a Wrapper around rayon threadpool to execute slow-jobs.
11/// slow means, the job doesn't need to not complete within the same tick.
12/// DO NOT USE I/O blocking jobs, but only CPU heavy jobs.
13/// Jobs run here, will reduce the ammount of threads rayon can use during the
14/// main tick.
15///
16/// ## Configuration
17/// This Pool allows you to configure certain names of jobs and assign them a
18/// maximum number of threads # Example
19/// Your system has 16 cores, you assign 12 cores for slow-jobs.
20/// Then you can configure all jobs with the name `CHUNK_GENERATOR` to spawn on
21/// max 50% (6 = cores)
22///
23/// ## Spawn Order
24/// - At least 1 job of a configuration is allowed to run if global limit isn't
25///   hit.
26/// - remaining capacities are spread in relation to their limit. e.g. a
27///   configuration with double the limit will be sheduled to spawn double the
28///   tasks, starting by a round robin.
29///
30/// ## States
31/// - queued
32/// - spawned
33/// - started
34/// - finished
35/// ```
36/// # use veloren_common::slowjob::SlowJobPool;
37/// # use std::sync::Arc;
38///
39/// let threadpool = rayon::ThreadPoolBuilder::new()
40///     .num_threads(16)
41///     .build()
42///     .unwrap();
43/// let pool = SlowJobPool::new(3, 10, Arc::new(threadpool));
44/// pool.configure("CHUNK_GENERATOR", |n| n / 2);
45/// pool.spawn("CHUNK_GENERATOR", move || println!("this is a job"));
46/// ```
47#[derive(Clone)]
48pub struct SlowJobPool {
49    internal: Arc<Mutex<InternalSlowJobPool>>,
50}
51
/// Handle to a job returned by [`SlowJobPool::spawn`] and
/// [`SlowJobPool::try_run`].
///
/// Pass it to [`SlowJobPool::cancel`] to remove the job from the queue while it
/// has not been started yet.
#[derive(Debug)]
pub struct SlowJob {
    /// Configuration name the job was spawned under.
    name: String,
    /// Identifier taken from the pool's incrementing job counter.
    id: u64,
}
57
/// Type-erased job closure: runs exactly once and can be moved to a worker
/// thread.
type JobType = Box<dyn FnOnce() + Sync + Send + 'static>;
59
60struct InternalSlowJobPool {
61    next_id: u64,
62    queue: HashMap<String, VecDeque<Queue>>,
63    configs: HashMap<String, Config>,
64    last_spawned_configs: Vec<String>,
65    global_spawned_and_running: u64,
66    global_limit: u64,
67    jobs_metrics_cnt: usize,
68    jobs_metrics: HashMap<String, Vec<JobMetrics>>,
69    threadpool: Arc<ThreadPool>,
70    internal: Option<Arc<Mutex<Self>>>,
71}
72
/// Scheduling state of one configured job name.
#[derive(Debug)]
struct Config {
    /// Maximum number of jobs of this name the scheduler starts concurrently
    /// (at least 1, see `SlowJobPool::configure`).
    local_limit: u64,
    /// Jobs of this name handed to the threadpool that have not finished yet.
    local_spawned_and_running: u64,
}
78
79struct Queue {
80    id: u64,
81    name: String,
82    task: JobType,
83}
84
/// Timing information recorded for a single finished slow job.
#[derive(Debug, Clone, Copy)]
pub struct JobMetrics {
    /// When the job was created and put into the queue.
    pub queue_created: Instant,
    /// When the job's closure started executing on a worker thread.
    pub execution_start: Instant,
    /// When the job's closure returned.
    pub execution_end: Instant,
}
90
91impl Queue {
92    fn new<F>(name: &str, id: u64, internal: &Arc<Mutex<InternalSlowJobPool>>, f: F) -> Self
93    where
94        F: FnOnce() + Send + Sync + 'static,
95    {
96        let internal = Arc::clone(internal);
97        let name_cloned = name.to_owned();
98        let queue_created = Instant::now();
99        Self {
100            id,
101            name: name.to_owned(),
102            task: Box::new(move || {
103                common_base::prof_span_alloc!(_guard, &name_cloned);
104                let execution_start = Instant::now();
105                f();
106                let execution_end = Instant::now();
107                let metrics = JobMetrics {
108                    queue_created,
109                    execution_start,
110                    execution_end,
111                };
112                // directly maintain the next task afterwards
113                {
114                    let mut lock = internal.lock().expect("slowjob lock poisoned");
115                    lock.finish(&name_cloned, metrics);
116                    lock.spawn_queued();
117                }
118            }),
119        }
120    }
121}
122
123impl InternalSlowJobPool {
124    pub fn new(
125        global_limit: u64,
126        jobs_metrics_cnt: usize,
127        _threadpool: Arc<ThreadPool>,
128    ) -> Arc<Mutex<Self>> {
129        // rayon is having a bug where a ECS task could work-steal a slowjob if we use
130        // the same threadpool, which would cause lagspikes we dont want!
131        let threadpool = Arc::new(
132            rayon::ThreadPoolBuilder::new()
133                .num_threads(global_limit as usize)
134                .thread_name(move |i| format!("slowjob-{}", i))
135                .build()
136                .unwrap(),
137        );
138        let link = Arc::new(Mutex::new(Self {
139            next_id: 0,
140            queue: HashMap::new(),
141            configs: HashMap::new(),
142            last_spawned_configs: Vec::new(),
143            global_spawned_and_running: 0,
144            global_limit: global_limit.max(1),
145            jobs_metrics_cnt,
146            jobs_metrics: HashMap::new(),
147            threadpool,
148            internal: None,
149        }));
150
151        let link_clone = Arc::clone(&link);
152        link.lock()
153            .expect("poisoned on InternalSlowJobPool::new")
154            .internal = Some(link_clone);
155        link
156    }
157
158    /// returns order of configuration which are queued next
159    fn calc_queued_order(
160        &self,
161        mut queued: HashMap<&String, u64>,
162        mut limit: usize,
163    ) -> Vec<String> {
164        let mut roundrobin = self.last_spawned_configs.clone();
165        let mut result = vec![];
166        let spawned = self
167            .configs
168            .iter()
169            .map(|(n, c)| (n, c.local_spawned_and_running))
170            .collect::<HashMap<_, u64>>();
171        let mut queried_capped = self
172            .configs
173            .iter()
174            .map(|(n, c)| {
175                (
176                    n,
177                    queued
178                        .get(&n)
179                        .cloned()
180                        .unwrap_or(0)
181                        .min(c.local_limit - c.local_spawned_and_running),
182                )
183            })
184            .collect::<HashMap<_, _>>();
185        // grab all configs that are queued and not running. in roundrobin order
186        for n in roundrobin.clone().into_iter() {
187            if let Some(c) = queued.get_mut(&n)
188                && *c > 0
189                && spawned.get(&n).cloned().unwrap_or(0) == 0
190            {
191                result.push(n.clone());
192                *c -= 1;
193                limit -= 1;
194                queried_capped.get_mut(&n).map(|v| *v -= 1);
195                roundrobin
196                    .iter()
197                    .position(|e| e == &n)
198                    .map(|i| roundrobin.remove(i));
199                roundrobin.push(n);
200                if limit == 0 {
201                    return result;
202                }
203            }
204        }
205        //schedule rest based on their possible limites, don't use round robin here
206        let total_limit = queried_capped.values().sum::<u64>() as f32;
207        if total_limit < f32::EPSILON {
208            return result;
209        }
210        let mut spawn_rates = queried_capped
211            .iter()
212            .map(|(&n, l)| (n, ((*l as f32 * limit as f32) / total_limit).min(*l as f32)))
213            .collect::<Vec<_>>();
214        while limit > 0 {
215            spawn_rates.sort_by(|(_, a), (_, b)| {
216                if b < a {
217                    core::cmp::Ordering::Less
218                } else if (b - a).abs() < f32::EPSILON {
219                    core::cmp::Ordering::Equal
220                } else {
221                    core::cmp::Ordering::Greater
222                }
223            });
224            match spawn_rates.first_mut() {
225                Some((n, r)) => {
226                    if *r > f32::EPSILON {
227                        result.push(n.clone());
228                        limit -= 1;
229                        *r -= 1.0;
230                    } else {
231                        break;
232                    }
233                },
234                None => break,
235            }
236        }
237        result
238    }
239
240    fn can_spawn(&self, name: &str) -> bool {
241        let queued = self
242            .queue
243            .iter()
244            .map(|(n, m)| (n, m.len() as u64))
245            .collect::<HashMap<_, u64>>();
246        let mut to_be_queued = queued.clone();
247        let name = name.to_owned();
248        *to_be_queued.entry(&name).or_default() += 1;
249        let limit = (self.global_limit - self.global_spawned_and_running) as usize;
250        // calculate to_be_queued first
251        let to_be_queued_order = self.calc_queued_order(to_be_queued, limit);
252        let queued_order = self.calc_queued_order(queued, limit);
253        // if its queued one time more then its okay to spawn
254        let to_be_queued_cnt = to_be_queued_order
255            .into_iter()
256            .filter(|n| n == &name)
257            .count();
258        let queued_cnt = queued_order.into_iter().filter(|n| n == &name).count();
259        to_be_queued_cnt > queued_cnt
260    }
261
262    pub fn spawn<F>(&mut self, name: &str, f: F) -> SlowJob
263    where
264        F: FnOnce() + Send + Sync + 'static,
265    {
266        let id = self.next_id;
267        self.next_id += 1;
268        let queue = Queue::new(name, id, self.internal.as_ref().expect("internal empty"), f);
269        self.queue
270            .entry(name.to_string())
271            .or_default()
272            .push_back(queue);
273        debug_assert!(
274            self.configs.contains_key(name),
275            "Can't spawn unconfigured task!"
276        );
277        //spawn already queued
278        self.spawn_queued();
279        SlowJob {
280            name: name.to_string(),
281            id,
282        }
283    }
284
285    fn finish(&mut self, name: &str, metrics: JobMetrics) {
286        let metric = self.jobs_metrics.entry(name.to_string()).or_default();
287
288        if metric.len() < self.jobs_metrics_cnt {
289            metric.push(metrics);
290        }
291        self.global_spawned_and_running -= 1;
292        if let Some(c) = self.configs.get_mut(name) {
293            c.local_spawned_and_running -= 1;
294        } else {
295            warn!(?name, "sync_maintain on a no longer existing config");
296        }
297    }
298
299    fn spawn_queued(&mut self) {
300        let queued = self
301            .queue
302            .iter()
303            .map(|(n, m)| (n, m.len() as u64))
304            .collect::<HashMap<_, u64>>();
305        let limit = self.global_limit as usize;
306        let queued_order = self.calc_queued_order(queued, limit);
307        for name in queued_order.into_iter() {
308            match self.queue.get_mut(&name) {
309                Some(deque) => match deque.pop_front() {
310                    Some(queue) => {
311                        //fire
312                        self.global_spawned_and_running += 1;
313                        self.configs
314                            .get_mut(&queue.name)
315                            .expect("cannot fire a unconfigured job")
316                            .local_spawned_and_running += 1;
317                        self.last_spawned_configs
318                            .iter()
319                            .position(|e| e == &queue.name)
320                            .map(|i| self.last_spawned_configs.remove(i));
321                        self.last_spawned_configs.push(queue.name.to_owned());
322                        self.threadpool.spawn(queue.task);
323                    },
324                    None => error!(
325                        "internal calculation is wrong, we extected a schedulable job to be \
326                         present in the queue"
327                    ),
328                },
329                None => error!(
330                    "internal calculation is wrong, we marked a queue as schedulable which \
331                     doesn't exist"
332                ),
333            }
334        }
335    }
336
337    pub fn take_metrics(&mut self) -> HashMap<String, Vec<JobMetrics>> {
338        core::mem::take(&mut self.jobs_metrics)
339    }
340}
341
342impl SlowJobPool {
343    pub fn new(global_limit: u64, jobs_metrics_cnt: usize, threadpool: Arc<ThreadPool>) -> Self {
344        Self {
345            internal: InternalSlowJobPool::new(global_limit, jobs_metrics_cnt, threadpool),
346        }
347    }
348
349    /// configure a NAME to spawn up to f(n) threads, depending on how many
350    /// threads we globally have available
351    pub fn configure<F>(&self, name: &str, f: F)
352    where
353        F: Fn(u64) -> u64,
354    {
355        let mut lock = self.internal.lock().expect("lock poisoned while configure");
356        let cnf = Config {
357            local_limit: f(lock.global_limit).max(1),
358            local_spawned_and_running: 0,
359        };
360        lock.configs.insert(name.to_owned(), cnf);
361        lock.last_spawned_configs.push(name.to_owned());
362    }
363
364    /// spawn a new slow job on a certain NAME IF it can run immediately
365    #[expect(clippy::result_unit_err)]
366    pub fn try_run<F>(&self, name: &str, f: F) -> Result<SlowJob, ()>
367    where
368        F: FnOnce() + Send + Sync + 'static,
369    {
370        let mut lock = self.internal.lock().expect("lock poisoned while try_run");
371        //spawn already queued
372        lock.spawn_queued();
373        if lock.can_spawn(name) {
374            Ok(lock.spawn(name, f))
375        } else {
376            Err(())
377        }
378    }
379
380    pub fn spawn<F>(&self, name: &str, f: F) -> SlowJob
381    where
382        F: FnOnce() + Send + Sync + 'static,
383    {
384        self.internal
385            .lock()
386            .expect("lock poisoned while spawn")
387            .spawn(name, f)
388    }
389
390    pub fn cancel(&self, job: SlowJob) -> Result<(), SlowJob> {
391        let mut lock = self.internal.lock().expect("lock poisoned while cancel");
392        if let Some(m) = lock.queue.get_mut(&job.name) {
393            let p = match m.iter().position(|p| p.id == job.id) {
394                Some(p) => p,
395                None => return Err(job),
396            };
397            if m.remove(p).is_some() {
398                return Ok(());
399            }
400        }
401        Err(job)
402    }
403
404    pub fn take_metrics(&self) -> HashMap<String, Vec<JobMetrics>> {
405        self.internal
406            .lock()
407            .expect("lock poisoned while take_metrics")
408            .take_metrics()
409    }
410}
411
#[cfg(test)]
mod tests {
    // These tests drive `calc_queued_order` directly with hand-built queue
    // counts, and exercise `spawn`/`try_run` end to end on real threads.
    use super::*;
    use std::{
        sync::{
            Barrier,
            atomic::{AtomicBool, AtomicU64, Ordering},
        },
        time::Duration,
    };

    /// Builds a pool with `global_threads` global slots and `metrics` metric
    /// samples per name. `FOO`, `BAR` and `BAZ` are configured with a local
    /// limit of `global_threads / divisor` (at least 1), where `foo`, `bar`
    /// and `baz` are the divisors. A divisor of 0 leaves that name
    /// unconfigured. The `pool_threads` pool is handed to `SlowJobPool::new`.
    fn mock_pool(
        pool_threads: usize,
        global_threads: u64,
        metrics: usize,
        foo: u64,
        bar: u64,
        baz: u64,
    ) -> SlowJobPool {
        let threadpool = rayon::ThreadPoolBuilder::new()
            .num_threads(pool_threads)
            .build()
            .unwrap();
        let pool = SlowJobPool::new(global_threads, metrics, Arc::new(threadpool));
        if foo != 0 {
            pool.configure("FOO", |x| x / foo);
        }
        if bar != 0 {
            pool.configure("BAR", |x| x / bar);
        }
        if baz != 0 {
            pool.configure("BAZ", |x| x / baz);
        }
        pool
    }

    // A single queued FOO job gets exactly one slot.
    #[test]
    fn simple_queue() {
        let pool = mock_pool(4, 4, 0, 1, 0, 0);
        let internal = pool.internal.lock().unwrap();
        let queue_data = [("FOO", 1u64)]
            .iter()
            .map(|(n, c)| ((*n).to_owned(), *c))
            .collect::<Vec<_>>();
        let queued = queue_data
            .iter()
            .map(|(s, c)| (s, *c))
            .collect::<HashMap<_, _>>();
        let result = internal.calc_queued_order(queued, 4);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0], "FOO");
    }

    // Two queued FOO jobs both get a slot (FOO's local limit is 4).
    #[test]
    fn multiple_queue() {
        let pool = mock_pool(4, 4, 0, 1, 0, 0);
        let internal = pool.internal.lock().unwrap();
        let queue_data = [("FOO", 2u64)]
            .iter()
            .map(|(n, c)| ((*n).to_owned(), *c))
            .collect::<Vec<_>>();
        let queued = queue_data
            .iter()
            .map(|(s, c)| (s, *c))
            .collect::<HashMap<_, _>>();
        let result = internal.calc_queued_order(queued, 4);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0], "FOO");
        assert_eq!(result[1], "FOO");
    }

    // 80 queued FOO jobs are capped by the `limit` argument (4), which is
    // below FOO's local limit of 5.
    #[test]
    fn limit_queue() {
        let pool = mock_pool(5, 5, 0, 1, 0, 0);
        let internal = pool.internal.lock().unwrap();
        let queue_data = [("FOO", 80u64)]
            .iter()
            .map(|(n, c)| ((*n).to_owned(), *c))
            .collect::<Vec<_>>();
        let queued = queue_data
            .iter()
            .map(|(s, c)| (s, *c))
            .collect::<HashMap<_, _>>();
        let result = internal.calc_queued_order(queued, 4);
        assert_eq!(result.len(), 4);
        assert_eq!(result[0], "FOO");
        assert_eq!(result[1], "FOO");
        assert_eq!(result[2], "FOO");
        assert_eq!(result[3], "FOO");
    }

    // One FOO and one BAR job: each idle configuration gets one slot.
    #[test]
    fn simple_queue_2() {
        let pool = mock_pool(4, 4, 0, 1, 1, 0);
        let internal = pool.internal.lock().unwrap();
        let queue_data = [("FOO", 1u64), ("BAR", 1u64)]
            .iter()
            .map(|(n, c)| ((*n).to_owned(), *c))
            .collect::<Vec<_>>();
        let queued = queue_data
            .iter()
            .map(|(s, c)| (s, *c))
            .collect::<HashMap<_, _>>();
        let result = internal.calc_queued_order(queued, 4);
        assert_eq!(result.len(), 2);
        assert_eq!(result.iter().filter(|&x| x == "FOO").count(), 1);
        assert_eq!(result.iter().filter(|&x| x == "BAR").count(), 1);
    }

    // Equal limits and equal queues split the 4 slots evenly.
    #[test]
    fn multiple_queue_3() {
        let pool = mock_pool(4, 4, 0, 1, 1, 0);
        let internal = pool.internal.lock().unwrap();
        let queue_data = [("FOO", 2u64), ("BAR", 2u64)]
            .iter()
            .map(|(n, c)| ((*n).to_owned(), *c))
            .collect::<Vec<_>>();
        let queued = queue_data
            .iter()
            .map(|(s, c)| (s, *c))
            .collect::<HashMap<_, _>>();
        let result = internal.calc_queued_order(queued, 4);
        assert_eq!(result.len(), 4);
        assert_eq!(result.iter().filter(|&x| x == "FOO").count(), 2);
        assert_eq!(result.iter().filter(|&x| x == "BAR").count(), 2);
    }

    // FOO has local limit 2 and BAR 4; with 4 slots FOO reaches its limit and
    // BAR gets the other 2.
    #[test]
    fn multiple_queue_4() {
        let pool = mock_pool(4, 4, 0, 2, 1, 0);
        let internal = pool.internal.lock().unwrap();
        let queue_data = [("FOO", 3u64), ("BAR", 3u64)]
            .iter()
            .map(|(n, c)| ((*n).to_owned(), *c))
            .collect::<Vec<_>>();
        let queued = queue_data
            .iter()
            .map(|(s, c)| (s, *c))
            .collect::<HashMap<_, _>>();
        let result = internal.calc_queued_order(queued, 4);
        assert_eq!(result.len(), 4);
        assert_eq!(result.iter().filter(|&x| x == "FOO").count(), 2);
        assert_eq!(result.iter().filter(|&x| x == "BAR").count(), 2);
    }

    // Same limits with 5 slots: FOO stays capped at 2, BAR gets 3.
    #[test]
    fn multiple_queue_5() {
        let pool = mock_pool(4, 4, 0, 2, 1, 0);
        let internal = pool.internal.lock().unwrap();
        let queue_data = [("FOO", 5u64), ("BAR", 5u64)]
            .iter()
            .map(|(n, c)| ((*n).to_owned(), *c))
            .collect::<Vec<_>>();
        let queued = queue_data
            .iter()
            .map(|(s, c)| (s, *c))
            .collect::<HashMap<_, _>>();
        let result = internal.calc_queued_order(queued, 5);
        assert_eq!(result.len(), 5);
        assert_eq!(result.iter().filter(|&x| x == "FOO").count(), 2);
        assert_eq!(result.iter().filter(|&x| x == "BAR").count(), 3);
    }

    // 11 slots but only 10 queued jobs: every queued job is scheduled.
    #[test]
    fn multiple_queue_6() {
        let pool = mock_pool(40, 40, 0, 2, 1, 0);
        let internal = pool.internal.lock().unwrap();
        let queue_data = [("FOO", 5u64), ("BAR", 5u64)]
            .iter()
            .map(|(n, c)| ((*n).to_owned(), *c))
            .collect::<Vec<_>>();
        let queued = queue_data
            .iter()
            .map(|(s, c)| (s, *c))
            .collect::<HashMap<_, _>>();
        let result = internal.calc_queued_order(queued, 11);
        assert_eq!(result.len(), 10);
        assert_eq!(result.iter().filter(|&x| x == "FOO").count(), 5);
        assert_eq!(result.iter().filter(|&x| x == "BAR").count(), 5);
    }

    // After a FOO job ran, BAR is first in round-robin order. Only computing
    // an order does not change it; spawning a BAR job makes FOO first again.
    #[test]
    fn roundrobin() {
        let pool = mock_pool(4, 4, 0, 2, 2, 0);
        let queue_data = [("FOO", 5u64), ("BAR", 5u64)]
            .iter()
            .map(|(n, c)| ((*n).to_owned(), *c))
            .collect::<Vec<_>>();
        let queued = queue_data
            .iter()
            .map(|(s, c)| (s, *c))
            .collect::<HashMap<_, _>>();
        // Spawn a FOO task.
        pool.internal
            .lock()
            .unwrap()
            .spawn("FOO", || println!("foo"));
        // A barrier in f doesn't work, as we need to wait for the cleanup
        // (`finish`) that runs after f returns.
        while pool.internal.lock().unwrap().global_spawned_and_running != 0 {
            std::thread::yield_now();
        }
        let result = pool
            .internal
            .lock()
            .unwrap()
            .calc_queued_order(queued.clone(), 1);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0], "BAR");
        // keep order if no new is spawned
        let result = pool
            .internal
            .lock()
            .unwrap()
            .calc_queued_order(queued.clone(), 1);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0], "BAR");
        // spawn a BAR task
        pool.internal
            .lock()
            .unwrap()
            .spawn("BAR", || println!("bar"));
        while pool.internal.lock().unwrap().global_spawned_and_running != 0 {
            std::thread::yield_now();
        }
        let result = pool.internal.lock().unwrap().calc_queued_order(queued, 1);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0], "FOO");
    }

    // Relies on the `debug_assert!` in `spawn`, so it only panics when debug
    // assertions are enabled.
    #[test]
    #[should_panic]
    fn unconfigured() {
        let pool = mock_pool(4, 4, 0, 2, 1, 0);
        let mut internal = pool.internal.lock().unwrap();
        internal.spawn("UNCONFIGURED", || println!());
    }

    #[test]
    fn correct_spawn_doesnt_panic() {
        let pool = mock_pool(4, 4, 0, 2, 1, 0);
        let mut internal = pool.internal.lock().unwrap();
        internal.spawn("FOO", || println!("foo"));
        internal.spawn("BAR", || println!("bar"));
    }

    // On an idle pool every configured name can start a job immediately.
    #[test]
    fn can_spawn() {
        let pool = mock_pool(4, 4, 0, 2, 1, 0);
        let internal = pool.internal.lock().unwrap();
        assert!(internal.can_spawn("FOO"));
        assert!(internal.can_spawn("BAR"));
    }

    #[test]
    fn try_run_works() {
        let pool = mock_pool(4, 4, 0, 2, 1, 0);
        pool.try_run("FOO", || println!("foo")).unwrap();
        pool.try_run("BAR", || println!("bar")).unwrap();
    }

    // FOO is limited to 8/4 = 2 and BAR to 8/2 = 4 concurrent jobs. The jobs
    // never return, so every `try_run` beyond those limits must fail.
    #[test]
    fn try_run_exhausted() {
        let pool = mock_pool(8, 8, 0, 4, 2, 0);
        let func = || loop {
            std::thread::sleep(Duration::from_secs(1))
        };
        pool.try_run("FOO", func).unwrap();
        pool.try_run("BAR", func).unwrap();
        pool.try_run("FOO", func).unwrap();
        pool.try_run("BAR", func).unwrap();
        pool.try_run("FOO", func).unwrap_err();
        pool.try_run("BAR", func).unwrap();
        pool.try_run("FOO", func).unwrap_err();
        pool.try_run("BAR", func).unwrap();
        pool.try_run("FOO", func).unwrap_err();
        pool.try_run("BAR", func).unwrap_err();
        pool.try_run("FOO", func).unwrap_err();
    }

    // A job started via `try_run` actually executes (it meets the test thread
    // at the barrier).
    #[test]
    fn actually_runs_1() {
        let pool = mock_pool(4, 4, 0, 0, 0, 1);
        let barrier = Arc::new(Barrier::new(2));
        let barrier_clone = Arc::clone(&barrier);
        pool.try_run("BAZ", move || {
            barrier_clone.wait();
        })
        .unwrap();
        barrier.wait();
    }

    // Same as above, via `spawn`.
    #[test]
    fn actually_runs_2() {
        let pool = mock_pool(4, 4, 0, 0, 0, 1);
        let barrier = Arc::new(Barrier::new(2));
        let barrier_clone = Arc::clone(&barrier);
        pool.spawn("BAZ", move || {
            barrier_clone.wait();
        });
        barrier.wait();
    }

    // FOO has a local limit of 4/4 = 1, so the second FOO job must stay
    // queued until the first one finishes.
    #[test]
    fn actually_waits() {
        let pool = mock_pool(4, 4, 0, 4, 0, 1);
        let ops_i_ran = Arc::new(AtomicBool::new(false));
        let ops_i_ran_clone = Arc::clone(&ops_i_ran);
        let barrier = Arc::new(Barrier::new(2));
        let barrier_clone = Arc::clone(&barrier);
        let barrier2 = Arc::new(Barrier::new(2));
        let barrier2_clone = Arc::clone(&barrier2);
        pool.try_run("FOO", move || {
            barrier_clone.wait();
        })
        .unwrap();
        pool.spawn("FOO", move || {
            ops_i_ran_clone.store(true, Ordering::SeqCst);
            barrier2_clone.wait();
        });
        // in this case we have to sleep
        std::thread::sleep(Duration::from_secs(1));
        assert!(!ops_i_ran.load(Ordering::SeqCst));
        // now finish the first job
        barrier.wait();
        // now wait on the second job to be actually finished
        barrier2.wait();
    }

    // FOO (limit 4) and BAZ (limit 1) run all four jobs at once, so they can
    // meet the test thread at the 5-party barrier. Metrics are capped at 2
    // samples per name.
    #[test]
    fn verify_metrics() {
        let pool = mock_pool(4, 4, 2, 1, 0, 4);
        let barrier = Arc::new(Barrier::new(5));
        for name in &["FOO", "BAZ", "FOO", "FOO"] {
            let barrier_clone = Arc::clone(&barrier);
            pool.spawn(name, move || {
                barrier_clone.wait();
            });
        }
        // now finish all jobs
        barrier.wait();
        // in this case we have to sleep to give it some time to store all the metrics
        std::thread::sleep(Duration::from_secs(2));
        let metrics = pool.take_metrics();
        let foo = metrics.get("FOO").expect("FOO doesn't exist in metrics");
        // it's limited to 2, even though we had 3 jobs
        assert_eq!(foo.len(), 2);
        assert!(metrics.get("BAR").is_none());
        let baz = metrics.get("BAZ").expect("BAZ doesn't exist in metrics");
        assert_eq!(baz.len(), 1);
    }

    /// Despite the name, this is not a barrier: it returns a job that sleeps
    /// for `ms` milliseconds and then increments `counter`.
    fn work_barrier(counter: &Arc<AtomicU64>, ms: u64) -> impl std::ops::FnOnce() -> () + use<> {
        let counter = Arc::clone(counter);
        println!("Create work_barrier");
        move || {
            println!(".{}..", ms);
            std::thread::sleep(Duration::from_millis(ms));
            println!(".{}..Done", ms);
            counter.fetch_add(1, Ordering::SeqCst);
        }
    }

    // A 2s BAZ job spawned from inside a par_iter on another rayon pool must
    // not delay that par_iter (it should finish well below 2s). The test then
    // waits until the BAZ job has run.
    #[test]
    fn verify_that_spawn_doesnt_block_par_iter() {
        let threadpool = Arc::new(
            rayon::ThreadPoolBuilder::new()
                .num_threads(20)
                .build()
                .unwrap(),
        );
        let pool = SlowJobPool::new(2, 100, Arc::<rayon::ThreadPool>::clone(&threadpool));
        pool.configure("BAZ", |_| 2);
        let counter = Arc::new(AtomicU64::new(0));
        let start = Instant::now();

        threadpool.install(|| {
            use rayon::prelude::*;
            (0..100)
                .into_par_iter()
                .map(|i| {
                    std::thread::sleep(Duration::from_millis(10));
                    if i == 50 {
                        pool.spawn("BAZ", work_barrier(&counter, 2000));
                    }
                    if i == 99 {
                        println!("The first ITER end, at {}ms", start.elapsed().as_millis());
                    }
                })
                .collect::<Vec<_>>();
            let elapsed = start.elapsed().as_millis();
            println!("The first ITER finished, at {}ms", elapsed);
            assert!(
                elapsed < 1900,
                "It seems like the par_iter waited on the 2s sleep task to finish"
            );
        });

        while counter.load(Ordering::SeqCst) == 0 {
            println!("waiting for BAZ task to finish");
            std::thread::sleep(Duration::from_secs(1));
        }
    }
}