Skip to content

Commit 0f04769

Browse files
authored
Merge pull request #122 from SymbolicML/allow-pretty-operators
feat: allow separate operator names for pretty printing
2 parents b3ed0b6 + e3cd937 commit 0f04769

File tree

3 files changed

+81
-14
lines changed

3 files changed

+81
-14
lines changed

src/DynamicExpressions.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ import .NodeModule:
6969
get_scalar_constants,
7070
set_scalar_constants!
7171
@reexport import .StringsModule: string_tree, print_tree
72-
import .StringsModule: get_op_name
72+
import .StringsModule: get_op_name, get_pretty_op_name
7373
@reexport import .OperatorEnumModule: AbstractOperatorEnum
7474
@reexport import .OperatorEnumConstructionModule:
7575
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!

src/Strings.jl

+22-13
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,24 @@ using ..UtilsModule: deprecate_varmap
44
using ..OperatorEnumModule: AbstractOperatorEnum
55
using ..NodeModule: AbstractExpressionNode, tree_mapreduce
66

7-
function dispatch_op_name(::Val{deg}, ::Nothing, idx)::Vector{Char} where {deg}
8-
if deg == 1
9-
return vcat(collect("unary_operator["), collect(string(idx)), [']'])
10-
else
11-
return vcat(collect("binary_operator["), collect(string(idx)), [']'])
12-
end
7+
function dispatch_op_name(
8+
::Val{deg}, ::Nothing, idx, pretty::Bool
9+
)::Vector{Char} where {deg}
10+
return vcat(
11+
collect(deg == 1 ? "unary_operator[" : "binary_operator["),
12+
collect(string(idx)),
13+
[']'],
14+
)
1315
end
14-
function dispatch_op_name(::Val{deg}, operators::AbstractOperatorEnum, idx) where {deg}
15-
if deg == 1
16-
return collect(get_op_name(operators.unaops[idx])::String)
16+
function dispatch_op_name(
17+
::Val{deg}, operators::AbstractOperatorEnum, idx, pretty::Bool
18+
) where {deg}
19+
op = if deg == 1
20+
operators.unaops[idx]
1721
else
18-
return collect(get_op_name(operators.binops[idx])::String)
22+
operators.binops[idx]
1923
end
24+
return collect((pretty ? get_pretty_op_name(op) : get_op_name(op))::String)
2025
end
2126

2227
const OP_NAME_CACHE = (; x=Dict{UInt64,String}(), lock=Threads.SpinLock())
@@ -47,6 +52,9 @@ function get_op_name(op::F) where {F}
4752
unlock(OP_NAME_CACHE.lock)
4853
end
4954
end
55+
function get_pretty_op_name(op::F) where {F}
56+
return get_op_name(op)
57+
end
5058

5159
@inline function strip_brackets(s::Vector{Char})::Vector{Char}
5260
if first(s) == '(' && last(s) == ')'
@@ -145,8 +153,9 @@ function string_tree(
145153
raw::Union{Bool,Nothing}=nothing,
146154
varMap=nothing,
147155
)::String where {T,F1<:Function,F2<:Function}
148-
!isnothing(raw) &&
156+
if !isnothing(raw)
149157
Base.depwarn("`raw` is deprecated; use `pretty` instead", :string_tree)
158+
end
150159
pretty = @something(pretty, _not(raw), false)
151160
variable_names = deprecate_varmap(variable_names, varMap, :string_tree)
152161
raw_output = tree_mapreduce(
@@ -162,9 +171,9 @@ function string_tree(
162171
end,
163172
let operators = operators
164173
(branch,) -> if branch.degree == 1
165-
dispatch_op_name(Val(1), operators, branch.op)::Vector{Char}
174+
dispatch_op_name(Val(1), operators, branch.op, pretty)::Vector{Char}
166175
else
167-
dispatch_op_name(Val(2), operators, branch.op)::Vector{Char}
176+
dispatch_op_name(Val(2), operators, branch.op, pretty)::Vector{Char}
168177
end
169178
end,
170179
combine_op_with_inputs,

test/test_print.jl

+58
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,61 @@ end
117117
@test string(tree) == "(k1 * x2) + x3"
118118
empty!(DynamicExpressions.OperatorEnumConstructionModule.LATEST_VARIABLE_NAMES.x)
119119
end
120+
121+
@testset "Test pretty format for operators" begin
122+
# Define a custom operator with different pretty representation
123+
@eval begin
124+
my_pretty_op(x, y) = x + y
125+
DE.get_op_name(::typeof(my_pretty_op)) = "my_pretty_op"
126+
DE.get_pretty_op_name(::typeof(my_pretty_op)) = "pretty_op_two"
127+
end
128+
129+
operators = OperatorEnum(;
130+
default_params...,
131+
binary_operators=(+, *, /, -, my_pretty_op),
132+
unary_operators=(cos, sin),
133+
)
134+
@extend_operators operators
135+
136+
x1, x2 = [Node(; feature=i) for i in 1:2]
137+
138+
# Test default format (not pretty)
139+
tree = my_pretty_op(x1, x2)
140+
@test string_tree(tree, operators) == "my_pretty_op(x1, x2)"
141+
142+
# Test pretty format
143+
@test string_tree(tree, operators; pretty=true) == "pretty_op_two(x1, x2)"
144+
145+
# Test with nested expressions
146+
tree = sin(my_pretty_op(x1, x2))
147+
@test string_tree(tree, operators) == "sin(my_pretty_op(x1, x2))"
148+
@test string_tree(tree, operators; pretty=true) == "sin(pretty_op_two(x1, x2))"
149+
150+
# Test with constants
151+
tree = my_pretty_op(x1, Node(; val=3.14))
152+
@test string_tree(tree, operators) == "my_pretty_op(x1, 3.14)"
153+
@test string_tree(tree, operators; pretty=true) == "pretty_op_two(x1, 3.14)"
154+
155+
# Test that the default implementation of get_pretty_op_name falls back to get_op_name
156+
tree = sin(x1)
157+
@test string_tree(tree, operators) == "sin(x1)"
158+
@test string_tree(tree, operators; pretty=true) == "sin(x1)"
159+
160+
# Test with a unary operator that has a different pretty name
161+
@eval begin
162+
my_unary_op(x) = sin(x)
163+
DE.get_op_name(::typeof(my_unary_op)) = "my_unary_op"
164+
DE.get_pretty_op_name(::typeof(my_unary_op)) = "sine"
165+
end
166+
167+
operators_with_unary = OperatorEnum(;
168+
default_params...,
169+
binary_operators=(+, *, /, -),
170+
unary_operators=(cos, sin, my_unary_op),
171+
)
172+
@extend_operators operators_with_unary
173+
174+
tree = my_unary_op(x1)
175+
@test string_tree(tree, operators_with_unary) == "my_unary_op(x1)"
176+
@test string_tree(tree, operators_with_unary; pretty=true) == "sine(x1)"
177+
end

0 commit comments

Comments
 (0)