Skip to content

Commit c875677

Browse files
h9jianggopherbot
authored andcommitted
gopls/internal/golang: support add test for receiver w/o constructor
Gopls loops over all the function and find the ones qualifies based on its signature. - When a qualifying constructor is found, the test skeleton uses the constructor to initialize the receiver and call the method. - When no constructor is found, gopls generates a test skeleton with a variable declaration using the receiver type (e.g., "var t T") and includes a TODO comment to remind users to implement receiver variable initialization. For golang/vscode-go#1594 Change-Id: I2a703bbd099f03fd1bf85e516f86484805b4a0ae Reviewed-on: https://go-review.googlesource.com/c/tools/+/620696 Auto-Submit: Hongxiang Jiang <hxjiang@golang.org> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Robert Findley <rfindley@google.com>
1 parent e26dff9 commit c875677

File tree

2 files changed

+558
-74
lines changed

2 files changed

+558
-74
lines changed

gopls/internal/golang/addtest.go

Lines changed: 257 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"path/filepath"
2020
"strconv"
2121
"strings"
22+
"unicode"
2223

2324
"golang.org/x/tools/go/ast/astutil"
2425
"golang.org/x/tools/gopls/internal/cache"
@@ -29,10 +30,21 @@ import (
2930
)
3031

3132
const testTmplString = `func {{.TestFuncName}}(t *testing.T) {
33+
{{- /* Constructor input parameters struct declaration. */}}
34+
{{- if and .Receiver .Receiver.Constructor}}
35+
{{- if gt (len .Receiver.Constructor.Args) 1}}
36+
type constructorArgs struct {
37+
{{- range .Receiver.Constructor.Args}}
38+
{{.Name}} {{.Type}}
39+
{{- end}}
40+
}
41+
{{- end}}
42+
{{- end}}
43+
3244
{{- /* Functions/methods input parameters struct declaration. */}}
33-
{{- if gt (len .Args) 1}}
45+
{{- if gt (len .Func.Args) 1}}
3446
type args struct {
35-
{{- range .Args}}
47+
{{- range .Func.Args}}
3648
{{.Name}} {{.Type}}
3749
{{- end}}
3850
}
@@ -41,13 +53,22 @@ const testTmplString = `func {{.TestFuncName}}(t *testing.T) {
4153
{{- /* Test cases struct declaration and empty initialization. */}}
4254
tests := []struct {
4355
name string // description of this test case
44-
{{- if gt (len .Args) 1}}
56+
{{- if and .Receiver .Receiver.Constructor}}
57+
{{- if gt (len .Receiver.Constructor.Args) 1}}
58+
constructorArgs constructorArgs
59+
{{- end}}
60+
{{- if eq (len .Receiver.Constructor.Args) 1}}
61+
constructorArg {{(index .Receiver.Constructor.Args 0).Type}}
62+
{{- end}}
63+
{{- end}}
64+
65+
{{- if gt (len .Func.Args) 1}}
4566
args args
4667
{{- end}}
47-
{{- if eq (len .Args) 1}}
48-
arg {{(index .Args 0).Type}}
68+
{{- if eq (len .Func.Args) 1}}
69+
arg {{(index .Func.Args 0).Type}}
4970
{{- end}}
50-
{{- range $index, $res := .Results}}
71+
{{- range $index, $res := .Func.Results}}
5172
{{- if eq $res.Name "gotErr"}}
5273
wantErr bool
5374
{{- else if eq $index 0}}
@@ -63,38 +84,64 @@ const testTmplString = `func {{.TestFuncName}}(t *testing.T) {
6384
{{- /* Loop over all the test cases. */}}
6485
for _, tt := range tests {
6586
t.Run(tt.name, func(t *testing.T) {
66-
{{/* Got variables. */}}
67-
{{- if .Results}}{{fieldNames .Results ""}} := {{end}}
87+
{{- /* Constructor or empty initialization. */}}
88+
{{- if .Receiver}}
89+
{{- if .Receiver.Constructor}}
90+
{{- /* Receiver variable by calling constructor. */}}
91+
{{fieldNames .Receiver.Constructor.Results ""}} := {{if .PackageName}}{{.PackageName}}.{{end}}
92+
{{- .Receiver.Constructor.Name}}
93+
94+
{{- /* Constructor input parameters. */ -}}
95+
({{- if eq (len .Receiver.Constructor.Args) 1}}tt.constructorArg{{end}}{{if gt (len .Func.Args) 1}}{{fieldNames .Receiver.Constructor.Args "tt.constructorArgs."}}{{end}})
96+
97+
{{- /* Handles the error return from constructor. */}}
98+
{{- $last := last .Receiver.Constructor.Results}}
99+
{{- if eq $last.Type "error"}}
100+
if err != nil {
101+
t.Fatalf("could not contruct receiver type: %v", err)
102+
}
103+
{{- end}}
104+
{{- else}}
105+
{{- /* Receiver variable declaration. */}}
106+
// TODO: construct the receiver type.
107+
var {{.Receiver.Var.Name}} {{.Receiver.Var.Type}}
108+
{{- end}}
109+
{{- end}}
110+
111+
{{- /* Got variables. */}}
112+
{{if .Func.Results}}{{fieldNames .Func.Results ""}} := {{end}}
68113
69-
{{- /* Call expression. In xtest package test, call function by PACKAGE.FUNC. */}}
70-
{{- /* TODO(hxjiang): consider any renaming in existing xtest package imports. E.g. import renamedfoo "foo". */}}
71-
{{- /* TODO(hxjiang): support add test for methods by calling the right constructor. */}}
72-
{{- if .PackageName}}{{.PackageName}}.{{end}}{{.FuncName}}
114+
{{- /* Call expression. */}}
115+
{{- if .Receiver}}{{/* Call method by VAR.METHOD. */}}
116+
{{- .Receiver.Var.Name}}.
117+
{{- else if .PackageName}}{{/* Call function by PACKAGE.FUNC. */}}
118+
{{- .PackageName}}.
119+
{{- end}}{{.Func.Name}}
73120
74-
{{- /* Input parameters. */ -}}
75-
({{- if eq (len .Args) 1}}tt.arg{{end}}{{if gt (len .Args) 1}}{{fieldNames .Args "tt.args."}}{{end}})
121+
{{- /* Input parameters. */ -}}
122+
({{- if eq (len .Func.Args) 1}}tt.arg{{end}}{{if gt (len .Func.Args) 1}}{{fieldNames .Func.Args "tt.args."}}{{end}})
76123
77124
{{- /* Handles the returned error before the rest of return value. */}}
78-
{{- $last := index .Results (add (len .Results) -1)}}
79-
{{- if eq $last.Name "gotErr"}}
125+
{{- $last := last .Func.Results}}
126+
{{- if eq $last.Type "error"}}
80127
if gotErr != nil {
81128
if !tt.wantErr {
82-
t.Errorf("{{$.FuncName}}() failed: %v", gotErr)
129+
t.Errorf("{{$.Func.Name}}() failed: %v", gotErr)
83130
}
84131
return
85132
}
86133
if tt.wantErr {
87-
t.Fatal("{{$.FuncName}}() succeeded unexpectedly")
134+
t.Fatal("{{$.Func.Name}}() succeeded unexpectedly")
88135
}
89136
{{- end}}
90137
91138
{{- /* Compare the returned values except for the last returned error. */}}
92-
{{- if or (and .Results (ne $last.Name "gotErr")) (and (gt (len .Results) 1) (eq $last.Name "gotErr"))}}
139+
{{- if or (and .Func.Results (ne $last.Type "error")) (and (gt (len .Func.Results) 1) (eq $last.Type "error"))}}
93140
// TODO: update the condition below to compare got with tt.want.
94-
{{- range $index, $res := .Results}}
141+
{{- range $index, $res := .Func.Results}}
95142
{{- if ne $res.Name "gotErr"}}
96143
if true {
97-
t.Errorf("{{$.FuncName}}() = %v, want %v", {{.Name}}, tt.{{if eq $index 0}}want{{else}}want{{add $index 1}}{{end}})
144+
t.Errorf("{{$.Func.Name}}() = %v, want %v", {{.Name}}, tt.{{if eq $index 0}}want{{else}}want{{add $index 1}}{{end}})
98145
}
99146
{{- end}}
100147
{{- end}}
@@ -108,16 +155,36 @@ type field struct {
108155
Name, Type string
109156
}
110157

158+
type function struct {
159+
Name string
160+
Args []field
161+
Results []field
162+
}
163+
164+
type receiver struct {
165+
// Var is the name and type of the receiver variable.
166+
Var field
167+
// Constructor holds information about the constructor for the receiver type.
168+
// If no qualified constructor is found, this field will be nil.
169+
Constructor *function
170+
}
171+
111172
type testInfo struct {
112173
PackageName string
113-
FuncName string
114174
TestFuncName string
115-
Args []field
116-
Results []field
175+
// Func holds information about the function or method being tested.
176+
Func function
177+
// Receiver holds information about the receiver of the function or method
178+
// being tested.
179+
// This field is nil for functions and non-nil for methods.
180+
Receiver *receiver
117181
}
118182

119183
var testTmpl = template.Must(template.New("test").Funcs(template.FuncMap{
120184
"add": func(a, b int) int { return a + b },
185+
"last": func(slice []field) field {
186+
return slice[len(slice)-1]
187+
},
121188
"fieldNames": func(fields []field, qualifier string) (res string) {
122189
var names []string
123190
for _, f := range fields {
@@ -309,43 +376,187 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
309376
if err != nil {
310377
return nil, err
311378
}
379+
312380
data := testInfo{
313-
FuncName: fn.Name(),
381+
PackageName: qf(pkg.Types()),
314382
TestFuncName: testName,
383+
Func: function{
384+
Name: fn.Name(),
385+
},
315386
}
316387

317-
if sig.Recv() == nil && xtest {
318-
data.PackageName = qf(pkg.Types())
319-
}
388+
errorType := types.Universe.Lookup("error").Type()
320389

321-
for i := range sig.Params().Len() {
322-
if i == 0 {
323-
data.Args = append(data.Args, field{
324-
Name: "in",
325-
Type: types.TypeString(sig.Params().At(i).Type(), qf),
326-
})
390+
// TODO(hxjiang): if input parameter is not named (meaning it's not used),
391+
// pass the zero value to the function call.
392+
// TODO(hxjiang): if the input parameter is named, define the field by using
393+
// the parameter's name instead of in%d.
394+
// TODO(hxjiang): handle special case for ctx.Context input.
395+
for index := range sig.Params().Len() {
396+
var name string
397+
if index == 0 {
398+
name = "in"
327399
} else {
328-
data.Args = append(data.Args, field{
329-
Name: fmt.Sprintf("in%d", i+1),
330-
Type: types.TypeString(sig.Params().At(i).Type(), qf),
331-
})
400+
name = fmt.Sprintf("in%d", index+1)
332401
}
402+
data.Func.Args = append(data.Func.Args, field{
403+
Name: name,
404+
Type: types.TypeString(sig.Params().At(index).Type(), qf),
405+
})
333406
}
334407

335-
errorType := types.Universe.Lookup("error").Type()
336-
for i := range sig.Results().Len() {
337-
name := "got"
338-
if i == sig.Results().Len()-1 && types.Identical(sig.Results().At(i).Type(), errorType) {
408+
for index := range sig.Results().Len() {
409+
var name string
410+
if index == sig.Results().Len()-1 && types.Identical(sig.Results().At(index).Type(), errorType) {
339411
name = "gotErr"
340-
} else if i > 0 {
341-
name = fmt.Sprintf("got%d", i+1)
412+
} else if index == 0 {
413+
name = "got"
414+
} else {
415+
name = fmt.Sprintf("got%d", index+1)
342416
}
343-
data.Results = append(data.Results, field{
417+
data.Func.Results = append(data.Func.Results, field{
344418
Name: name,
345-
Type: types.TypeString(sig.Results().At(i).Type(), qf),
419+
Type: types.TypeString(sig.Results().At(index).Type(), qf),
346420
})
347421
}
348422

423+
if sig.Recv() != nil {
424+
// Find the preferred type for the receiver. We don't use
425+
// typesinternal.ReceiverNamed here as we want to preserve aliases.
426+
recvType := sig.Recv().Type()
427+
if ptr, ok := recvType.(*types.Pointer); ok {
428+
recvType = ptr.Elem()
429+
}
430+
431+
t, ok := recvType.(typesinternal.NamedOrAlias)
432+
if !ok {
433+
return nil, fmt.Errorf("the receiver type is neither named type nor alias type")
434+
}
435+
436+
var varName string
437+
{
438+
var possibleNames []string // list of candidates, preferring earlier entries.
439+
if len(sig.Recv().Name()) > 0 {
440+
possibleNames = append(possibleNames,
441+
sig.Recv().Name(), // receiver name.
442+
string(sig.Recv().Name()[0]), // first character of receiver name.
443+
)
444+
}
445+
possibleNames = append(possibleNames,
446+
string(t.Obj().Name()[0]), // first character of receiver type name.
447+
)
448+
if len(t.Obj().Name()) >= 2 {
449+
possibleNames = append(possibleNames,
450+
string(t.Obj().Name()[:2]), // first two character of receiver type name.
451+
)
452+
}
453+
var camelCase []rune
454+
for i, s := range t.Obj().Name() {
455+
if i == 0 || unicode.IsUpper(s) {
456+
camelCase = append(camelCase, s)
457+
}
458+
}
459+
possibleNames = append(possibleNames,
460+
string(camelCase), // captalized initials.
461+
)
462+
for _, name := range possibleNames {
463+
name = strings.ToLower(name)
464+
if name == "" || name == "t" || name == "tt" {
465+
continue
466+
}
467+
varName = name
468+
break
469+
}
470+
if varName == "" {
471+
varName = "r" // default as "r" for "receiver".
472+
}
473+
}
474+
475+
data.Receiver = &receiver{
476+
Var: field{
477+
Name: varName,
478+
Type: types.TypeString(recvType, qf),
479+
},
480+
}
481+
482+
// constructor is the selected constructor for type T.
483+
var constructor *types.Func
484+
485+
// When finding the qualified constructor, the function should return the
486+
// any type whose named type is the same type as T's named type.
487+
_, wantType := typesinternal.ReceiverNamed(sig.Recv())
488+
for _, name := range pkg.Types().Scope().Names() {
489+
f, ok := pkg.Types().Scope().Lookup(name).(*types.Func)
490+
if !ok {
491+
continue
492+
}
493+
if f.Signature().Recv() != nil {
494+
continue
495+
}
496+
// Unexported constructor is not visible in x_test package.
497+
if xtest && !f.Exported() {
498+
continue
499+
}
500+
// Only allow constructors returning T, T, (T, error), or (T, error).
501+
if f.Signature().Results().Len() > 2 || f.Signature().Results().Len() == 0 {
502+
continue
503+
}
504+
505+
_, gotType := typesinternal.ReceiverNamed(f.Signature().Results().At(0))
506+
if gotType == nil || !types.Identical(gotType, wantType) {
507+
continue
508+
}
509+
510+
if f.Signature().Results().Len() == 2 && !types.Identical(f.Signature().Results().At(1).Type(), errorType) {
511+
continue
512+
}
513+
514+
if constructor == nil {
515+
constructor = f
516+
}
517+
518+
// Functions named NewType are prioritized as constructors over other
519+
// functions that match only the signature criteria.
520+
if strings.EqualFold(strings.ToLower(f.Name()), strings.ToLower("new"+t.Obj().Name())) {
521+
constructor = f
522+
}
523+
}
524+
525+
if constructor != nil {
526+
data.Receiver.Constructor = &function{Name: constructor.Name()}
527+
for index := range constructor.Signature().Params().Len() {
528+
var name string
529+
if index == 0 {
530+
name = "in"
531+
} else {
532+
name = fmt.Sprintf("in%d", index+1)
533+
}
534+
data.Receiver.Constructor.Args = append(data.Receiver.Constructor.Args, field{
535+
Name: name,
536+
Type: types.TypeString(constructor.Signature().Params().At(index).Type(), qf),
537+
})
538+
}
539+
for index := range constructor.Signature().Results().Len() {
540+
var name string
541+
if index == 0 {
542+
// The first return value must be of type T, *T, or a type whose named
543+
// type is the same as named type of T.
544+
name = varName
545+
} else if index == constructor.Signature().Results().Len()-1 && types.Identical(constructor.Signature().Results().At(index).Type(), errorType) {
546+
name = "err"
547+
} else {
548+
// Drop any return values beyond the first and the last.
549+
// e.g., "f, _, _, err := NewFoo()".
550+
name = "_"
551+
}
552+
data.Receiver.Constructor.Results = append(data.Receiver.Constructor.Results, field{
553+
Name: name,
554+
Type: types.TypeString(constructor.Signature().Results().At(index).Type(), qf),
555+
})
556+
}
557+
}
558+
}
559+
349560
var test bytes.Buffer
350561
if err := testTmpl.Execute(&test, data); err != nil {
351562
return nil, err

0 commit comments

Comments
 (0)