@@ -331,7 +331,7 @@ module Operations =
331
331
| _ -> failwith " Not implemented yet"
332
332
333
333
/// <summary>
334
- /// CSR Matrix - sparse vector multiplication. Optimized for bool OR and AND operations.
334
+ /// CSR Matrix - sparse vector multiplication. Optimized for bool OR and AND operations by skipping reduction stage .
335
335
/// </summary>
336
336
/// <param name="add">Type of binary function to reduce entries.</param>
337
337
/// <param name="mul">Type of binary function to combine entries.</param>
@@ -352,6 +352,50 @@ module Operations =
352
352
| ClMatrix.CSR m, ClVector.Sparse v -> Option.map ClVector.Sparse ( run queue m v)
353
353
| _ -> failwith " Not implemented yet"
354
354
355
+ /// <summary>
356
+ /// CSR Matrix - sparse vector multiplication with mask. Mask is complemented.
357
+ /// </summary>
358
+ /// <param name="add">Type of binary function to reduce entries.</param>
359
+ /// <param name="mul">Type of binary function to combine entries.</param>
360
+ /// <param name="clContext">OpenCL context.</param>
361
+ /// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
362
+ let SpMSpVMasked
363
+ ( add : Expr < 'c option -> 'c option -> 'c option >)
364
+ ( mul : Expr < 'a option -> 'b option -> 'c option >)
365
+ ( clContext : ClContext )
366
+ workGroupSize
367
+ =
368
+
369
+ let run =
370
+ SpMSpV.Masked.runMasked add mul clContext workGroupSize
371
+
372
+ fun ( queue : RawCommandQueue ) ( matrix : ClMatrix < 'a >) ( vector : ClVector < 'b >) ( mask : ClVector < 'd >) ->
373
+ match matrix, vector, mask with
374
+ | ClMatrix.CSR m, ClVector.Sparse v, ClVector.Dense mask -> Option.map ClVector.Sparse ( run queue m v mask)
375
+ | _ -> failwith " Not implemented yet"
376
+
377
+ /// <summary>
378
+ /// CSR Matrix - sparse vector multiplication with mask. Mask is complemented. Optimized for bool OR and AND operations by skipping reduction stage.
379
+ /// </summary>
380
+ /// <param name="add">Type of binary function to reduce entries.</param>
381
+ /// <param name="mul">Type of binary function to combine entries.</param>
382
+ /// <param name="clContext">OpenCL context.</param>
383
+ /// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
384
+ let SpMSpVMaskedBool
385
+ ( add : Expr < bool option -> bool option -> bool option >)
386
+ ( mul : Expr < bool option -> bool option -> bool option >)
387
+ ( clContext : ClContext )
388
+ workGroupSize
389
+ =
390
+
391
+ let run =
392
+ SpMSpV.Masked.runMaskedBoolStandard add mul clContext workGroupSize
393
+
394
+ fun ( queue : RawCommandQueue ) ( matrix : ClMatrix < 'a >) ( vector : ClVector < 'b >) ( mask : ClVector < 'd >) ->
395
+ match matrix, vector, mask with
396
+ | ClMatrix.CSR m, ClVector.Sparse v, ClVector.Dense mask -> Option.map ClVector.Sparse ( run queue m v mask)
397
+ | _ -> failwith " Not implemented yet"
398
+
355
399
/// <summary>
356
400
/// CSR Matrix - sparse vector multiplication.
357
401
/// </summary>
0 commit comments