Efficient way of traversing an Octree and doing ray hit intersection in a shader

59 views Asked by At

I am traversing my octree in a compute shader in wgsl and WGPU.

I construct and serialize the Tree in Rust to send it to the GPU:

pub enum Node<T: Clone> {
    Group(Box<[Node<T>; 8]>),
    Filled(T),
    Empty,
}

pub fn serialize_chunk(chunk: Chunk, coord: Coord) -> Vec<u8> {
    // chunk.make_dense(6);
    let mut bytes = ByteArray::new();

    let root = chunk.get_root_node();
    let data = serialize_node(root);
    let x = coord[0] as i32;
    let y = coord[1] as i32;
    let z = coord[2] as i32;

    // println!("{} {} {}", x, y, z);

    bytes.write_u32(data.len() as u32 + BYTES_PER_COORD); // length
    bytes.write_i32(x);
    bytes.write_i32(y);
    bytes.write_i32(z);
    bytes.append(data);

    let data = bytes.data();
    // println!("data len: {}", data.len());
    data
}

fn serialize_node(node: &Node<Block>) -> Vec<u8> {
    let mut bytes = ByteArray::new();

    match node {
        Node::Empty => {
            bytes.write_u32(NodeType::Empty as u32);
            bytes.write_u32(BYTES_PER_BLOCK);
            // bytes.write_u32(0x00);
            bytes.write_u8(BlockFlag::Empty as u8);
            bytes.write_u8(0x00);
            bytes.write_u8(0x00);
            bytes.write_u8(0x00);
        }
        Node::Filled(block) => {
            let color = block.to_color();

            bytes.write_u32(NodeType::Filled as u32);
            bytes.write_u32(BYTES_PER_BLOCK);
            bytes.write_u8(BlockFlag::Color as u8); // block flag
            bytes.write_u8(color.r);
            bytes.write_u8(color.g);
            bytes.write_u8(color.b);
        }
        Node::Group(group) => {
            let mut length: u32 = 0;
            bytes.write_u32(NodeType::Group as u32);
            bytes.write_u32(0); // placeholder for number of bytes, only known at end

            for corner in Corner::all() {
                let node = serialize_node(&group[corner as usize]);
                length += node.len() as u32;
                bytes.append(node);
            }

            // settings number of bytes
            bytes.write_u32_at(length, 1);
        }
    }

    bytes.data()
}

This gets stored in the shader as:

var<storage, read_write> chunks: array<u32>;

And recursive ray hit intersection is done like this:

fn hit_chunk_old(ray: ray::Ray, mem_offset: u32, node_position: vec3<f32>) -> ray::Hit {
    var nearest_hit = ray::Hit(
        false,
        10000000.0, // should be infinity
        vec3<f32>(0.0),
    );
    var nearest_distance: f32 = 100000000.0; // should be infinity

    var mem_offset_stack = array<u32, STACK_SIZE>();
    var node_position_stack = array<vec3<f32>, STACK_SIZE>();
    var size_stack = array<f32, STACK_SIZE>();
    var distance_stack = array<f32, STACK_SIZE>();
    var stack_top: u32 = 0u;

    mem_offset_stack[stack_top] = mem_offset;
    node_position_stack[stack_top] = node_position;
    size_stack[stack_top] = 1.0;
    distance_stack[stack_top] = 0.0;

    stack_top += 1u;

    while stack_top > 0u {
        stack_top -= 1u;
        let mem_offset = mem_offset_stack[stack_top];
        let node_position = node_position_stack[stack_top];
        let size = size_stack[stack_top];
        let distance = distance_stack[stack_top];

        let node_type = chunks[mem_offset];

        if node_type == 0u {
            // continue
        } else if node_type == 1u {
            // filled
            if distance > NEAR_CLIP {
                let data = chunks[mem_offset + 2u];
                let mask: u32 = 0xFFu;
                var flag = data & mask;
                var r = (data >> 8u) & mask;
                var g = (data >> 16u) & mask;
                var b = (data >> 24u) & mask;


                var modified = 0u;
                let t = noise::rng(); 
                if t > 0.99 {
                    if r < 255u {
                        r += 1u;
                    } else {
                        r = 0u;
                    }
                }
                
                modified = modified | (flag << 0u);
                modified = modified | (r << 8u);
                modified = modified | (g << 16u);
                modified = modified | (b << 24u);
                chunks[mem_offset + 2u] = modified;

                let color = vec3<f32>(f32(r) / 255.0, f32(g) / 255.0, f32(b) / 255.0);
                if distance < nearest_distance && flag != 0u {
                    nearest_distance = distance;
                    var position = ray.origin + ray.direction * (distance + 0.000001);
                    let block_position = vec3<i32>(position * 64.0);
                    position = vec3<f32>(block_position);
                    let noise = noise::d3(position / vec3<f32>(4.0)) / 8.0;
                    nearest_hit = ray::Hit(
                        true,
                        distance,
                        color + vec3<f32>(noise),
                    );
                }
            }
            
        } else if node_type == 2u {
            // group
            let child_mem_offset = mem_offset + 2u;
            let child_size = size / 2.0;

            for (var c = 0u; c < 8u; c += 1u) {
                let corner_mem_offset = skip_to_corner_in_group(child_mem_offset, c);
                let corner_position_offset = offset_from_corner(c) * child_size;
                let aabb = aabb::Aabb(
                    node_position + corner_position_offset,
                    node_position + corner_position_offset + child_size
                );
                let aabb_hit = aabb::hit(aabb, ray);
                if aabb_hit.hit || aabb::in_bound(aabb, ray.origin) {
                    mem_offset_stack[stack_top] = corner_mem_offset;
                    node_position_stack[stack_top] = node_position + corner_position_offset;
                    size_stack[stack_top] = child_size;
                    distance_stack[stack_top] = aabb_hit.distance;
                    stack_top += 1u;
                }
            }
        } else {
            // return nearest_hit;
            return ray::Hit(true, 1.0, vec3<f32>(0.5, 0.0, 1.0)); // return errorish purple 
        }
    }
    
    return nearest_hit;    
}

But this method is very inefficient, because I cannot just simply return the first hit, as it may not be the nearest one.

I previously used a DDA approach on a fixed size 32x32x32 chunk, which performed way better.

But When moving to an Octree, i hoped to improved performance by being able to skip empty parts.

The problem is that I don't know how to implement a direction based approach like the previous DDA one on an Octree.

0

There are 0 answers