Correct way to measure divergence between similar point clouds #482
Hi @arturtoshev , yes, I believe the small negative values are just coming from numerical imprecision.
The regularized OT cost is not a divergence, so if you need a divergence, I wouldn't use it and would use the Sinkhorn divergence instead.
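For reference, here is a sketch of the standard definitions under one common convention (KL regularization against the product of the marginals $a \otimes b$, cost matrix $C$, couplings $U(a, b)$); the exact regularizer and normalization used inside OTT may differ in detail:

```latex
\mathrm{OT}_\varepsilon(\mu, \nu)
  = \min_{P \in U(a, b)} \langle P, C \rangle
  + \varepsilon \, \mathrm{KL}(P \,\|\, a \otimes b)

S_\varepsilon(\mu, \nu)
  = \mathrm{OT}_\varepsilon(\mu, \nu)
  - \tfrac{1}{2} \mathrm{OT}_\varepsilon(\mu, \mu)
  - \tfrac{1}{2} \mathrm{OT}_\varepsilon(\nu, \nu)
```

Since $\mathrm{OT}_\varepsilon(\mu, \mu)$ is generally nonzero, the raw regularized cost carries an $\varepsilon$-dependent offset, and the two self-terms remove it so that $S_\varepsilon(\mu, \mu) = 0$. When $\mu \approx \nu$, $S_\varepsilon$ is a difference of nearly equal numbers, so round-off and the solver's convergence threshold can push the computed value slightly below zero.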
Epsilon is the entropy regularization parameter. AFAIK, our notebooks only show its effect here (see the animation at the bottom as epsilon increases).
Lastly, in your code snippet, instead of manually computing the cost, you can define a custom cost function and pass it via `cost_fn`:

import jax
import jax.numpy as jnp
import ott
from ott.geometry import pointcloud

# `displacement_fn` is assumed to be defined elsewhere (e.g. in your snippet).

@jax.tree_util.register_pytree_node_class
class MyCost(ott.geometry.costs.CostFn):
    """Squared Euclidean distance."""

    def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
        return ((displacement_fn(x, y)) ** 2).sum(axis=-1)


@jax.jit
def sinkhorn_divergence_ott(x, y):
    out = ott.tools.sinkhorn_divergence.sinkhorn_divergence(
        pointcloud.PointCloud,
        x,
        y,
        cost_fn=MyCost(),
        sinkhorn_kwargs={"threshold": 1e-6},
    )
    return out.divergence, out
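To make the behavior of the divergence concrete, here is a minimal pure-NumPy sketch (not OTT's implementation). It assumes that `displacement_fn` in `MyCost` is a periodic minimum-image displacement in the unit box, which the thread does not show; the helper names, `epsilon=0.01`, and the fixed iteration count are all illustrative choices.

```python
import numpy as np


def periodic_sq_dist(x, y, box=1.0):
    """Pairwise squared distances in a periodic box (minimum-image convention)."""
    d = x[:, None, :] - y[None, :, :]
    d = d - box * np.round(d / box)  # wrap each component into [-box/2, box/2]
    return (d ** 2).sum(axis=-1)


def logsumexp(z, axis):
    """Numerically stable log(sum(exp(z))) along one axis."""
    m = z.max(axis=axis, keepdims=True)
    return np.squeeze(m + np.log(np.exp(z - m).sum(axis=axis, keepdims=True)), axis=axis)


def entropic_ot_cost(x, y, epsilon, n_iters=500):
    """Regularized OT cost for uniform weights, via log-domain Sinkhorn."""
    n, m = len(x), len(y)
    cost = periodic_sq_dist(x, y)
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iters):
        f = -epsilon * logsumexp((g[None, :] - cost) / epsilon - np.log(m), axis=1)
        g = -epsilon * logsumexp((f[:, None] - cost) / epsilon - np.log(n), axis=0)
    # After the g-update the plan has total mass 1, so the dual value is <a, f> + <b, g>.
    return f.mean() + g.mean()


def sinkhorn_divergence(x, y, epsilon=0.01):
    """Debiased cost: OT(x, y) - 0.5 * (OT(x, x) + OT(y, y))."""
    return entropic_ot_cost(x, y, epsilon) - 0.5 * (
        entropic_ot_cost(x, x, epsilon) + entropic_ot_cost(y, y, epsilon)
    )


# 4 x 4 Cartesian grid in the unit box.
ticks = (np.arange(4) + 0.5) / 4
grid = np.stack(np.meshgrid(ticks, ticks), axis=-1).reshape(-1, 2)


def shifted(s):
    """Grid shifted by s along the first axis, wrapped back into the box."""
    return (grid + np.array([s, 0.0])) % 1.0


div_same = sinkhorn_divergence(grid, grid)
div_small = sinkhorn_divergence(grid, shifted(0.05))
div_large = sinkhorn_divergence(grid, shifted(0.10))
div_period = sinkhorn_divergence(grid, shifted(0.25))  # one grid spacing: same point set
div_cluster = sinkhorn_divergence(grid, 0.5 * grid)    # cloud squeezed into one quadrant
```

In this toy setup the divergence vanishes for identical clouds, is larger for the 0.10 shift than for the 0.05 shift, returns to roughly zero after a shift by a full grid spacing (up to round-off, which is the kind of place where tiny negative values can show up), and is largest for the clustered cloud.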
Here is what I want to do: Given are two 2D point clouds of N*N particles of the same weight in a periodic box $[0,1]^2$. What is the correct way to measure the distance between their distributions? By distance, I mean that if both particle sets are on a Cartesian grid (say, slightly shifted relative to each other), then the distance should be close to zero and should increase the more the clouds are shifted, and if one of the particle sets populates only a subset of the domain, then the distance should increase.
We have been using the `ott.tools.sinkhorn_divergence.sinkhorn_divergence().divergence` function so far, but I experimented with some edge cases and now I doubt that I understand what this function does (I didn't expect negative divergences to be allowed, and I also expected that the more shifted the point clouds are, the larger the divergence gets). Then, I also looked at `ott.solvers.linear.sinkhorn.Sinkhorn(ott.problems.linear.linear_problem.LinearProblem(ott.geometry.pointcloud.PointCloud())).reg_ot_cost`, as I expected this to be somehow proportional to the divergence, but this quantity is extremely dependent on an `epsilon` parameter, which I don't understand.
The code I experimented with is here and the plots I got look like that:
If the negative divergence is only because we are in the range of numerical precision errors (I actually ran the whole thing also with double precision, but still got negative divergences), then what is the correct way to capture slight deviations in the actual particle distribution, e.g. the onset of particle clustering?
Best, Artur