How can I operate on pairs of arguments to macro_rules?


I am building a simple feed-forward neural network at compile time using const generics and macros. The network is essentially a sequence of matrices applied one after the other.

I have created the network! macro, which works like this:

network!(2, 4, 1)

The first item is the number of inputs, and the rest are the number of neurons per layer. The macro looks as follows:

#[macro_export]
macro_rules! network {
    ( $inputs:expr, $($outputs:expr),* ) => {
        {
            Network {
                layers: [
                    $(
                        &Layer::<$inputs, $outputs>::new(),
                    )*
                ]
            }
        }
    };
}

It declares an array of layers. Each layer uses const generics to hold a fixed-size array of weights: the first const parameter is the number of inputs the layer expects, and the second is the number of outputs.

For network!(2, 4, 1), this macro expands to the following code:

Network {
    layers: [
         &Layer::<2, 4>::new(),
         &Layer::<2, 1>::new(),
    ]
}

This is completely wrong, because for each layer the number of inputs should be the number of outputs of the previous one, like so (notice 2 -> 4):

Network {
    layers: [
         &Layer::<2, 4>::new(),
         &Layer::<4, 1>::new(),
    ]
}

To fix this, each iteration needs to use the previous iteration's $outputs value as its $inputs, but I have no clue how to do that.
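
To make the mismatch concrete, here is a minimal sketch that reuses the same repetition shape as network! but yields plain (inputs, outputs) tuples, so the pairing can be checked at runtime. The naive_pairs! macro is a hypothetical helper introduced only for illustration; it is not part of the original code.

```rust
// Hypothetical helper: same matcher and repetition as `network!`, but it
// builds an array of (inputs, outputs) tuples instead of `Layer` values.
macro_rules! naive_pairs {
    ( $inputs:expr, $($outputs:expr),* ) => {
        [ $( ($inputs, $outputs), )* ]
    };
}

fn main() {
    let pairs = naive_pairs!(2, 4, 1);
    // `$inputs` is captured once, outside the repetition, so the same value
    // is substituted for every `$outputs`.
    assert_eq!(pairs, [(2, 4), (2, 1)]);
    println!("{:?}", pairs);
}
```

Because $inputs is bound outside the $( ... )* repetition, nothing inside the repetition can update it between iterations.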


There are 2 answers

Shepmaster On BEST ANSWER

You can match on the two leading values and then all the rest. Do something specific for the two values and call the macro recursively, reusing the second value:

struct Layer<const I: usize, const O: usize>;

macro_rules! example {
    // Do something interesting for a given pair of arguments
    ($a:literal, $b:literal) => {
        Layer::<$a, $b>;
    };

    // Recursively traverse the arguments
    ($a:literal, $b:literal, $($rest:literal),+) => {
        example!($a, $b);
        example!($b, $($rest),*);
    };
}

fn main() {
    example!(1, 2, 3);
}

Expanding the macro leads to:

fn main() {
    Layer::<1, 2>;
    Layer::<2, 3>;
}
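
As a rough way to check the pairing at runtime, the same two-rule recursion can push each pair's dimensions into a vector. This is only a sketch: the dims method and the collect_dims! macro (including its leading `$out` argument) are hypothetical additions for illustration, not part of the answer above.

```rust
struct Layer<const I: usize, const O: usize>;

impl<const I: usize, const O: usize> Layer<I, O> {
    // Hypothetical helper: reports the const parameters as runtime values.
    fn dims(&self) -> (usize, usize) {
        (I, O)
    }
}

// Hypothetical macro with the same recursion as `example!`, except that each
// pair records its dimensions in the vector named by `$out`.
macro_rules! collect_dims {
    // Base case: exactly one pair left
    ($out:ident; $a:literal, $b:literal) => {
        $out.push(Layer::<$a, $b>.dims());
    };

    // Recursive case: handle the leading pair, then continue from `$b`
    ($out:ident; $a:literal, $b:literal, $($rest:literal),+) => {
        collect_dims!($out; $a, $b);
        collect_dims!($out; $b, $($rest),*);
    };
}

fn main() {
    let mut dims = Vec::new();
    collect_dims!(dims; 1, 2, 3, 4);
    // Each value after the first is reused as the input of the next pair.
    assert_eq!(dims, vec![(1, 2), (2, 3), (3, 4)]);
}
```
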
codearm On

For those interested, I was finally able to populate my network like this, based on @Shepmaster's answer:

struct Network<'a, const L: usize> {
    layers: [&'a dyn Forward; L],
}

macro_rules! network {
    // Recursively accumulate token tree
    (@accum ($a:literal, $b:literal, $($others:literal),+) $($e:tt)*) => {
        network!(@accum ($b, $($others),*) $($e)*, &Layer::<$a, $b>::new())
    };

    // Final iteration: emit the accumulated tokens as an array expression
    (@accum ($a:literal, $b:literal) $($e:tt)*) => {
        [$($e)*, &Layer::<$a, $b>::new()]
    };

    // Entry point: needs at least three values (inputs plus two layers)
    ($a:literal, $b:literal, $($others:literal),+) => {
        Network {
            layers: network!(@accum ($b, $($others),*) &Layer::<$a, $b>::new())
        }
    };
}

For network!(2, 3, 4, 5, 1) it expands to:

Network {
     layers:
          [&Layer::<2, 3>::new(),
           &Layer::<3, 4>::new(),
           &Layer::<4, 5>::new(),
           &Layer::<5, 1>::new()]
};
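
The answer does not show Layer or Forward, so the following is a self-contained sketch under assumed definitions: Layer is a unit struct with a const fn new(), and Forward is a hypothetical trait with a single dims method used only to check the result. The macro itself is the one from the answer.

```rust
// Hypothetical trait: the answer names `Forward` but does not define it.
trait Forward {
    fn dims(&self) -> (usize, usize);
}

// Hypothetical layer: weights are omitted, only the dimensions are kept.
struct Layer<const I: usize, const O: usize>;

impl<const I: usize, const O: usize> Layer<I, O> {
    const fn new() -> Self {
        Layer
    }
}

impl<const I: usize, const O: usize> Forward for Layer<I, O> {
    fn dims(&self) -> (usize, usize) {
        (I, O)
    }
}

struct Network<'a, const L: usize> {
    layers: [&'a dyn Forward; L],
}

// Macro as given in the answer.
macro_rules! network {
    (@accum ($a:literal, $b:literal, $($others:literal),+) $($e:tt)*) => {
        network!(@accum ($b, $($others),*) $($e)*, &Layer::<$a, $b>::new())
    };
    (@accum ($a:literal, $b:literal) $($e:tt)*) => {
        [$($e)*, &Layer::<$a, $b>::new()]
    };
    ($a:literal, $b:literal, $($others:literal),+) => {
        Network {
            layers: network!(@accum ($b, $($others),*) &Layer::<$a, $b>::new())
        }
    };
}

fn main() {
    // Binding with `let` keeps the borrowed temporary layers alive.
    let net = network!(2, 3, 4, 5, 1);
    let dims: Vec<_> = net.layers.iter().map(|l| l.dims()).collect();
    assert_eq!(dims, vec![(2, 3), (3, 4), (4, 5), (5, 1)]);
}
```

As written, the entry rule requires at least three literals, so a single-layer call such as network!(2, 4) would not match any rule; supporting it would need an extra rule.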