Skip to content

Commit a80901c

Browse files
authoredJun 2, 2024
Merge pull request #96 from kirillgarbar/spmspvmasked
SpMSpVMasked
2 parents e5ad7ac + 12e733f commit a80901c

File tree

9 files changed

+360
-42
lines changed

9 files changed

+360
-42
lines changed
 

‎src/GraphBLAS-sharp.Backend/Algorithms/BFS.fs

+14-23
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ module internal BFS =
135135
Operations.SpMVInPlace add mul clContext workGroupSize
136136

137137
let spMSpV =
138-
Operations.SpMSpVBool add mul clContext workGroupSize
138+
Operations.SpMSpVMaskedBool add mul clContext workGroupSize
139139

140140
let zeroCreate =
141141
Vector.zeroCreate clContext workGroupSize
@@ -145,13 +145,11 @@ module internal BFS =
145145
let maskComplementedInPlace =
146146
Vector.map2InPlace Mask.complementedOp clContext workGroupSize
147147

148-
let maskComplemented =
149-
Vector.map2Sparse Mask.complementedOp clContext workGroupSize
150-
151148
let fillSubVectorInPlace =
152149
Vector.assignByMaskInPlace (Mask.assign) clContext workGroupSize
153150

154-
let toSparse = Vector.toSparse clContext workGroupSize
151+
let toSparse =
152+
Vector.toSparseUnsorted clContext workGroupSize
155153

156154
let toDense = Vector.toDense clContext workGroupSize
157155

@@ -190,28 +188,21 @@ module internal BFS =
190188
match frontier with
191189
| ClVector.Sparse _ ->
192190
//Getting new frontier
193-
match spMSpV queue matrix frontier with
191+
match spMSpV queue matrix frontier levels with
194192
| None ->
195193
frontier.Dispose()
196194
stop <- true
197-
| Some newFrontier ->
195+
| Some newMaskedFrontier ->
198196
frontier.Dispose()
199-
//Filtering visited vertices
200-
match maskComplemented queue DeviceOnly newFrontier levels with
201-
| None ->
202-
stop <- true
203-
newFrontier.Dispose()
204-
| Some newMaskedFrontier ->
205-
newFrontier.Dispose()
206-
207-
//Push/pull
208-
let NNZ = getNNZ queue newMaskedFrontier
209-
210-
if (push NNZ newMaskedFrontier.Size) then
211-
frontier <- newMaskedFrontier
212-
else
213-
frontier <- toDense queue DeviceOnly newMaskedFrontier
214-
newMaskedFrontier.Dispose()
197+
198+
//Push/pull
199+
let NNZ = getNNZ queue newMaskedFrontier
200+
201+
if (push NNZ newMaskedFrontier.Size) then
202+
frontier <- newMaskedFrontier
203+
else
204+
frontier <- toDense queue DeviceOnly newMaskedFrontier
205+
newMaskedFrontier.Dispose()
215206
| ClVector.Dense oldFrontier ->
216207
//Getting new frontier
217208
spMVInPlace queue matrix frontier frontier

‎src/GraphBLAS-sharp.Backend/Common/ClArray.fs

+28-10
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ module ClArray =
362362

363363
let gid = ndRange.GlobalID0
364364

365-
if gid < length then
365+
if gid < length && not result.Value then
366366
let isExist = (%predicate) vector.[gid]
367367

368368
if isExist then result.Value <- true @>
@@ -902,22 +902,40 @@ module ClArray =
902902

903903
let count<'a> (predicate: Expr<'a -> bool>) (clContext: ClContext) workGroupSize =
904904

