wasserstein_distance for multidimensional data

I am trying to compute the Wasserstein distance between two point clouds, and I have written the code below. My original point clouds are 3D, and to fit the library I am using, I flatten the data. Is that okay?

import numpy as np
from scipy.stats import wasserstein_distance
from plyfile import PlyData

def load_point_cloud(file_name):
    plydata = PlyData.read(file_name)
    points = np.array([list(x) for x in plydata['vertex'].data])
    return points

def compute_wasserstein_distance(cloud1, cloud2):
    # Flatten the point clouds to 1D arrays (assuming 3D points)
    u_values = cloud1.flatten()
    v_values = cloud2.flatten()

    # Compute Wasserstein distance
    distance = wasserstein_distance(u_values, v_values)
    return distance

# Load point clouds
cloud1 = load_point_cloud('/home/amit/bunny.ply')
cloud2 = load_point_cloud('/home/amit/Datasets/Animal Dataset/dragon.ply')

# Compute Wasserstein distance
distance = compute_wasserstein_distance(cloud1, cloud2)
print(f"Wasserstein Distance: {distance}")

I am unsure whether flattening the data changes the geometry of the original point cloud. If it does, is there a way to compute the distance in my case?

1 answer

Frank Yellin On

It is certainly not true that flattening the data preserves the Wasserstein distance.

Imagine each cloud has exactly one point, say <1, 2, 3> and <2, 3, 1>.

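The Wasserstein distance between these two clouds is non-zero (with a Euclidean ground metric it is simply the distance between the two points, sqrt(6)), while the distance between the flattened arrays [1, 2, 3] and [2, 3, 1] is zero, because the one-dimensional function only sees the same set of values {1, 2, 3} on both sides.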
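
To make this concrete, here is a minimal sketch that reproduces the counterexample and compares it with a distance computed on the 3D points themselves. The helper name wasserstein_equal_size is hypothetical, and it rests on an assumption: both clouds have the same number of points, each with equal weight. Under that assumption the optimal transport plan can be taken to be a one-to-one matching, so scipy.optimize.linear_sum_assignment on the pairwise Euclidean cost matrix gives the answer.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from scipy.stats import wasserstein_distance


def wasserstein_equal_size(cloud1, cloud2):
    """1-Wasserstein distance between two (n, d) clouds with uniform weights.

    Assumption: both clouds contain the same number of points, so the
    optimal plan reduces to a one-to-one assignment.
    """
    # Pairwise Euclidean distances between every point of cloud1 and cloud2.
    cost = cdist(cloud1, cloud2)
    # Minimum-cost matching of points in cloud1 to points in cloud2.
    rows, cols = linear_sum_assignment(cost)
    # Each point carries mass 1/n, so the total cost is the mean matched distance.
    return cost[rows, cols].mean()


# The one-point clouds from the example above.
cloud1 = np.array([[1.0, 2.0, 3.0]])
cloud2 = np.array([[2.0, 3.0, 1.0]])

# Flattening discards which coordinate each value belonged to.
flattened = wasserstein_distance(cloud1.flatten(), cloud2.flatten())

# Working on the 3D points keeps the geometry.
true_3d = wasserstein_equal_size(cloud1, cloud2)

print(f"Flattened (1D) distance: {flattened}")
print(f"3D distance: {true_3d}")
```

The cost matrix has n x n entries, so this sketch is only practical for small clouds or for random subsamples of large meshes. The two files in the question will generally have different numbers of vertices, in which case the matching shortcut does not apply directly. Options to look at in that case include subsampling both clouds to a common size, scipy.stats.wasserstein_distance_nd (available in recent SciPy releases), or a general optimal transport solver such as ot.emd2 from the POT library.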