Abstracting over mutable/immutable references in Rust

I need to get rid of the duplication in this code:

pub struct Memory {
    layout: MemoryLayout,
    rom: Vec<u8>,
    ram: Vec<u8>,
}

impl Memory {
    pub fn get_mem_vec_ref(&self, address: u32) -> Result<&Vec<u8>, RiscvError> {
        // ...

        let mem_vec_ref = match address {
            addr if (rom_start..rom_end).contains(&addr) => Ok(&self.rom),
            addr if (ram_start..ram_end).contains(&addr) => Ok(&self.ram),
            addr => Err(RiscvError::MemoryAlignmentError(addr)),
        }?;

        return Ok(mem_vec_ref);
    }

    pub fn get_mem_vec_mut_ref(&mut self, address: u32) -> Result<&mut Vec<u8>, RiscvError> {
        // ...

        let mem_vec_ref = match address {
            addr if (rom_start..rom_end).contains(&addr) => Ok(&mut self.rom),
            addr if (ram_start..ram_end).contains(&addr) => Ok(&mut self.ram),
            addr => Err(RiscvError::MemoryAlignmentError(addr)),
        }?;

        return Ok(mem_vec_ref);
    }
}

How can I abstract over whether self is borrowed mutably or immutably? Could Box or RefCell help in this case?

Accepted answer by vallentin:

Since you're dealing with references in both cases, you can define a generic function where T is either &Vec<u8> or &mut Vec<u8>. Something like this:

fn get_mem<T>(address: u32, rom: T, ram: T) -> Result<T, RiscvError> {
    // ...

    match address {
        addr if (rom_start..rom_end).contains(&addr) => Ok(rom),
        addr if (ram_start..ram_end).contains(&addr) => Ok(ram),
        addr => Err(RiscvError::MemoryAlignmentError(addr)),
    }
}
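As a minimal runnable sketch, the version below passes the ROM and RAM address ranges as explicit Range<u32> parameters (an assumption, since the answer leaves that part elided) and defines a stand-in RiscvError with only the variant that appears in the question. The same function is then called once with shared references and once with mutable ones:

```rust
use std::ops::Range;

// Assumed minimal error type; only this variant appears in the question.
#[derive(Debug, PartialEq)]
enum RiscvError {
    MemoryAlignmentError(u32),
}

// Generic over T: callers pass either two &Vec<u8> or two &mut Vec<u8>.
// Passing the ranges explicitly is an assumption made for this sketch.
fn get_mem<T>(
    address: u32,
    rom_range: Range<u32>,
    ram_range: Range<u32>,
    rom: T,
    ram: T,
) -> Result<T, RiscvError> {
    match address {
        addr if rom_range.contains(&addr) => Ok(rom),
        addr if ram_range.contains(&addr) => Ok(ram),
        addr => Err(RiscvError::MemoryAlignmentError(addr)),
    }
}

fn main() {
    let mut rom = vec![0u8; 4];
    let mut ram = vec![0u8; 8];

    // T = &Vec<u8>: 0x0 falls in the ROM range, so this selects rom.
    let r = get_mem(0x0, 0x0..0x100, 0x100..0x200, &rom, &ram).unwrap();
    println!("selected region has {} bytes", r.len()); // prints "selected region has 4 bytes"

    // T = &mut Vec<u8>: 0x150 falls in the RAM range, so this selects ram.
    let w = get_mem(0x150, 0x0..0x100, 0x100..0x200, &mut rom, &mut ram).unwrap();
    w[0] = 42;
    println!("ram[0] = {}", ram[0]); // prints "ram[0] = 42"
}
```

Nothing in get_mem depends on T being a reference at all; the caller's choice of &T or &mut T is what decides the mutability of the result.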

impl Memory {
    pub fn get_mem_vec_ref(&self, address: u32) -> Result<&Vec<u8>, RiscvError> {
        // ...

        get_mem(address, &self.rom, &self.ram)
    }

    pub fn get_mem_vec_mut_ref(&mut self, address: u32) -> Result<&mut Vec<u8>, RiscvError> {
        // ...

        get_mem(address, &mut self.rom, &mut self.ram)
    }
}
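Putting the pieces together, here is a self-contained sketch of Memory that delegates both methods to a single generic helper. The rom/ram Range<u32> fields of MemoryLayout and the RiscvError definition are hypothetical stand-ins, because the question does not show them; the helper takes &MemoryLayout instead of separate start/end values:

```rust
use std::ops::Range;

// Assumed minimal error type; only this variant appears in the question.
#[derive(Debug, PartialEq)]
pub enum RiscvError {
    MemoryAlignmentError(u32),
}

// Hypothetical layout: the question does not show MemoryLayout's fields.
pub struct MemoryLayout {
    rom: Range<u32>,
    ram: Range<u32>,
}

pub struct Memory {
    layout: MemoryLayout,
    rom: Vec<u8>,
    ram: Vec<u8>,
}

// Lookup logic written once; T is &Vec<u8> or &mut Vec<u8> at the call sites.
fn get_mem<T>(layout: &MemoryLayout, address: u32, rom: T, ram: T) -> Result<T, RiscvError> {
    match address {
        addr if layout.rom.contains(&addr) => Ok(rom),
        addr if layout.ram.contains(&addr) => Ok(ram),
        addr => Err(RiscvError::MemoryAlignmentError(addr)),
    }
}

impl Memory {
    pub fn get_mem_vec_ref(&self, address: u32) -> Result<&Vec<u8>, RiscvError> {
        get_mem(&self.layout, address, &self.rom, &self.ram)
    }

    pub fn get_mem_vec_mut_ref(&mut self, address: u32) -> Result<&mut Vec<u8>, RiscvError> {
        // Borrowing layout immutably and rom/ram mutably is allowed: they are disjoint fields.
        get_mem(&self.layout, address, &mut self.rom, &mut self.ram)
    }
}

fn main() {
    let mut mem = Memory {
        layout: MemoryLayout { rom: 0x0000..0x1000, ram: 0x1000..0x2000 },
        rom: vec![0; 0x1000],
        ram: vec![0; 0x1000],
    };

    // As in the question, translating the address into an index within the Vec
    // is left to the caller; index 0 is used here for simplicity.
    mem.get_mem_vec_mut_ref(0x1000).unwrap()[0] = 7;
    println!("{}", mem.get_mem_vec_ref(0x1000).unwrap()[0]); // prints 7
    println!("out of range: {}", mem.get_mem_vec_ref(0x3000).is_err()); // prints "out of range: true"
}
```

The mutable method compiles because the borrow checker tracks self.layout, self.rom and self.ram as separate places, so the shared borrow of the layout does not conflict with the mutable borrows of the two vectors.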

Of course, you still need to modify get_mem() to take rom_start, rom_end, ram_start and ram_end into account. If you want to avoid passing a long list of parameters to get_mem(), it may be worth introducing a newtype for addresses instead, e.g. something like:

struct Addr {
    // ...
}

impl Addr {
    fn get_mem<T>(&self, rom: T, ram: T) -> Result<T, RiscvError> {
        // ...
    }
}
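One possible way to fill in Addr, sketched under assumptions: the answer elides its fields, so here Addr hypothetically pairs the raw u32 with a borrowed MemoryLayout (again with assumed rom/ram Range<u32> fields). Strictly speaking that makes it a small wrapper struct rather than a single-field newtype; a true newtype such as struct Addr(u32) would have to get the ranges from elsewhere, for example constants. Either way, get_mem then only needs the two candidate references:

```rust
use std::ops::Range;

// Assumed minimal error type; only this variant appears in the question.
#[derive(Debug, PartialEq)]
pub enum RiscvError {
    MemoryAlignmentError(u32),
}

// Hypothetical layout: the question does not show MemoryLayout's fields.
pub struct MemoryLayout {
    rom: Range<u32>,
    ram: Range<u32>,
}

// Hypothetical Addr: a raw address plus the layout it is resolved against,
// so get_mem takes no range parameters.
struct Addr<'a> {
    value: u32,
    layout: &'a MemoryLayout,
}

impl Addr<'_> {
    fn get_mem<T>(&self, rom: T, ram: T) -> Result<T, RiscvError> {
        match self.value {
            v if self.layout.rom.contains(&v) => Ok(rom),
            v if self.layout.ram.contains(&v) => Ok(ram),
            v => Err(RiscvError::MemoryAlignmentError(v)),
        }
    }
}

pub struct Memory {
    layout: MemoryLayout,
    rom: Vec<u8>,
    ram: Vec<u8>,
}

impl Memory {
    pub fn get_mem_vec_ref(&self, address: u32) -> Result<&Vec<u8>, RiscvError> {
        let addr = Addr { value: address, layout: &self.layout };
        addr.get_mem(&self.rom, &self.ram)
    }

    pub fn get_mem_vec_mut_ref(&mut self, address: u32) -> Result<&mut Vec<u8>, RiscvError> {
        // addr borrows only self.layout, so rom and ram can still be borrowed mutably.
        let addr = Addr { value: address, layout: &self.layout };
        addr.get_mem(&mut self.rom, &mut self.ram)
    }
}

fn main() {
    let mut mem = Memory {
        layout: MemoryLayout { rom: 0x0000..0x1000, ram: 0x1000..0x2000 },
        rom: vec![0xAA; 0x1000],
        ram: vec![0; 0x1000],
    };

    mem.get_mem_vec_mut_ref(0x1800).unwrap()[0] = 1;
    println!("ram[0] = {}", mem.ram[0]); // prints "ram[0] = 1"
    println!("rom[0] = {:#x}", mem.get_mem_vec_ref(0x0004).unwrap()[0]); // prints "rom[0] = 0xaa"
}
```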