Conversion of an abstract syntax tree with R

93 views Asked by At

Given an arithmetic expression, for example x + y*z, I want to convert it to add(x, multiply(y, z)).

I found a helpful function here:

> getAST <- function(ee) purrr::map_if(as.list(ee), is.call, getAST)
> getAST(quote(x + y*z)) 
[[1]]
`+`

[[2]]
x

[[3]]
[[3]][[1]]
`*`

[[3]][[2]]
y

[[3]][[3]]
z

One can use rapply(result, as.character, how = "list") to get characters instead of symbols.

How to get add(x, multiply(y, z)) from this AST (the result)? This becomes more complicated when there are some parentheses:

> getAST(quote((x + y) * z)) 
[[1]]
`*`

[[2]]
[[2]][[1]]
`(`

[[2]][[2]]
[[2]][[2]][[1]]
`+`

[[2]][[2]][[2]]
x

[[2]][[2]][[3]]
y



[[3]]
z

The answer does not have to use the getAST function; it is just one possible way to go.

Of course in my real use case the expressions are longer.


Here is a solution (I think) for the case when there are no parentheses:

# Recursively unfold a language object into a nested list: every call becomes
# a list whose first element is the function/operator and whose remaining
# elements are its arguments; symbols and constants are kept as they are.
# Base lapply() does the job, so no extra package is required.
getAST <- function(ee) {
  lapply(as.list(ee), function(node) {
    if (is.call(node)) getAST(node) else node
  })
}

# Turn every leaf (symbols, constants) into a character string while keeping
# the nested list structure.
ast <- rapply(getAST(quote(x + y*z)), as.character, how = "list")

# Convert a character AST (as produced by
# rapply(getAST(expr), as.character, how = "list")) into a string written in
# function-call notation, e.g. list("+", "x", list("*", "y", "z")) becomes
# "add(x, multiply(y, z))".
#
# ast: either a leaf (a length-one value that as.character() can convert) or a
#      list whose first element names the operator/function and whose
#      remaining elements are the operands (leaves or nested lists).
#
# Returns a length-one character vector.
#
# Parentheses nodes ("(") are dropped because nesting is already explicit in
# function-call notation. Operators other than + - * / keep their own name,
# so calls such as sin(x) pass through unchanged.
convertAST <- function(ast) {
  # Base case: a leaf is emitted verbatim.
  if (!is.list(ast)) {
    return(as.character(ast))
  }

  # The function position is normally a name, but may itself be a nested call.
  fn <- ast[[1]]
  fn <- if (is.list(fn)) convertAST(fn) else as.character(fn)

  # Convert every operand, however many there are (unary, binary, n-ary).
  operands <- vapply(ast[-1], convertAST, character(1), USE.NAMES = FALSE)

  if (identical(fn, "(")) {
    return(operands[[1]])
  }

  op <- switch(
    fn,
    "+" = "add",
    "-" = "subtract",
    "*" = "multiply",
    "/" = "divide",
    fn
  )
  sprintf("%s(%s)", op, paste(operands, collapse = ", "))
}

convertAST(ast)
3

There are 3 answers

0
G. Grothendieck On BEST ANSWER

We can use substitute like this:

# Rename operators/functions inside a language object.
#
# e:   a quoted expression (call or symbol).
# sub: named collection mapping an existing name to the replacement name,
#      given as a string.
#
# Returns `e` with every symbol listed in names(sub) swapped for its
# replacement symbol.
subst <- function(e,
                  sub = list(
                    `+` = "add",
                    `-` = "minus",
                    `/` = "divide",
                    `*` = "multiply"
                  )) {
  # Replacement strings must be symbols for substitute() to splice them in.
  replacements <- mapply(as.name, sub, SIMPLIFY = FALSE)
  # Build the call substitute(<e>, <replacements>) by hand so that the value
  # of `e`, not the symbol `e`, is what gets rewritten.
  rewrite <- as.call(list(as.name("substitute"), e, replacements))
  eval(rewrite)
}

