I'm trying to implement an example at:
https://portal.klewel.com/watch/webcast/scala-days-2019/talk/37/
using Scala continuations:
object ReverseGrad_CPSImproved {

  import scala.util.continuations._

  case class Num(
      x: Double,
      var d: Double = 0.0
  ) {

    def +(that: Num) = shift { (cont: Num => Unit) =>
      val y = Num(x + that.x)
      cont(y)
      this.d += y.d
      that.d += y.d
    }

    def *(that: Num) = shift { (cont: Num => Unit) =>
      val y = Num(x * that.x)
      cont(y)
      this.d += that.x * y.d
      that.d += this.x * y.d
    }
  }

  object Num {
    implicit def fromX(x: Double): Num = Num(x)
  }

  def grad(f: Num => Num @cps[Unit])(x: Double): Double = {
    val _x = Num(x)
    reset { f(_x).d = 1.0 }
    _x.d
  }
}
This works as long as I use a simple expression:
it("simple") {
  val fn = { x: Num =>
    val result = (x + 3) * (x + 4)
    result
  }
  val gg = grad(fn)(3)
  println(gg)
}
But once I start using a loop, it all falls apart:
it("benchmark") {
  import scala.util.continuations._

  for (i <- 1 to 20) {
    val n = Math.pow(2, i).toInt
    val fn = { x: Num =>
      var result = x + 1
      for (j <- 2 to n) {
        result = result * (x + j)
      }
      result
    }
    val nanoFrom = System.nanoTime()
    val gg = grad(fn)(3)
    val nanoTo = System.nanoTime()
    println(s"diff = $gg,\t time = ${nanoTo - nanoFrom}")
  }
}
[Error] /home/peng/git-spike/scalaspike/meta/src/test/scala/com/tribbloids/spike/meta/multistage/lms/ReverseGrad_CPSImproved.scala:78: found cps expression in non-cps position
one error found
I'm under the impression that the continuations library should have its own loop implementation that can be rewritten into a recursion, but I cannot find it anywhere in the latest version (Scala 2.12). What's the easiest way to use a loop in this case?
In CPS, you have to rewrite your code so that it does NOT perform a nested, iterative, or recursive call in the same context; instead, it performs just one step of the computation and passes the partial result forward.
E.g. if you wanted to calculate a product of numbers A to B you could implement it this way:
(see this scastie).
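As a rough illustration of the idea, here is a minimal sketch of a product of the numbers from A to B written with explicit continuations in plain Scala. It deliberately does not use the continuations plugin, so the continuation shows up as an ordinary parameter; the names ProductCPS, product, and k are hypothetical and chosen only for this example.

```scala
object ProductCPS {
  // Multiplies the integers from `from` to `to` (inclusive).
  // Each call performs a single step and hands the rest of the work
  // to the continuation `k`, which receives the product of the remaining range.
  def product(from: Int, to: Int)(k: Long => Long): Long =
    if (from > to) k(1L) // empty range: the product is 1
    else product(from + 1, to)(rest => k(rest * from))

  def main(args: Array[String]): Unit = {
    println(product(1, 5)(identity)) // 120
  }
}
```

Note that the chain of continuations still grows with the size of the range, so this sketch is only suitable for small inputs.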
The most interesting part is this fragment:
Here, the compiler (plugin) rewrites this to be something similar to:
The compiler can do this because:
- it knows where the CPS code is, because it is delimited by shift and reset calls
- a continuation takes a parameter A and returns an intermediate result B (usable e.g. inside this or another reset), and the whole composition produces a final result C (what you get when you run it); this is denoted as A @cpsParam[B, C], and if B =:= C you can use the type alias A @cps[B]
- reset makes it easier not to go insane passing parameters around: it handles taking the parameter (A in A @cpsParam[B, C]), passing it to all nested CPS calls, obtaining the intermediate result (B in A @cpsParam[B, C]), and making the whole block return the final result (C in A @cpsParam[B, C])
- shift lifts a function (A => B) => C into A @cpsParam[B, C]
- when it sees a type such as Input @cpsParam[Output1, Output2], it knows that it should rewrite the code to introduce a parameter and pass it there

In practice, it's slightly more complex underneath, but that's basically it.
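To make the rewriting described above more tangible, here is a hand-written sketch of roughly what the question's Num and grad look like once the continuation is an explicit parameter. This is not the plugin's actual output; the object name ExplicitCps and the method names plus and times are assumptions made for this example.

```scala
object ExplicitCps {
  final class Num(val x: Double, var d: Double = 0.0) {
    // Stand-in for `shift { (cont: Num => Unit) => ... }`:
    // the continuation is passed in as an ordinary argument.
    def plus(that: Num)(cont: Num => Unit): Unit = {
      val y = new Num(x + that.x)
      cont(y)       // run the rest of the computation first
      this.d += y.d // then propagate the gradient backwards
      that.d += y.d
    }

    def times(that: Num)(cont: Num => Unit): Unit = {
      val y = new Num(x * that.x)
      cont(y)
      this.d += that.x * y.d
      that.d += this.x * y.d
    }
  }

  // Stand-in for `reset { f(_x).d = 1.0 }`: the final continuation seeds d = 1.0.
  def grad(f: Num => (Num => Unit) => Unit)(x: Double): Double = {
    val in = new Num(x)
    f(in) { result => result.d = 1.0 }
    in.d
  }

  // (x + 3) * (x + 4), with every intermediate result named explicitly.
  val fn: Num => (Num => Unit) => Unit = x => k =>
    x.plus(new Num(3.0))(a => x.plus(new Num(4.0))(b => a.times(b)(k)))

  def main(args: Array[String]): Unit = {
    println(grad(fn)(3.0)) // 13.0
  }
}
```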
Meanwhile, you call your CPS operations (the * and + inside the for loop) outside of this context, where the compiler doesn't perform any transformations. You have to at least compose all of those CPS operations within reset. (Additionally, you run things in a loop and with mutation, both of which can also be delegated to CPS.)

That said, CPS (as in: this particular implementation) is dead. It was dropped in Scala 2.13, nobody supports it, and using some trampoline-based monad (e.g. Cont from Cats) is much easier to understand, so the only places I still see it are outdated courses or articles about historical trivia.
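For the loop from the question, one possible way to express it without the plugin is to turn the loop into a recursive function that threads the continuation, and to trampoline the calls. The sketch below uses scala.util.control.TailCalls from the standard library instead of Cats, purely to stay dependency-free; the names TrampolinedGrad, K, plus, times, and loop are hypothetical, and the intent (not a guarantee) is that trampolining keeps the stack depth bounded for larger n.

```scala
import scala.util.control.TailCalls._

object TrampolinedGrad {
  final class Num(val x: Double, var d: Double = 0.0)

  // A continuation that returns a trampolined result instead of running directly.
  type K = Num => TailRec[Unit]

  def plus(a: Num, b: Num)(k: K): TailRec[Unit] = {
    val y = new Num(a.x + b.x)
    // Run the rest of the computation, then propagate the gradient backwards.
    tailcall(k(y)).map { _ => a.d += y.d; b.d += y.d }
  }

  def times(a: Num, b: Num)(k: K): TailRec[Unit] = {
    val y = new Num(a.x * b.x)
    tailcall(k(y)).map { _ =>
      a.d += b.x * y.d
      b.d += a.x * y.d
    }
  }

  // Recursive form of: for (j <- 2 to n) result = result * (x + j)
  def loop(x: Num, acc: Num, j: Int, n: Int)(k: K): TailRec[Unit] =
    if (j > n) k(acc)
    else
      plus(x, new Num(j.toDouble)) { term =>
        times(acc, term)(next => loop(x, next, j + 1, n)(k))
      }

  // Gradient of (x + 1) * (x + 2) * ... * (x + n) at x0.
  def grad(n: Int)(x0: Double): Double = {
    val x = new Num(x0)
    plus(x, new Num(1.0)) { acc =>
      loop(x, acc, 2, n) { r => r.d = 1.0; done(()) }
    }.result
    x.d
  }

  def main(args: Array[String]): Unit = {
    println(grad(2)(3.0)) // 9.0, the derivative of (x + 1) * (x + 2) at x = 3
  }
}
```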