Skip to content

Commit 218f457

Browse files
committed
text/template: make reflect.Value indirections more robust
Always shadow or modify the original parameter name. With code like: func index(item reflect.Value, ... { v := indirectInterface(item) It was possible to incorrectly use 'item' and 'v' later in the function, which could result in subtle bugs. This is precisely the kind of mistake that led to #36199. Instead, don't keep both the old and new reflect.Value variables in scope. Always shadow or modify the original variable. While at it, simplify the signature of 'length', to receive a reflect.Value directly and save a few redundant lines. Change-Id: I01416636a9d49f81246d28b91aca6413b1ba1aa5 Reviewed-on: https://go-review.googlesource.com/c/go/+/212117 Run-TryBot: Daniel Martí <mvdan@mvdan.cc> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Roberto Clapis <robclap8@gmail.com> Reviewed-by: Rob Pike <r@golang.org>
1 parent 31acdcc commit 218f457

File tree

1 file changed

+66
-72
lines changed

1 file changed

+66
-72
lines changed

src/text/template/funcs.go

+66-72
Original file line numberDiff line numberDiff line change
@@ -185,41 +185,41 @@ func indexArg(index reflect.Value, cap int) (int, error) {
185185
// arguments. Thus "index x 1 2 3" is, in Go syntax, x[1][2][3]. Each
186186
// indexed item must be a map, slice, or array.
187187
func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
188-
v := indirectInterface(item)
189-
if !v.IsValid() {
188+
item = indirectInterface(item)
189+
if !item.IsValid() {
190190
return reflect.Value{}, fmt.Errorf("index of untyped nil")
191191
}
192-
for _, i := range indexes {
193-
index := indirectInterface(i)
192+
for _, index := range indexes {
193+
index = indirectInterface(index)
194194
var isNil bool
195-
if v, isNil = indirect(v); isNil {
195+
if item, isNil = indirect(item); isNil {
196196
return reflect.Value{}, fmt.Errorf("index of nil pointer")
197197
}
198-
switch v.Kind() {
198+
switch item.Kind() {
199199
case reflect.Array, reflect.Slice, reflect.String:
200-
x, err := indexArg(index, v.Len())
200+
x, err := indexArg(index, item.Len())
201201
if err != nil {
202202
return reflect.Value{}, err
203203
}
204-
v = v.Index(x)
204+
item = item.Index(x)
205205
case reflect.Map:
206-
index, err := prepareArg(index, v.Type().Key())
206+
index, err := prepareArg(index, item.Type().Key())
207207
if err != nil {
208208
return reflect.Value{}, err
209209
}
210-
if x := v.MapIndex(index); x.IsValid() {
211-
v = x
210+
if x := item.MapIndex(index); x.IsValid() {
211+
item = x
212212
} else {
213-
v = reflect.Zero(v.Type().Elem())
213+
item = reflect.Zero(item.Type().Elem())
214214
}
215215
case reflect.Invalid:
216-
// the loop holds invariant: v.IsValid()
216+
// the loop holds invariant: item.IsValid()
217217
panic("unreachable")
218218
default:
219-
return reflect.Value{}, fmt.Errorf("can't index item of type %s", v.Type())
219+
return reflect.Value{}, fmt.Errorf("can't index item of type %s", item.Type())
220220
}
221221
}
222-
return v, nil
222+
return item, nil
223223
}
224224

225225
// Slicing.
@@ -229,29 +229,27 @@ func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error)
229229
// is x[:], "slice x 1" is x[1:], and "slice x 1 2 3" is x[1:2:3]. The first
230230
// argument must be a string, slice, or array.
231231
func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
232-
var (
233-
cap int
234-
v = indirectInterface(item)
235-
)
236-
if !v.IsValid() {
232+
item = indirectInterface(item)
233+
if !item.IsValid() {
237234
return reflect.Value{}, fmt.Errorf("slice of untyped nil")
238235
}
239236
if len(indexes) > 3 {
240237
return reflect.Value{}, fmt.Errorf("too many slice indexes: %d", len(indexes))
241238
}
242-
switch v.Kind() {
239+
var cap int
240+
switch item.Kind() {
243241
case reflect.String:
244242
if len(indexes) == 3 {
245243
return reflect.Value{}, fmt.Errorf("cannot 3-index slice a string")
246244
}
247-
cap = v.Len()
245+
cap = item.Len()
248246
case reflect.Array, reflect.Slice:
249-
cap = v.Cap()
247+
cap = item.Cap()
250248
default:
251-
return reflect.Value{}, fmt.Errorf("can't slice item of type %s", v.Type())
249+
return reflect.Value{}, fmt.Errorf("can't slice item of type %s", item.Type())
252250
}
253251
// set default values for cases item[:], item[i:].
254-
idx := [3]int{0, v.Len()}
252+
idx := [3]int{0, item.Len()}
255253
for i, index := range indexes {
256254
x, err := indexArg(index, cap)
257255
if err != nil {
@@ -264,44 +262,40 @@ func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error)
264262
return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[0], idx[1])
265263
}
266264
if len(indexes) < 3 {
267-
return v.Slice(idx[0], idx[1]), nil
265+
return item.Slice(idx[0], idx[1]), nil
268266
}
269267
// given item[i:j:k], make sure i <= j <= k.
270268
if idx[1] > idx[2] {
271269
return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[1], idx[2])
272270
}
273-
return v.Slice3(idx[0], idx[1], idx[2]), nil
271+
return item.Slice3(idx[0], idx[1], idx[2]), nil
274272
}
275273

276274
// Length
277275

278276
// length returns the length of the item, with an error if it has no defined length.
279-
func length(item interface{}) (int, error) {
280-
v := reflect.ValueOf(item)
281-
if !v.IsValid() {
282-
return 0, fmt.Errorf("len of untyped nil")
283-
}
284-
v, isNil := indirect(v)
277+
func length(item reflect.Value) (int, error) {
278+
item, isNil := indirect(item)
285279
if isNil {
286280
return 0, fmt.Errorf("len of nil pointer")
287281
}
288-
switch v.Kind() {
282+
switch item.Kind() {
289283
case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
290-
return v.Len(), nil
284+
return item.Len(), nil
291285
}
292-
return 0, fmt.Errorf("len of type %s", v.Type())
286+
return 0, fmt.Errorf("len of type %s", item.Type())
293287
}
294288

295289
// Function invocation
296290

297291
// call returns the result of evaluating the first argument as a function.
298292
// The function must return 1 result, or 2 results, the second of which is an error.
299293
func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
300-
v := indirectInterface(fn)
301-
if !v.IsValid() {
294+
fn = indirectInterface(fn)
295+
if !fn.IsValid() {
302296
return reflect.Value{}, fmt.Errorf("call of nil")
303297
}
304-
typ := v.Type()
298+
typ := fn.Type()
305299
if typ.Kind() != reflect.Func {
306300
return reflect.Value{}, fmt.Errorf("non-function of type %s", typ)
307301
}
@@ -322,19 +316,19 @@ func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
322316
}
323317
argv := make([]reflect.Value, len(args))
324318
for i, arg := range args {
325-
value := indirectInterface(arg)
319+
arg = indirectInterface(arg)
326320
// Compute the expected type. Clumsy because of variadics.
327321
argType := dddType
328322
if !typ.IsVariadic() || i < numIn-1 {
329323
argType = typ.In(i)
330324
}
331325

332326
var err error
333-
if argv[i], err = prepareArg(value, argType); err != nil {
327+
if argv[i], err = prepareArg(arg, argType); err != nil {
334328
return reflect.Value{}, fmt.Errorf("arg %d: %s", i, err)
335329
}
336330
}
337-
return safeCall(v, argv)
331+
return safeCall(fn, argv)
338332
}
339333

340334
// safeCall runs fun.Call(args), and returns the resulting value and error, if
@@ -440,52 +434,52 @@ func basicKind(v reflect.Value) (kind, error) {
440434

441435
// eq evaluates the comparison a == b || a == c || ...
442436
func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
443-
v1 := indirectInterface(arg1)
444-
if v1 != zero {
445-
if t1 := v1.Type(); !t1.Comparable() {
446-
return false, fmt.Errorf("uncomparable type %s: %v", t1, v1)
437+
arg1 = indirectInterface(arg1)
438+
if arg1 != zero {
439+
if t1 := arg1.Type(); !t1.Comparable() {
440+
return false, fmt.Errorf("uncomparable type %s: %v", t1, arg1)
447441
}
448442
}
449443
if len(arg2) == 0 {
450444
return false, errNoComparison
451445
}
452-
k1, _ := basicKind(v1)
446+
k1, _ := basicKind(arg1)
453447
for _, arg := range arg2 {
454-
v2 := indirectInterface(arg)
455-
k2, _ := basicKind(v2)
448+
arg = indirectInterface(arg)
449+
k2, _ := basicKind(arg)
456450
truth := false
457451
if k1 != k2 {
458452
// Special case: Can compare integer values regardless of type's sign.
459453
switch {
460454
case k1 == intKind && k2 == uintKind:
461-
truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint()
455+
truth = arg1.Int() >= 0 && uint64(arg1.Int()) == arg.Uint()
462456
case k1 == uintKind && k2 == intKind:
463-
truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int())
457+
truth = arg.Int() >= 0 && arg1.Uint() == uint64(arg.Int())
464458
default:
465459
return false, errBadComparison
466460
}
467461
} else {
468462
switch k1 {
469463
case boolKind:
470-
truth = v1.Bool() == v2.Bool()
464+
truth = arg1.Bool() == arg.Bool()
471465
case complexKind:
472-
truth = v1.Complex() == v2.Complex()
466+
truth = arg1.Complex() == arg.Complex()
473467
case floatKind:
474-
truth = v1.Float() == v2.Float()
468+
truth = arg1.Float() == arg.Float()
475469
case intKind:
476-
truth = v1.Int() == v2.Int()
470+
truth = arg1.Int() == arg.Int()
477471
case stringKind:
478-
truth = v1.String() == v2.String()
472+
truth = arg1.String() == arg.String()
479473
case uintKind:
480-
truth = v1.Uint() == v2.Uint()
474+
truth = arg1.Uint() == arg.Uint()
481475
default:
482-
if v2 == zero {
483-
truth = v1 == v2
476+
if arg == zero {
477+
truth = arg1 == arg
484478
} else {
485-
if t2 := v2.Type(); !t2.Comparable() {
486-
return false, fmt.Errorf("uncomparable type %s: %v", t2, v2)
479+
if t2 := arg.Type(); !t2.Comparable() {
480+
return false, fmt.Errorf("uncomparable type %s: %v", t2, arg)
487481
}
488-
truth = v1.Interface() == v2.Interface()
482+
truth = arg1.Interface() == arg.Interface()
489483
}
490484
}
491485
}
@@ -505,13 +499,13 @@ func ne(arg1, arg2 reflect.Value) (bool, error) {
505499

506500
// lt evaluates the comparison a < b.
507501
func lt(arg1, arg2 reflect.Value) (bool, error) {
508-
v1 := indirectInterface(arg1)
509-
k1, err := basicKind(v1)
502+
arg1 = indirectInterface(arg1)
503+
k1, err := basicKind(arg1)
510504
if err != nil {
511505
return false, err
512506
}
513-
v2 := indirectInterface(arg2)
514-
k2, err := basicKind(v2)
507+
arg2 = indirectInterface(arg2)
508+
k2, err := basicKind(arg2)
515509
if err != nil {
516510
return false, err
517511
}
@@ -520,9 +514,9 @@ func lt(arg1, arg2 reflect.Value) (bool, error) {
520514
// Special case: Can compare integer values regardless of type's sign.
521515
switch {
522516
case k1 == intKind && k2 == uintKind:
523-
truth = v1.Int() < 0 || uint64(v1.Int()) < v2.Uint()
517+
truth = arg1.Int() < 0 || uint64(arg1.Int()) < arg2.Uint()
524518
case k1 == uintKind && k2 == intKind:
525-
truth = v2.Int() >= 0 && v1.Uint() < uint64(v2.Int())
519+
truth = arg2.Int() >= 0 && arg1.Uint() < uint64(arg2.Int())
526520
default:
527521
return false, errBadComparison
528522
}
@@ -531,13 +525,13 @@ func lt(arg1, arg2 reflect.Value) (bool, error) {
531525
case boolKind, complexKind:
532526
return false, errBadComparisonType
533527
case floatKind:
534-
truth = v1.Float() < v2.Float()
528+
truth = arg1.Float() < arg2.Float()
535529
case intKind:
536-
truth = v1.Int() < v2.Int()
530+
truth = arg1.Int() < arg2.Int()
537531
case stringKind:
538-
truth = v1.String() < v2.String()
532+
truth = arg1.String() < arg2.String()
539533
case uintKind:
540-
truth = v1.Uint() < v2.Uint()
534+
truth = arg1.Uint() < arg2.Uint()
541535
default:
542536
panic("invalid kind")
543537
}

0 commit comments

Comments
 (0)