Sliced Wasserstein kernel: initial implementation
This commit is contained in:
parent
6253d94589
commit
d9f627d6e7
1 changed files with 36 additions and 0 deletions
36
sliced_wasserstein.py
Normal file
36
sliced_wasserstein.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
import numpy as np
|
||||
import dionysus as d
|
||||
|
||||
|
||||
def diagram_array(dgm):
|
||||
res = []
|
||||
for p in dgm:
|
||||
if p.death != np.inf:
|
||||
res.append([p.birth, p.death])
|
||||
return np.array(res)
|
||||
|
||||
|
||||
def SW_approx(dgm1, dgm2, M):
|
||||
dgm1 = diagram_array(dgm1)
|
||||
dgm2 = diagram_array(dgm2)
|
||||
# Add \pi_\delta(dgm1) to dgm2 and vice-versa
|
||||
proj1 = dgm1.dot([1, 1])/np.sqrt(2)
|
||||
proj2 = dgm2.dot([1, 1])/np.sqrt(2)
|
||||
dgm1 = np.vstack((dgm1, np.vstack((proj2, proj2)).T))
|
||||
dgm2 = np.vstack((dgm2, np.vstack((proj1, proj1)).T))
|
||||
SW = 0
|
||||
theta = -np.pi/2
|
||||
s = np.pi/M
|
||||
for i in range(M):
|
||||
# Project each diagram on the direction theta
|
||||
vec = [1, np.arctan(theta)]
|
||||
vec = vec / np.linalg.norm(vec)
|
||||
V1 = dgm1.dot(vec)
|
||||
V2 = dgm2.dot(vec)
|
||||
# Sort the projections
|
||||
V1.sort()
|
||||
V2.sort()
|
||||
# l1-distance between the projections
|
||||
SW = SW + s * np.sum(np.abs(V1 - V2))
|
||||
theta = theta + s
|
||||
return SW / M
|
Loading…
Add table
Add a link
Reference in a new issue