I'm having difficulty understanding how to implement a task scheduler using crossbeam deque


The issue I'm running into is lifetime management. I have an injector that is shared between the worker threads; I had to wrap it in Arc<> because a plain & reference caused lifetime errors. The local queue isn't shared, so no worries there. Then we come to the stealers.

According to the docs, stealers can be shared across threads. So with 4 threads, each owning a worker, every thread ends up holding 3 stealers. However, I again run into lifetime issues that seem to point to self being moved when the thread is spawned. The example task picker in the docs takes the stealers as a slice of stealers (&[Stealer<T>]), whereas what I end up with is a slice of references to stealers.

Maybe I've misunderstood the docs, and I should be using Arc-style reference counting to manage the lifetimes of the injector and stealers. Can someone clarify this for me?

Here's a simplification (of sorts) with a pool of 3 threads. I've removed the lifetime management around the thread's stealers and injector, since those could be Arc, an explicit <'a> lifetime, or something else.

// details on injector, worker and stealer can be found here:
// https://docs.rs/crossbeam/0.7.3/crossbeam/deque/index.html
struct ThreadData {
    injector: &deque::Injector::<Task>,     // common global queue
    task_q: deque::Worker::<Task>,          // local queue
    stealers: Vec<&deque::Stealer<Task>>    // stealers for other threads local queue
}
impl ThreadData {
    fn spawn(self) ->  Option<std::thread::JoinHandle<()>> {
        let thread = std::thread::spawn(move|| {
            find_task( &self.task_q, &self.injector, &self.stealers);
        });
        Some(thread)
    }
}
struct Worker {
    stealer: deque::Stealer<Task>,      // to be shared with other threads
    thread: Option<std::thread::JoinHandle<()>>
}

struct Factory {
    injector: deque::Injector::<Task> ,  // owner of the global queue
    workers: Vec<Worker>
}
impl Factory {
    fn new() -> Self {
        Self { injector: deque::Injector::<Task>::new(), workers: Vec::new() }
    }
    fn build_threadpool(mut self) {
        let mut t1 = ThreadData {
            injector: &self.injector,
            task_q: deque::Worker::<Task>::new_fifo(),
            stealers: Vec::new(),
        };
        let mut w1 = Worker {stealer: t1.task_q.stealer(), thread: None };

        let mut t2 = ThreadData {
            injector: &self.injector,
            task_q: deque::Worker::<Task>::new_fifo(),
            stealers: Vec::new(),
        };
        let mut w2 = Worker {stealer: t2.task_q.stealer(), thread: None};

        let mut t3 = ThreadData {
            injector: &self.injector,
            task_q: deque::Worker::<Task>::new_fifo(),
            stealers: Vec::new(),
        };
        let mut w3 = Worker {stealer: t3.task_q.stealer(), thread: None };

        t1.stealers.push(&w2.stealer);
        t1.stealers.push(&w3.stealer);

        t2.stealers.push(&w1.stealer);
        t2.stealers.push(&w3.stealer);

        t3.stealers.push(&w1.stealer);
        t3.stealers.push(&w2.stealer);

        // launch threads and save workers
        w1.thread = t1.spawn();
        w2.thread = t2.spawn();
        w3.thread = t3.spawn();

        self.workers.push(w1);
        self.workers.push(w2);
        self.workers.push(w3);
    }

}

Answer by globalyoung7:

https://docs.rs/crossbeam/latest/crossbeam/deque/struct.Injector.html

use std::iter;

use crossbeam_deque::{Injector, Stealer, Worker};

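fn find_task<T>(local: &Worker<T>, global: &Injector<T>, stealers: &[Stealer<T>]) -> Option<T> {
    // Pop a task from the local queue, if not empty.
    local.pop().or_else(|| {
        // Otherwise, look for a task elsewhere.
        iter::repeat_with(|| {
            // Try stealing a batch of tasks from the global queue.
            global
                .steal_batch_and_pop(local)
                // Or try stealing a task from one of the other threads.
                .or_else(|| stealers.iter().map(|s| s.steal()).collect())
        })
        // Loop while no task was stolen and a steal operation needs to be retried.
        .find(|s| !s.is_retry())
        // Extract the stolen task, if there is one.
        .and_then(|s| s.success())
    })
}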

#[cfg(test)]
mod tests {
    use super::*;
    use crossbeam_deque::Steal;

    #[test]
    fn it_works() {
        let q = Injector::new();
        q.push(1);
        q.push(2);
        dbg!(&q);
        assert_eq!(q.steal(), Steal::Success(1));
        assert_eq!(q.steal(), Steal::Success(2));
        assert_eq!(q.steal(), Steal::Empty);
        dbg!(q);

        let w = Worker::new_lifo();
        let collect = w.stealer();

        w.push(1);
        w.push(2);
        w.push(3);
        dbg!(&w);

        assert_eq!(collect.steal(), Steal::Success(1));
        assert_eq!(w.pop(), Some(3));
        assert_eq!(w.pop(), Some(2));
        let q = Injector::new();
        q.push(1);
        q.push(2);
        q.push(3);
        q.push(4);
        q.push(5);
        q.push(6);

        let w = Worker::new_fifo();
        let _ = q.steal_batch_with_limit(&w, 2);
        assert_eq!(w.pop(), Some(1));
        assert_eq!(w.pop(), Some(2));
        assert_eq!(w.pop(), None);

        q.push(7);
        q.push(8);
        dbg!(&q);
        // Setting a large limit does not guarantee that all elements will be popped. In this case,
        // half of the elements are currently popped, but the number of popped elements is considered
        // an implementation detail that may be changed in the future.
        let _ = q.steal_batch_with_limit(&w, std::usize::MAX);
        assert_eq!(w.len(), 3);
    }
    #[test]
    fn it_works_find_task() {
        let local_worker: Worker<i32> = Worker::new_lifo();
        let global_injector = Injector::new();
        // No other workers exist in this test, so there is nothing to steal from.
        let stealers: Vec<Stealer<i32>> = vec![];
        local_worker.push(1);
        local_worker.push(2);
        local_worker.push(3);
        local_worker.push(9);
        global_injector.push(1);
        global_injector.push(2);
        global_injector.push(9);
        global_injector.push(3);

        // The local LIFO queue is non-empty, so find_task pops its most recently pushed task.
        let result = find_task(&local_worker, &global_injector, &stealers);
        dbg!(result);
        assert_eq!(result, Some(9));
    }
}

cargo nextest run --no-capture
   Compiling a01_2_deque_crossbeam v0.1.0 (/Users/gy-gyoung/my_project/Rust_Lang/Algorithm_Training/02_Algorithm_Rust_Zig_C_etc/01_Rust_Algorithm/02_data_structures/Deque_crossbeam/a01_2_deque_crossbeam)
warning: function `find_task` is never used
 --> src/lib.rs:5:4
  |
5 | fn find_task<T>(local: &Worker<T>, global: &Injector<T>, stealers: &[Stealer<T>]) -> Option<T> {
  |    ^^^^^^^^^
  |
  = note: `#[warn(dead_code)]` on by default

warning: `a01_2_deque_crossbeam` (lib) generated 1 warning
    Finished test [unoptimized + debuginfo] target(s) in 0.02s
    Starting 2 tests across 1 binary
       START             a01_2_deque_crossbeam tests::it_works

running 1 test
[src/lib.rs:28:9] &q = Worker { .. }
[src/lib.rs:32:9] q = Worker { .. }
[src/lib.rs:40:9] &w = Worker { .. }
[src/lib.rs:61:9] &q = Worker { .. }
test tests::it_works ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 1 filtered out; finished in 0.00s

        PASS [   0.004s] a01_2_deque_crossbeam tests::it_works
       START             a01_2_deque_crossbeam tests::it_works_find_task

running 1 test
[src/lib.rs:87:9] result = Some(
    9,
)
test tests::it_works_find_task ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 1 filtered out; finished in 0.00s

        PASS [   0.004s] a01_2_deque_crossbeam tests::it_works_find_task
------------
     Summary [   0.008s] 2 tests run: 2 passed, 0 skipped