@@ -680,6 +680,82 @@ def sample(
680
680
self .c_sparse_matrix .sample (dim , fanout , ids , replace , bias )
681
681
)
682
682
683
    def compact(
        self,
        dim: int,
        leading_indices: Optional[torch.Tensor] = None,
    ):
        """Compact sparse matrix by removing rows or columns without non-zero
        elements in the sparse matrix and relabeling indices of the dimension.

        This function serves a dual purpose: it allows you to reorganize the
        indices within a specific dimension (rows or columns) of the sparse
        matrix and, if needed, place certain 'leading_indices' at the beginning
        of the relabeled dimension.

        In the absence of 'leading_indices' (when it's set to `None`), the order
        of relabeled indices remains the same as the original order, except that
        rows or columns without non-zero elements are removed. When
        'leading_indices' are provided, they are positioned at the start of the
        relabeled dimension. To be precise, all rows selected by the specified
        indices will be remapped from 0 to length(indices) - 1. Rows that are
        not selected and contain any non-zero elements will be positioned after
        those remapped rows while maintaining their original order.

        This function mimics 'dgl.to_block', a method used to compress a
        sampled subgraph by eliminating redundant nodes. The 'leading_indices'
        parameter replicates the behavior of 'include_dst_in_src' in
        'dgl.to_block', adding destination node information for message
        passing. Setting 'leading_indices' to column IDs when relabeling the
        row dimension, for example, achieves the same effect as including
        destination nodes in source nodes.

        Parameters
        ----------
        dim : int
            The dimension to relabel. Should be 0 or 1. Use `dim = 0` for
            rowwise relabeling and `dim = 1` for columnwise relabeling.
        leading_indices : torch.Tensor, optional
            An optional tensor containing row or column ids that should be
            placed at the beginning of the relabeled dimension.

        Returns
        -------
        Tuple[SparseMatrix, torch.Tensor]
            A tuple containing the relabeled sparse matrix and the index
            mapping of the relabeled dimension from the new index to the
            original index.

        Examples
        --------
        >>> indices = torch.tensor([[0, 2],
        ...                         [1, 2]])
        >>> A = dglsp.spmatrix(indices)
        >>> print(A.to_dense())
        tensor([[0., 1., 0.],
                [0., 0., 0.],
                [0., 0., 1.]])

        Case 1: Compact rows without indices.

        >>> B, original_rows = A.compact(dim=0, leading_indices=None)
        >>> print(B.to_dense())
        tensor([[0., 1., 0.],
                [0., 0., 1.]])
        >>> print(original_rows)
        tensor([0, 2])

        Case 2: Compact rows with indices.

        >>> B, original_rows = A.compact(
        ...     dim=0, leading_indices=torch.tensor([1, 2]))
        >>> print(B.to_dense())
        tensor([[0., 0., 0.],
                [0., 0., 1.],
                [0., 1., 0.]])
        >>> print(original_rows)
        tensor([1, 2, 0])
        """
        # NOTE(review): stub — the actual compaction is expected to be
        # implemented by the sparse backend (cf. the `c_sparse_matrix`
        # delegation used by sibling methods such as `sample`).
        raise NotImplementedError
758
+
683
759
684
760
def spmatrix (
685
761
indices : torch .Tensor ,
0 commit comments