# Demo: rewrite an expression that contains parentheses.
e <- quote(x + (y + 1) * z)
res <- subst(e)
res
## add(x, multiply((add(y, 1)), z))

# Bind the replacement names to the real operators and pick sample values,
# then check that the rewritten call still evaluates.
add <- `+`
multiply <- `*`
x <- 1
y <- 2
z <- 3
eval(res)
## [1] 10

If you want a character string result then

deparse1(subst(e))
## [1] "add(x, multiply((add(y, 1)), z))"
0
Allan Cameron On

It may just be because I don't understand rapply very well, but any time I try to use it my code is more complex than just writing my own recursive function.

In this case I have put the recursive function in a thin wrapper that permits direct entry of expressions without using quote (if desired)

# Replace operators/functions in an expression by other function names.
#
# input:  the expression to rewrite. With direct = TRUE it is captured
#         unevaluated (so `sub_call(x + y*z)` works); with direct = FALSE it
#         must already be a language object, e.g. the result of quote().
# direct: TRUE to capture `input` as typed, FALSE to use its value.
# subs:   named collection mapping a function name to the text of its
#         replacement (parsed with str2lang()).
#
# Returns the rewritten language object; parentheses are left intact.
sub_call <- function(input, direct = TRUE,
                     subs = list(`+` = "add", `-` = "minus",
                                 `/` = "divide", `*` = "multiply")) {
  scall <- function(x, subs) {
    # Symbols and constants carry no operator to replace.
    if (!is.call(x)) {
      return(x)
    }

    head <- x[[1]]
    if (is.symbol(head)) {
      key <- as.character(head)
      if (key %in% names(subs)) {
        x[[1]] <- str2lang(subs[[match(key, names(subs))]])
      }
    } else if (is.call(head)) {
      # The function position can itself be a call (e.g. `pkg::fun`);
      # as.character() on it has length > 1, so it is handled by recursion
      # instead of a name lookup.
      x[[1]] <- scall(head, subs)
    }

    # A call without arguments has nothing left to visit.
    if (length(x) > 1) {
      x[-1] <- lapply(x[-1], scall, subs = subs)
    }
    x
  }

  expr <- if (direct) as.list(match.call())[["input"]] else input
  scall(expr, subs)
}

This allows direct input of an expression:

sub_call(x + y*z)
#> add(x, multiply(y, z))

Or indirect input:

my_expr <- quote(x + y*z)

sub_call(my_expr, direct = FALSE)
#> add(x, multiply(y, z))

And handles arbitrarily deep nesting, leaving parentheses intact:

sub_call(sin(((x + (1/3))^2)))
#> sin(((add(x, (divide(1, 3))))^2))
0
Joris C. On

Using an external package, this can also be done with rrapply() (in package rrapply), which --in contrast to base rapply()-- also recurses through call objects/expression vectors:

## examples 
lang <- quote(x + (y + 1) * z)
expr <- expression(x + (y + 1) * z, sin(((x + (1/3))^2)))

## replacement function
# Map an arithmetic operator symbol to the symbol of its named replacement.
#
# s: a leaf of a language object (symbol or constant).
#
# Returns quote(add), quote(minus), quote(divide) or quote(multiply) for the
# symbols `+`, `-`, `/`, `*`; anything else is returned unchanged.
replace_symbol <- function(s) {
  # Only symbols denote operators. Constants are returned untouched so that a
  # string literal "+" is not rewritten and zero-length leaves such as NULL
  # do not reach switch(), which requires a length-one EXPR.
  if (!is.symbol(s)) {
    return(s)
  }
  switch(
    as.character(s),
    "+" = quote(add),
    "-" = quote(minus),
    "/" = quote(divide),
    "*" = quote(multiply),
    s
  )
}

lang1 <- rrapply::rrapply(lang, f = replace_symbol, how = "replace")
str(lang1)
#>  language add(x, multiply((add(y, 1)), z))

expr1 <- rrapply::rrapply(expr, f = replace_symbol, how = "replace")
str(expr1)
#>   expression(add(x, multiply((add(y, 1)), z)), sin(((add(x, (divide(1, 3))))^2)))