905-
let sum =
906-
Reduce.reduce <@ (+) @> clContext workGroupSize
905+
let count =
906+
<@ fun (ndRange: Range1D) (length: int) (array: ClArray<'a>) (count: ClCell<int>) ->
907+
let gid = ndRange.GlobalID0
908+
let mutable countLocal = 0
909+
let step = ndRange.GlobalWorkSize
910+
911+
let mutable i = gid
907912

908-
let getBitmap =
909-
Map.map<'a, int> (Map.predicateBitmap predicate) clContext workGroupSize
913+
while i < length do
914+
let res = (%predicate) array.[i]
915+
if res then countLocal <- countLocal + 1
916+
i <- i + step
917+
918+
atomic (+) count.Value countLocal |> ignore @>
919+
920+
let count = clContext.Compile count
910921

911922
fun (processor: RawCommandQueue) (array: ClArray<'a>) ->
912923

913-
let bitmap = getBitmap processor DeviceOnly array
924+
let result = clContext.CreateClCell<int>(0)
914925

915-
let result =
916-
(sum processor bitmap).ToHostAndFree processor
926+
let numberOfGroups =
927+
Utils.divUpClamp array.Length workGroupSize 1 1024
917928

918-
bitmap.Free()
929+
let ndRange =
930+
Range1D.CreateValid(workGroupSize * numberOfGroups, workGroupSize)
919931

920-
result
932+
let kernel = count.GetKernel()
933+
934+
kernel.KernelFunc ndRange array.Length array result
935+
936+
processor.RunKernel kernel
937+
938+
result.ToHostAndFree processor
921939

922940
/// <summary>
923941
/// Builds a new array whose elements are the results of applying the given function

‎src/GraphBLAS-sharp.Backend/Common/Sort/Bitonic.fs

+3-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ module Bitonic =
1212
int (clContext.ClDevice.LocalMemSize)
1313
/ (sizeof<uint64> + sizeof<'a>)
1414
)
15+
/ 2
1516

1617
let maxThreadsPerBlock =
1718
min (clContext.ClDevice.MaxWorkGroupSize) (localSize / 2)
@@ -257,6 +258,7 @@ module Bitonic =
257258
int (clContext.ClDevice.LocalMemSize)
258259
/ (sizeof<int> + sizeof<'a>)
259260
)
261+
/ 2
260262

261263
let maxThreadsPerBlock =
262264
min (clContext.ClDevice.MaxWorkGroupSize) (localSize / 2)
@@ -476,4 +478,4 @@ module Bitonic =
476478

477479
kernelGlobal.KernelFunc ndRangeGlobal rows values values.Length (localSize * 2)
478480

479-
queue.RunKernel(kernelGlobal)
481+
queue.RunKernel(kernelGlobal)

‎src/GraphBLAS-sharp.Backend/Common/Utils.fs

+4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ module internal Utils =
1919
>> fun x -> x ||| (x >>> 16)
2020
>> fun x -> x + 1
2121

22+
let divUp x y = x / y + (if x % y = 0 then 0 else 1)
23+
24+
let divUpClamp x y left right = min (max (divUp x y) left) right
25+
2226
let floorToMultiple multiple x = x / multiple * multiple
2327

2428
let ceilToMultiple multiple x = ((x - 1) / multiple + 1) * multiple

‎src/GraphBLAS-sharp.Backend/Operations/Operations.fs

+45-1
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ module Operations =
331331
| _ -> failwith "Not implemented yet"
332332

333333
/// <summary>
334-
/// CSR Matrix - sparse vector multiplication. Optimized for bool OR and AND operations.
334+
/// CSR Matrix - sparse vector multiplication. Optimized for bool OR and AND operations by skipping reduction stage.
335335
/// </summary>
336336
/// <param name="add">Type of binary function to reduce entries.</param>
337337
/// <param name="mul">Type of binary function to combine entries.</param>
@@ -352,6 +352,50 @@ module Operations =
352352
| ClMatrix.CSR m, ClVector.Sparse v -> Option.map ClVector.Sparse (run queue m v)
353353
| _ -> failwith "Not implemented yet"
354354

355+
/// <summary>
356+
/// CSR Matrix - sparse vector multiplication with mask. Mask is complemented.
357+
/// </summary>
358+
/// <param name="add">Type of binary function to reduce entries.</param>
359+
/// <param name="mul">Type of binary function to combine entries.</param>
360+
/// <param name="clContext">OpenCL context.</param>
361+
/// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
362+
let SpMSpVMasked
363+
(add: Expr<'c option -> 'c option -> 'c option>)
364+
(mul: Expr<'a option -> 'b option -> 'c option>)
365+
(clContext: ClContext)
366+
workGroupSize
367+
=
368+
369+
let run =
370+
SpMSpV.Masked.runMasked add mul clContext workGroupSize
371+
372+
fun (queue: RawCommandQueue) (matrix: ClMatrix<'a>) (vector: ClVector<'b>) (mask: ClVector<'d>) ->
373+
match matrix, vector, mask with
374+
| ClMatrix.CSR m, ClVector.Sparse v, ClVector.Dense mask -> Option.map ClVector.Sparse (run queue m v mask)
375+
| _ -> failwith "Not implemented yet"
376+
377+
/// <summary>
378+
/// CSR Matrix - sparse vector multiplication with mask. Mask is complemented. Optimized for bool OR and AND operations by skipping reduction stage.
379+
/// </summary>
380+
/// <param name="add">Type of binary function to reduce entries.</param>
381+
/// <param name="mul">Type of binary function to combine entries.</param>
382+
/// <param name="clContext">OpenCL context.</param>
383+
/// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
384+
let SpMSpVMaskedBool
385+
(add: Expr<bool option -> bool option -> bool option>)
386+
(mul: Expr<bool option -> bool option -> bool option>)
387+
(clContext: ClContext)
388+
workGroupSize
389+
=
390+
391+
let run =
392+
SpMSpV.Masked.runMaskedBoolStandard add mul clContext workGroupSize
393+
394+
fun (queue: RawCommandQueue) (matrix: ClMatrix<'a>) (vector: ClVector<'b>) (mask: ClVector<'d>) ->
395+
match matrix, vector, mask with
396+
| ClMatrix.CSR m, ClVector.Sparse v, ClVector.Dense mask -> Option.map ClVector.Sparse (run queue m v mask)
397+
| _ -> failwith "Not implemented yet"
398+
355399
/// <summary>
356400
/// CSR Matrix - sparse vector multiplication.
357401
/// </summary>

0 commit comments

Comments
 (0)