Skip to content

Commit 8275bc2

Browse files
authoredSep 15, 2023
[Sparse] Add relabel python API (#6323)
1 parent 2a715f2 commit 8275bc2

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed
 

‎python/dgl/sparse/sparse_matrix.py

+76
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,82 @@ def sample(
680680
self.c_sparse_matrix.sample(dim, fanout, ids, replace, bias)
681681
)
682682

683+
def compact(
684+
self,
685+
dim: int,
686+
leading_indices: Optional[torch.Tensor] = None,
687+
):
688+
"""Compact sparse matrix by removing rows or columns without non-zero
689+
elements in the sparse matrix and relabeling indices of the dimension.
690+
691+
This function serves a dual purpose: it allows you to reorganize the
692+
indices within a specific dimension (rows or columns) of the sparse
693+
matrix and, if needed, place certain 'leading_indices' at the beginning
694+
of the relabeled dimension.
695+
696+
In the absence of 'leading_indices' (when it's set to `None`), the order
697+
of relabeled indices remains the same as the original order, except that
698+
rows or columns without non-zero elements are removed. When
699+
'leading_indices' are provided, they are positioned at the start of the
700+
relabeled dimension. To be precise, all rows selected by the specified
701+
indices will be remapped from 0 to length(indices) - 1. Rows that are not
702+
selected and contain any non-zero elements will be positioned after those
703+
remapped rows while maintaining their original order.
704+
705+
This function mimics 'dgl.to_block', a method used to compress a sampled
706+
subgraph by eliminating redundant nodes. The 'leading_indices' parameter
707+
replicates the behavior of 'include_dst_in_src' in 'dgl.to_block',
708+
adding destination node information for message passing.
709+
Setting 'leading_indices' to column IDs when relabeling the row
710+
dimension, for example, achieves the same effect as including destination
711+
nodes in source nodes.
712+
713+
Parameters
714+
----------
715+
dim : int
716+
The dimension to relabel. Should be 0 or 1. Use `dim = 0` for rowwise
717+
relabeling and `dim = 1` for columnwise relabeling.
718+
leading_indices : torch.Tensor, optional
719+
An optional tensor containing row or column ids that should be placed
720+
at the beginning of the relabeled dimension.
721+
722+
Returns
723+
-------
724+
Tuple[SparseMatrix, torch.Tensor]
725+
A tuple containing the relabeled sparse matrix and the index mapping
726+
of the relabeled dimension from the new index to the original index.
727+
728+
Examples
729+
--------
730+
>>> indices = torch.tensor([[0, 2],
731+
[1, 2]])
732+
>>> A = dglsp.spmatrix(indices)
733+
>>> print(A.to_dense())
734+
tensor([[0., 1., 0.],
735+
[0., 0., 0.],
736+
[0., 0., 1.]])
737+
738+
Case 1: Compact rows without indices.
739+
740+
>>> B, original_rows = A.compact(dim=0, leading_indices=None)
741+
>>> print(B.to_dense())
742+
tensor([[0., 1., 0.],
743+
[0., 0., 1.]])
744+
>>> print(original_rows)
745+
torch.Tensor([0, 2])
746+
747+
Case 2: Compact rows with indices.
748+
749+
>>> B, original_rows = A.compact(dim=0, leading_indices=[1, 2])
750+
>>> print(B.to_dense())
751+
tensor([[0., 0., 0.],
752+
[0., 0., 1.],
753+
[0., 1., 0.],])
754+
>>> print(original_rows)
755+
torch.Tensor([1, 2, 0])
756+
"""
757+
raise NotImplementedError
758+
683759

684760
def spmatrix(
685761
indices: torch.Tensor,

0 commit comments

Comments
 (0)
Please sign in to comment.