Defining partial function with variable arguments in Scala


I am trying to define a generic framework for validation. My idea is to express each rule as a partial function and pass the rules to a generic function that evaluates them and returns the result. The code looks like this:

  def assert(name: TField, validator: PartialFunction[(Any*), String], input: Any*): Seq[ValidationError] = {
    if (validator.isDefinedAt(input)) {
      invalidInput(name.name, validator(input))
    } else {
      Seq.empty
    }
  }

  assert(fieldA, hasValidValue, inputValue, allowedValues)
  assert(fieldB, isPositive, input)

  // =-=-=-=--=-=-= Validation Rules =-=-=-=-=-=-=-=
  def hasValidValue[T] = PartialFunction[(T, Set[T]), String] {
    case (input, validValues) if !validValues.contains(input) => "Value not allowed"
  }

  def isPositive = PartialFunction[Long, String] {
    case value: Long if value <= 0 => "Value should always be positive"
  }

But I cannot figure out how to declare a partial function parameter that takes variable arguments on this line:

  def assert(name: TField, validator: PartialFunction[(Any*), String], input: Any*): Seq[ValidationError] = {

The definition above compiles, but the compiler reports errors as soon as I actually call assert:

// Error: Type mismatch, Expected: PartialFunction[Any, String], Found: PartialFunction[(Nothing, Set[Nothing]), String]
  assert(fieldA, hasValidValue, inputValue, allowedValues)

// Error: Type mismatch, Expected: PartialFunction[Any, String], Found: PartialFunction[Long, String]
  assert(fieldB, isPositive, input)

So how can I define this?

1 Answer

Answer by bjfletcher

All you need to do is change PartialFunction[(Any*), String] to PartialFunction[Seq[Any], String]. Inside the method, the vararg parameter input: Any* has the type Seq[Any], so that is what the validator actually receives. The rules passed in must therefore be defined on Seq[Any] as well.
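As a minimal sketch of why Seq[Any] is the right parameter type, the snippet below shows a vararg parameter being used as a Seq[Any] inside a method, and a partial function over Seq[Any] that picks the arguments apart by position and type. The names VarargsSketch, describe and firstIsNonPositive are hypothetical and only serve the illustration.

```scala
object VarargsSketch {
  // Inside the method body, the repeated parameter `input: Any*`
  // is seen as a Seq[Any], so it can be assigned to one directly.
  def describe(input: Any*): String = {
    val asSeq: Seq[Any] = input
    s"${asSeq.length} argument(s): ${asSeq.mkString(", ")}"
  }

  // A rule written against Seq[Any] can destructure the arguments.
  // It is only defined for a single non-positive Long.
  val firstIsNonPositive: PartialFunction[Seq[Any], String] = {
    case Seq(value: Long) if value <= 0 => "Value should always be positive"
  }
}
```

Because the rule is a type-ascribed pattern-matching literal, isDefinedAt reflects both the pattern and the guard, which is what the isDefinedAt check in assert relies on.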

Update

Here's a working demo of how the vararg can work:

case class TField(name: String)
type ValidationError = String
def invalidInput(s: String, v: String): Seq[String] = Seq(v)

def assert(name: TField, validator: PartialFunction[Seq[Any], String], input: Any*): Seq[ValidationError] = {
  if (validator.isDefinedAt(input)) {
    invalidInput(name.name, validator(input))
  } else {
    Seq.empty
  }
}

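// A type-ascribed pattern-matching literal gives a genuinely partial function.
// PartialFunction.apply (removed in Scala 2.13) would report isDefinedAt == true for every input.
def isPositive: PartialFunction[Seq[Any], String] = {
  case value if value.length < 2 => "Need at least 2 values"
}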

assert(TField("s"), isPositive, 5L)
assert(TField("s"), isPositive, "s1", 1, -2L, "s3")
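The two rules from the question could be rewritten against Seq[Any] along the following lines. This is a sketch under assumptions, not the asker's actual framework: the validation function is named validate (hypothetical, chosen to avoid shadowing Predef.assert), and invalidInput is a stand-in that formats the field name and message.

```scala
object ValidationSketch {
  final case class TField(name: String)
  type ValidationError = String

  // Stand-in helper (assumption): one formatted error per failed rule.
  def invalidInput(field: String, message: String): Seq[ValidationError] =
    Seq(s"$field: $message")

  // `lift` turns the partial function into Seq[Any] => Option[String],
  // so the definedness check and the application happen in one step.
  def validate(name: TField,
               validator: PartialFunction[Seq[Any], String],
               input: Any*): Seq[ValidationError] =
    validator.lift(input) match {
      case Some(message) => invalidInput(name.name, message)
      case None          => Seq.empty
    }

  // Defined only for a single non-positive Long argument.
  val isPositive: PartialFunction[Seq[Any], String] = {
    case Seq(value: Long) if value <= 0 => "Value should always be positive"
  }

  // Defined only when the second argument is a Set that lacks the first.
  // The element type is erased, hence the cast to Set[Any].
  val hasValidValue: PartialFunction[Seq[Any], String] = {
    case Seq(input, validValues: Set[_])
        if !validValues.asInstanceOf[Set[Any]].contains(input) =>
      "Value not allowed"
  }
}
```

With these definitions, validate(TField("a"), isPositive, -2L) yields one error and validate(TField("a"), isPositive, 5L) yields none. The trade-off of Seq[Any] is that argument types are only checked at runtime: passing an Int such as 5 instead of 5L simply fails to match the Long pattern, so the rule is silently skipped.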