Quasiquotes in Scalafix


Here is Spark 2.4 code using unionAll

import org.apache.spark.sql.{DataFrame, Dataset}

object UnionRewrite {
  def inSource(
    df1: DataFrame,
    df2: DataFrame,
    df3: DataFrame,
    ds1: Dataset[String],
    ds2: Dataset[String]
  ): Unit = {
    val res1 = df1.unionAll(df2)
    val res2 = df1.unionAll(df2).unionAll(df3)
    val res3 = Seq(df1, df2, df3).reduce(_ unionAll _)
    val res4 = ds1.unionAll(ds2)
    val res5 = Seq(ds1, ds2).reduce(_ unionAll _)
  }
}

In Spark 3.x unionAll is deprecated. Here is the equivalent code using union:

import org.apache.spark.sql.{DataFrame, Dataset}

object UnionRewrite {
  def inSource(
    df1: DataFrame,
    df2: DataFrame,
    df3: DataFrame,
    ds1: Dataset[String],
    ds2: Dataset[String]
  ): Unit = {
    val res1 = df1.union(df2)
    val res2 = df1.union(df2).union(df3)
    val res3 = Seq(df1, df2, df3).reduce(_ union _)
    val res4 = ds1.union(ds2)
    val res5 = Seq(ds1, ds2).reduce(_ union _)
  }
}

The question is: how do you write a Scalafix rule (using quasiquotes) that replaces unionAll with union?

Without quasiquotes I implemented the rule, and it works:

override def fix(implicit doc: SemanticDocument): Patch = {
  def matchOnTree(t: Tree): Patch = {
    t.collect {
      case Term.Apply(
          Term.Select(_, deprecated @ Term.Name(name)),
          _
          ) if config.deprecatedMethod.contains(name) =>
        Patch.replaceTree(
          deprecated,
          config.deprecatedMethod(name)
        )
      case Term.Apply(
          Term.Select(_, _ @Term.Name(name)),
          List(
            Term.AnonymousFunction(
              Term.ApplyInfix(
                _,
                deprecatedAnm @ Term.Name(nameAnm),
                _,
                _
              )
            )
          )
          ) if "reduce".contains(name) && config.deprecatedMethod.contains(nameAnm) =>
        Patch.replaceTree(
          deprecatedAnm,
          config.deprecatedMethod(nameAnm)
        )
    }.asPatch
  }

  matchOnTree(doc.tree)
}
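
The rule above reads its replacements from a configuration object that is not shown here; a minimal sketch of the shape it needs (the name UnionRewriteConfig is an assumption) could be

case class UnionRewriteConfig(
  // maps each deprecated method name to its replacement, e.g. unionAll -> union
  deprecatedMethod: Map[String, String] = Map("unionAll" -> "union")
)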

There are 3 answers

Dmytro Mitin (BEST ANSWER)

Try the rule

override def fix(implicit doc: SemanticDocument): Patch = {

  def isDatasetSubtype(expr: Tree): Boolean =
    expr.symbol.info.flatMap(_.signature match {
      case ValueSignature(tpe)        => Some(tpe)
      case MethodSignature(_, _, tpe) => Some(tpe)
      case _                          => None
    }) match {
      case Some(TypeRef(_, symbol, _)) =>
        Seq("package.DataFrame", "Dataset")
          .map(tp => Symbol(s"org/apache/spark/sql/$tp#"))
          .contains(symbol)
      case _ => false
    }

  def mkPatch(ename: Tree): Patch = Patch.replaceTree(ename, "union")

  def matchOnTree(t: Tree): Patch =
    t.collect {
        case q"$expr.${ename@q"unionAll"}($expr1)" if isDatasetSubtype(expr) =>
          mkPatch(ename)

        // infix application
        case q"$expr ${ename@q"unionAll"} $expr1" /*if isDatasetSubtype(expr)*/ =>
          mkPatch(ename)
    }.asPatch

  matchOnTree(doc.tree)
}

It transforms

import org.apache.spark.sql.{DataFrame, Dataset}

object UnionRewrite {
  def inSource(
                df1: DataFrame,
                df2: DataFrame,
                df3: DataFrame,
                ds1: Dataset[String],
                ds2: Dataset[String]
              ): Unit = {
    val res1 = df1.unionAll(df2)
    val res2 = df1.unionAll(df2).unionAll(df3)
    val res3 = Seq(df1, df2, df3).reduce(_ unionAll _)
    val res4 = ds1.unionAll(ds2)
    val res5 = Seq(ds1, ds2).reduce(_ unionAll _)
    val res6 = Seq(ds1, ds2).reduce(_ unionAll (_))

    val unionAll = 42
  }
}

into

import org.apache.spark.sql.{DataFrame, Dataset}

object UnionRewrite {
  def inSource(
                df1: DataFrame,
                df2: DataFrame,
                df3: DataFrame,
                ds1: Dataset[String],
                ds2: Dataset[String]
              ): Unit = {
    val res1 = df1.union(df2)
    val res2 = df1.union(df2).union(df3)
    val res3 = Seq(df1, df2, df3).reduce(_ union _)
    val res4 = ds1.union(ds2)
    val res5 = Seq(ds1, ds2).reduce(_ union _)
    val res6 = Seq(ds1, ds2).reduce(_ union (_))

    val unionAll = 42
  }
}

https://scalacenter.github.io/scalafix/docs/developers/setup.html

https://scalameta.org/docs/trees/quasiquotes.html

https://scalameta.org/docs/semanticdb/guide.html

Your Ver: 1 implementation erroneously transformed val unionAll = 42 into val union = 42.
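
For illustration, a small scala.meta snippet (a sketch; run it in a REPL with scalameta on the classpath) shows why: the left-hand side of the val is itself a Term.Name wrapped in a Pat.Var, so collecting every Term and matching q"unionAll" also hits the binder, not just method calls.

import scala.meta._

// the binder in `val unionAll = 42` parses (roughly) as Pat.Var(Term.Name("unionAll"))
val stat = "val unionAll = 42".parse[Stat].get
println(stat.structure)
// Defn.Val(Nil, List(Pat.Var(Term.Name("unionAll"))), None, Lit.Int(42))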

Sadly, <: Dataset[_] can't be checked for the infix application, since SemanticDB doesn't seem to have type information in this case (the underscore _ in a lambda). This seems to be a SemanticDB limitation. If you really need the subtype check in this case, you may need a compiler plugin.


Update. We can use multiple rules: first apply a rule that replaces underscore lambdas with parameter lambdas:

override def fix(implicit doc: SemanticDocument): Patch = {
  def matchOnTree(t: Tree): Patch =
    t.collect {
      case t1@q"_.unionAll(_)" =>
        Patch.replaceTree(t1, "(x, y) => x.unionAll(y)")
      case t1@q"_ unionAll _" =>
        Patch.replaceTree(t1, "(x, y) => x unionAll y")
    }.asPatch

  matchOnTree(doc.tree)
}

then re-compile the code (new .semanticdb files will be generated) and apply the second rule, which replaces unionAll with union (if the types match):

override def fix(implicit doc: SemanticDocument): Patch = {

  def isDatasetSubtype(expr: Tree): Boolean = {
    expr.symbol.info.flatMap(_.signature match {
      case ValueSignature(tpe)        => Some(tpe)
      case MethodSignature(_, _, tpe) => Some(tpe)
      case _                          => None
    }) match {
      case Some(TypeRef(_, symbol, _)) =>
        Seq("package.DataFrame", "Dataset")
          .map(tp => Symbol(s"org/apache/spark/sql/$tp#"))
          .contains(symbol)
      case _ => false
    }
  }

  def mkPatch(ename: Tree): Patch = Patch.replaceTree(ename, "union")

  def matchOnTree(t: Tree): Patch =
    t.collect {
      case q"$expr.${ename@q"unionAll"}($_)" if isDatasetSubtype(expr) =>
        mkPatch(ename)
      case q"$expr ${ename@q"unionAll"} $_" if isDatasetSubtype(expr) =>
        mkPatch(ename)
    }.asPatch

  matchOnTree(doc.tree)
}

then apply the third rule, which replaces the parameter lambdas back with underscore lambdas:

override def fix(implicit doc: SemanticDocument): Patch = {
  def matchOnTree(t: Tree): Patch =
    t.collect {
      case t1@q"(x, y) => x.union(y)" =>
        Patch.replaceTree(t1, "_.union(_)")
      case t1@q"(x, y) => x union y" =>
        Patch.replaceTree(t1, "_ union _")
    }.asPatch

  matchOnTree(doc.tree)
}

The 1st and 3rd rules can be syntactic.
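
For example, the third rule written syntactically might look like this (a sketch; the rule name is an assumption, and no SemanticDocument is needed because it matches syntax only)

package fix

import scalafix.v1._
import scala.meta._

// syntactic counterpart of the third rule above: no .semanticdb files required
class ParamLambdaBackToUnderscore
    extends SyntacticRule("ParamLambdaBackToUnderscore") {
  override def fix(implicit doc: SyntacticDocument): Patch =
    doc.tree.collect {
      case t1 @ q"(x, y) => x.union(y)" => Patch.replaceTree(t1, "_.union(_)")
      case t1 @ q"(x, y) => x union y"  => Patch.replaceTree(t1, "_ union _")
    }.asPatch
}

The first rule can be rewritten the same way; only the second rule needs SemanticDB, so only the step between rules 1 and 2 requires re-compilation.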

mvasyliv

Ver: 1

package fix

import scalafix.v1._
import scala.meta._

class RuleQuasiquotesUnionAll extends SemanticRule("RuleQuasiquotesUnionAll") {
  override val description =
    """Quasiquotes in Scalafix. Replacing unionAll with union"""
  override val isRewrite = true

  override def fix(implicit doc: SemanticDocument): Patch = {

    def matchOnTree(t: Tree): Patch = {
      t.collect { case tt: Term =>
        tt match {
          case q"""unionAll""" =>
            Patch.replaceTree(tt, """union""")
          case _ => Patch.empty
        }
      }.asPatch
    }

    matchOnTree(doc.tree)
  }

}

Ver 2:

package fix
import scalafix.v1._
import scala.meta._
class UnionRewriteWithCheckType
    extends SemanticRule("UnionRewriteWithCheckType") {
  override val description = {
    """Replacing unionAll with union only for Dataset and DataFrame"""
    // TODO: add type(s) to config
  }
  override val isRewrite = true

  override def fix(implicit doc: SemanticDocument): Patch = {

    def isDatasetDataFrame(
        tp: String,
        q: Term,
        a: List[Term]
    ): Boolean = {
      if (a.nonEmpty) {
        if (q.toString().indexOf("unionAll") >= 0 && tp == "DataFrame") {
          // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
          // When val res: Dataset[Row]= DataFrame1.unionAll(DataFrame2) !!
          // !!!!! result type Dataset[Row] !!!!!                        !!
          // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
          (q.symbol.info.get.signature.toString().indexOf("Dataset") >= 0)
            .equals(true) &&
          (a.head.symbol.info.get.signature.toString().indexOf(tp) >= 0)
            .equals(true)
        } else
          (q.symbol.info.get.signature.toString().indexOf(tp) >= 0)
            .equals(true) &&
          (a.head.symbol.info.get.signature.toString().indexOf(tp) >= 0)
            .equals(true)
      } else false
    }

    def matchOnTree(t: Tree): Patch = {
      t collect {
        case meth @ Defn.Def(a1, a2, a3, a4, a5, a6) =>
          a6.collect {
            case ta @ Term.Apply(
                  Term.Select(qual, trm @ q"""unionAll"""),
                  args
                ) =>
              if (
                isDatasetDataFrame(
                  "DataFrame",
                  qual,
                  args
                ) || isDatasetDataFrame("Dataset", qual, args)
              ) {

                Patch.replaceTree(
                  trm,
                  """union"""
                )
              } else Patch.empty
            case tasr @ Term.Apply(
                  Term.Select(qual, tnr @ q"""reduce"""),
                  args @ List(
                    Term.AnonymousFunction(
                      Term.ApplyInfix(_, op @ q"""unionAll""", _, _)
                    )
                  )
                ) =>
              if (
                qual.symbol.info.get.signature
                  .toString()
                  .indexOf("Dataset") >= 0 || qual.symbol.info.get.signature
                  .toString()
                  .indexOf("DataFrame") >= 0
              ) Patch.replaceTree(op, """union""")
              else Patch.empty
            case _ => Patch.empty
          }.asPatch
        case _ => Patch.empty
      }
    }.asPatch

    matchOnTree(doc.tree)
  }

}


mvasyliv

Answer to Dmytro Mitin:

Check 1. When we use Slick:

import slick.jdbc.H2Profile.api._ // assuming the H2 profile; any Slick profile's api import works

def inSourceSlickUnionAll(): Unit = {
  case class Coffee(name: String, price: Double)
  class Coffees(tag: Tag) extends Table[(String, Double)](tag, "COFFEES") {
    def name = column[String]("COF_NAME")
    def price = column[Double]("PRICE")
    def * = (name, price)
  }

  val coffees = TableQuery[Coffees]

  val q1 = coffees.filter(_.price < 8.0)
  val q2 = coffees.filter(_.price > 9.0)

  val unionQuery = q1 union q2
  val unionAllQuery = q1 unionAll q2
  val unionAllQuery1 = q1 ++ q2
}

The result of your rule:

=======
=> Diff
=======
--- obtained
+++ expected
@@ -82,3 +82,3 @@
     val unionQuery = q1 union q2
-    val unionAllQuery = q1 union q2
+    val unionAllQuery = q1 unionAll q2
     val unionAllQuery1 = q1 ++ q2
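
A possible mitigation for this false positive (a sketch; it is essentially the guarded matcher from the second rule in the accepted answer's update, placed inside fix next to its isDatasetSubtype helper, and not verified against the Slick code above): keep the semantic check on the infix branch too, so Slick's q1 unionAll q2 is left untouched, at the cost of missing _ unionAll _ inside reduce unless the multi-rule workaround is applied.

def matchOnTree(t: Tree): Patch =
  t.collect {
    // plain call: df1.unionAll(df2)
    case q"$expr.${ename @ q"unionAll"}($_)" if isDatasetSubtype(expr) =>
      Patch.replaceTree(ename, "union")
    // infix call with a real receiver: df1 unionAll df2, but not Slick's q1 unionAll q2
    case q"$expr ${ename @ q"unionAll"} $_" if isDatasetSubtype(expr) =>
      Patch.replaceTree(ename, "union")
  }.asPatch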