diff --git a/src/cmd/go/test.go b/src/cmd/go/test.go index 995ba146f5..e23f939255 100644 --- a/src/cmd/go/test.go +++ b/src/cmd/go/test.go @@ -1206,11 +1206,11 @@ func (b *builder) notest(a *action) error { return nil } -// isTestMain tells whether fn is a TestMain(m *testing.M) function. -func isTestMain(fn *ast.FuncDecl) bool { - if fn.Name.String() != "TestMain" || - fn.Type.Results != nil && len(fn.Type.Results.List) > 0 || - fn.Type.Params == nil || +// isTestFunc tells whether fn has the type of a testing function. arg +// specifies the parameter type we look for: B, M or T. +func isTestFunc(fn *ast.FuncDecl, arg string) bool { + if fn.Type.Results != nil && len(fn.Type.Results.List) > 0 || + fn.Type.Params.List == nil || len(fn.Type.Params.List) != 1 || len(fn.Type.Params.List[0].Names) > 1 { return false @@ -1222,10 +1222,11 @@ func isTestMain(fn *ast.FuncDecl) bool { // We can't easily check that the type is *testing.M // because we don't know how testing has been imported, // but at least check that it's *M or *something.M. - if name, ok := ptr.X.(*ast.Ident); ok && name.Name == "M" { + // Same applies for B and T. + if name, ok := ptr.X.(*ast.Ident); ok && name.Name == arg { return true } - if sel, ok := ptr.X.(*ast.SelectorExpr); ok && sel.Sel.Name == "M" { + if sel, ok := ptr.X.(*ast.SelectorExpr); ok && sel.Sel.Name == arg { return true } return false @@ -1344,16 +1345,22 @@ func (t *testFuncs) load(filename, pkg string, doImport, seen *bool) error { } name := n.Name.String() switch { - case isTestMain(n): + case name == "TestMain" && isTestFunc(n, "M"): if t.TestMain != nil { return errors.New("multiple definitions of TestMain") } t.TestMain = &testFunc{pkg, name, ""} *doImport, *seen = true, true case isTest(name, "Test"): + if !isTestFunc(n, "T") { + return fmt.Errorf("wrong type for %s", name) + } t.Tests = append(t.Tests, testFunc{pkg, name, ""}) *doImport, *seen = true, true case isTest(name, "Benchmark"): + if !isTestFunc(n, "B") { + return fmt.Errorf("wrong type for %s", name) + } t.Benchmarks = append(t.Benchmarks, testFunc{pkg, name, ""}) *doImport, *seen = true, true }