@@ -19,6 +19,7 @@ import (
19
19
"path/filepath"
20
20
"strconv"
21
21
"strings"
22
+ "unicode"
22
23
23
24
"golang.org/x/tools/go/ast/astutil"
24
25
"golang.org/x/tools/gopls/internal/cache"
@@ -29,10 +30,21 @@ import (
29
30
)
30
31
31
32
const testTmplString = `func {{.TestFuncName}}(t *testing.T) {
33
+ {{- /* Constructor input parameters struct declaration. */}}
34
+ {{- if and .Receiver .Receiver.Constructor}}
35
+ {{- if gt (len .Receiver.Constructor.Args) 1}}
36
+ type constructorArgs struct {
37
+ {{- range .Receiver.Constructor.Args}}
38
+ {{.Name}} {{.Type}}
39
+ {{- end}}
40
+ }
41
+ {{- end}}
42
+ {{- end}}
43
+
32
44
{{- /* Functions/methods input parameters struct declaration. */}}
33
- {{- if gt (len .Args) 1}}
45
+ {{- if gt (len .Func. Args) 1}}
34
46
type args struct {
35
- {{- range .Args}}
47
+ {{- range .Func. Args}}
36
48
{{.Name}} {{.Type}}
37
49
{{- end}}
38
50
}
@@ -41,13 +53,22 @@ const testTmplString = `func {{.TestFuncName}}(t *testing.T) {
41
53
{{- /* Test cases struct declaration and empty initialization. */}}
42
54
tests := []struct {
43
55
name string // description of this test case
44
- {{- if gt (len .Args) 1}}
56
+ {{- if and .Receiver .Receiver.Constructor}}
57
+ {{- if gt (len .Receiver.Constructor.Args) 1}}
58
+ constructorArgs constructorArgs
59
+ {{- end}}
60
+ {{- if eq (len .Receiver.Constructor.Args) 1}}
61
+ constructorArg {{(index .Receiver.Constructor.Args 0).Type}}
62
+ {{- end}}
63
+ {{- end}}
64
+
65
+ {{- if gt (len .Func.Args) 1}}
45
66
args args
46
67
{{- end}}
47
- {{- if eq (len .Args) 1}}
48
- arg {{(index .Args 0).Type}}
68
+ {{- if eq (len .Func. Args) 1}}
69
+ arg {{(index .Func. Args 0).Type}}
49
70
{{- end}}
50
- {{- range $index, $res := .Results}}
71
+ {{- range $index, $res := .Func. Results}}
51
72
{{- if eq $res.Name "gotErr"}}
52
73
wantErr bool
53
74
{{- else if eq $index 0}}
@@ -63,38 +84,64 @@ const testTmplString = `func {{.TestFuncName}}(t *testing.T) {
63
84
{{- /* Loop over all the test cases. */}}
64
85
for _, tt := range tests {
65
86
t.Run(tt.name, func(t *testing.T) {
66
- {{/* Got variables. */}}
67
- {{- if .Results}}{{fieldNames .Results ""}} := {{end}}
87
+ {{- /* Constructor or empty initialization. */}}
88
+ {{- if .Receiver}}
89
+ {{- if .Receiver.Constructor}}
90
+ {{- /* Receiver variable by calling constructor. */}}
91
+ {{fieldNames .Receiver.Constructor.Results ""}} := {{if .PackageName}}{{.PackageName}}.{{end}}
92
+ {{- .Receiver.Constructor.Name}}
93
+
94
+ {{- /* Constructor input parameters. */ -}}
95
+ ({{- if eq (len .Receiver.Constructor.Args) 1}}tt.constructorArg{{end}}{{if gt (len .Func.Args) 1}}{{fieldNames .Receiver.Constructor.Args "tt.constructorArgs."}}{{end}})
96
+
97
+ {{- /* Handles the error return from constructor. */}}
98
+ {{- $last := last .Receiver.Constructor.Results}}
99
+ {{- if eq $last.Type "error"}}
100
+ if err != nil {
101
+ t.Fatalf("could not contruct receiver type: %v", err)
102
+ }
103
+ {{- end}}
104
+ {{- else}}
105
+ {{- /* Receiver variable declaration. */}}
106
+ // TODO: construct the receiver type.
107
+ var {{.Receiver.Var.Name}} {{.Receiver.Var.Type}}
108
+ {{- end}}
109
+ {{- end}}
110
+
111
+ {{- /* Got variables. */}}
112
+ {{if .Func.Results}}{{fieldNames .Func.Results ""}} := {{end}}
68
113
69
- {{- /* Call expression. In xtest package test, call function by PACKAGE.FUNC. */}}
70
- {{- /* TODO(hxjiang): consider any renaming in existing xtest package imports. E.g. import renamedfoo "foo". */}}
71
- {{- /* TODO(hxjiang): support add test for methods by calling the right constructor. */}}
72
- {{- if .PackageName}}{{.PackageName}}.{{end}}{{.FuncName}}
114
+ {{- /* Call expression. */}}
115
+ {{- if .Receiver}}{{/* Call method by VAR.METHOD. */}}
116
+ {{- .Receiver.Var.Name}}.
117
+ {{- else if .PackageName}}{{/* Call function by PACKAGE.FUNC. */}}
118
+ {{- .PackageName}}.
119
+ {{- end}}{{.Func.Name}}
73
120
74
- {{- /* Input parameters. */ -}}
75
- ({{- if eq (len .Args) 1}}tt.arg{{end}}{{if gt (len .Args) 1}}{{fieldNames .Args "tt.args."}}{{end}})
121
+ {{- /* Input parameters. */ -}}
122
+ ({{- if eq (len .Func. Args) 1}}tt.arg{{end}}{{if gt (len .Func. Args) 1}}{{fieldNames .Func .Args "tt.args."}}{{end}})
76
123
77
124
{{- /* Handles the returned error before the rest of return value. */}}
78
- {{- $last := index .Results (add (len .Results) -1) }}
79
- {{- if eq $last.Name "gotErr "}}
125
+ {{- $last := last .Func .Results}}
126
+ {{- if eq $last.Type "error "}}
80
127
if gotErr != nil {
81
128
if !tt.wantErr {
82
- t.Errorf("{{$.FuncName }}() failed: %v", gotErr)
129
+ t.Errorf("{{$.Func.Name }}() failed: %v", gotErr)
83
130
}
84
131
return
85
132
}
86
133
if tt.wantErr {
87
- t.Fatal("{{$.FuncName }}() succeeded unexpectedly")
134
+ t.Fatal("{{$.Func.Name }}() succeeded unexpectedly")
88
135
}
89
136
{{- end}}
90
137
91
138
{{- /* Compare the returned values except for the last returned error. */}}
92
- {{- if or (and .Results (ne $last.Name "gotErr ")) (and (gt (len .Results) 1) (eq $last.Name "gotErr "))}}
139
+ {{- if or (and .Func. Results (ne $last.Type "error ")) (and (gt (len .Func. Results) 1) (eq $last.Type "error "))}}
93
140
// TODO: update the condition below to compare got with tt.want.
94
- {{- range $index, $res := .Results}}
141
+ {{- range $index, $res := .Func. Results}}
95
142
{{- if ne $res.Name "gotErr"}}
96
143
if true {
97
- t.Errorf("{{$.FuncName }}() = %v, want %v", {{.Name}}, tt.{{if eq $index 0}}want{{else}}want{{add $index 1}}{{end}})
144
+ t.Errorf("{{$.Func.Name }}() = %v, want %v", {{.Name}}, tt.{{if eq $index 0}}want{{else}}want{{add $index 1}}{{end}})
98
145
}
99
146
{{- end}}
100
147
{{- end}}
@@ -108,16 +155,36 @@ type field struct {
108
155
Name , Type string
109
156
}
110
157
158
+ type function struct {
159
+ Name string
160
+ Args []field
161
+ Results []field
162
+ }
163
+
164
+ type receiver struct {
165
+ // Var is the name and type of the receiver variable.
166
+ Var field
167
+ // Constructor holds information about the constructor for the receiver type.
168
+ // If no qualified constructor is found, this field will be nil.
169
+ Constructor * function
170
+ }
171
+
111
172
type testInfo struct {
112
173
PackageName string
113
- FuncName string
114
174
TestFuncName string
115
- Args []field
116
- Results []field
175
+ // Func holds information about the function or method being tested.
176
+ Func function
177
+ // Receiver holds information about the receiver of the function or method
178
+ // being tested.
179
+ // This field is nil for functions and non-nil for methods.
180
+ Receiver * receiver
117
181
}
118
182
119
183
var testTmpl = template .Must (template .New ("test" ).Funcs (template.FuncMap {
120
184
"add" : func (a , b int ) int { return a + b },
185
+ "last" : func (slice []field ) field {
186
+ return slice [len (slice )- 1 ]
187
+ },
121
188
"fieldNames" : func (fields []field , qualifier string ) (res string ) {
122
189
var names []string
123
190
for _ , f := range fields {
@@ -309,43 +376,187 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
309
376
if err != nil {
310
377
return nil , err
311
378
}
379
+
312
380
data := testInfo {
313
- FuncName : fn . Name ( ),
381
+ PackageName : qf ( pkg . Types () ),
314
382
TestFuncName : testName ,
383
+ Func : function {
384
+ Name : fn .Name (),
385
+ },
315
386
}
316
387
317
- if sig .Recv () == nil && xtest {
318
- data .PackageName = qf (pkg .Types ())
319
- }
388
+ errorType := types .Universe .Lookup ("error" ).Type ()
320
389
321
- for i := range sig .Params ().Len () {
322
- if i == 0 {
323
- data .Args = append (data .Args , field {
324
- Name : "in" ,
325
- Type : types .TypeString (sig .Params ().At (i ).Type (), qf ),
326
- })
390
+ // TODO(hxjiang): if input parameter is not named (meaning it's not used),
391
+ // pass the zero value to the function call.
392
+ // TODO(hxjiang): if the input parameter is named, define the field by using
393
+ // the parameter's name instead of in%d.
394
+ // TODO(hxjiang): handle special case for ctx.Context input.
395
+ for index := range sig .Params ().Len () {
396
+ var name string
397
+ if index == 0 {
398
+ name = "in"
327
399
} else {
328
- data .Args = append (data .Args , field {
329
- Name : fmt .Sprintf ("in%d" , i + 1 ),
330
- Type : types .TypeString (sig .Params ().At (i ).Type (), qf ),
331
- })
400
+ name = fmt .Sprintf ("in%d" , index + 1 )
332
401
}
402
+ data .Func .Args = append (data .Func .Args , field {
403
+ Name : name ,
404
+ Type : types .TypeString (sig .Params ().At (index ).Type (), qf ),
405
+ })
333
406
}
334
407
335
- errorType := types .Universe .Lookup ("error" ).Type ()
336
- for i := range sig .Results ().Len () {
337
- name := "got"
338
- if i == sig .Results ().Len ()- 1 && types .Identical (sig .Results ().At (i ).Type (), errorType ) {
408
+ for index := range sig .Results ().Len () {
409
+ var name string
410
+ if index == sig .Results ().Len ()- 1 && types .Identical (sig .Results ().At (index ).Type (), errorType ) {
339
411
name = "gotErr"
340
- } else if i > 0 {
341
- name = fmt .Sprintf ("got%d" , i + 1 )
412
+ } else if index == 0 {
413
+ name = "got"
414
+ } else {
415
+ name = fmt .Sprintf ("got%d" , index + 1 )
342
416
}
343
- data .Results = append (data .Results , field {
417
+ data .Func . Results = append (data . Func .Results , field {
344
418
Name : name ,
345
- Type : types .TypeString (sig .Results ().At (i ).Type (), qf ),
419
+ Type : types .TypeString (sig .Results ().At (index ).Type (), qf ),
346
420
})
347
421
}
348
422
423
+ if sig .Recv () != nil {
424
+ // Find the preferred type for the receiver. We don't use
425
+ // typesinternal.ReceiverNamed here as we want to preserve aliases.
426
+ recvType := sig .Recv ().Type ()
427
+ if ptr , ok := recvType .(* types.Pointer ); ok {
428
+ recvType = ptr .Elem ()
429
+ }
430
+
431
+ t , ok := recvType .(typesinternal.NamedOrAlias )
432
+ if ! ok {
433
+ return nil , fmt .Errorf ("the receiver type is neither named type nor alias type" )
434
+ }
435
+
436
+ var varName string
437
+ {
438
+ var possibleNames []string // list of candidates, preferring earlier entries.
439
+ if len (sig .Recv ().Name ()) > 0 {
440
+ possibleNames = append (possibleNames ,
441
+ sig .Recv ().Name (), // receiver name.
442
+ string (sig .Recv ().Name ()[0 ]), // first character of receiver name.
443
+ )
444
+ }
445
+ possibleNames = append (possibleNames ,
446
+ string (t .Obj ().Name ()[0 ]), // first character of receiver type name.
447
+ )
448
+ if len (t .Obj ().Name ()) >= 2 {
449
+ possibleNames = append (possibleNames ,
450
+ string (t .Obj ().Name ()[:2 ]), // first two character of receiver type name.
451
+ )
452
+ }
453
+ var camelCase []rune
454
+ for i , s := range t .Obj ().Name () {
455
+ if i == 0 || unicode .IsUpper (s ) {
456
+ camelCase = append (camelCase , s )
457
+ }
458
+ }
459
+ possibleNames = append (possibleNames ,
460
+ string (camelCase ), // captalized initials.
461
+ )
462
+ for _ , name := range possibleNames {
463
+ name = strings .ToLower (name )
464
+ if name == "" || name == "t" || name == "tt" {
465
+ continue
466
+ }
467
+ varName = name
468
+ break
469
+ }
470
+ if varName == "" {
471
+ varName = "r" // default as "r" for "receiver".
472
+ }
473
+ }
474
+
475
+ data .Receiver = & receiver {
476
+ Var : field {
477
+ Name : varName ,
478
+ Type : types .TypeString (recvType , qf ),
479
+ },
480
+ }
481
+
482
+ // constructor is the selected constructor for type T.
483
+ var constructor * types.Func
484
+
485
+ // When finding the qualified constructor, the function should return the
486
+ // any type whose named type is the same type as T's named type.
487
+ _ , wantType := typesinternal .ReceiverNamed (sig .Recv ())
488
+ for _ , name := range pkg .Types ().Scope ().Names () {
489
+ f , ok := pkg .Types ().Scope ().Lookup (name ).(* types.Func )
490
+ if ! ok {
491
+ continue
492
+ }
493
+ if f .Signature ().Recv () != nil {
494
+ continue
495
+ }
496
+ // Unexported constructor is not visible in x_test package.
497
+ if xtest && ! f .Exported () {
498
+ continue
499
+ }
500
+ // Only allow constructors returning T, T, (T, error), or (T, error).
501
+ if f .Signature ().Results ().Len () > 2 || f .Signature ().Results ().Len () == 0 {
502
+ continue
503
+ }
504
+
505
+ _ , gotType := typesinternal .ReceiverNamed (f .Signature ().Results ().At (0 ))
506
+ if gotType == nil || ! types .Identical (gotType , wantType ) {
507
+ continue
508
+ }
509
+
510
+ if f .Signature ().Results ().Len () == 2 && ! types .Identical (f .Signature ().Results ().At (1 ).Type (), errorType ) {
511
+ continue
512
+ }
513
+
514
+ if constructor == nil {
515
+ constructor = f
516
+ }
517
+
518
+ // Functions named NewType are prioritized as constructors over other
519
+ // functions that match only the signature criteria.
520
+ if strings .EqualFold (strings .ToLower (f .Name ()), strings .ToLower ("new" + t .Obj ().Name ())) {
521
+ constructor = f
522
+ }
523
+ }
524
+
525
+ if constructor != nil {
526
+ data .Receiver .Constructor = & function {Name : constructor .Name ()}
527
+ for index := range constructor .Signature ().Params ().Len () {
528
+ var name string
529
+ if index == 0 {
530
+ name = "in"
531
+ } else {
532
+ name = fmt .Sprintf ("in%d" , index + 1 )
533
+ }
534
+ data .Receiver .Constructor .Args = append (data .Receiver .Constructor .Args , field {
535
+ Name : name ,
536
+ Type : types .TypeString (constructor .Signature ().Params ().At (index ).Type (), qf ),
537
+ })
538
+ }
539
+ for index := range constructor .Signature ().Results ().Len () {
540
+ var name string
541
+ if index == 0 {
542
+ // The first return value must be of type T, *T, or a type whose named
543
+ // type is the same as named type of T.
544
+ name = varName
545
+ } else if index == constructor .Signature ().Results ().Len ()- 1 && types .Identical (constructor .Signature ().Results ().At (index ).Type (), errorType ) {
546
+ name = "err"
547
+ } else {
548
+ // Drop any return values beyond the first and the last.
549
+ // e.g., "f, _, _, err := NewFoo()".
550
+ name = "_"
551
+ }
552
+ data .Receiver .Constructor .Results = append (data .Receiver .Constructor .Results , field {
553
+ Name : name ,
554
+ Type : types .TypeString (constructor .Signature ().Results ().At (index ).Type (), qf ),
555
+ })
556
+ }
557
+ }
558
+ }
559
+
349
560
var test bytes.Buffer
350
561
if err := testTmpl .Execute (& test , data ); err != nil {
351
562
return nil , err
0 commit comments