Skip to content

Commit 77d7a0a

Browse files
committed
check nn graphs building against pdist reference
1 parent b298600 commit 77d7a0a

File tree

1 file changed

+29
-3
lines changed

1 file changed

+29
-3
lines changed

pygsp/tests/test_graphs.py

+29-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import numpy as np
1111
import scipy.linalg
12+
import scipy.sparse.linalg
1213
from skimage import data, img_as_float
1314

1415
from pygsp import graphs
@@ -199,20 +200,45 @@ def test_nngraph(self):
199200
graphs.NNGraph(Xin, NNtype='knn',
200201
backend=cur_backend,
201202
dist_type=dist_type, order=order)
203+
self.assertRaises(ValueError, graphs.NNGraph, Xin,
204+
NNtype='badtype', backend=cur_backend,
205+
dist_type=dist_type)
206+
self.assertRaises(ValueError, graphs.NNGraph, Xin,
207+
NNtype='knn', backend='badtype',
208+
dist_type=dist_type)
202209

203210
def test_nngraph_consistency(self):
204-
#Xin = np.arange(180).reshape(60, 3)
205211
Xin = np.random.uniform(-5, 5, (60, 3))
206212
dist_types = ['euclidean', 'manhattan', 'max_dist', 'minkowski']
207213
backends = ['scipy-kdtree', 'flann']
208-
num_neighbors=5
214+
num_neighbors=4
215+
epsilon=0.1
209216

217+
# use pdist as ground truth
210218
G = graphs.NNGraph(Xin, NNtype='knn',
211219
backend='scipy-pdist', k=num_neighbors)
212-
for cur_backend in backends:
220+
for cur_backend in backends:
213221
for dist_type in dist_types:
222+
if cur_backend == 'flann' and dist_type == 'max_dist':
223+
continue
224+
#print("backend={} dist={}".format(cur_backend, dist_type))
214225
Gt = graphs.NNGraph(Xin, NNtype='knn',
215226
backend=cur_backend, k=num_neighbors)
227+
d = scipy.sparse.linalg.norm(G.W - Gt.W)
228+
self.assertTrue(d < 0.01, 'Graphs (knn) are not identical error='.format(d))
229+
230+
G = graphs.NNGraph(Xin, NNtype='radius',
231+
backend='scipy-pdist', epsilon=epsilon)
232+
for cur_backend in backends:
233+
for dist_type in dist_types:
234+
if cur_backend == 'flann' and dist_type == 'max_dist':
235+
continue
236+
#print("backend={} dist={}".format(cur_backend, dist_type))
237+
Gt = graphs.NNGraph(Xin, NNtype='radius',
238+
backend=cur_backend, epsilon=epsilon)
239+
d = scipy.sparse.linalg.norm(G.W - Gt.W, ord=1)
240+
self.assertTrue(d < 0.01,
241+
'Graphs (radius) are not identical error='.format(d))
216242

217243
def test_bunny(self):
218244
graphs.Bunny()

0 commit comments

Comments
 (0)