diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 8e32d899..6c0ba5f8 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -9,6 +9,7 @@ using DispatchDoctor: @stable, @unstable include("OperatorEnum.jl") include("Node.jl") include("NodeUtils.jl") + include("NodePreallocation.jl") include("Strings.jl") include("Evaluate.jl") include("EvaluateDerivative.jl") @@ -41,11 +42,11 @@ import .ValueInterfaceModule: GraphNode, Node, copy_node, - copy_node!, set_node!, tree_mapreduce, filter_map, filter_map! +import .NodePreallocationModule: allocate_container, copy_into! import .NodeModule: constructorof, with_type_parameters, diff --git a/src/Expression.jl b/src/Expression.jl index be927269..68fd8818 100644 --- a/src/Expression.jl +++ b/src/Expression.jl @@ -19,6 +19,7 @@ import ..NodeUtilsModule: count_scalar_constants, get_scalar_constants, set_scalar_constants! +import ..NodePreallocationModule: copy_into!, allocate_container import ..EvaluateModule: eval_tree_array, differentiable_eval_tree_array import ..EvaluateDerivativeModule: eval_grad_tree_array import ..EvaluationHelpersModule: _grad_evaluator @@ -502,4 +503,23 @@ function (ex::AbstractExpression)( return get_tree(ex)(X, get_operators(ex, operators); kws...) end +# We don't require users to overload this, as it's not part of the required interface. +# Also, there's no way to generally do this from the required interface, so for backwards +# compatibility, we just return nothing. 
+# COV_EXCL_START +function copy_into!(::Nothing, src::AbstractExpression) + return copy(src) +end +function allocate_container(::AbstractExpression, ::Union{Nothing,Integer}=nothing) + return nothing +end +# COV_EXCL_STOP +function allocate_container(prototype::Expression, n::Union{Nothing,Integer}=nothing) + return (; tree=allocate_container(get_contents(prototype), n)) +end +function copy_into!(dest::NamedTuple, src::Expression) + tree = copy_into!(dest.tree, get_contents(src)) + return with_contents(src, tree) +end + end diff --git a/src/Interfaces.jl b/src/Interfaces.jl index b950ec97..c2d44b59 100644 --- a/src/Interfaces.jl +++ b/src/Interfaces.jl @@ -13,12 +13,10 @@ using ..NodeModule: default_allocator, with_type_parameters, leaf_copy, - leaf_copy!, leaf_convert, leaf_hash, leaf_equal, branch_copy, - branch_copy!, branch_convert, branch_hash, branch_equal, @@ -38,6 +36,8 @@ using ..NodeUtilsModule: has_constants, get_scalar_constants, set_scalar_constants! +using ..NodePreallocationModule: + copy_into!, leaf_copy_into!, branch_copy_into!, allocate_container using ..StringsModule: string_tree using ..EvaluateModule: eval_tree_array using ..EvaluateDerivativeModule: eval_grad_tree_array @@ -96,6 +96,11 @@ function _check_with_metadata(ex::AbstractExpression) end ## optional +function _check_copy_into!(ex::AbstractExpression) + container = allocate_container(ex) + prealloc_ex = copy_into!(container, ex) + return container !== nothing && prealloc_ex == ex && prealloc_ex !== ex +end function _check_count_nodes(ex::AbstractExpression) return count_nodes(ex) isa Int64 end @@ -156,6 +161,7 @@ ei_components = ( with_metadata = "returns the expression with different metadata" => _check_with_metadata, ), optional = ( + copy_into! 
= "copies an expression into a preallocated container" => _check_copy_into!, count_nodes = "counts the number of nodes in the expression tree" => _check_count_nodes, count_constant_nodes = "counts the number of constant nodes in the expression tree" => _check_count_constant_nodes, count_depth = "calculates the depth of the expression tree" => _check_count_depth, @@ -260,14 +266,19 @@ function _check_tree_mapreduce(tree::AbstractExpressionNode) end ## optional +function _check_copy_into!(tree::AbstractExpressionNode) + container = allocate_container(tree) + prealloc_tree = copy_into!(container, tree) + return container !== nothing && prealloc_tree == tree && prealloc_tree !== tree +end function _check_leaf_copy(tree::AbstractExpressionNode) tree.degree != 0 && return true return leaf_copy(tree) isa typeof(tree) end -function _check_leaf_copy!(tree::AbstractExpressionNode{T}) where {T} +function _check_leaf_copy_into!(tree::AbstractExpressionNode{T}) where {T} tree.degree != 0 && return true new_leaf = constructorof(typeof(tree))(; val=zero(T)) - ret = leaf_copy!(new_leaf, tree) + ret = leaf_copy_into!(new_leaf, tree) return new_leaf == tree && ret === new_leaf end function _check_leaf_convert(tree::AbstractExpressionNode) @@ -292,16 +303,16 @@ function _check_branch_copy(tree::AbstractExpressionNode) return branch_copy(tree, tree.l, tree.r) isa typeof(tree) end end -function _check_branch_copy!(tree::AbstractExpressionNode{T}) where {T} +function _check_branch_copy_into!(tree::AbstractExpressionNode{T}) where {T} if tree.degree == 0 return true end new_branch = constructorof(typeof(tree))(; val=zero(T)) if tree.degree == 1 - ret = branch_copy!(new_branch, tree, copy(tree.l)) + ret = branch_copy_into!(new_branch, tree, copy(tree.l)) return new_branch == tree && ret === new_branch else - ret = branch_copy!(new_branch, tree, copy(tree.l), copy(tree.r)) + ret = branch_copy_into!(new_branch, tree, copy(tree.l), copy(tree.r)) return new_branch == tree && ret === 
new_branch end end @@ -372,13 +383,14 @@ ni_components = ( tree_mapreduce = "applies a function across the tree" => _check_tree_mapreduce ), optional = ( + copy_into! = "copies a node into a preallocated container" => _check_copy_into!, leaf_copy = "copies a leaf node" => _check_leaf_copy, - leaf_copy! = "copies a leaf node in-place" => _check_leaf_copy!, + leaf_copy_into! = "copies a leaf node in-place" => _check_leaf_copy_into!, leaf_convert = "converts a leaf node" => _check_leaf_convert, leaf_hash = "computes the hash of a leaf node" => _check_leaf_hash, leaf_equal = "checks equality of two leaf nodes" => _check_leaf_equal, branch_copy = "copies a branch node" => _check_branch_copy, - branch_copy! = "copies a branch node in-place" => _check_branch_copy!, + branch_copy_into! = "copies a branch node in-place" => _check_branch_copy_into!, branch_convert = "converts a branch node" => _check_branch_convert, branch_hash = "computes the hash of a branch node" => _check_branch_hash, branch_equal = "checks equality of two branch nodes" => _check_branch_equal, @@ -419,7 +431,7 @@ ni_description = ( [Arguments()] ) @implements( - NodeInterface{all_ni_methods_except((:leaf_copy!, :branch_copy!))}, + NodeInterface{all_ni_methods_except(())}, GraphNode, [Arguments()] ) diff --git a/src/Node.jl b/src/Node.jl index 40667f94..ddceaaa4 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -321,23 +321,12 @@ function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{GraphNode{T2}}) where { return GraphNode{promote_type(T1, T2)} end -# TODO: Verify using this helps with garbage collection -create_dummy_node(::Type{N}) where {N<:AbstractExpressionNode} = N() - """ set_node!(tree::AbstractExpressionNode{T}, new_tree::AbstractExpressionNode{T}) where {T} Set every field of `tree` equal to the corresponding field of `new_tree`. 
""" function set_node!(tree::AbstractExpressionNode, new_tree::AbstractExpressionNode) - # First, ensure we free some memory: - if new_tree.degree < 2 && tree.degree == 2 - tree.r = create_dummy_node(typeof(tree)) - end - if new_tree.degree < 1 && tree.degree >= 1 - tree.l = create_dummy_node(typeof(tree)) - end - tree.degree = new_tree.degree if new_tree.degree == 0 tree.constant = new_tree.constant diff --git a/src/NodePreallocation.jl b/src/NodePreallocation.jl new file mode 100644 index 00000000..ccce372d --- /dev/null +++ b/src/NodePreallocation.jl @@ -0,0 +1,69 @@ +module NodePreallocationModule + +using ..NodeModule: + AbstractExpressionNode, + with_type_parameters, + tree_mapreduce, + leaf_copy, + branch_copy, + set_node! + +""" + allocate_container(prototype::AbstractExpressionNode, n=nothing) + +Preallocate an array of `n` empty nodes matching the type of `prototype`. +If `n` is not provided, it will be computed from `length(prototype)`. + +A given return value of this will be passed to `copy_into!` as the first argument, +so it should be compatible. +""" +function allocate_container( + prototype::N, n::Union{Nothing,Integer}=nothing +) where {T,N<:AbstractExpressionNode{T}} + num_nodes = @something(n, length(prototype)) + return N[with_type_parameters(N, T)() for _ in 1:num_nodes] +end + +""" + copy_into!(dest::AbstractArray{N}, src::N) where {N<:AbstractExpressionNode} + +Copy a node, recursively copying all children nodes, in-place to a preallocated container. +This should result in no extra allocations. 
+""" +function copy_into!( + dest::AbstractArray{N}, src::N; ref::Union{Nothing,Base.RefValue{<:Integer}}=nothing +) where {N<:AbstractExpressionNode} + _ref = if ref === nothing + Ref(0) + else + ref.x = 0 + ref + end + return tree_mapreduce( + leaf -> leaf_copy_into!(@inbounds(dest[_ref.x += 1]), leaf), + identity, + ((p, c::Vararg{Any,M}) where {M}) -> + branch_copy_into!(@inbounds(dest[_ref.x += 1]), p, c...), + src, + N, + ) +end +# COV_EXCL_START +function leaf_copy_into!(dest::N, src::N) where {N<:AbstractExpressionNode} + set_node!(dest, src) + return dest +end +# COV_EXCL_STOP +function branch_copy_into!( + dest::N, src::N, children::Vararg{N,M} +) where {N<:AbstractExpressionNode,M} + dest.degree = M + dest.op = src.op + dest.l = children[1] + if M == 2 + dest.r = children[2] + end + return dest +end + +end diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 60f5fb41..16d27254 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -5,7 +5,7 @@ using ChainRulesCore: ChainRulesCore as CRC, NoTangent, @thunk using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce -using ..ExpressionModule: AbstractExpression, Metadata +using ..ExpressionModule: AbstractExpression, Metadata, with_contents, with_metadata using ..ChainRulesModule: NodeTangent import ..NodeModule: @@ -13,11 +13,11 @@ import ..NodeModule: with_type_parameters, preserve_sharing, leaf_copy, - leaf_copy!, leaf_convert, leaf_hash, leaf_equal, - branch_copy! + set_node! 
+import ..NodePreallocationModule: copy_into!, allocate_container import ..NodeUtilsModule: count_constant_nodes, index_constant_nodes, @@ -124,21 +124,29 @@ function leaf_copy(t::ParametricNode{T}) where {T} return n end end -function leaf_copy!(dest::N, src::N) where {T,N<:ParametricNode{T}} - dest.degree = 0 - if src.constant - dest.constant = true - dest.val = src.val - elseif !src.is_parameter - dest.constant = false - dest.is_parameter = false - dest.feature = src.feature +function set_node!(tree::ParametricNode, new_tree::ParametricNode) + tree.degree = new_tree.degree + if new_tree.degree == 0 + if new_tree.constant + tree.constant = true + tree.val = new_tree.val + elseif !new_tree.is_parameter + tree.constant = false + tree.is_parameter = false + tree.feature = new_tree.feature + else + tree.constant = false + tree.is_parameter = true + tree.parameter = new_tree.parameter + end else - dest.constant = false - dest.is_parameter = true - dest.parameter = src.parameter + tree.op = new_tree.op + tree.l = new_tree.l + if new_tree.degree == 2 + tree.r = new_tree.r + end end - return dest + return nothing end function leaf_convert(::Type{N}, t::ParametricNode) where {T,N<:ParametricNode{T}} if t.constant @@ -444,6 +452,28 @@ end return node_type(; val=ex) end end +function allocate_container( + prototype::ParametricExpression, n::Union{Nothing,Integer}=nothing +) + return (; + tree=allocate_container(get_contents(prototype), n), + parameters=similar(get_metadata(prototype).parameters), + ) +end +function copy_into!(dest::NamedTuple, src::ParametricExpression) + new_tree = copy_into!(dest.tree, get_contents(src)) + metadata = get_metadata(src) + new_parameters = dest.parameters + new_parameters .= metadata.parameters + new_metadata = Metadata((; + operators=metadata.operators, + variable_names=metadata.variable_names, + parameters=new_parameters, + parameter_names=metadata.parameter_names, + )) + # TODO: Better interface for this^ + return 
with_metadata(with_contents(src, new_tree), new_metadata) +end ############################################################################### end diff --git a/src/StructuredExpression.jl b/src/StructuredExpression.jl index 28ed6cd6..963da0e4 100644 --- a/src/StructuredExpression.jl +++ b/src/StructuredExpression.jl @@ -6,12 +6,14 @@ using ..ExpressionModule: AbstractExpression, Metadata, node_type using ..ChainRulesModule: NodeTangent import ..NodeModule: constructorof +import ..NodePreallocationModule: copy_into!, allocate_container import ..ExpressionModule: get_contents, get_metadata, get_tree, get_operators, get_variable_names, + with_contents, Metadata, _copy, _data, @@ -164,4 +166,16 @@ function set_scalar_constants!(e::AbstractStructuredExpression, constants, refs) return e end +function allocate_container( + e::AbstractStructuredExpression, n::Union{Nothing,Integer}=nothing +) + ts = get_contents(e) + return (; trees=NamedTuple{keys(ts)}(map(t -> allocate_container(t, n), values(ts)))) +end +function copy_into!(dest::NamedTuple, src::AbstractStructuredExpression) + ts = get_contents(src) + new_contents = NamedTuple{keys(ts)}(map(copy_into!, values(dest.trees), values(ts))) + return with_contents(src, new_contents) +end + end diff --git a/src/base.jl b/src/base.jl index 8a656ab3..3d29a404 100644 --- a/src/base.jl +++ b/src/base.jl @@ -485,54 +485,6 @@ function branch_copy(t::N, children::Vararg{Any,M}) where {T,N<:AbstractExpressi return constructorof(N)(T; op=t.op, children) end -# In-place versions - -""" - copy_node!(dest::AbstractArray{N}, src::N; break_sharing::Val{BS}=Val(false)) where {BS,N<:AbstractExpressionNode} - -Copy a node, recursively copying all children nodes, in-place to an -array of pre-allocated nodes. This should result in no extra allocations. 
-""" -function copy_node!( - dest::AbstractArray{N}, - src::N; - break_sharing::Val{BS}=Val(false), - ref::Base.RefValue{<:Integer}=Ref(0), -) where {BS,N<:AbstractExpressionNode} - ref.x = 0 - return tree_mapreduce( - leaf -> leaf_copy!(@inbounds(dest[ref.x += 1]), leaf), - identity, - ((p, c::Vararg{Any,M}) where {M}) -> - branch_copy!(@inbounds(dest[ref.x += 1]), p, c...), - src, - N; - break_sharing=Val(BS), - ) -end -function leaf_copy!(dest::N, src::N) where {T,N<:AbstractExpressionNode{T}} - dest.degree = 0 - if src.constant - dest.constant = true - dest.val = src.val - else - dest.constant = false - dest.feature = src.feature - end - return dest -end -function branch_copy!( - dest::N, src::N, children::Vararg{N,M} -) where {T,N<:AbstractExpressionNode{T},M} - dest.degree = M - dest.op = src.op - dest.l = children[1] - if M == 2 - dest.r = children[2] - end - return dest -end - """ copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false)) diff --git a/test/test_copy_inplace.jl b/test/test_copy_inplace.jl index 4b337aec..839d3d55 100644 --- a/test/test_copy_inplace.jl +++ b/test/test_copy_inplace.jl @@ -1,6 +1,6 @@ -@testitem "copy_node! - random trees" begin +@testitem "copy_into! - random trees" begin using DynamicExpressions - using DynamicExpressions: copy_node! + using DynamicExpressions: copy_into! include("tree_gen_utils.jl") operators = OperatorEnum(; binary_operators=[+, *, /], unary_operators=[sin, cos]) @@ -15,7 +15,7 @@ orig_nodes = dest_array[(n_nodes + 1):end] # Save reference to unused nodes ref = Ref(0) - result = copy_node!(dest_array, tree; ref) + result = copy_into!(dest_array, tree; ref) @test ref[] == n_nodes # Increment once per node @@ -35,9 +35,9 @@ end end -@testitem "copy_node! - leaf nodes" begin +@testitem "copy_into! - leaf nodes" begin using DynamicExpressions - using DynamicExpressions: copy_node! + using DynamicExpressions: copy_into! 
leaf_constant = Node{Float64}(; val=1.0) leaf_feature = Node{Float64}(; feature=1) @@ -45,9 +45,51 @@ end for leaf in [leaf_constant, leaf_feature] dest_array = [Node{Float64}() for _ in 1:1] ref = Ref(0) - result = copy_node!(dest_array, leaf; ref=ref) + result = copy_into!(dest_array, leaf; ref=ref) @test ref[] == 1 @test result == leaf @test result === dest_array[1] end end + +@testitem "copy_into! with expressions" begin + using DynamicExpressions + using DynamicExpressions: + copy_into!, allocate_container, get_operators, get_variable_names + + operators = OperatorEnum(; binary_operators=[+, *], unary_operators=[sin]) + variable_names = ["x", "y"] + + # Test regular Expression + ex = @parse_expression( + sin(x + 2.0 * y), operators = operators, variable_names = variable_names + ) + container = allocate_container(ex) + result = copy_into!(container, ex) + + @test result == ex + @test result !== ex + @test get_tree(result) !== get_tree(ex) + @test get_operators(result, nothing) === get_operators(ex, nothing) + @test get_variable_names(result, nothing) === get_variable_names(ex, nothing) + + # Test ParametricExpression + parameters = [1.0 2.0; 3.0 4.0] + pex = @parse_expression( + sin(x + p1 * y + p2), + operators = operators, + variable_names = variable_names, + expression_type = ParametricExpression, + extra_metadata = (; parameters=parameters, parameter_names=["p1", "p2"]) + ) + container = allocate_container(pex) + result = copy_into!(container, pex) + + @test result == pex + @test result !== pex + @test get_tree(result) !== get_tree(pex) + @test get_operators(result, nothing) === get_operators(pex, nothing) + @test get_variable_names(result, nothing) === get_variable_names(pex, nothing) + @test result.metadata.parameters !== pex.metadata.parameters + @test result.metadata.parameters == pex.metadata.parameters +end diff --git a/test/test_parametric_expression.jl b/test/test_parametric_expression.jl index a3aceed3..e222765f 100644 --- 
a/test/test_parametric_expression.jl +++ b/test/test_parametric_expression.jl @@ -26,7 +26,7 @@ end using Interfaces: test ex = @parse_expression( - x + y + p1 * p2, + x + y + p1 * p2 + 1.5, binary_operators = [+, -, *, /], variable_names = ["x", "y"], node_type = ParametricNode,