Shared memory in Rust


Environment:

macOS Sonoma 14.0 (M1 Mac), Rust 1.65.0

What I want to do: I want to share a Vec whose elements are [u8; 128] arrays between multiple threads. While it is shared, I need to be able to:

  1. read the entire vec
  2. overwrite a specific [u8; 128] element in the vec
  3. insert new [u8; 128] elements into the vec

Below is the code I wrote. Reading works, but writes are not reflected. If I start the program and then run the following command once on the same machine


    nc -v localhost 50051


the server prints


    [[0u8; 128],[1u8; 128],[2u8; 128]]

This is correct so far, but the second run prints exactly the same data as the first. Because the first run updates the data, I expect the second run to print the second element filled with 3s:


    [[0u8; 128],[3u8; 128],[2u8; 128]]

I am guessing that my use of Arc is wrong and that a clone of SharedData is being passed around instead of a reference to it, but I don't know how to verify this. How can I fix the code so that it works as intended?
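One way to test the Arc guess is Arc::ptr_eq, which reports whether two Arc handles point at the same allocation. Since SharedData does not implement Clone, db_arc.clone() in serve() and main() can only be Arc::clone, which copies the pointer and increments a reference count. A minimal std-only sketch, assuming a trimmed-down SharedData with the same field layout:

```rust
use std::sync::{Arc, RwLock};

// Same field layout as the question's SharedData; methods omitted.
struct SharedData {
    data: Arc<RwLock<Vec<[u8; 128]>>>,
}

fn main() {
    let shared = Arc::new(SharedData {
        data: Arc::new(RwLock::new(vec![[0u8; 128], [1u8; 128], [2u8; 128]])),
    });

    // What `db_arc.clone()` does: a new handle, not a new SharedData.
    let handle = Arc::clone(&shared);

    // `Arc::ptr_eq` compares the pointers, not the contents.
    println!("same SharedData: {}", Arc::ptr_eq(&shared, &handle)); // true
    println!("same RwLock: {}", Arc::ptr_eq(&shared.data, &handle.data)); // true
    println!("strong count: {}", Arc::strong_count(&shared)); // 2
}
```

If both comparisons are true, every thread sees the same RwLock, and the problem lies in how that lock is written to.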

main.rs:

use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::RwLock;
use std::time::Duration;
use tokio_task_pool::Pool;

struct SharedData {
    data: Arc<RwLock<Vec<[u8; 128]>>>
}

impl SharedData {
    fn new(data: RwLock<Vec<[u8; 128]>>) -> Self {
        Self {
            data: Arc::new(data)
        }
    }

    fn update(&self, index: usize, update_data: [u8; 128]) {
        let read_guard_for_array = self.data.read().unwrap();
        let write_lock = RwLock::new((*read_guard_for_array)[index]);
        let mut write_guard_for_item = write_lock.write().unwrap();
        *write_guard_for_item = update_data;
    }
}

fn socket_to_async_tcplistener(s: socket2::Socket) -> std::io::Result<tokio::net::TcpListener> {
    std::net::TcpListener::from(s).try_into()
}

async fn process(mut stream: tokio::net::TcpStream, db_arc: Arc<SharedData>) {
    let read_guard = db_arc.data.read().unwrap();
    println!("In process() read: {:?}", *read_guard);
    db_arc.update(1, [3u8; 128]);
}

async fn serve(_: usize, tcplistener_arc: Arc<tokio::net::TcpListener>, db_arc: Arc<SharedData>) {
    let task_pool_capacity = 10;

    let task_pool = Pool::bounded(task_pool_capacity)
        .with_spawn_timeout(Duration::from_secs(300))
        .with_run_timeout(Duration::from_secs(300));
    
    loop {
        let (stream, _) = tcplistener_arc.as_ref().accept().await.unwrap();
        let db_arc_clone = db_arc.clone();

        task_pool.spawn(async move {
            process(stream, db_arc_clone).await;
        }).await.unwrap();
    }
}

#[tokio::main]
async fn main() {
    let addr: std::net::SocketAddr = "0.0.0.0:50051".parse().unwrap();
    let soc2 = socket2::Socket::new(
        match addr {
            SocketAddr::V4(_) => socket2::Domain::IPV4,
            SocketAddr::V6(_) => socket2::Domain::IPV6,
        },
        socket2::Type::STREAM,
        Some(socket2::Protocol::TCP)
    ).unwrap();
    
    soc2.set_reuse_address(true).unwrap();
    soc2.set_reuse_port(true).unwrap();
    soc2.set_nonblocking(true).unwrap();
    soc2.bind(&addr.into()).unwrap();
    soc2.listen(8192).unwrap();

    let tcp_listener = Arc::new(socket_to_async_tcplistener(soc2).unwrap());

    let mut vec = vec![
        [0u8; 128],
        [1u8; 128],
        [2u8; 128],
    ];

    let share_db = Arc::new(SharedData::new(RwLock::new(vec)));
    let mut handlers = Vec::new();
    for i in 0..num_cpus::get() - 1 {
        let cloned_listener = Arc::clone(&tcp_listener);
        let db_arc = share_db.clone();

        let h = std::thread::spawn(move || {
            tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap()
                .block_on(serve(i, cloned_listener, db_arc));
        });
        handlers.push(h);
    }

    for h in handlers {
        h.join().unwrap();
    }
}

Cargo.toml:

[package]
name = "tokio-test"
version = "0.1.0"
edition = "2021"

[dependencies]
log = "0.4.20"
env_logger = "0.10.0"
tokio = { version = "1.34.0", features = ["full"] }
tokio-stream = { version = "0.1.14", features = ["net"] }
serde = { version = "1.0.193", features = ["derive"] }
serde_yaml = "0.9.27"
serde_derive = "1.0.193"
mio = {version="0.8.9", features=["net", "os-poll", "os-ext"]}
num_cpus = "1.16.0"
socket2 = { version="0.5.5", features = ["all"]}
array-macro = "2.1.8"
tokio-task-pool = "0.1.5"
argparse = "0.2.2"

There are 2 answers

Answer by Yoric (accepted)

I haven't looked at the entire code, but there are a few errors.

fn update()

    fn update(&self, index: usize, update_data: [u8; 128]) {
        let read_guard_for_array = self.data.read().unwrap();
        let write_lock = RwLock::new((*read_guard_for_array)[index]);
        let mut write_guard_for_item = write_lock.write().unwrap();
        *write_guard_for_item = update_data;
    }

That's not how you use a RwLock:

  • if you want to modify the data, use self.data.write() instead of self.data.read();
  • I'm not sure what you intend to do with the second RwLock, but it is useless.

Rather, do something like

    fn update(&self, index: usize, update_data: [u8; 128]) {
        let mut write_guard_for_array = self.data.write().unwrap();
        write_guard_for_array[index] = update_data;
    }
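For reference, here is the corrected update() in a minimal std-only program (no tokio or sockets), assuming the question's SharedData layout. The write guard has to be bound with let mut, because assigning through it goes through DerefMut:

```rust
use std::sync::{Arc, RwLock};

// Same layout as the question's SharedData.
struct SharedData {
    data: Arc<RwLock<Vec<[u8; 128]>>>,
}

impl SharedData {
    fn new(data: RwLock<Vec<[u8; 128]>>) -> Self {
        Self { data: Arc::new(data) }
    }

    // Lock the existing RwLock for writing and assign through the guard.
    fn update(&self, index: usize, update_data: [u8; 128]) {
        let mut write_guard_for_array = self.data.write().unwrap();
        write_guard_for_array[index] = update_data;
    } // the guard is dropped here, releasing the write lock
}

fn main() {
    let shared = SharedData::new(RwLock::new(vec![[0u8; 128], [1u8; 128], [2u8; 128]]));
    shared.update(1, [3u8; 128]);
    // Print only the first byte of each element to keep the output short.
    let firsts: Vec<u8> = shared.data.read().unwrap().iter().map(|item| item[0]).collect();
    println!("{:?}", firsts); // [0, 3, 2]
}
```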

fn process()

    async fn process(mut stream: tokio::net::TcpStream, db_arc: Arc<SharedData>) {
        let read_guard = db_arc.data.read().unwrap();
        println!("In process() read: {:?}", *read_guard);
        db_arc.update(1, [3u8; 128]);
    }

Generally, you probably shouldn't access db_arc.data directly. But beyond that, once you fix update(), this code will deadlock:

  1. You acquire db_arc.data.read(). By definition of a RwLock, this means that nobody can modify the contents of db_arc.data until the read lock is released.
  2. The read lock is released only at the end of the scope.
  3. Before the end of the scope, you call update(), which is going to attempt to acquire data.write(). But it cannot acquire it until the read lock is released.
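The conflict in step 3 can be observed without hanging the program by using RwLock::try_write, which returns Err(TryLockError::WouldBlock) instead of waiting. A minimal std-only sketch, with a bare RwLock standing in for db_arc.data:

```rust
use std::sync::{RwLock, TryLockError};

fn main() {
    let lock = RwLock::new(vec![[0u8; 128], [1u8; 128], [2u8; 128]]);

    // Step 1: take the read lock, as process() does.
    let read_guard = lock.read().unwrap();

    // Step 3: no writer can get in while `read_guard` is alive.
    // `try_write` reports this immediately instead of blocking.
    match lock.try_write() {
        Err(TryLockError::WouldBlock) => println!("write lock unavailable while reading"),
        Err(TryLockError::Poisoned(_)) => println!("lock poisoned"),
        Ok(_) => println!("write lock acquired"),
    }

    // Once the read guard is released, the write lock can be taken.
    drop(read_guard);
    println!("after drop, write available: {}", lock.try_write().is_ok());
}
```

With a plain write() in place of try_write, the thread would wait on a lock it holds itself; the std documentation also allows write() to panic when the current thread already holds the lock.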

You probably want something along the lines of:

    async fn process(mut stream: tokio::net::TcpStream, db_arc: Arc<SharedData>) {
        {
            let read_guard = db_arc.data.read().unwrap();
            println!("In process() read: {:?}", *read_guard);
        } // End of scope: `read_guard` is dropped and the read lock released.
        db_arc.update(1, [3u8; 128]);
    }
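Following the suggestion not to reach into db_arc.data from process(), one option is to keep the lock private and expose one method per requirement in the question (read the whole vec, overwrite an element, insert an element). This is a minimal std-only sketch; the method names snapshot and insert are hypothetical, and it drops the inner Arc on the assumption that the outer Arc<SharedData> is what gets shared between threads:

```rust
use std::sync::{Arc, RwLock};

// Hypothetical wrapper: the lock is private, so callers can only use the methods.
struct SharedData {
    data: RwLock<Vec<[u8; 128]>>,
}

impl SharedData {
    fn new(items: Vec<[u8; 128]>) -> Self {
        Self { data: RwLock::new(items) }
    }

    // Requirement 1: read the whole vec. Returns a copy, so the read guard
    // is already released when this method returns.
    fn snapshot(&self) -> Vec<[u8; 128]> {
        self.data.read().unwrap().clone()
    }

    // Requirement 2: overwrite one element in place.
    fn update(&self, index: usize, item: [u8; 128]) {
        let mut guard = self.data.write().unwrap();
        guard[index] = item;
    }

    // Requirement 3 (hypothetical name): insert at `index`, shifting later elements right.
    fn insert(&self, index: usize, item: [u8; 128]) {
        self.data.write().unwrap().insert(index, item);
    }
}

// Same flow as the question's process(), minus the socket: read, then update.
fn process(db: &SharedData) {
    let current = db.snapshot(); // no guard is alive after this line
    println!("read {} items, element 1 starts with {}", current.len(), current[1][0]);
    db.update(1, [3u8; 128]);
}

fn main() {
    let db = Arc::new(SharedData::new(vec![[0u8; 128], [1u8; 128], [2u8; 128]]));
    process(&db); // element 1 starts with 1
    process(&db); // element 1 starts with 3
    db.insert(3, [4u8; 128]);
    let firsts: Vec<u8> = db.snapshot().iter().map(|item| item[0]).collect();
    println!("{:?}", firsts); // [0, 3, 2, 4]
}
```

Returning a copy from snapshot() costs 128 bytes per element, but it guarantees that no read guard is still alive when process() calls update().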

tokio + threads

You're mixing threads and tokio. It's possible, but risky. Both approaches are valid on their own, so I suggest picking one: typically tokio if you have lots of I/O (e.g. network requests or disk access), or threads if you have lots of CPU-bound work.
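If you go the threads-only route, the same Arc<RwLock<...>> works with plain std::thread and no runtime. A minimal sketch, assuming one thread per element; update_in_parallel and the Db alias are hypothetical names. Each thread overwrites its own element, and the main thread reads the result after joining:

```rust
use std::sync::{Arc, RwLock};
use std::thread;

// Hypothetical alias for the shared storage.
type Db = Arc<RwLock<Vec<[u8; 128]>>>;

// Hypothetical helper: spawns one OS thread per element; thread `i` overwrites element `i`.
fn update_in_parallel(db: &Db) {
    // The read guard is a temporary, released at the end of this statement.
    let len = db.read().unwrap().len();
    let handles: Vec<thread::JoinHandle<()>> = (0..len)
        .map(|i| {
            // A new handle to the *same* lock; the Vec itself is not copied.
            let db = Arc::clone(db);
            thread::spawn(move || {
                let mut guard = db.write().unwrap(); // writers take turns
                guard[i] = [10 + i as u8; 128]; // u8 arithmetic: fine for small `len`
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
}

fn main() {
    let db: Db = Arc::new(RwLock::new(vec![[0u8; 128], [1u8; 128], [2u8; 128]]));
    update_in_parallel(&db);
    let firsts: Vec<u8> = db.read().unwrap().iter().map(|item| item[0]).collect();
    println!("{:?}", firsts); // [10, 11, 12]
}
```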

Answer by Jmb
    fn update(&self, index: usize, update_data: [u8; 128]) {
        let read_guard_for_array = self.data.read().unwrap();
        let write_lock = RwLock::new((*read_guard_for_array)[index]);

This creates a copy of the element ([u8; 128] is Copy, so indexing copies it) and wraps it in a useless RwLock (useless because that copy is never shared with another thread).

        let mut write_guard_for_item = write_lock.write().unwrap();
        *write_guard_for_item = update_data;
    }

This modifies the copy, which then gets immediately discarded at the end of the function.
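The copy is easy to observe in isolation. This std-only sketch reproduces the body of the question's update() as a hypothetical free function buggy_update (the SharedData wrapper is left out) and shows that the shared vector is unchanged afterwards:

```rust
use std::sync::RwLock;

// Hypothetical free-function version of the question's update() body.
fn buggy_update(data: &RwLock<Vec<[u8; 128]>>, index: usize, update_data: [u8; 128]) {
    let read_guard = data.read().unwrap();
    // `[u8; 128]` is `Copy`, so this puts a *copy* of the element into a new lock.
    let write_lock = RwLock::new(read_guard[index]);
    *write_lock.write().unwrap() = update_data; // modifies only the copy
} // `write_lock` and the copy are dropped here; `data` is untouched

fn main() {
    let data = RwLock::new(vec![[0u8; 128], [1u8; 128], [2u8; 128]]);
    buggy_update(&data, 1, [3u8; 128]);
    println!("element 1 starts with {}", data.read().unwrap()[1][0]); // 1
}
```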

Instead you need to lock the RwLock that you already have:

    fn update(&self, index: usize, update_data: [u8; 128]) {
        let mut write_guard = self.data.write().unwrap();
        write_guard[index] = update_data;
    }

Note that there is no way to get a write lock for only a specific item and read locks for the whole array: both read and write locks must relate to the same data. That means you also need to release the read lock before you can update:

    async fn process(mut stream: tokio::net::TcpStream, db_arc: Arc<SharedData>) {
        let read_guard = db_arc.data.read().unwrap();
        println!("In process() read: {:?}", *read_guard);
        drop(read_guard);
        db_arc.update(1, [3u8; 128]);
    }
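If you do need per-item write locks while other threads read the rest of the vec, one option that neither answer uses is to nest locks: an outer RwLock for the Vec itself (taken exclusively only for insertions) and an inner RwLock per element. A minimal std-only sketch; ItemLocked and its methods are hypothetical names:

```rust
use std::sync::RwLock;

// Hypothetical type: the outer lock guards the Vec itself (length, insertions);
// each inner lock guards the bytes of one element.
struct ItemLocked {
    items: RwLock<Vec<RwLock<[u8; 128]>>>,
}

impl ItemLocked {
    fn new(items: Vec<[u8; 128]>) -> Self {
        Self { items: RwLock::new(items.into_iter().map(RwLock::new).collect()) }
    }

    // Shared outer lock + exclusive lock on a single element, so other
    // elements can still be read or updated at the same time.
    fn update(&self, index: usize, item: [u8; 128]) {
        let outer = self.items.read().unwrap();
        *outer[index].write().unwrap() = item;
    }

    // Insertion changes the Vec itself, so it needs the exclusive outer lock.
    fn push(&self, item: [u8; 128]) {
        self.items.write().unwrap().push(RwLock::new(item));
    }

    // Copies every element out, taking each element's read lock in turn.
    fn snapshot(&self) -> Vec<[u8; 128]> {
        self.items
            .read()
            .unwrap()
            .iter()
            .map(|cell| *cell.read().unwrap())
            .collect()
    }
}

fn main() {
    let db = ItemLocked::new(vec![[0u8; 128], [1u8; 128], [2u8; 128]]);
    db.update(1, [3u8; 128]);
    db.push([4u8; 128]);
    let firsts: Vec<u8> = db.snapshot().iter().map(|item| item[0]).collect();
    println!("{:?}", firsts); // [0, 3, 2, 4]
}
```

The trade-off is that snapshot() locks each element separately, so if an update runs concurrently, the copy may mix values from before and after it; with a single RwLock around the whole Vec, a read always sees a consistent view.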