Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 43bc955

Browse files
BarclayIIrudongyufrozenbugs
authored andcommittedMar 12, 2024
[Optimization] Use scipy's eigs instead of numpy in lap_pe (dmlc#5855)
Co-authored-by: rudongyu <ru_dongyu@outlook.com> Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
1 parent 570f80c commit 43bc955

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed
 

‎python/dgl/transforms/functional.py

+20-7
Original file line numberDiff line numberDiff line change
@@ -3680,13 +3680,26 @@ def lap_pe(g, k, padding=False, return_eigval=False):
36803680
L = sparse.eye(g.num_nodes()) - N * A * N
36813681

36823682
# select eigenvectors with smaller eigenvalues O(n + klogk)
3683-
EigVal, EigVec = np.linalg.eig(L.toarray())
3684-
max_freqs = min(n - 1, k)
3685-
kpartition_indices = np.argpartition(EigVal, max_freqs)[: max_freqs + 1]
3686-
topk_eigvals = EigVal[kpartition_indices]
3687-
topk_indices = kpartition_indices[topk_eigvals.argsort()][1:]
3688-
topk_EigVec = EigVec[:, topk_indices]
3689-
eigvals = F.tensor(EigVal[topk_indices], dtype=F.float32)
3683+
if k + 1 < n - 1:
3684+
# Use scipy if k + 1 < n - 1 for memory efficiency.
3685+
EigVal, EigVec = scipy.sparse.linalg.eigs(
3686+
L, k=k + 1, which="SR", tol=1e-2
3687+
)
3688+
topk_indices = EigVal.argsort()[1:]
3689+
# Since scipy may return complex value, to avoid crashing in NN code,
3690+
# convert them to real number.
3691+
topk_eigvals = EigVal[topk_indices].real
3692+
topk_EigVec = EigVec[:, topk_indices].real
3693+
else:
3694+
# Fallback to numpy since scipy.sparse do not support this case.
3695+
EigVal, EigVec = np.linalg.eig(L.toarray())
3696+
max_freqs = min(n - 1, k)
3697+
kpartition_indices = np.argpartition(EigVal, max_freqs)[: max_freqs + 1]
3698+
topk_eigvals = EigVal[kpartition_indices]
3699+
topk_indices = kpartition_indices[topk_eigvals.argsort()][1:]
3700+
topk_EigVec = EigVec[:, topk_indices]
3701+
topk_EigVal = EigVal[topk_indices]
3702+
eigvals = F.tensor(topk_EigVal, dtype=F.float32)
36903703

36913704
# get random flip signs
36923705
rand_sign = 2 * (np.random.rand(max_freqs) > 0.5) - 1.0

0 commit comments

Comments
 (0)
Please sign in to comment.