Skip to content

Zero-allocation tree evaluation with buffer #112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Dec 12, 2024
Merged
25 changes: 13 additions & 12 deletions ext/DynamicExpressionsLoopVectorizationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ module DynamicExpressionsLoopVectorizationExt

using LoopVectorization: @turbo
using DynamicExpressions: AbstractExpressionNode
-using DynamicExpressions.UtilsModule: ResultOk, fill_similar
-using DynamicExpressions.EvaluateModule: @return_on_nonfinite_val, EvalOptions
+using DynamicExpressions.UtilsModule: ResultOk
+using DynamicExpressions.EvaluateModule:
+@return_on_nonfinite_val, EvalOptions, get_array, get_feature_array, get_filled_array
import DynamicExpressions.EvaluateModule:
deg1_eval,
deg2_eval,
Expand Down Expand Up @@ -56,12 +57,12 @@ function deg1_l2_ll0_lr0_eval(
@return_on_nonfinite_val(eval_options, x_l, cX)
x = op(x_l)::T
@return_on_nonfinite_val(eval_options, x, cX)
-return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
+return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true)
elseif tree.l.l.constant
val_ll = tree.l.l.val
@return_on_nonfinite_val(eval_options, val_ll, cX)
feature_lr = tree.l.r.feature
-cumulator = similar(cX, axes(cX, 2))
+cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
@turbo for j in axes(cX, 2)
x_l = op_l(val_ll, cX[feature_lr, j])
x = op(x_l)
Expand All @@ -72,7 +73,7 @@ function deg1_l2_ll0_lr0_eval(
feature_ll = tree.l.l.feature
val_lr = tree.l.r.val
@return_on_nonfinite_val(eval_options, val_lr, cX)
-cumulator = similar(cX, axes(cX, 2))
+cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
@turbo for j in axes(cX, 2)
x_l = op_l(cX[feature_ll, j], val_lr)
x = op(x_l)
Expand All @@ -82,7 +83,7 @@ function deg1_l2_ll0_lr0_eval(
else
feature_ll = tree.l.l.feature
feature_lr = tree.l.r.feature
-cumulator = similar(cX, axes(cX, 2))
+cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
@turbo for j in axes(cX, 2)
x_l = op_l(cX[feature_ll, j], cX[feature_lr, j])
x = op(x_l)
Expand All @@ -106,10 +107,10 @@ function deg1_l1_ll0_eval(
@return_on_nonfinite_val(eval_options, x_l, cX)
x = op(x_l)::T
@return_on_nonfinite_val(eval_options, x, cX)
-return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
+return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true)
else
feature_ll = tree.l.l.feature
-cumulator = similar(cX, axes(cX, 2))
+cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
@turbo for j in axes(cX, 2)
x_l = op_l(cX[feature_ll, j])
x = op(x_l)
Expand All @@ -132,9 +133,9 @@ function deg2_l0_r0_eval(
@return_on_nonfinite_val(eval_options, val_r, cX)
x = op(val_l, val_r)::T
@return_on_nonfinite_val(eval_options, x, cX)
-return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
+return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true)
elseif tree.l.constant
-cumulator = similar(cX, axes(cX, 2))
+cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
val_l = tree.l.val
@return_on_nonfinite_val(eval_options, val_l, cX)
feature_r = tree.r.feature
Expand All @@ -144,7 +145,7 @@ function deg2_l0_r0_eval(
end
return ResultOk(cumulator, true)
elseif tree.r.constant
-cumulator = similar(cX, axes(cX, 2))
+cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
feature_l = tree.l.feature
val_r = tree.r.val
@return_on_nonfinite_val(eval_options, val_r, cX)
Expand All @@ -154,7 +155,7 @@ function deg2_l0_r0_eval(
end
return ResultOk(cumulator, true)
else
-cumulator = similar(cX, axes(cX, 2))
+cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
feature_l = tree.l.feature
feature_r = tree.r.feature
@turbo for j in axes(cX, 2)
Expand Down
1 change: 1 addition & 0 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ import .StringsModule: get_op_name
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
@reexport import .EvaluateModule:
eval_tree_array, differentiable_eval_tree_array, EvalOptions
+import .EvaluateModule: ArrayBuffer
@reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
@reexport import .ChainRulesModule: NodeTangent, extract_gradient
@reexport import .SimplifyModule: combine_operators, simplify_tree!
Expand Down
Loading
Loading