Implementing a Python interface for a Rust function with a generic type parameter

This function works perfectly in Rust:

use std::collections::HashSet;
use std::hash::Hash;

fn jaccard_similarity<T>(s1: Vec<T>, s2: Vec<T>) -> f32
where
    T: Hash + Eq + Clone,
{
    let s1 = vec_to_set(&s1);
    let s2 = vec_to_set(&s2);
    let i = s1.intersection(&s2).count() as f32;
    let u = s1.union(&s2).count() as f32;
    return i / u;
}

// Collects the elements into a set, dropping duplicates.
fn vec_to_set<T>(vec: &[T]) -> HashSet<T>
where
    T: Hash + Eq + Clone,
{
    vec.iter().cloned().collect()
}

on the following test cases:

#[test]
fn test_jaccard_similarity() {
    let left = vec!["kitten", "sitting", "saturday", "sunday"];
    let right = vec!["kitten", "sitting", "saturday", "sunday"];
    assert_eq!(jaccard_similarity(left, right), 1.0);
    let left = vec![1,2,3,4];
    let right = vec![1,2,3,4];
    assert_eq!(jaccard_similarity(left, right), 1.0);
    let left = vec![1,2,3,4];
    let right = vec![2,2,3,4];
    assert_eq!(jaccard_similarity(left, right), 0.75);

}

However, as soon as I wrap it as a #[pyfunction] from the pyo3 crate (version 0.13.2) and update my lib.rs and mod.rs files accordingly, it no longer compiles. For context, I am building with Maturin.

#[pyfunction]
fn jaccard_similarity<T>(s1: Vec<T>, s2: Vec<T>) -> f32
where
    T: Hash + Eq + Clone,
{
    let s1 = vec_to_set(&s1);
    let s2 = vec_to_set(&s2);
    let i = s1.intersection(&s2).count() as f32;
    let u = s1.union(&s2).count() as f32;
    return i / u;
}

I get the following error:

--> src\distance_functions\jaccard_similarity.rs:6:4
  |
6 | fn jaccard_similarity<T>(s1: Vec<T>, s2: Vec<T>) -> f32
  |    ^^^^^^^^^^^^^^^^^^ cannot infer type of the type parameter `T` declared on the function `jaccard_similarity`
  |
  = note: cannot satisfy `_: Hash`
note: required by a bound in `jaccard_similarity`
 --> src\distance_functions\jaccard_similarity.rs:8:8
  |
6 | fn jaccard_similarity<T>(s1: Vec<T>, s2: Vec<T>) -> f32
  |    ------------------ required by a bound in this function
7 | where
8 |     T: Hash + Eq + Clone,
  |        ^^^^ required by this bound in `jaccard_similarity`
help: consider specifying the generic argument
  |
6 | fn jaccard_similarity::<T><T>(s1: Vec<T>, s2: Vec<T>) -> f32
  |                      +++++

The generic argument is already declared on the function. I cannot understand what the compiler is asking me to do.

What works in Rust should also work when the Rust code is wrapped in a Python interface.

EDIT: I updated my pyo3 version to 0.20.0 and now I get a more meaningful error message:

error: Python functions cannot have generic type parameters
 --> src\distance_functions\jaccard_similarity.rs:6:23
  |
6 | fn jaccard_similarity<T>(s1: Vec<T>, s2: Vec<T>) -> f32

Is there a way to use generic type parameters for Python functions?

There is 1 answer

Masklinn On

"What works in Rust should also work when the Rust code is wrapped in a Python interface."

No. Python is not statically typed and the Python interface does not support generics, so there is no way for pyo3 to create bindings that bridge a Python call and your generic Rust function.

In fact, this matches the behaviour of Rust itself: the generic jaccard_similarity does not generate any code on its own. Instead, the compiler looks at the call sites and generates a separate instance for each T the function is invoked with (monomorphization); those instances are the code that ends up in the binary. pyo3 cannot perform that instantiation step, because the argument types are only known when Python calls the function at runtime, so it can't work.
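One common way around this is to keep the generic function as an internal helper and expose one concrete wrapper per element type you care about, so that the compiler has a fixed T to instantiate. The following is only a minimal sketch in plain Rust: the wrapper names jaccard_similarity_str and jaccard_similarity_int are hypothetical, the helper is a slice-borrowing variation on the function from the question, and in a real pyo3 crate each wrapper (not the generic helper) would carry the #[pyfunction] attribute and be registered in the module.

```rust
use std::collections::HashSet;
use std::hash::Hash;

// Generic core, kept private to the Rust side. It borrows slices so that
// no element has to be cloned (a variation on the question's version).
fn jaccard_similarity<T>(s1: &[T], s2: &[T]) -> f32
where
    T: Hash + Eq,
{
    let s1: HashSet<&T> = s1.iter().collect();
    let s2: HashSet<&T> = s2.iter().collect();
    let i = s1.intersection(&s2).count() as f32;
    let u = s1.union(&s2).count() as f32;
    i / u
}

// Hypothetical concrete entry point for strings.
// In a pyo3 crate this is where #[pyfunction] would go.
fn jaccard_similarity_str(s1: Vec<String>, s2: Vec<String>) -> f32 {
    jaccard_similarity(&s1, &s2)
}

// Hypothetical concrete entry point for integers.
// In a pyo3 crate this is where #[pyfunction] would go.
fn jaccard_similarity_int(s1: Vec<i64>, s2: Vec<i64>) -> f32 {
    jaccard_similarity(&s1, &s2)
}
```

Each wrapper is a call site with a known T, so the compiler instantiates the generic helper once for String and once for i64. The cost is that the Python caller has to pick the function matching the element type of its lists.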

I would also say that this code is fundamentally not useful: the Python equivalent is five calls to things implemented in C (create two sets, intersect them, union them, and divide). The overhead of copying the data into vectors and then building sets from those vectors will likely be as large as just letting Python do the work, especially with Rust's default hash function, which is designed for resistance to hash-flooding attacks rather than raw speed.

To have any chance of real gains, I think you would need to avoid the two Vecs entirely (performing the conversions on the fly from the PyList) and avoid reifying one of the sets (likely the larger one).
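As a rough illustration of that idea, here is one possible reading of it in plain Rust, not a definitive implementation. The function name jaccard_streaming is hypothetical. It consumes two iterators instead of two Vecs, indexes only the first input (assumed to be the smaller one), and streams the second, keeping aside only those elements of the second input that are missing from the first. Returning 1.0 for two empty inputs is a convention chosen for this sketch. In a pyo3 function the iterators would come from walking the Python lists and extracting each item on the fly, which is not shown here.

```rust
use std::collections::{HashMap, HashSet};
use std::hash::Hash;

// Hypothetical streaming variant: `small` is indexed, `large` is only iterated.
fn jaccard_streaming<T, A, B>(small: A, large: B) -> f32
where
    T: Hash + Eq,
    A: IntoIterator<Item = T>,
    B: IntoIterator<Item = T>,
{
    // Distinct elements of `small`, each flagged with whether it has
    // already been matched by an element of `large`.
    let mut seen: HashMap<T, bool> = small.into_iter().map(|x| (x, false)).collect();
    // Distinct elements of `large` that do not occur in `small`.
    let mut extra: HashSet<T> = HashSet::new();
    let mut intersection = 0usize;

    for item in large {
        match seen.get_mut(&item) {
            Some(matched) => {
                // Count each shared element once, even if `large` repeats it.
                if !*matched {
                    *matched = true;
                    intersection += 1;
                }
            }
            None => {
                extra.insert(item);
            }
        }
    }

    let union = seen.len() + extra.len();
    if union == 0 {
        // Assumption for this sketch: two empty inputs count as identical.
        return 1.0;
    }
    intersection as f32 / union as f32
}
```

For example, jaccard_streaming(vec![1, 2, 3, 4], vec![2, 2, 3, 4]) gives 0.75, the same result as the set-based version in the question. Whether this actually beats plain Python sets would have to be measured.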