Metal / SceneKit Fragment Shaders - How to avoid rendering on top of other geometry?


Given this basic SceneKit scene with a cube, sphere and pyramid positioned next to each other, where the sphere intersects the pyramid:

SceneKit scene with no techniques applied

And given the below, simplified Metal shader / SCNTechnique which just renders any nodes with a specific categoryBitmask solid red:

Technique Definition:

<dict>
    <key>passes</key>
    <dict>
        <key>pass_fill_drawMask</key>
        <dict>
            <key>draw</key>
            <string>DRAW_SCENE</string>
            <key>program</key>
            <string>doesntexist</string>
            <key>metalVertexShader</key>
            <string>pass_fill_drawMask_vertex</string>
            <key>metalFragmentShader</key>
            <string>pass_fill_drawMask_fragment</string>
            <key>includeCategoryMask</key>
            <string>0</string>
            <key>colorStates</key>
            <dict>
                <key>clear</key>
                <true/>
            </dict>
            <key>inputs</key>
            <dict>
                <key>aPos</key>
                <string>vertexSymbol</string>
            </dict>
            <key>outputs</key>
            <dict>
                <key>color</key>
                <string>MASK</string>
            </dict>
        </dict>
        <key>pass_fill_render</key>
        <dict>
            <key>draw</key>
            <string>DRAW_QUAD</string>
            <key>program</key>
            <string>doesntexist</string>
            <key>metalVertexShader</key>
            <string>pass_fill_render_vertex</string>
            <key>metalFragmentShader</key>
            <string>pass_fill_render_fragment</string>
            <key>inputs</key>
            <dict>
                <key>aPos</key>
                <string>vertexSymbol</string>
                <key>colorSampler</key>
                <string>COLOR</string>
                <key>maskSampler</key>
                <string>MASK</string>
                <key>resolution</key>
                <string>resolution</string>
            </dict>
            <key>outputs</key>
            <dict>
                <key>color</key>
                <string>COLOR</string>
            </dict>
        </dict>
    </dict>
    <key>sequence</key>
    <array>
        <string>pass_fill_drawMask</string>
        <string>pass_fill_render</string>
    </array>
    <key>symbols</key>
    <dict>
        <key>resolution</key>
        <dict>
            <key>type</key>
            <string>float</string>
        </dict>
        <key>vertexSymbol</key>
        <dict>
            <key>semantic</key>
            <string>vertex</string>
        </dict>
    </dict>
    <key>targets</key>
    <dict>
        <key>MASK</key>
        <dict>
            <key>type</key>
            <string>color</string>
            <key>format</key>
            <string>rgb</string>
            <key>size</key>
            <string>1024x1024</string>
            <key>scaleFactor</key>
            <integer>1</integer>
        </dict>
    </dict>
</dict>

Shader:

#include <metal_stdlib>
using namespace metal;
#include <SceneKit/scn_metal>

struct Node {
    float4x4 modelTransform;
    float4x4 modelViewTransform;
    float4x4 normalTransform;
    float4x4 modelViewProjectionTransform;
};

struct VertexIn {
    float4 position [[attribute(SCNVertexSemanticPosition)]];
    float4 normal [[attribute(SCNVertexSemanticNormal)]];
};

struct VertexOut {
    float4 position [[position]];
    float2 uv;
};

typedef struct {
    float resolution;
} Uniforms;

constexpr sampler s = sampler(coord::normalized,
                              r_address::clamp_to_edge,
                              t_address::clamp_to_edge,
                              filter::linear);

////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// PASS 1 - Render solid pixels for the input geometry to a target image for later use
////////////////////////////////////////////////////////////////////////////////////////////////////////////

vertex VertexOut pass_fill_drawMask_vertex(VertexIn in [[stage_in]],
                                           constant Node& scn_node [[buffer(0)]]) {
    VertexOut out;
    out.position = scn_node.modelViewProjectionTransform * float4(in.position.xyz, 1.0);
    out.uv = float2((in.position.x + 1.0) * 0.5, 1.0 - (in.position.y + 1.0) * 0.5);
    return out;
};

fragment half4 pass_fill_drawMask_fragment(VertexOut in [[stage_in]]) {
    return half4(1.0, 1.0, 1.0, 1.0);
};

////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// PASS 2 - Render any opaque pixels from the target image a solid color
////////////////////////////////////////////////////////////////////////////////////////////////////////////

vertex VertexOut pass_fill_render_vertex(VertexIn in [[stage_in]],
                                    texture2d<float, access::sample> colorSampler,
                                    texture2d<float, access::sample> maskSampler) {
    VertexOut out;
    out.position = in.position;
    out.uv = float2((in.position.x + 1.0) * 0.5,1.0 - (in.position.y + 1.0) * 0.5);
    return out;
};

fragment half4 pass_fill_render_fragment(VertexOut in [[stage_in]],
                                         texture2d<float, access::sample> colorSampler [[texture(0)]],
                                         texture2d<float, access::sample> maskSampler [[texture(1)]],
                                         constant SCNSceneBuffer& scn_frame [[buffer(0)]],
                                         constant Uniforms& uniforms [[buffer(1)]]) {
    
    float2 ratio = float2(colorSampler.get_width() / uniforms.resolution, colorSampler.get_height() / uniforms.resolution);
    ratio = float2(ratio.x > 1 ? 1 : ratio.x, ratio.y > 1 ? 1 : ratio.y);
    
    float4 maskColor = maskSampler.sample(s, in.uv * ratio);
    
    if (maskColor.a > 0) {
        // This pixel belongs to the geometry, render it red
        return half4(1.0, 0.0, 0.0, 1.0);
    } else {
        // This pixel does not belong to the geometry, render it the normal scene color
        float4 fragmentColor = colorSampler.sample(s, in.uv);
        return half4(fragmentColor);
    }
    
};

Swift Implementation:

static func fillTechnique(resolution: CGFloat, nodeCategoryBitmask: UInt32) -> SCNTechnique? {

    guard
        let fileUrl = Bundle.main.url(forResource: "FillTechnique", withExtension: "plist"),
        let data = try? Data(contentsOf: fileUrl),
        var result = try? PropertyListSerialization.propertyList(from: data, options: [], format: nil) as? [String: Any] else {
        return nil
    }

    result[keyPath: "passes.pass_fill_drawMask.includeCategoryMask"] = Int(nodeCategoryBitmask)
    print(result)
    guard let technique = SCNTechnique(dictionary: result) else {
        fatalError("Unable to create outline technique")
    }
    
    technique.setObject(resolution, forKeyedSubscript: "resolution" as NSCopying)

    return technique

}

When applied to the original scene, where the sphere mesh has had its category bitmask set to the same value applied to the includeCategoryMask in the technique's definition, we get the following output:

SceneKit render without depth testing

What I would like to do is take into account the depth of the fragments passed to the shader, such that the red sphere is not rendered "on top" of geometry that is occluding it from the camera's view.

I know SceneKit provides both a COLOR and a DEPTH input to SCNTechnique Metal shaders, and even without that input we can calculate the depth of a given fragment by transforming the vertex position by the node's modelTransform and the scene's viewProjectionTransform, both also provided by SceneKit. To that end I have come up with a modification of the above, which performs 3 passes:

  1. Renders an image containing the depth information of all of the nodes matched by the includeCategoryMask
  2. Renders an image containing the depth information of all of the nodes matched by the excludeCategoryMask (so we now have 2 depth images: one for the node(s) we want to "stylize" and one for the node(s) we don't)
  3. Compares the depth of the fragment in the two images above. If the depth from the first image is greater than or equal to the depth from the second image, it renders a solid red fragment; otherwise it renders the normal scene's fragment color.

Below is this implementation:

Technique Definition:

<dict>
    <key>passes</key>
    <dict>
        <key>pass_fill_drawMask</key>
        <dict>
            <key>draw</key>
            <string>DRAW_SCENE</string>
            <key>program</key>
            <string>doesntexist</string>
            <key>metalVertexShader</key>
            <string>pass_fill_drawMask_vertex</string>
            <key>metalFragmentShader</key>
            <string>pass_fill_drawMask_fragment</string>
            <key>includeCategoryMask</key>
            <string>2</string>
            <key>colorStates</key>
            <dict>
                <key>clear</key>
                <true/>
            </dict>
            <key>inputs</key>
            <dict>
                <key>fillColorR</key>
                <string>fillColorR</string>
                <key>fillColorG</key>
                <string>fillColorG</string>
                <key>fillColorB</key>
                <string>fillColorB</string>
                <key>aPos</key>
                <string>vertexSymbol</string>
                <key>colorSampler</key>
                <string>COLOR</string>
            </dict>
            <key>outputs</key>
            <dict>
                <key>color</key>
                <string>MASK</string>
            </dict>
        </dict>
        <key>pass_fill_render</key>
        <dict>
            <key>draw</key>
            <string>DRAW_QUAD</string>
            <key>program</key>
            <string>doesntexist</string>
            <key>metalVertexShader</key>
            <string>pass_fill_render_vertex</string>
            <key>metalFragmentShader</key>
            <string>pass_fill_render_fragment</string>
            <key>inputs</key>
            <dict>
                <key>aPos</key>
                <string>vertexSymbol</string>
                <key>colorSampler</key>
                <string>COLOR</string>
                <key>maskSampler</key>
                <string>MASK</string>
                <key>resolution</key>
                <string>resolution</string>
            </dict>
            <key>outputs</key>
            <dict>
                <key>color</key>
                <string>COLOR</string>
            </dict>
        </dict>
    </dict>
    <key>sequence</key>
    <array>
        <string>pass_fill_drawMask</string>
        <string>pass_fill_render</string>
    </array>
    <key>symbols</key>
    <dict>
        <key>fillColorR</key>
        <dict>
            <key>type</key>
            <string>float</string>
        </dict>
        <key>fillColorG</key>
        <dict>
            <key>type</key>
            <string>float</string>
        </dict>
        <key>fillColorB</key>
        <dict>
            <key>type</key>
            <string>float</string>
        </dict>
        <key>resolution</key>
        <dict>
            <key>type</key>
            <string>float</string>
        </dict>
        <key>vertexSymbol</key>
        <dict>
            <key>semantic</key>
            <string>vertex</string>
        </dict>
    </dict>
    <key>targets</key>
    <dict>
        <key>MASK</key>
        <dict>
            <key>type</key>
            <string>color</string>
            <key>format</key>
            <string>rgb</string>
            <key>size</key>
            <string>1024x1024</string>
            <key>scaleFactor</key>
            <integer>1</integer>
        </dict>
    </dict>
</dict>
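
The definition above still names the passes from the first version (pass_fill_drawMask, pass_fill_render), while the shader and Swift code in this section refer to pass_depth_incl, pass_depth_excl, TARG_DEPTH_INCL and TARG_DEPTH_EXCL. A three-pass definition matching those names might look roughly like the sketch below, written as a Swift dictionary of the kind passed to SCNTechnique(dictionary:). It is a reconstruction under assumptions, not the actual DepthCheckTechnique.plist: the pass name pass_depth_check is inferred from the shader function names, makeDepthPass is a hypothetical helper, the input names are taken from the fragment shader's parameters, and excludeCategoryMask is the key SceneKit documents for excluding nodes.

```swift
import Foundation

// Hypothetical helper: builds one DRAW_SCENE pass that writes into a colour target.
// Pass and target names are inferred from the question's shader and Swift code.
func makeDepthPass(name: String, maskKey: String, target: String) -> [String: Any] {
    return [
        "draw": "DRAW_SCENE",
        "program": "doesntexist",
        "metalVertexShader": "\(name)_vertex",
        "metalFragmentShader": "\(name)_fragment",
        maskKey: 2, // placeholder, overwritten at runtime with the node's bitmask
        "colorStates": ["clear": true],
        "inputs": ["aPos": "vertexSymbol"],
        "outputs": ["color": target]
    ]
}

// Final full-screen pass that compares the two depth images (name assumed).
let checkPass: [String: Any] = [
    "draw": "DRAW_QUAD",
    "program": "doesntexist",
    "metalVertexShader": "pass_depth_check_vertex",
    "metalFragmentShader": "pass_depth_check_fragment",
    "inputs": [
        "aPos": "vertexSymbol",
        "inclSampler": "TARG_DEPTH_INCL",
        "exclSampler": "TARG_DEPTH_EXCL",
        "colorSampler": "COLOR",
        "resolution": "resolution"
    ],
    "outputs": ["color": "COLOR"]
]

// Same colour target settings as the MASK target in the definition above.
let colorTarget: [String: Any] = [
    "type": "color", "format": "rgb", "size": "1024x1024", "scaleFactor": 1
]

let technique: [String: Any] = [
    "passes": [
        "pass_depth_incl": makeDepthPass(name: "pass_depth_incl",
                                         maskKey: "includeCategoryMask",
                                         target: "TARG_DEPTH_INCL"),
        "pass_depth_excl": makeDepthPass(name: "pass_depth_excl",
                                         maskKey: "excludeCategoryMask",
                                         target: "TARG_DEPTH_EXCL"),
        "pass_depth_check": checkPass
    ],
    "sequence": ["pass_depth_incl", "pass_depth_excl", "pass_depth_check"],
    "symbols": [
        "resolution": ["type": "float"],
        "vertexSymbol": ["semantic": "vertex"]
    ],
    "targets": ["TARG_DEPTH_INCL": colorTarget, "TARG_DEPTH_EXCL": colorTarget]
]
```

The same structure can be written as a plist; the dictionary form is used here only to keep the sketch short.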

Shader:

#include <metal_stdlib>
using namespace metal;
#include <SceneKit/scn_metal>

struct Node {
    float4x4 modelTransform;
    float4x4 modelViewTransform;
    float4x4 normalTransform;
    float4x4 modelViewProjectionTransform;
};

struct VertexIn {
    float4 position [[attribute(SCNVertexSemanticPosition)]];
    float4 normal [[attribute(SCNVertexSemanticNormal)]];
};

struct VertexOut {
    float4 position [[position]];
    float2 uv;
};

typedef struct {
    float resolution;
} Uniforms;

constexpr sampler s = sampler(coord::normalized,
                              r_address::clamp_to_edge,
                              t_address::clamp_to_edge,
                              filter::linear);

////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// PASS 1 - Render depth for included geometries
////////////////////////////////////////////////////////////////////////////////////////////////////////////

vertex VertexOut pass_depth_incl_vertex(VertexIn in [[stage_in]],
                                        constant SCNSceneBuffer& scn_frame [[buffer(0)]],
                                        constant Node& scn_node [[buffer(1)]]) {
    VertexOut out;
    out.position = scn_node.modelViewProjectionTransform * float4(in.position.xyz, 1.0);
    
    // Store the screen depth in the position's z axis
    float4 depth = scn_frame.viewProjectionTransform * scn_node.modelTransform * in.position;
    out.position.z = depth.z;
    
    out.uv = float2((in.position.x + 1.0) * 0.5,1.0 - (in.position.y + 1.0) * 0.5);
    return out;
};

fragment half4 pass_depth_incl_fragment(VertexOut in [[stage_in]]) {
    return half4(in.position.z, in.position.z, in.position.z, 1.0);
};

////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// PASS 2 - Render depth for excluded geometries
////////////////////////////////////////////////////////////////////////////////////////////////////////////

vertex VertexOut pass_depth_excl_vertex(VertexIn in [[stage_in]],
                                        constant SCNSceneBuffer& scn_frame [[buffer(0)]],
                                        constant Node& scn_node [[buffer(1)]]) {
    VertexOut out;
    out.position = scn_node.modelViewProjectionTransform * float4(in.position.xyz, 1.0);
    
    // Store the screen depth in the position's z axis
    float4 depth = scn_frame.viewProjectionTransform * scn_node.modelTransform * in.position;
    out.position.z = depth.z;
    
    out.uv = float2((in.position.x + 1.0) * 0.5,1.0 - (in.position.y + 1.0) * 0.5);
    return out;
};

fragment half4 pass_depth_excl_fragment(VertexOut in [[stage_in]]) {
    return half4(in.position.z, in.position.z, in.position.z, 1.0);
};

////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// PASS 3 - Render fragments for pixels passing depth test
////////////////////////////////////////////////////////////////////////////////////////////////////////////

vertex VertexOut pass_depth_check_vertex(VertexIn in [[stage_in]]) {
    VertexOut out;
    out.position = in.position;
    out.uv = float2(
                    (in.position.x + 1.0) * 0.5,
                    1.0 - (in.position.y + 1.0) * 0.5
                    );
    return out;
};

fragment half4 pass_depth_check_fragment(VertexOut in [[stage_in]],
                                         texture2d<float, access::sample> inclSampler [[texture(0)]],
                                         texture2d<float, access::sample> exclSampler [[texture(1)]],
                                         texture2d<float, access::sample> colorSampler [[texture(2)]],
                                         constant SCNSceneBuffer& scn_frame [[buffer(0)]],
                                         constant Uniforms& uniforms [[buffer(1)]]) {

    float2 ratio = float2(colorSampler.get_width() / uniforms.resolution, colorSampler.get_height() / uniforms.resolution);
    ratio = float2(ratio.x > 1 ? 1 : ratio.x, ratio.y > 1 ? 1 : ratio.y);

    float4 inclColor = inclSampler.sample(s, in.uv * ratio);
    float4 exclColor = exclSampler.sample(s, in.uv * ratio);
    
    float inclDepth = inclColor.r;
    float exclDepth = exclColor.r;
    
    bool isBackground = inclColor.a == 0 && exclColor.a == 0;
    
    if (inclDepth >= exclDepth && !isBackground) {
        return half4(1.0, 0.0, 0.0, 1.0);
    } else {
        float4 color = colorSampler.sample(s, in.uv);
        return half4(color);
    }
    
};

Swift Implementation:

static func depthCheckTechnique(resolution: CGFloat, nodeCategoryBitmask: UInt32) -> SCNTechnique? {

    guard
        let fileUrl = Bundle.main.url(forResource: "DepthCheckTechnique", withExtension: "plist"),
        let data = try? Data(contentsOf: fileUrl),
        var result = try? PropertyListSerialization.propertyList(from: data, options: [], format: nil) as? [String: Any] else {
        return nil
    }

    result[keyPath: "passes.pass_depth_incl.includeCategoryMask"] = Int(nodeCategoryBitmask)
    result[keyPath: "passes.pass_depth_excl.excludeCategoryMask"] = Int(nodeCategoryBitmask)
    result[keyPath: "targets.TARG_DEPTH_INCL.size"] = "\(resolution)x\(resolution)"
    result[keyPath: "targets.TARG_DEPTH_EXCL.size"] = "\(resolution)x\(resolution)"
    
    guard let technique = SCNTechnique(dictionary: result) else {
        fatalError("Unable to create outline technique")
    }

    technique.setObject(resolution, forKeyedSubscript: "resolution" as NSCopying)

    return technique

}

The above, applied to the same scene as before, gets close to the desired result:

SceneKit shader with depth

However, there are a few issues:

  1. The edge where the sphere meets the pyramid is jagged, and jitters as the camera orbits around the scene
  2. I don't believe this method takes into account things like translucency or rendering order. For example, if I have a node in the scene with a higher rendering order so that it appears on top of other nodes, or a node with a slightly transparent material, I don't think this approach will handle those situations properly.
  3. This feels like I may be re-inventing the wheel. Like I said I know SceneKit provides us with an actual depth input to read, but I'm not sure how to properly use it to achieve what I want.

Is this the best approach with SCNTechniques to create shaders that properly get occluded by fragments "in front" of the fragments we want to affect? Am I trying to use SCNTechnique for something it wasn't designed for? I'm fairly new to shader development.

Thanks!

  • Adam

There is 1 answer.

Answered by Nick:

The inaccurate intersections you're seeing are caused by encoding the depth value into a single channel of a colour target. Each output to the colour target gets stored in 32 bits, but because you're only using one channel, your full-precision float gets packed into only one quarter (8 bits) of those 32 bits.
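
To make the precision loss concrete, here is a small sketch in plain Swift. It assumes the channel behaves like an 8-bit normalised value (256 levels between 0 and 1); storeIn8BitChannel and the two depth values are made up for illustration.

```swift
// Quantise a depth in 0...1 the way an 8-bit normalised colour channel would store it.
// Hypothetical helper for illustration only.
func storeIn8BitChannel(_ depth: Float) -> Float {
    let clamped = min(max(depth, 0), 1)
    return (clamped * 255).rounded() / 255
}

// Two surfaces that are close together but clearly ordered in full precision.
let sphereDepth: Float = 0.502
let pyramidDepth: Float = 0.503

let storedSphere = storeIn8BitChannel(sphereDepth)
let storedPyramid = storeIn8BitChannel(pyramidDepth)

print(sphereDepth < pyramidDepth)     // true: the float values are distinguishable
print(storedSphere == storedPyramid)  // true: both land on the same 8-bit level
```

Once both depths collapse to the same stored level, the comparison in the final pass can no longer tell which surface is in front, which would show up as jagged, flickering edges along the intersection.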

Instead of colour targets, you should render to two depth targets. That way you get much higher-precision depth values when sampling in your final mix pass.

Depending on what you're trying to achieve, you might be better off using an SCNProgram.
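
As a rough sketch of what that change could look like in the technique dictionary (shown in Swift): SCNTechnique documents a "depth" target type, a "depth" key under a pass's outputs, and a depthStates dictionary. The target name is the one from the question's Swift code; the sizes, mask value and the choice to omit a colour output are assumptions that would need checking against a real scene.

```swift
import Foundation

// Assumed depth target; "type": "depth" replaces the colour target used in the question.
let depthTarget: [String: Any] = [
    "type": "depth",
    "size": "1024x1024",
    "scaleFactor": 1
]

// Sketch of the first pass writing the included nodes' depth straight into that target.
// Whether a pass also needs a colour output here is not verified.
let inclPass: [String: Any] = [
    "draw": "DRAW_SCENE",
    "includeCategoryMask": 2, // placeholder, set from the node's bitmask at runtime
    "depthStates": ["clear": true],
    "outputs": ["depth": "TARG_DEPTH_INCL"]
]

// The final pass would then bind the depth target as a sampler input.
let checkInputs: [String: String] = [
    "inclSampler": "TARG_DEPTH_INCL",
    "colorSampler": "COLOR"
]
```

The second (excluded) pass would mirror this with its own depth target, and the fragment shader in the final pass would read the depth from the sampled value instead of from a colour channel written by a custom fragment function.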