Stuck on the de-filtering step of decoding a PNG file in Luau


I am attempting to decode PNG files in Luau, and I have made good progress: I can already make out the image in the final decoded result, even though it is clearly incorrect. The one thing left to fix is the de-filtering (reconstruction) process.

So far I have correctly parsed the file's chunks and decompressed the IDAT chunk with an external library. I have also converted the raw data to the format I need (a table of numbers between 0 and 1, four per pixel, one for each RGBA channel). However, the data is "filtered" before being compressed to improve compression, so I need to reverse that process (de-filter it) to get the correct result.
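PNG filters operate on raw byte values and wrap around modulo 256, which is why reconstruction has to happen on the bytes before they are scaled to the 0 to 1 range. A minimal sketch of the Sub filter and its inverse for a single byte, given its left neighbour `a` (the function names are hypothetical, chosen for illustration):

```luau
-- Filtering: subtract the left neighbour, wrapping into 0..255.
-- Lua's % takes the sign of the divisor, so (4 - 10) % 256 == 250.
local function filterSub(raw: number, a: number): number
	return (raw - a) % 256
end

-- Reconstruction: add the left neighbour back, again modulo 256.
local function reconSub(filt: number, a: number): number
	return (filt + a) % 256
end

print(filterSub(4, 10)) -- 250
print(reconSub(250, 10)) -- 4
```

The same modulo-256 rule applies to the Up, Average and Paeth filters, which is what the `% 256` in the question's loop is there for.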

For every scanline, I am correctly reading the filter-type byte that indicates which filter function was applied to it. All that remains is to apply the matching function to reverse the filtering for each scanline (Recon(x)). To test this, I'm using the test images from PNGSuite (specifically the colored ones), because each of them uses a single filter type for the entire image, which makes it easy to see the result of each de-filtering function individually.
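For comparison, here is a minimal sketch of reconstruction as the PNG specification describes it, assuming a non-interlaced image whose decompressed data is already in a `buffer`. Per the spec, a, b and c are the bytes to the left (`bpp` bytes back), above, and above-left, taken from the already reconstructed output rather than from the filtered input, and all arithmetic is modulo 256. The names `unfilterImage` and `paeth` and the parameters `stride` (bytes per scanline, excluding the filter byte) and `bpp` (bytes per complete pixel, at least 1) are hypothetical; this is a sketch, not the code from the question.

```luau
-- Paeth predictor as defined in the PNG specification.
local function paeth(a: number, b: number, c: number): number
	local p = a + b - c
	local pa, pb, pc = math.abs(p - a), math.abs(p - b), math.abs(p - c)
	if pa <= pb and pa <= pc then
		return a
	elseif pb <= pc then
		return b
	end
	return c
end

-- Reverses filter types 0-4 for every scanline of a non-interlaced image.
-- `data` holds, per scanline, one filter-type byte followed by `stride` bytes.
local function unfilterImage(data: buffer, height: number, stride: number, bpp: number): buffer
	local out = buffer.create(height * stride)
	local pos = 0 -- read offset into `data`
	for y = 0, height - 1 do
		local filterType = buffer.readu8(data, pos)
		pos += 1
		local row = y * stride -- offset of this scanline in `out`
		local prevRow = row - stride -- only read when y > 0
		for x = 0, stride - 1 do
			local filt = buffer.readu8(data, pos + x)
			-- a, b and c are read from the reconstructed output
			local a = if x >= bpp then buffer.readu8(out, row + x - bpp) else 0
			local b = if y > 0 then buffer.readu8(out, prevRow + x) else 0
			local c = if x >= bpp and y > 0 then buffer.readu8(out, prevRow + x - bpp) else 0
			local recon: number
			if filterType == 0 then
				recon = filt
			elseif filterType == 1 then
				recon = filt + a
			elseif filterType == 2 then
				recon = filt + b
			elseif filterType == 3 then
				recon = filt + math.floor((a + b) / 2)
			elseif filterType == 4 then
				recon = filt + paeth(a, b, c)
			else
				error("unknown filter type " .. tostring(filterType))
			end
			buffer.writeu8(out, row + x, recon % 256)
		end
		pos += stride
	end
	return out
end
```

For an 8-bit truecolor image such as the f0Xn2c08 files, `bpp` would be 3 and `stride` would be 3 × width.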

Here are the results for each image:

  • 0 (No Operation): the output image is correct!

Test image f00n2c08.png after being decoded

  • 1 (Sub): the image looks nearly correct, but the red channel is slightly off:

Test image f01n2c08.png after being decoded

  • 2 (Up): the shapes in the image resemble the correct result, but it's clearly very wrong:

Test image f02n2c08.png after being decoded

  • 3 (Average): the number in the image can be seen, but otherwise it's almost completely incomprehensible:

Test image f03n2c08.png after being decoded

  • 4 (Paeth Predictor): shows patterns similar to Up, but is also very wrong:

Test image f04n2c08.png after being decoded

(Note: although I know it's not a good idea, I don't plan on supporting bit depths below 8. For testing purposes I'm only supporting a bit depth of 8, but I do plan on supporting 16-bit color later on.)
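With 16-bit support in mind, two details from the PNG specification may matter. First, the filter's "previous pixel" distance is the number of bytes per complete pixel (rounded up to 1), so it is 8 for 16-bit RGBA rather than the channel count of 4. Second, 16-bit samples are stored big-endian, whereas Luau's `buffer.readu16` reads little-endian. A sketch under those assumptions; `filterSizes`, `readU16BE` and the `CHANNELS` table are hypothetical helpers:

```luau
-- Channels per pixel for each PNG color type.
local CHANNELS: {[number]: number} = { [0] = 1, [2] = 3, [3] = 1, [4] = 2, [6] = 4 }

-- Returns bpp (the filter's "previous pixel" distance in bytes) and the
-- scanline length in bytes, excluding the leading filter-type byte.
local function filterSizes(colorType: number, bitDepth: number, width: number): (number, number)
	local bitsPerPixel = CHANNELS[colorType] * bitDepth
	-- Bytes per complete pixel, rounded up to 1 for sub-byte pixels.
	local bpp = math.max(1, math.floor(bitsPerPixel / 8))
	-- Whole bytes needed for `width` pixels (a partial final byte rounds up).
	local stride = math.floor((width * bitsPerPixel + 7) / 8)
	return bpp, stride
end

-- PNG stores 16-bit samples most significant byte first, so combine the
-- two bytes manually instead of using buffer.readu16 (little-endian).
local function readU16BE(b: buffer, offset: number): number
	return buffer.readu8(b, offset) * 256 + buffer.readu8(b, offset + 1)
end
```

For example, `filterSizes(6, 16, 10)` gives a `bpp` of 8 and a stride of 80 bytes.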

(For anyone familiar with Lua but not with Luau: my code makes heavy use of type annotations and the built-in buffer library. These shouldn't be hard to figure out, but both are covered in the official Luau documentation if needed.)

-- Util function to query a certain number of bits from the image data bytes
-- Will be useful if support for bit depths other than 8 are introduced
-- From here, CurrentPos shall be measured in bits instead of bytes
CurrentPos = 0
local CurrentBuffer:buffer = DecompressedData
local function GetInt(BitDepth:number):number
    local ByteIndex:number = math.floor(CurrentPos / 8)
    local BitOffset:number = (CurrentPos * BitDepth) % 8

    local Int:number = 0
    
    if BitDepth == 8 then
        Int = buffer.readu8(CurrentBuffer, ByteIndex)
    elseif BitDepth == 16 then
        Int = buffer.readu16(CurrentBuffer, ByteIndex)
    elseif BitDepth == 32 then
        Int = buffer.readu32(CurrentBuffer, ByteIndex)
    else
        -- For bit depths less than 8, you need to handle bit-wise extraction
        local Byte:number = buffer.readu8(CurrentBuffer, ByteIndex)
        local Shift:number = 8 - BitDepth - BitOffset
        return bit32.rshift(bit32.band(Byte, bit32.lshift(0xFF, BitOffset)), Shift)
    end

    CurrentPos += BitDepth
    return Int
end
    
-- Defilter process
-- Start by splitting the bytes into scanlines

-- How many bytes are in a scanline? (not including the filter byte)
local ScanlineBytes:number = math.floor(ImageStruct.BitDepth * ImageStruct.Size.X / 8) 
local PixelBytes:number = 1 -- how many bytes in a single pixel?
local Scanlines:{buffer} = {} -- List of scanlines processed

if ImageStruct.ColorType == 2 then -- Truecolor
    PixelBytes = 3
elseif ImageStruct.ColorType == 3 then -- Indexed color
    PixelBytes = 1 -- lol....
elseif ImageStruct.ColorType == 6 then -- Truecolor with alpha
    PixelBytes = 4
end -- No support for greyscale color types yet, sorry
ScanlineBytes *= PixelBytes -- The amount of bytes in a scanline is actually the width * how many bytes are in one pixel.

-- For every scanline
for Y:number = 1, ImageStruct.Size.Y do
    
    -- Get this scanline's filter byte
    local FilterByte:number = GetInt(8)
    
    -- These should be unfiltered
    --local PreviousScanLine:buffer? = Scanlines[Y - 1] -- Warning: will be null if we are on the first scanline
    local PreviousScanLine:buffer? = nil
    local PreviousByte:number = 0 -- Currently unused variable
    
    local Scanline:buffer = buffer.create(ScanlineBytes)
    local ScanlineUnfiltered:buffer = buffer.create(ScanlineBytes)
        
    -- For every byte on the scanline (defiltering works with bytes, not pixels, so we can disregard bit depth for now)
    for X:number = 0, ScanlineBytes-1, 1 do
            
        -- Get this byte
        local Byte:number = GetInt(8)
            
        --local ByteA:number = PreviousByte
            
        -- Get the byte in the previous pixel (the result should be 0 if we're on the first pixel)
        local ByteA:number = if (X > PixelBytes) then buffer.readu8(Scanline, X - PixelBytes) else 0
            
        -- Get the same byte in the previous scanline (the result should be 0 if we're on the first scanline)
        local ByteB:number = if PreviousScanLine then buffer.readu8(PreviousScanLine, X) else 0
            
        -- Get the byte in the previous pixel on the previous scanline (the result should be 0 if blah blah blah you get the point)
        local ByteC:number = if (PreviousScanLine and (X > PixelBytes)) then buffer.readu8(PreviousScanLine, X - PixelBytes) else 0
        
        -- Apply appropriate defiltering function
        local DefilteredByte:number = m:ApplyDefiltering(FilterByte, Byte, ByteA, ByteB, ByteC) % 256

        buffer.writeu8(ScanlineUnfiltered, X, Byte)
        buffer.writeu8(Scanline, X, DefilteredByte)
        --PreviousByte = DefilteredByte
        PreviousByte = Byte
    end
    --GetInt(8)
    PreviousScanLine = ScanlineUnfiltered
    table.insert(Scanlines, Scanline)
end
    
    
-- Finally, we need to actually get the pixels
local Pixels:{number} = {}
for Y:number = 1, ImageStruct.Size.Y do
    local Scanline:buffer = Scanlines[Y]
    
    -- Set variables used by GetInt (we are reading from the current scanline)
    CurrentBuffer = Scanline
    CurrentPos = 0
    
    for X:number = 1, ImageStruct.Size.X, 1 do
        local R:number, G:number, B:number, A:number
        
        if ImageStruct.ColorType == 2 then -- Truecolor
            R = GetInt(8) / 255
            G = GetInt(8) / 255
            B = GetInt(8) / 255
            A = 1
        elseif ImageStruct.ColorType == 3 then -- Indexed color
            local PaletteIndex:number = GetInt(8)
            local Color:Color3 = ImageStruct.Pallette[PaletteIndex] -- Get the color from the palette (PLTE)
            R = Color.R
            G = Color.G
            B = Color.B
            A = 1
        elseif ImageStruct.ColorType == 6 then -- Truecolor with alpha
            R = GetInt(8) / 255
            G = GetInt(8) / 255
            B = GetInt(8) / 255
            A = GetInt(8) / 255
        end
        
        table.insert(Pixels, R)
        table.insert(Pixels, G)
        table.insert(Pixels, B)
        table.insert(Pixels, A)
    end
end
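A point about `PreviousScanLine` that may be relevant: in Luau, a `local` declared inside a loop body is a new variable on every iteration, so `local PreviousScanLine:buffer? = nil` is reset at the start of each scanline and the assignment at the end of the loop body is discarded. The two hypothetical functions below illustrate the difference. Separately, the PNG specification computes b and c from the previous reconstructed scanline (`Scanline` in the loop above), not from the raw filtered bytes (`ScanlineUnfiltered`).

```luau
-- A local declared inside the loop is re-created (as nil) on every iteration.
local function declaredInside(rows: {string}): {string}
	local seen: {string} = {}
	for _, row in ipairs(rows) do
		local previous: string? = nil -- fresh variable each iteration
		table.insert(seen, previous or "<none>")
		previous = row -- lost: this local goes out of scope right after
	end
	return seen
end

-- A local declared before the loop is shared by all iterations.
local function declaredOutside(rows: {string}): {string}
	local seen: {string} = {}
	local previous: string? = nil
	for _, row in ipairs(rows) do
		table.insert(seen, previous or "<none>")
		previous = row -- visible in the next iteration
	end
	return seen
end
```

With the rows `{"row1", "row2", "row3"}`, `declaredInside` sees no previous row every time, while `declaredOutside` sees `"row1"` and then `"row2"`.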

Additionally, here's the function that applies de-filtering to a byte:

function m:ApplyDefiltering(FilterByte:number, Byte:number, ByteA:number, ByteB:number, ByteC:number)
    if FilterByte == 0 then -- None
        return Byte
    elseif FilterByte == 1 then -- Sub
        return Byte + ByteA
    elseif FilterByte == 2 then -- Up
        return Byte + ByteB
    elseif FilterByte == 3 then -- Average
        -- Average(x) + floor((ByteA + ByteB) / 2)
        return Byte + math.floor((ByteA + ByteB) / 2)
    elseif FilterByte == 4 then -- Paeth Predictor
        local P:number = ByteA + ByteB - ByteC
        local PA:number = math.abs(P - ByteA)
        local PB:number = math.abs(P - ByteB)
        local PC:number = math.abs(P - ByteC)
        local PR:number
        if (PA <= PB) and (PA <= PC) then 
            PR = ByteA
        elseif PB <= PC then 
            PR = ByteB
        else 
            PR = ByteC
        end
        return Byte + PR
    end
end
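On the Sub result: `X` in the main loop is 0-based, so the first byte that has a left neighbour is at `X == PixelBytes`, but the condition `X > PixelBytes` (used for both `ByteA` and `ByteC`) gives that byte a neighbour of 0. For RGB data that byte is the red sample of the second pixel, and because every later red byte is built from the previous reconstructed one, the error carries across the whole row while green and blue stay correct. A hypothetical single-scanline sketch reproducing both conditions, using a 1-based Lua table for simplicity:

```luau
-- Reconstructs one Sub-filtered scanline; `x` is 0-based like the question's X.
-- `useStrictCheck = true` reproduces the `X > PixelBytes` condition.
local function unfilterSub(filtered: {number}, bpp: number, useStrictCheck: boolean): {number}
	local recon: {number} = {}
	for x = 0, #filtered - 1 do
		local hasLeft = if useStrictCheck then x > bpp else x >= bpp
		local a = if hasLeft then recon[x - bpp + 1] else 0
		recon[x + 1] = (filtered[x + 1] + a) % 256
	end
	return recon
end

-- Three RGB pixels (10,20,30), (15,25,35), (20,30,40) after the Sub filter.
local filtered = {10, 20, 30, 5, 5, 5, 5, 5, 5}
local good = unfilterSub(filtered, 3, false) -- 10,20,30,15,25,35,20,30,40
local bad = unfilterSub(filtered, 3, true) -- 10,20,30,5,25,35,10,30,40
```

In the strict version only the red bytes after the first pixel are wrong, each off by the first pixel's red value.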