Skip to content

Create preallocation utility functions for expressions #114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 13, 2024
3 changes: 2 additions & 1 deletion src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using DispatchDoctor: @stable, @unstable
include("OperatorEnum.jl")
include("Node.jl")
include("NodeUtils.jl")
include("NodePreallocation.jl")
include("Strings.jl")
include("Evaluate.jl")
include("EvaluateDerivative.jl")
Expand Down Expand Up @@ -41,11 +42,11 @@ import .ValueInterfaceModule:
GraphNode,
Node,
copy_node,
copy_node!,
set_node!,
tree_mapreduce,
filter_map,
filter_map!
import .NodePreallocationModule: allocate_container, copy_into!
import .NodeModule:
constructorof,
with_type_parameters,
Expand Down
20 changes: 20 additions & 0 deletions src/Expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import ..NodeUtilsModule:
count_scalar_constants,
get_scalar_constants,
set_scalar_constants!
import ..NodePreallocationModule: copy_into!, allocate_container
import ..EvaluateModule: eval_tree_array, differentiable_eval_tree_array
import ..EvaluateDerivativeModule: eval_grad_tree_array
import ..EvaluationHelpersModule: _grad_evaluator
Expand Down Expand Up @@ -502,4 +503,23 @@ function (ex::AbstractExpression)(
return get_tree(ex)(X, get_operators(ex, operators); kws...)
end

# We don't require users to overload this, as it's not part of the required interface.
# Also, there's no way to generally do this from the required interface, so for backwards
# compatibility, we just return nothing.
# COV_EXCL_START
function copy_into!(::Nothing, src::AbstractExpression)
return copy(src)
end
function allocate_container(::AbstractExpression, ::Union{Nothing,Integer}=nothing)
return nothing
end
# COV_EXCL_STOP
function allocate_container(prototype::Expression, n::Union{Nothing,Integer}=nothing)
return (; tree=allocate_container(get_contents(prototype), n))
end
function copy_into!(dest::NamedTuple, src::Expression)
tree = copy_into!(dest.tree, get_contents(src))
return with_contents(src, tree)
end

end
32 changes: 22 additions & 10 deletions src/Interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@ using ..NodeModule:
default_allocator,
with_type_parameters,
leaf_copy,
leaf_copy!,
leaf_convert,
leaf_hash,
leaf_equal,
branch_copy,
branch_copy!,
branch_convert,
branch_hash,
branch_equal,
Expand All @@ -38,6 +36,8 @@ using ..NodeUtilsModule:
has_constants,
get_scalar_constants,
set_scalar_constants!
using ..NodePreallocationModule:
copy_into!, leaf_copy_into!, branch_copy_into!, allocate_container
using ..StringsModule: string_tree
using ..EvaluateModule: eval_tree_array
using ..EvaluateDerivativeModule: eval_grad_tree_array
Expand Down Expand Up @@ -96,6 +96,11 @@ function _check_with_metadata(ex::AbstractExpression)
end

## optional
function _check_copy_into!(ex::AbstractExpression)
container = allocate_container(ex)
prealloc_ex = copy_into!(container, ex)
return container !== nothing && prealloc_ex == ex && prealloc_ex !== ex
end
function _check_count_nodes(ex::AbstractExpression)
return count_nodes(ex) isa Int64
end
Expand Down Expand Up @@ -156,6 +161,7 @@ ei_components = (
with_metadata = "returns the expression with different metadata" => _check_with_metadata,
),
optional = (
copy_into! = "copies an expression into a preallocated container" => _check_copy_into!,
count_nodes = "counts the number of nodes in the expression tree" => _check_count_nodes,
count_constant_nodes = "counts the number of constant nodes in the expression tree" => _check_count_constant_nodes,
count_depth = "calculates the depth of the expression tree" => _check_count_depth,
Expand Down Expand Up @@ -260,14 +266,19 @@ function _check_tree_mapreduce(tree::AbstractExpressionNode)
end

## optional
function _check_copy_into!(tree::AbstractExpressionNode)
container = allocate_container(tree)
prealloc_tree = copy_into!(container, tree)
return container !== nothing && prealloc_tree == tree && prealloc_tree !== container
end
function _check_leaf_copy(tree::AbstractExpressionNode)
tree.degree != 0 && return true
return leaf_copy(tree) isa typeof(tree)
end
function _check_leaf_copy!(tree::AbstractExpressionNode{T}) where {T}
function _check_leaf_copy_into!(tree::AbstractExpressionNode{T}) where {T}
tree.degree != 0 && return true
new_leaf = constructorof(typeof(tree))(; val=zero(T))
ret = leaf_copy!(new_leaf, tree)
ret = leaf_copy_into!(new_leaf, tree)
return new_leaf == tree && ret === new_leaf
end
function _check_leaf_convert(tree::AbstractExpressionNode)
Expand All @@ -292,16 +303,16 @@ function _check_branch_copy(tree::AbstractExpressionNode)
return branch_copy(tree, tree.l, tree.r) isa typeof(tree)
end
end
function _check_branch_copy!(tree::AbstractExpressionNode{T}) where {T}
function _check_branch_copy_into!(tree::AbstractExpressionNode{T}) where {T}
if tree.degree == 0
return true
end
new_branch = constructorof(typeof(tree))(; val=zero(T))
if tree.degree == 1
ret = branch_copy!(new_branch, tree, copy(tree.l))
ret = branch_copy_into!(new_branch, tree, copy(tree.l))
return new_branch == tree && ret === new_branch
else
ret = branch_copy!(new_branch, tree, copy(tree.l), copy(tree.r))
ret = branch_copy_into!(new_branch, tree, copy(tree.l), copy(tree.r))
return new_branch == tree && ret === new_branch
end
end
Expand Down Expand Up @@ -372,13 +383,14 @@ ni_components = (
tree_mapreduce = "applies a function across the tree" => _check_tree_mapreduce
),
optional = (
copy_into! = "copies a node into a preallocated container" => _check_copy_into!,
leaf_copy = "copies a leaf node" => _check_leaf_copy,
leaf_copy! = "copies a leaf node in-place" => _check_leaf_copy!,
leaf_copy_into! = "copies a leaf node in-place" => _check_leaf_copy_into!,
leaf_convert = "converts a leaf node" => _check_leaf_convert,
leaf_hash = "computes the hash of a leaf node" => _check_leaf_hash,
leaf_equal = "checks equality of two leaf nodes" => _check_leaf_equal,
branch_copy = "copies a branch node" => _check_branch_copy,
branch_copy! = "copies a branch node in-place" => _check_branch_copy!,
branch_copy_into! = "copies a branch node in-place" => _check_branch_copy_into!,
branch_convert = "converts a branch node" => _check_branch_convert,
branch_hash = "computes the hash of a branch node" => _check_branch_hash,
branch_equal = "checks equality of two branch nodes" => _check_branch_equal,
Expand Down Expand Up @@ -419,7 +431,7 @@ ni_description = (
[Arguments()]
)
@implements(
NodeInterface{all_ni_methods_except((:leaf_copy!, :branch_copy!))},
NodeInterface{all_ni_methods_except(())},
GraphNode,
[Arguments()]
)
Expand Down
11 changes: 0 additions & 11 deletions src/Node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,23 +321,12 @@ function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{GraphNode{T2}}) where {
return GraphNode{promote_type(T1, T2)}
end

# TODO: Verify using this helps with garbage collection
create_dummy_node(::Type{N}) where {N<:AbstractExpressionNode} = N()

"""
set_node!(tree::AbstractExpressionNode{T}, new_tree::AbstractExpressionNode{T}) where {T}

Set every field of `tree` equal to the corresponding field of `new_tree`.
"""
function set_node!(tree::AbstractExpressionNode, new_tree::AbstractExpressionNode)
# First, ensure we free some memory:
if new_tree.degree < 2 && tree.degree == 2
tree.r = create_dummy_node(typeof(tree))
end
if new_tree.degree < 1 && tree.degree >= 1
tree.l = create_dummy_node(typeof(tree))
end

tree.degree = new_tree.degree
if new_tree.degree == 0
tree.constant = new_tree.constant
Expand Down
69 changes: 69 additions & 0 deletions src/NodePreallocation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
module NodePreallocationModule

using ..NodeModule:
AbstractExpressionNode,
with_type_parameters,
tree_mapreduce,
leaf_copy,
branch_copy,
set_node!

"""
allocate_container(prototype::AbstractExpressionNode, n=nothing)

Preallocate an array of `n` empty nodes matching the type of `prototype`.
If `n` is not provided, it will be computed from `length(prototype)`.

A given return value of this will be passed to `copy_into!` as the first argument,
so it should be compatible.
"""
function allocate_container(
prototype::N, n::Union{Nothing,Integer}=nothing
) where {T,N<:AbstractExpressionNode{T}}
num_nodes = @something(n, length(prototype))
return N[with_type_parameters(N, T)() for _ in 1:num_nodes]
end

"""
copy_into!(dest::AbstractArray{N}, src::N) where {N<:AbstractExpressionNode}

Copy a node, recursively copying all children nodes, in-place to a preallocated container.
This should result in no extra allocations.
"""
function copy_into!(
dest::AbstractArray{N}, src::N; ref::Union{Nothing,Base.RefValue{<:Integer}}=nothing
) where {N<:AbstractExpressionNode}
_ref = if ref === nothing
Ref(0)
else
ref.x = 0
ref
end
return tree_mapreduce(
leaf -> leaf_copy_into!(@inbounds(dest[_ref.x += 1]), leaf),
identity,
((p, c::Vararg{Any,M}) where {M}) ->
branch_copy_into!(@inbounds(dest[_ref.x += 1]), p, c...),
src,
N,
)
end
# COV_EXCL_START
function leaf_copy_into!(dest::N, src::N) where {N<:AbstractExpressionNode}
set_node!(dest, src)
return dest
end
# COV_EXCL_STOP
function branch_copy_into!(
dest::N, src::N, children::Vararg{N,M}
) where {N<:AbstractExpressionNode,M}
dest.degree = M
dest.op = src.op
dest.l = children[1]
if M == 2
dest.r = children[2]
end
return dest
end

end
62 changes: 46 additions & 16 deletions src/ParametricExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@ using ChainRulesCore: ChainRulesCore as CRC, NoTangent, @thunk

using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
using ..ExpressionModule: AbstractExpression, Metadata
using ..ExpressionModule: AbstractExpression, Metadata, with_contents, with_metadata
using ..ChainRulesModule: NodeTangent

import ..NodeModule:
constructorof,
with_type_parameters,
preserve_sharing,
leaf_copy,
leaf_copy!,
leaf_convert,
leaf_hash,
leaf_equal,
branch_copy!
set_node!
import ..NodePreallocationModule: copy_into!, allocate_container
import ..NodeUtilsModule:
count_constant_nodes,
index_constant_nodes,
Expand Down Expand Up @@ -124,21 +124,29 @@ function leaf_copy(t::ParametricNode{T}) where {T}
return n
end
end
function leaf_copy!(dest::N, src::N) where {T,N<:ParametricNode{T}}
dest.degree = 0
if src.constant
dest.constant = true
dest.val = src.val
elseif !src.is_parameter
dest.constant = false
dest.is_parameter = false
dest.feature = src.feature
function set_node!(tree::ParametricNode, new_tree::ParametricNode)
tree.degree = new_tree.degree
if new_tree.degree == 0
if new_tree.constant
tree.constant = true
tree.val = new_tree.val
elseif !new_tree.is_parameter
tree.constant = false
tree.is_parameter = false
tree.feature = new_tree.feature
else
tree.constant = false
tree.is_parameter = true
tree.parameter = new_tree.parameter
end
else
dest.constant = false
dest.is_parameter = true
dest.parameter = src.parameter
tree.op = new_tree.op
tree.l = new_tree.l
if new_tree.degree == 2
tree.r = new_tree.r
end
end
return dest
return nothing
end
function leaf_convert(::Type{N}, t::ParametricNode) where {T,N<:ParametricNode{T}}
if t.constant
Expand Down Expand Up @@ -444,6 +452,28 @@ end
return node_type(; val=ex)
end
end
function allocate_container(
prototype::ParametricExpression, n::Union{Nothing,Integer}=nothing
)
return (;
tree=allocate_container(get_contents(prototype), n),
parameters=similar(get_metadata(prototype).parameters),
)
end
function copy_into!(dest::NamedTuple, src::ParametricExpression)
new_tree = copy_into!(dest.tree, get_contents(src))
metadata = get_metadata(src)
new_parameters = dest.parameters
new_parameters .= metadata.parameters
new_metadata = Metadata((;
operators=metadata.operators,
variable_names=metadata.variable_names,
parameters=new_parameters,
parameter_names=metadata.parameter_names,
))
# TODO: Better interface for this^
return with_metadata(with_contents(src, new_tree), new_metadata)
end
###############################################################################

end
14 changes: 14 additions & 0 deletions src/StructuredExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ using ..ExpressionModule: AbstractExpression, Metadata, node_type
using ..ChainRulesModule: NodeTangent

import ..NodeModule: constructorof
import ..NodePreallocationModule: copy_into!, allocate_container
import ..ExpressionModule:
get_contents,
get_metadata,
get_tree,
get_operators,
get_variable_names,
with_contents,
Metadata,
_copy,
_data,
Expand Down Expand Up @@ -164,4 +166,16 @@ function set_scalar_constants!(e::AbstractStructuredExpression, constants, refs)
return e
end

function allocate_container(
e::AbstractStructuredExpression, n::Union{Nothing,Integer}=nothing
)
ts = get_contents(e)
return (; trees=NamedTuple{keys(ts)}(map(t -> allocate_container(t, n), values(ts))))
end
function copy_into!(dest::NamedTuple, src::AbstractStructuredExpression)
ts = get_contents(src)
new_contents = NamedTuple{keys(ts)}(map(copy_into!, values(dest.trees), values(ts)))
return with_contents(src, new_contents)
end

end
Loading
Loading