How to avoid code repetition where types need to be chosen dynamically?


The following code is a simplified example of a video stream parser. The input is binary data consisting of video and audio frames. Each frame consists of:

  1. A frame type flag indicating whether it is a video or an audio frame
  2. Header
  3. Payload

The goal is to parse the stream and extract fields from the headers and the payload.

My first approach is:

package main
import (
    "fmt"
    "encoding/binary"
    "bytes"
)

type Type byte

const (
    Video  Type = 0xFC
    Audio   Type = 0xFA
)

var HMap = map[Type]string {
    Video:   "Video",
    Audio:   "Audio",
}

type CommonHeader struct {
    Type      Type
}

type HeaderVideo struct {
    Width       uint16
    Height      uint16
    Length      uint32
}

type HeaderAudio struct {
    SampleRate  uint16
    Length      uint16
}


func main() {
    data := bytes.NewReader([]byte{0xFC, 0x80, 0x07, 0x38, 0x04, 0x02, 0x00, 0x00, 0x00, 0xFF, 0xAF, 0xFA, 0x10, 0x00, 0x01, 0x00, 0xFF})
    var cHeader CommonHeader
    var dataLength int
    for {
        err := binary.Read(data, binary.LittleEndian, &cHeader)
        if err != nil {
            break
        }
        fmt.Println(HMap[cHeader.Type])
        switch cHeader.Type {
            case Video:
                var info HeaderVideo
                binary.Read(data, binary.LittleEndian, &info)
                dataLength = int(info.Length)
                fmt.Println(info)
            case Audio:
                var info HeaderAudio
                binary.Read(data, binary.LittleEndian, &info)
                dataLength = int(info.Length)
                fmt.Println(info)
        }
        payload := make([]byte, dataLength)
        data.Read(payload)
        fmt.Println(payload)
    }
}

It works, but I don't like the code repetition in the switch cases. Essentially, we have to repeat the same code just because the frame types are different.

One approach trying to avoid the repetition is this:

package main
import (
    "fmt"
    "encoding/binary"
    "bytes"
)

type Type byte

const (
    Video  Type = 0xFC
    Audio   Type = 0xFA
)

var HMap = map[Type]string {
    Video:   "Video",
    Audio:   "Audio",
}

type CommonHeader struct {
    Type      Type
}

type Header interface {
    GetLength() int
}

type HeaderVideo struct {
    Width       uint16
    Height      uint16
    Length      uint32
}

func (h HeaderVideo) GetLength() int {
    return int(h.Length)
}

type HeaderAudio struct {
    SampleRate  uint16
    Length      uint16
}

func (h HeaderAudio) GetLength() int {
    return int(h.Length)
}

var TMap = map[Type]func() Header {
    Video:     func() Header { return &HeaderVideo{} },
    Audio:     func() Header { return &HeaderAudio{} },
}

func main() {
    data := bytes.NewReader([]byte{0xFC, 0x80, 0x07, 0x38, 0x04, 0x02, 0x00, 0x00, 0x00, 0xFF, 0xAF, 0xFA, 0x10, 0x00, 0x01, 0x00, 0xFF})
    var cHeader CommonHeader
    for {
        err := binary.Read(data, binary.LittleEndian, &cHeader)
        if err != nil {
            break
        }
        fmt.Println(HMap[cHeader.Type])
        info := TMap[cHeader.Type]()
        binary.Read(data, binary.LittleEndian, info)
        fmt.Println(info)
        payload := make([]byte, info.GetLength())
        data.Read(payload)
        fmt.Println(payload)
    }
}

That is, we implement dynamic type selection by introducing the TMap map, which lets us create an instance of the right struct depending on the frame type. However, this solution comes at the cost of repeating the GetLength() method for each of the frame types.

I find it quite disconcerting that there doesn't seem to be a way to avoid repetition completely. Am I missing some way, or is it just a limitation of the language?

Here is a related question (which was actually triggered by the same problem), however, its premise omits the need for dynamic type selection, and thus the accepted solution (using generics) does not help.


There are 4 answers

1
Dental Floss Tycoon On BEST ANSWER

The King's answer requires duplication for each integer type used to encode the length. Mondarian's answer uses the dreaded reflect package. Here's a solution that avoids both problems; it builds on the King's answer.

Declare a generic type with the GetLength() method.

type Length[T uint8 | uint16 | uint32 | uint64] struct { Length T }

func (l Length[T]) GetLength() int { return int(l.Length) }

Remove the GetLength method from each header type. Embed the generic length type in each header type:

type HeaderVideo struct {
    Width  uint16
    Height uint16
    Length[uint32]
}

type HeaderAudio struct {
    SampleRate uint16
    Length[uint16]
}

Declare TMap as in the question. The GetLength method is provided by the embedded field.

var TMap = map[Type]func() Header{
    Video: func() Header { return &HeaderVideo{} },
    Audio: func() Header { return &HeaderAudio{} },
}

https://go.dev/play/p/H2gWStsouly

(Like the code in the question, this answer uses the reflect package indirectly through the binary.Read function. The reflect package is a great tool for keeping code DRY.)

1
Pak Uula On

You can wrap the details of stream processing into methods of custom types:

package main

import (
    "bytes"
    "encoding/binary"
    "fmt"
    "io"
    "log"
)

type Type byte

const (
    Video Type = 0xFC
    Audio Type = 0xFA
)

var HMap = map[Type]string{
    Video: "Video",
    Audio: "Audio",
}

type CommonHeader struct {
    Type Type
}

func (ch CommonHeader) GetMediaHeader() MediaHeader {
    switch ch.Type {
    case Video:
        return &HeaderVideo{}
    case Audio:
        return &HeaderAudio{}
    default:
        panic("Unsupported media type")
    }
}

func (ch *CommonHeader) Fill(r io.Reader, o binary.ByteOrder) error {
    return binary.Read(r, o, ch)
}

func (ch CommonHeader) String() string {
    return HMap[ch.Type]
}

type MediaHeader interface {
    GetLength() int
    Fill(io.Reader, binary.ByteOrder) error
}

type HeaderVideo struct {
    Width  uint16
    Height uint16
    Length uint32
}

// Fill implements MediaHeader.
func (hv *HeaderVideo) Fill(r io.Reader, o binary.ByteOrder) error {
    return binary.Read(r, o, hv)
}

func (hv HeaderVideo) GetLength() int {
    return int(hv.Length)
}

func (hv HeaderVideo) String() string {
    return fmt.Sprintf("%#v", hv)
}

var _ MediaHeader = &HeaderVideo{}

type HeaderAudio struct {
    SampleRate uint16
    Length     uint16
}

func (ha *HeaderAudio) Fill(r io.Reader, o binary.ByteOrder) error {
    return binary.Read(r, o, ha)
}

func (ha HeaderAudio) GetLength() int {
    return int(ha.Length)
}

func (ha HeaderAudio) String() string {
    return fmt.Sprintf("%#v", ha)
}

var _ MediaHeader = &HeaderAudio{}

func main() {
    data := bytes.NewReader([]byte{0xFC, 0x80, 0x07, 0x38, 0x04, 0x02, 0x00, 0x00, 0x00, 0xFF, 0xAF, 0xFA, 0x10, 0x00, 0x01, 0x00, 0xFF})
    cHeader := &CommonHeader{}
    var dataLength int
    for {
        err := cHeader.Fill(data, binary.LittleEndian)
        if err != nil {
            log.Println(err)
            break
        }
        fmt.Println("cHeader: ", cHeader)
        mediaHeader := cHeader.GetMediaHeader()
        err = mediaHeader.Fill(data, binary.LittleEndian)
        if err != nil {
            log.Println(err)
            break
        }
        dataLength = mediaHeader.GetLength()
        fmt.Println("Media header: ", mediaHeader)

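        payload := make([]byte, dataLength)
        // io.ReadFull reports an error if the stream ends before the payload is complete.
        if _, err = io.ReadFull(data, payload); err != nil {
            log.Println(err)
            break
        }
        fmt.Println("Payload: ", payload)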
    }
}

The key feature is the cHeader.GetMediaHeader() factory method. It inspects the frame type and returns the matching header type.

Each media header type knows how to parse itself from the stream (the Fill method) and how to report the payload length (GetLength). With the MediaHeader interface, the code in main is streamlined: there is no switch and no duplicated business logic.

1
The King of io.Reader On

Here's how to eliminate the per-header GetLength method in the question's second example:

Declare a type for each encoding of length (uint16, uint32, ...). Implement the GetLength() int method on each of these types.

type Luint32 uint32

func (l Luint32) GetLength() int { return int(l) }

type Luint16 uint16

func (l Luint16) GetLength() int { return int(l) }

Embed one of these types in the header type. Remove the header's explicit GetLength method.

type HeaderVideo struct {
    Width  uint16
    Height uint16
    Luint32
}

type HeaderAudio struct {
    SampleRate uint16
    Luint16
}

Declare TMap as in the question. The GetLength method is provided by the embedded field.

var TMap = map[Type]func() Header{
    Video: func() Header { return &HeaderVideo{} },
    Audio: func() Header { return &HeaderAudio{} },
}

https://go.dev/play/p/-POlpa_z0VR
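
A small standalone sketch of the method promotion this answer relies on. The types are the ones declared above; the compile-time assertions and the sample values in main are added here only for illustration. Note that an embedded field is named after its type, so the length is set and read as Luint32 or Luint16 rather than Length.

```go
package main

import "fmt"

type Header interface{ GetLength() int }

type Luint32 uint32

func (l Luint32) GetLength() int { return int(l) }

type Luint16 uint16

func (l Luint16) GetLength() int { return int(l) }

type HeaderVideo struct {
	Width  uint16
	Height uint16
	Luint32
}

type HeaderAudio struct {
	SampleRate uint16
	Luint16
}

// Compile-time checks: the promoted GetLength method makes both header
// types (and pointers to them) satisfy Header without any extra code.
var (
	_ Header = HeaderVideo{}
	_ Header = &HeaderAudio{}
)

func main() {
	// Sample values chosen for illustration only.
	v := HeaderVideo{Width: 1920, Height: 1080, Luint32: 2}
	a := HeaderAudio{SampleRate: 16, Luint16: 1}
	fmt.Println(v.GetLength(), a.GetLength()) // prints "2 1"
}
```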

0
Mondarian Sequester On

Use the reflect package.

Here's how to get the length from any of the header types:

func GetLength(v any) int {
    return int(reflect.ValueOf(v).Elem().FieldByName("Length").Uint())
}

Change the type of the factory functions to return any:

var TMap = map[Type]func() any {
    Video: func() any { return &HeaderVideo{} },
    Audio: func() any { return &HeaderAudio{} },
}

Call the function above instead of the GetLength method:

    payload := make([]byte, GetLength(info))
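
A runnable sketch of the function above, using the header types from the question. GetLength as written panics if it is handed something other than a pointer to a struct with an unsigned integer field named Length; SafeLength below is a hypothetical checked variant added for illustration, not part of the original suggestion.

```go
package main

import (
	"errors"
	"fmt"
	"reflect"
)

type HeaderVideo struct {
	Width  uint16
	Height uint16
	Length uint32
}

type HeaderAudio struct {
	SampleRate uint16
	Length     uint16
}

// GetLength is the function from the answer, unchanged.
func GetLength(v any) int {
	return int(reflect.ValueOf(v).Elem().FieldByName("Length").Uint())
}

// SafeLength (hypothetical name) does the same lookup but returns an error
// instead of panicking when v does not have the expected shape.
func SafeLength(v any) (int, error) {
	rv := reflect.ValueOf(v)
	if rv.Kind() != reflect.Pointer || rv.IsNil() {
		return 0, errors.New("want a non-nil pointer to a header struct")
	}
	rv = rv.Elem()
	if rv.Kind() != reflect.Struct {
		return 0, errors.New("want a pointer to a struct")
	}
	f := rv.FieldByName("Length")
	if !f.IsValid() {
		return 0, errors.New("header has no Length field")
	}
	switch f.Kind() {
	case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return int(f.Uint()), nil
	}
	return 0, fmt.Errorf("field Length has unsupported kind %s", f.Kind())
}

func main() {
	fmt.Println(GetLength(&HeaderVideo{Length: 2}), GetLength(&HeaderAudio{Length: 1})) // prints "2 1"

	_, err := SafeLength(&struct{ Size uint16 }{})
	fmt.Println(err) // prints "header has no Length field"
}
```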

Bonus suggestion: Use io.ReadFull to ensure that the payload buffer is filled.

    io.ReadFull(data, payload)

Bonus bonus suggestion: Check and handle errors.