Macro Annotations in Scala 3


The following is a quote from "Macros: the Plan for Scala 3", written more than 3 years ago:

For instance, one will be able to define a macro annotation @json that adds a JSON serializer to a type.

Any idea how/if this is actually possible in Scala 3?

More generally, is there anything in Scala 3 that can provide "Macro Annotations" functionality? The following is a quote from Macro Annotations - Scala 2.13:

Unlike in the previous versions of macro paradise, macro annotations in 2.0 are done right in the sense that they: 1) apply not just to classes and objects, but to arbitrary definitions, 2) allow expansions of classes to modify or even create companion objects


There are 2 answers

Answer by francoisr (best answer)

As of June 2021, macro annotations are not supported in Scala 3, and they are not mentioned anywhere in the documentation.

Right now, if you'd like to generate methods, classes or objects, I believe you have to use scalameta or write a compiler plugin.

Obviously, this situation might change in the future: macro annotations were not part of Scala 2 at the beginning either.

Answer by Dmytro Mitin

Starting from Scala 3.3.0-RC2, Scala 3 has macro annotations (implemented by Nicolas Stucki).

Macro annotation (part 1) https://github.com/lampepfl/dotty/pull/16392

Macro annotations class modifications (part 2) https://github.com/lampepfl/dotty/pull/16454

Enable returning classes from MacroAnnotations (part 3) https://github.com/lampepfl/dotty/pull/16534

New definitions are not visible from outside the macro expansion.


build.sbt

scalaVersion := "3.3.0-RC3"

Several examples:

  • Macro annotation @memoize adds memoization to a method
import scala.annotation.{MacroAnnotation, experimental}
import scala.collection.mutable
import scala.quoted.*

object Macros:
  @experimental
  class memoize extends MacroAnnotation:
    def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
      import quotes.reflect.*
      tree match
        case DefDef(name, TermParamClause(param :: Nil) :: Nil, tpt, Some(rhsTree)) =>
          (Ref(param.symbol).asExpr, rhsTree.asExpr) match
            case ('{ $paramRefExpr: t }, '{ $rhsExpr: u }) =>
              val cacheTpe = TypeRepr.of[mutable.Map[t, u]]
              val cacheSymbol = Symbol.newVal(tree.symbol.owner, name + "Cache", cacheTpe, Flags.Private, Symbol.noSymbol)
              val cacheRhs = '{ mutable.Map.empty[t, u] }.asTerm
              val cacheVal = ValDef(cacheSymbol, Some(cacheRhs))
              val cacheRefExpr = Ref(cacheSymbol).asExprOf[mutable.Map[t, u]]
              val newRhs = '{ $cacheRefExpr.getOrElseUpdate($paramRefExpr, $rhsExpr) }.asTerm
              val newTree = DefDef.copy(tree)(name, TermParamClause(param :: Nil) :: Nil, tpt, Some(newRhs))
              val res = List(cacheVal, newTree)
              println(res.map(_.show))
              res
        case _ =>
          report.error("Annotation is only supported on `def` with a single argument")
          List(tree)

// Usage (in a separate file from the macro definition):
import scala.annotation.experimental
import Macros.memoize

@experimental
object App:
  @memoize
  def fib(n: Int): Int =
    println(s"compute fib of $n")
    if n <= 1 then n else fib(n - 1) + fib(n - 2)

  def main(args: Array[String]): Unit =
    println(fib(5))

//scalac: List(val fibCache: scala.collection.mutable.Map[n, scala.Int] = scala.collection.mutable.Map.empty[n.type, scala.Int],
// @Macros.memoize def fib(n: scala.Int): scala.Int = App.fibCache.getOrElseUpdate(n, {
//  scala.Predef.println(_root_.scala.StringContext.apply("compute fib of ", "").s(n))
//  if (n.<=(1)) n else App.fib(n.-(1)).+(App.fib(n.-(2)))
//}))

//compute fib of 5
//compute fib of 4
//compute fib of 3
//compute fib of 2
//compute fib of 1
//compute fib of 0
//5
  • Macro annotation @equals generates the methods equals and hashCode for a class
import scala.annotation.{MacroAnnotation, experimental}
import scala.quoted.*

object Macros:
  @experimental
  class equals extends MacroAnnotation:
    def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
      import quotes.reflect.*
      tree match
        case ClassDef(className, ctr, parents, self, body) =>
          val cls = tree.symbol

          val constructorParameters = ctr.paramss.collect { case clause: TermParamClause => clause }
          if constructorParameters.size != 1 || constructorParameters.head.params.isEmpty then
            report.errorAndAbort("@equals class must have a single argument list with at least one argument", ctr.pos)

          def checkNotOverridden(sym: Symbol): Unit =
            if sym.overridingSymbol(cls).exists then
              report.error(s"Cannot override ${sym.name} in a @equals class")

          val fields = body.collect {
            case vdef: ValDef if vdef.symbol.flags.is(Flags.ParamAccessor) =>
              Select(This(cls), vdef.symbol).asExpr
          }

          val equalsSym = Symbol.requiredMethod("java.lang.Object.equals")
          checkNotOverridden(equalsSym)
          val equalsOverrideSym = Symbol.newMethod(cls, "equals", equalsSym.info, Flags.Override, Symbol.noSymbol)

          def equalsOverrideDefBody(argss: List[List[Tree]]): Option[Term] =
            given Quotes = equalsOverrideSym.asQuotes

            cls.typeRef.asType match
              case '[c] =>
                Some(equalsExpr[c](argss.head.head.asExpr, fields).asTerm)

          val equalsOverrideDef = DefDef(equalsOverrideSym, equalsOverrideDefBody)

          val hashSym = Symbol.newVal(cls, Symbol.freshName("hash"), TypeRepr.of[Int], Flags.Private | Flags.Lazy, Symbol.noSymbol)
          val hashVal = ValDef(hashSym, Some(hashCodeExpr(className, fields)(using hashSym.asQuotes).asTerm))

          val hashCodeSym = Symbol.requiredMethod("java.lang.Object.hashCode")
          checkNotOverridden(hashCodeSym)
          val hashCodeOverrideSym = Symbol.newMethod(cls, "hashCode", hashCodeSym.info, Flags.Override, Symbol.noSymbol)
          val hashCodeOverrideDef = DefDef(hashCodeOverrideSym, _ => Some(Ref(hashSym)))

          val newBody = equalsOverrideDef :: hashVal :: hashCodeOverrideDef :: body
          val res = List(ClassDef.copy(tree)(className, ctr, parents, self, newBody))
          println(res.map(_.show))
          res
        case _ =>
          report.error("Annotation only supports `class`")
          List(tree)

    private def equalsExpr[T: Type](that: Expr[Any], thisFields: List[Expr[Any]])(using Quotes): Expr[Boolean] =
      '{
        $that match
          case that: T@unchecked =>
            ${
              val thatFields: List[Expr[Any]] =
                import quotes.reflect.*
                thisFields.map(field => Select('{ that }.asTerm, field.asTerm.symbol).asExpr)
              thisFields.zip(thatFields)
                .map { case (thisField, thatField) => '{ $thisField == $thatField } }
                .reduce { case (pred1, pred2) => '{ $pred1 && $pred2 } }
            }
          case _ => false
      }

    private def hashCodeExpr(className: String, thisFields: List[Expr[Any]])(using Quotes): Expr[Int] =
      '{
        var acc: Int = ${ Expr(scala.runtime.Statics.mix(-889275714, className.hashCode)) }
        ${
          Expr.block(
            thisFields.map {
              case '{ $field: Boolean } => '{ if $field then 1231 else 1237 }
              case '{ $field: Byte } => '{ $field.toInt }
              case '{ $field: Char } => '{ $field.toInt }
              case '{ $field: Short } => '{ $field.toInt }
              case '{ $field: Int } => field
              case '{ $field: Long } => '{ scala.runtime.Statics.longHash($field) }
              case '{ $field: Double } => '{ scala.runtime.Statics.doubleHash($field) }
              case '{ $field: Float } => '{ scala.runtime.Statics.floatHash($field) }
              case '{ $field: Null } => '{ 0 }
              case '{ $field: Unit } => '{ 0 }
              case field => '{ scala.runtime.Statics.anyHash($field) }
            }.map(hash => '{ acc = scala.runtime.Statics.mix(acc, $hash) }),
            '{ scala.runtime.Statics.finalizeHash(acc, ${ Expr(thisFields.size) }) }
          )
        }
      }

// Usage (in a separate file from the macro definition):
import scala.annotation.experimental
import Macros.equals

@experimental
object App:
  @equals
  class User(val name: String, val id: Int)

  def main(args: Array[String]): Unit =
    println(User("a", 1) == User("a", 1)) // true

//scalac: List(@Macros.equals class User(val name: scala.Predef.String, val id: scala.Int) {
//  override def equals(x$0: scala.Any): scala.Boolean = x$0 match {
//    case that: App.User @scala.unchecked =>
//      User.this.name.==(that.name).&&(User.this.id.==(that.id))
//    case _ =>
//      false
//  }
//  lazy val hash$macro$1: scala.Int = {
//    var acc: scala.Int = 515782504
//    acc = scala.runtime.Statics.mix(acc, scala.runtime.Statics.anyHash(User.this.name))
//    acc = scala.runtime.Statics.mix(acc, User.this.id)
//    scala.runtime.Statics.finalizeHash(acc, 2)
//  }
//  override def hashCode(): scala.Int = User.this.hash$macro$1
//})
  • Macro annotation @addClass generates a new class next to the annotated method and rewrites the method to delegate to it
import scala.annotation.{MacroAnnotation, experimental}
import scala.quoted.*

object Macros:
  @experimental
  class addClass extends MacroAnnotation:
    def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
      import quotes.reflect._
      tree match
        case DefDef(name, List(TermParamClause(Nil)), tpt, Some(rhs)) =>
          val parents = List(TypeTree.of[Object])
          def decls(cls: Symbol): List[Symbol] =
            List(Symbol.newMethod(cls, "run", MethodType(Nil)(_ => Nil, _ => TypeRepr.of[Unit]), Flags.EmptyFlags, Symbol.noSymbol))

          val newClassName = Symbol.freshName("Bar")
          val cls = Symbol.newClass(Symbol.spliceOwner/*.owner*/, newClassName, parents = parents.map(_.tpe), decls, selfType = None)
          val runSym = cls.declaredMethod("run").head

          val runDef = DefDef(runSym, _ => Some(rhs))
          val clsDef = ClassDef(cls, parents, body = List(runDef))

          val newCls = Apply(Select(New(TypeIdent(cls)), cls.primaryConstructor), Nil)

          val newDef = DefDef.copy(tree)(name, List(TermParamClause(Nil)), tpt, Some(Apply(Select(newCls, runSym), Nil)))
          val res = List(clsDef, newDef)
          println(res.map(_.show))
          res
        case _ =>
          report.error("Annotation only supports `def` without arguments")
          List(tree)

// Usage (in a separate file from the macro definition):
import Macros.addClass

import scala.annotation.experimental

object App:
  @addClass @experimental
  def bar(): Unit = println("bar")

//scalac: List(class Bar$macro$1 extends java.lang.Object {
//  def run(): scala.Unit = scala.Predef.println("bar")
//}, @scala.annotation.experimental @Macros.addClass def bar(): scala.Unit = new App.Bar$macro$1().run())
  • Macro annotation @mainMacro generates a class with a runnable main method that calls the annotated method
import scala.annotation.{experimental, MacroAnnotation}
import scala.quoted._

object Macros:
  @experimental
  class mainMacro extends MacroAnnotation:
    def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
      import quotes.reflect._
      tree match
        case DefDef(name, List(TermParamClause(Nil)), _, _) =>
          val parents = List(TypeTree.of[Object])
          def decls(cls: Symbol): List[Symbol] =
            List(Symbol.newMethod(cls, "main", MethodType(List("args"))(_ => List(TypeRepr.of[Array[String]]), _ => TypeRepr.of[Unit]), Flags.Static, Symbol.noSymbol))

          val cls = Symbol.newClass(Symbol.spliceOwner, name, parents = parents.map(_.tpe), decls, selfType = None)
          val mainSym = cls.declaredMethod("main").head

          val mainDef = DefDef(mainSym, _ => Some(Apply(Ref(tree.symbol), Nil)))
          val clsDef = ClassDef(cls, parents, body = List(mainDef))

          val res = List(clsDef, tree)
          println(res.map(_.show))
          res

        case _ =>
          report.error("Annotation only supports `def` without arguments")
          List(tree)

// Usage (in a separate file from the macro definition):
import Macros.mainMacro
import scala.annotation.experimental

@experimental
object App:
  @mainMacro def Test(): Unit = println("macro generated main")

//scalac: List(class Test extends java.lang.Object {
//  def main(args: scala.Array[scala.Predef.String]): scala.Unit = App.Test()
//}, @Macros.mainMacro def Test(): scala.Unit = scala.Predef.println("macro generated main"))
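
For comparison, here is a hand-written sketch, in plain Scala 3 with no macros and no @experimental, of roughly what the @memoize and @equals expansions printed above amount to. It is an approximation assembled from those printed trees, not the code the compiler literally emits: for instance, the printed cache type mentions `n.type`, whereas this sketch simply uses Int keys. The names MemoApp and demo are hypothetical.

```scala
import scala.collection.mutable
import scala.runtime.Statics

// Hand-written approximation of the @memoize expansion (no macros involved).
object MemoApp:
  // Plays the role of the generated private `fibCache` val.
  private val fibCache = mutable.Map.empty[Int, Int]

  def fib(n: Int): Int =
    // The original body becomes the by-name default of getOrElseUpdate,
    // so it only runs on a cache miss.
    fibCache.getOrElseUpdate(n, {
      println(s"compute fib of $n")
      if n <= 1 then n else fib(n - 1) + fib(n - 2)
    })

// Hand-written approximation of the @equals expansion for `User`.
class User(val name: String, val id: Int):
  // Field-by-field comparison, like the generated `equals`.
  override def equals(x: Any): Boolean = x match
    case that: User => this.name == that.name && this.id == that.id
    case _          => false

  // Mirrors the generated lazy `hash$macro$1` val: seed with the class name,
  // mix in each field, then finalize with the number of fields.
  private lazy val hash: Int =
    var acc: Int = Statics.mix(-889275714, "User".hashCode)
    acc = Statics.mix(acc, Statics.anyHash(name))
    acc = Statics.mix(acc, id)
    Statics.finalizeHash(acc, 2)

  override def hashCode(): Int = hash

@main def demo(): Unit =
  println(MemoApp.fib(5))
  println(User("a", 1) == User("a", 1))
  println(User("a", 1).hashCode == User("a", 1).hashCode)
```

The design follows the printed expansions: the cache lives in a private field next to the method, and the hash code is computed once through a lazy val, in the same style as the hashing used for case classes.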

https://github.com/lampepfl/dotty/blob/3.3.0-RC3/library/src/scala/annotation/MacroAnnotation.scala

How to generate a class in Dotty with macro?

How to generate parameterless constructor at compile time using scala 3 macro?

Scala 3 macro to create enum

https://users.scala-lang.org/t/macro-annotations-replacement-in-scala-3/7374

https://contributors.scala-lang.org/t/sponsoring-work-on-scala-3-macro-annotations/5658

https://contributors.scala-lang.org/t/scala-3-macro-annotations-and-code-generation/6035

https://contributors.scala-lang.org/t/scala-3-macros-next-steps/6105

https://contributors.scala-lang.org/t/whitebox-macros-in-scala-3-are-possible-after-all/5014