|
9 | 9 |
|
10 | 10 | import numpy as np
|
11 | 11 | import scipy.linalg
|
| 12 | +import scipy.sparse.linalg |
12 | 13 | from skimage import data, img_as_float
|
13 | 14 |
|
14 | 15 | from pygsp import graphs
|
@@ -199,20 +200,45 @@ def test_nngraph(self):
|
199 | 200 | graphs.NNGraph(Xin, NNtype='knn',
|
200 | 201 | backend=cur_backend,
|
201 | 202 | dist_type=dist_type, order=order)
|
| 203 | + self.assertRaises(ValueError, graphs.NNGraph, Xin, |
| 204 | + NNtype='badtype', backend=cur_backend, |
| 205 | + dist_type=dist_type) |
| 206 | + self.assertRaises(ValueError, graphs.NNGraph, Xin, |
| 207 | + NNtype='knn', backend='badtype', |
| 208 | + dist_type=dist_type) |
202 | 209 |
|
203 | 210 | def test_nngraph_consistency(self):
|
204 |
| - #Xin = np.arange(180).reshape(60, 3) |
205 | 211 | Xin = np.random.uniform(-5, 5, (60, 3))
|
206 | 212 | dist_types = ['euclidean', 'manhattan', 'max_dist', 'minkowski']
|
207 | 213 | backends = ['scipy-kdtree', 'flann']
|
208 |
| - num_neighbors=5 |
| 214 | + num_neighbors=4 |
| 215 | + epsilon=0.1 |
209 | 216 |
|
| 217 | + # use pdist as ground truth |
210 | 218 | G = graphs.NNGraph(Xin, NNtype='knn',
|
211 | 219 | backend='scipy-pdist', k=num_neighbors)
|
212 |
| - for cur_backend in backends: |
| 220 | + for cur_backend in backends: |
213 | 221 | for dist_type in dist_types:
|
| 222 | + if cur_backend == 'flann' and dist_type == 'max_dist': |
| 223 | + continue |
| 224 | + #print("backend={} dist={}".format(cur_backend, dist_type)) |
214 | 225 | Gt = graphs.NNGraph(Xin, NNtype='knn',
|
215 | 226 | backend=cur_backend, k=num_neighbors)
|
| 227 | + d = scipy.sparse.linalg.norm(G.W - Gt.W) |
| 228 | + self.assertTrue(d < 0.01, 'Graphs (knn) are not identical error='.format(d)) |
| 229 | + |
| 230 | + G = graphs.NNGraph(Xin, NNtype='radius', |
| 231 | + backend='scipy-pdist', epsilon=epsilon) |
| 232 | + for cur_backend in backends: |
| 233 | + for dist_type in dist_types: |
| 234 | + if cur_backend == 'flann' and dist_type == 'max_dist': |
| 235 | + continue |
| 236 | + #print("backend={} dist={}".format(cur_backend, dist_type)) |
| 237 | + Gt = graphs.NNGraph(Xin, NNtype='radius', |
| 238 | + backend=cur_backend, epsilon=epsilon) |
| 239 | + d = scipy.sparse.linalg.norm(G.W - Gt.W, ord=1) |
| 240 | + self.assertTrue(d < 0.01, |
| 241 | + 'Graphs (radius) are not identical error='.format(d)) |
216 | 242 |
|
217 | 243 | def test_bunny(self):
|
218 | 244 | graphs.Bunny()
|
|
0 commit comments