Associated type for ndarray arguments in Rust


I want to create an interface for a (numeric) algorithm and provide implementations of it with ndarray and similar libraries (let's say PyTorch bindings).


struct A<D> {
   array: D
}

trait T<D> {
   // type of the second argument of foo; should depend on D
   type ArgType;
   fn foo(&mut self, other: &Self::ArgType);
}

The type of the second argument, other, should depend on the selected generic type D. With ndarray, there are two kinds of arrays: those that own their data, and views that share data owned elsewhere. The best way to accept both seems to be something like the following:

fn bar<S>(a: &ArrayBase<S, Ix2>)
where
   S: Data<Elem = f64> {}

For the implementation of the trait T for A using ndarray, that would mean that I need something like this:

use ndarray::{prelude::*, Data};
impl T<Array2<f64>> for A<Array2<f64>> {
   // Not valid Rust: sketches the intent of making ArgType generic over the storage S.
   type ArgType<S>=ArrayBase<S: Data<Elem=f64>, Ix2>;
   fn foo(&mut self, other: &Self::ArgType){
      // ...
   }
}

Then, however, how would I add the new generic parameter to foo? More importantly, generic associated types are not available on stable Rust at the time of writing. I assume it would be simpler if ndarray had a trait that defined all of these methods; however, they are implemented directly on the ArrayBase type. I thought about writing a thin abstraction layer, but that would be too complex (even if I only use a small subset of the methods). The AsArray trait seemed to be a promising solution, but it requires a lifetime parameter too (is there even a concept of associated traits?).

What would be the recommended way of handling this kind of situation; is there an easy way?

Answer by Brian Bowman:

Maybe this approach is what you want:

use ndarray::{ArrayBase, Ix2};

trait TheAlgorithm<OtherArg> {
    fn calculate(&self, other: &OtherArg) -> f64;
}

impl<ReprSelf, ReprOther> TheAlgorithm<ArrayBase<ReprOther, Ix2>>
    for ArrayBase<ReprSelf, Ix2>
where
    ReprSelf: ndarray::Data<Elem = f64>,
    ReprOther: ndarray::Data<Elem = f64>,
{
    fn calculate(&self, other: &ArrayBase<ReprOther, Ix2>) -> f64 {
        // dummy calculation
        self.view()[(0, 0)] + other.view()[(0, 0)]
    }
}

This lets you call the calculate() method on either owned arrays or views, and the other argument can be owned or a view in either case.
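
The same idea can probably be carried over to the wrapper struct A<D> from the question: instead of an associated type, the argument type becomes a type parameter of the trait, and the impl is generic over the storage of the argument. The following is only a minimal sketch under stated assumptions. To keep it compilable without the ndarray crate, Storage and MiniArray are hypothetical stand-ins for ndarray's Data trait and ArrayBase type, and demo is an illustrative helper; none of these names come from ndarray or from the question.

```rust
// Hypothetical stand-in for ndarray's `Data` trait: any storage that can
// expose its elements as a slice of f64.
trait Storage {
    fn elems(&self) -> &[f64];
}

// Owned storage (plays the role of an owned array).
impl Storage for Vec<f64> {
    fn elems(&self) -> &[f64] {
        self
    }
}

// Borrowed storage (plays the role of an array view).
impl<'a> Storage for &'a [f64] {
    fn elems(&self) -> &[f64] {
        self
    }
}

// Hypothetical stand-in for `ArrayBase<S, Ix2>`: an array generic over its storage.
struct MiniArray<S> {
    data: S,
}

// The wrapper struct from the question.
struct A<D> {
    array: D,
}

// The argument type is a type parameter of the trait, not an associated type,
// so one implementor can implement the trait for many argument types.
trait T<Other> {
    fn foo(&mut self, other: &Other);
}

// A holds an owned array; `other` may use any storage S.
impl<S> T<MiniArray<S>> for A<MiniArray<Vec<f64>>>
where
    S: Storage,
{
    fn foo(&mut self, other: &MiniArray<S>) {
        // dummy calculation: element-wise addition into the owned array
        for (x, y) in self.array.data.iter_mut().zip(other.data.elems()) {
            *x += *y;
        }
    }
}

// Illustrative helper: calls foo once with an owned argument and once with a view.
fn demo() -> Vec<f64> {
    let mut a = A { array: MiniArray { data: vec![1.0, 2.0] } };
    let owned = MiniArray { data: vec![10.0, 20.0] };
    let backing = [100.0, 200.0];
    let view = MiniArray { data: &backing[..] };
    a.foo(&owned);
    a.foo(&view);
    a.array.data
}
```

Calling demo() returns [111.0, 222.0]: both the owned argument and the view go through the same foo. With ndarray itself, the impl header would presumably read impl<S: Data<Elem = f64>> T<ArrayBase<S, Ix2>> for A<Array2<f64>>, following the pattern in the answer above.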