diff --git a/src/cmd/compile/internal/syntax/parser.go b/src/cmd/compile/internal/syntax/parser.go index 770175fe54..40c5eca408 100644 --- a/src/cmd/compile/internal/syntax/parser.go +++ b/src/cmd/compile/internal/syntax/parser.go @@ -588,44 +588,81 @@ func (p *parser) typeDecl(group *Group) Decl { d.Name = p.name() if p.allowGenerics() && p.tok == _Lbrack { // d.Name "[" ... - // array/slice or type parameter list + // array/slice type or type parameter list pos := p.pos() p.next() switch p.tok { case _Name: - // d.Name "[" name ... - // array or type parameter list - name := p.name() - // Index or slice expressions are never constant and thus invalid - // array length expressions. Thus, if we see a "[" following name - // we can safely assume that "[" name starts a type parameter list. - var x Expr // x != nil means x is the array length expression + // We may have an array type or a type parameter list. + // In either case we expect an expression x (which may + // just be a name, or a more complex expression) which + // we can analyze further. + // + // A type parameter list may have a type bound starting + // with a "[" as in: P []E. In that case, simply parsing + // an expression would lead to an error: P[] is invalid. + // But since index or slice expressions are never constant + // and thus invalid array length expressions, if we see a + // "[" following a name it must be the start of an array + // or slice constraint. Only if we don't see a "[" do we + // need to parse a full expression. + var x Expr = p.name() if p.tok != _Lbrack { - // d.Name "[" name ... - // If we reach here, the next token is not a "[", and we need to - // parse the expression starting with name. If that expression is - // just that name, not followed by a "]" (in which case we might - // have the array length "[" name "]"), we can also safely assume - // a type parameter list. + // To parse the expression starting with name, expand + // the call sequence we would get by passing in name + // to parser.expr, and pass in name to parser.pexpr. p.xnest++ - // To parse the expression starting with name, expand the call - // sequence we would get by passing in name to parser.expr, and - // pass in name to parser.pexpr. - x = p.binaryExpr(p.pexpr(name, false), 0) + x = p.binaryExpr(p.pexpr(x, false), 0) p.xnest-- - if x == name && p.tok != _Rbrack { - x = nil + } + + // analyze the cases + var pname *Name // pname != nil means pname is the type parameter name + var ptype Expr // ptype != nil means ptype is the type parameter type; pname != nil in this case + switch t := x.(type) { + case *Name: + // Unless we see a "]", we are at the start of a type parameter list. + if p.tok != _Rbrack { + // d.Name "[" name ... + pname = t + // no ptype + } + case *Operation: + // If we have an expression of the form name*T, and T is a (possibly + // parenthesized) type literal or the next token is a comma, we are + // at the start of a type parameter list. + if name, _ := t.X.(*Name); name != nil { + if t.Op == Mul && (isTypeLit(t.Y) || p.tok == _Comma) { + // d.Name "[" name "*" t.Y + // d.Name "[" name "*" t.Y "," + t.X, t.Y = t.Y, nil // convert t into unary *t.Y + pname = name + ptype = t + } + } + case *CallExpr: + // If we have an expression of the form name(T), and T is a (possibly + // parenthesized) type literal or the next token is a comma, we are + // at the start of a type parameter list. + if name, _ := t.Fun.(*Name); name != nil { + if len(t.ArgList) == 1 && !t.HasDots && (isTypeLit(t.ArgList[0]) || p.tok == _Comma) { + // d.Name "[" name "(" t.ArgList[0] ")" + // d.Name "[" name "(" t.ArgList[0] ")" "," + pname = name + ptype = t.ArgList[0] + } } } - if x == nil { - // d.Name "[" name ... - // type parameter list - d.TParamList = p.paramList(name, _Rbrack, true) + + if pname != nil { + // d.Name "[" pname ... + // d.Name "[" pname ptype ... + // d.Name "[" pname ptype "," ... + d.TParamList = p.paramList(pname, ptype, _Rbrack, true) d.Alias = p.gotAssign() d.Type = p.typeOrNil() } else { - // d.Name "[" x "]" ... - // x is the array length expression + // d.Name "[" x ... d.Type = p.arrayType(pos, x) } case _Rbrack: @@ -650,6 +687,21 @@ func (p *parser) typeDecl(group *Group) Decl { return d } +// isTypeLit reports whether x is a (possibly parenthesized) type literal. +func isTypeLit(x Expr) bool { + switch x := x.(type) { + case *ArrayType, *StructType, *FuncType, *InterfaceType, *SliceType, *MapType, *ChanType: + return true + case *Operation: + // *T may be a pointer dereferenciation. + // Only consider *T as type literal if T is a type literal. + return x.Op == Mul && x.Y == nil && isTypeLit(x.X) + case *ParenExpr: + return isTypeLit(x.X) + } + return false +} + // VarSpec = IdentifierList ( Type [ "=" ExpressionList ] | "=" ExpressionList ) . func (p *parser) varDecl(group *Group) Decl { if trace { @@ -689,7 +741,7 @@ func (p *parser) funcDeclOrNil() *FuncDecl { f.Pragma = p.takePragma() if p.got(_Lparen) { - rcvr := p.paramList(nil, _Rparen, false) + rcvr := p.paramList(nil, nil, _Rparen, false) switch len(rcvr) { case 0: p.error("method has no receiver") @@ -1369,12 +1421,12 @@ func (p *parser) funcType(context string) ([]*Field, *FuncType) { p.syntaxError("empty type parameter list") p.next() } else { - tparamList = p.paramList(nil, _Rbrack, true) + tparamList = p.paramList(nil, nil, _Rbrack, true) } } p.want(_Lparen) - typ.ParamList = p.paramList(nil, _Rparen, false) + typ.ParamList = p.paramList(nil, nil, _Rparen, false) typ.ResultList = p.funcResult() return tparamList, typ @@ -1392,6 +1444,13 @@ func (p *parser) arrayType(pos Pos, len Expr) Expr { len = p.expr() p.xnest-- } + if p.tok == _Comma { + // Trailing commas are accepted in type parameter + // lists but not in array type declarations. + // Accept for better error handling but complain. + p.syntaxError("unexpected comma; expecting ]") + p.next() + } p.want(_Rbrack) t := new(ArrayType) t.pos = pos @@ -1516,7 +1575,7 @@ func (p *parser) funcResult() []*Field { } if p.got(_Lparen) { - return p.paramList(nil, _Rparen, false) + return p.paramList(nil, nil, _Rparen, false) } pos := p.pos() @@ -1742,7 +1801,7 @@ func (p *parser) methodDecl() *Field { // A type argument list looks like a parameter list with only // types. Parse a parameter list and decide afterwards. - list := p.paramList(nil, _Rbrack, false) + list := p.paramList(nil, nil, _Rbrack, false) if len(list) == 0 { // The type parameter list is not [] but we got nothing // due to other errors (reported by paramList). Treat @@ -1948,17 +2007,41 @@ func (p *parser) paramDeclOrNil(name *Name, follow token) *Field { // ParameterList = ParameterDecl { "," ParameterDecl } . // "(" or "[" has already been consumed. // If name != nil, it is the first name after "(" or "[". +// If typ != nil, name must be != nil, and (name, typ) is the first field in the list. // In the result list, either all fields have a name, or no field has a name. -func (p *parser) paramList(name *Name, close token, requireNames bool) (list []*Field) { +func (p *parser) paramList(name *Name, typ Expr, close token, requireNames bool) (list []*Field) { if trace { defer p.trace("paramList")() } + // p.list won't invoke its function argument if we're at the end of the + // parameter list. If we have a complete field, handle this case here. + if name != nil && typ != nil && p.tok == close { + p.next() + par := new(Field) + par.pos = name.pos + par.Name = name + par.Type = typ + return []*Field{par} + } + var named int // number of parameters that have an explicit name and type var typed int // number of parameters that have an explicit type end := p.list(_Comma, close, func() bool { - par := p.paramDeclOrNil(name, close) + var par *Field + if typ != nil { + if debug && name == nil { + panic("initial type provided without name") + } + par = new(Field) + par.pos = name.pos + par.Name = name + par.Type = typ + } else { + par = p.paramDeclOrNil(name, close) + } name = nil // 1st name was consumed if present + typ = nil // 1st type was consumed if present if par != nil { if debug && par.Name == nil && par.Type == nil { panic("parameter without name or type") diff --git a/src/cmd/compile/internal/syntax/printer.go b/src/cmd/compile/internal/syntax/printer.go index c8d31799af..11190ab287 100644 --- a/src/cmd/compile/internal/syntax/printer.go +++ b/src/cmd/compile/internal/syntax/printer.go @@ -666,9 +666,7 @@ func (p *printer) printRawNode(n Node) { } p.print(n.Name) if n.TParamList != nil { - p.print(_Lbrack) - p.printFieldList(n.TParamList, nil, _Comma) - p.print(_Rbrack) + p.printParameterList(n.TParamList, true) } p.print(blank) if n.Alias { @@ -700,9 +698,7 @@ func (p *printer) printRawNode(n Node) { } p.print(n.Name) if n.TParamList != nil { - p.print(_Lbrack) - p.printFieldList(n.TParamList, nil, _Comma) - p.print(_Rbrack) + p.printParameterList(n.TParamList, true) } p.printSignature(n.Type) if n.Body != nil { @@ -887,38 +883,47 @@ func (p *printer) printDeclList(list []Decl) { } func (p *printer) printSignature(sig *FuncType) { - p.printParameterList(sig.ParamList) + p.printParameterList(sig.ParamList, false) if list := sig.ResultList; list != nil { p.print(blank) if len(list) == 1 && list[0].Name == nil { p.printNode(list[0].Type) } else { - p.printParameterList(list) + p.printParameterList(list, false) } } } -func (p *printer) printParameterList(list []*Field) { - p.print(_Lparen) - if len(list) > 0 { - for i, f := range list { - if i > 0 { - p.print(_Comma, blank) - } - if f.Name != nil { - p.printNode(f.Name) - if i+1 < len(list) { - f1 := list[i+1] - if f1.Name != nil && f1.Type == f.Type { - continue // no need to print type - } +func (p *printer) printParameterList(list []*Field, types bool) { + open, close := _Lparen, _Rparen + if types { + open, close = _Lbrack, _Rbrack + } + p.print(open) + for i, f := range list { + if i > 0 { + p.print(_Comma, blank) + } + if f.Name != nil { + p.printNode(f.Name) + if i+1 < len(list) { + f1 := list[i+1] + if f1.Name != nil && f1.Type == f.Type { + continue // no need to print type } - p.print(blank) } - p.printNode(f.Type) + p.print(blank) + } + p.printNode(unparen(f.Type)) // no need for (extra) parentheses around parameter types + } + // A type parameter list [P *T] where T is not a type literal requires a comma as in [P *T,] + // so that it's not parsed as [P*T]. + if types && len(list) == 1 { + if t, _ := list[0].Type.(*Operation); t != nil && t.Op == Mul && t.Y == nil && !isTypeLit(t.X) { + p.print(_Comma) } } - p.print(_Rparen) + p.print(close) } func (p *printer) printStmtList(list []Stmt, braces bool) { diff --git a/src/cmd/compile/internal/syntax/printer_test.go b/src/cmd/compile/internal/syntax/printer_test.go index 604f1fc1ca..941af0aeb4 100644 --- a/src/cmd/compile/internal/syntax/printer_test.go +++ b/src/cmd/compile/internal/syntax/printer_test.go @@ -53,54 +53,77 @@ func TestPrintError(t *testing.T) { } } -var stringTests = []string{ - "package p", - "package p; type _ int; type T1 = struct{}; type ( _ *struct{}; T2 = float32 )", +var stringTests = [][2]string{ + dup("package p"), + dup("package p; type _ int; type T1 = struct{}; type ( _ *struct{}; T2 = float32 )"), // generic type declarations - "package p; type _[T any] struct{}", - "package p; type _[A, B, C interface{m()}] struct{}", - "package p; type _[T any, A, B, C interface{m()}, X, Y, Z interface{~int}] struct{}", + dup("package p; type _[T any] struct{}"), + dup("package p; type _[A, B, C interface{m()}] struct{}"), + dup("package p; type _[T any, A, B, C interface{m()}, X, Y, Z interface{~int}] struct{}"), + + dup("package p; type _[P *T,] struct{}"), + dup("package p; type _[P *T, _ any] struct{}"), + {"package p; type _[P (*T),] struct{}", "package p; type _[P *T,] struct{}"}, + {"package p; type _[P (*T), _ any] struct{}", "package p; type _[P *T, _ any] struct{}"}, + {"package p; type _[P (T),] struct{}", "package p; type _[P T] struct{}"}, + {"package p; type _[P (T), _ any] struct{}", "package p; type _[P T, _ any] struct{}"}, + + dup("package p; type _[P *struct{}] struct{}"), + {"package p; type _[P (*struct{})] struct{}", "package p; type _[P *struct{}] struct{}"}, + {"package p; type _[P ([]int)] struct{}", "package p; type _[P []int] struct{}"}, + + dup("package p; type _ [P(T)]struct{}"), + dup("package p; type _ [P((T))]struct{}"), + dup("package p; type _ [P * *T]struct{}"), + dup("package p; type _ [P * T]struct{}"), + dup("package p; type _ [P(*T)]struct{}"), + dup("package p; type _ [P(**T)]struct{}"), + dup("package p; type _ [P * T - T]struct{}"), + + // array type declarations + dup("package p; type _ [P * T]struct{}"), + dup("package p; type _ [P * T - T]struct{}"), // generic function declarations - "package p; func _[T any]()", - "package p; func _[A, B, C interface{m()}]()", - "package p; func _[T any, A, B, C interface{m()}, X, Y, Z interface{~int}]()", + dup("package p; func _[T any]()"), + dup("package p; func _[A, B, C interface{m()}]()"), + dup("package p; func _[T any, A, B, C interface{m()}, X, Y, Z interface{~int}]()"), // methods with generic receiver types - "package p; func (R[T]) _()", - "package p; func (*R[A, B, C]) _()", - "package p; func (_ *R[A, B, C]) _()", + dup("package p; func (R[T]) _()"), + dup("package p; func (*R[A, B, C]) _()"), + dup("package p; func (_ *R[A, B, C]) _()"), // type constraint literals with elided interfaces - "package p; func _[P ~int, Q int | string]() {}", - "package p; func _[P struct{f int}, Q *P]() {}", + dup("package p; func _[P ~int, Q int | string]() {}"), + dup("package p; func _[P struct{f int}, Q *P]() {}"), // channels - "package p; type _ chan chan int", - "package p; type _ chan (<-chan int)", - "package p; type _ chan chan<- int", + dup("package p; type _ chan chan int"), + dup("package p; type _ chan (<-chan int)"), + dup("package p; type _ chan chan<- int"), - "package p; type _ <-chan chan int", - "package p; type _ <-chan <-chan int", - "package p; type _ <-chan chan<- int", + dup("package p; type _ <-chan chan int"), + dup("package p; type _ <-chan <-chan int"), + dup("package p; type _ <-chan chan<- int"), - "package p; type _ chan<- chan int", - "package p; type _ chan<- <-chan int", - "package p; type _ chan<- chan<- int", + dup("package p; type _ chan<- chan int"), + dup("package p; type _ chan<- <-chan int"), + dup("package p; type _ chan<- chan<- int"), // TODO(gri) expand } func TestPrintString(t *testing.T) { - for _, want := range stringTests { - ast, err := Parse(nil, strings.NewReader(want), nil, nil, AllowGenerics) + for _, test := range stringTests { + ast, err := Parse(nil, strings.NewReader(test[0]), nil, nil, AllowGenerics) if err != nil { t.Error(err) continue } - if got := String(ast); got != want { - t.Errorf("%q: got %q", want, got) + if got := String(ast); got != test[1] { + t.Errorf("%q: got %q", test[1], got) } } } diff --git a/src/cmd/compile/internal/syntax/testdata/issue49482.go2 b/src/cmd/compile/internal/syntax/testdata/issue49482.go2 new file mode 100644 index 0000000000..1fc303d169 --- /dev/null +++ b/src/cmd/compile/internal/syntax/testdata/issue49482.go2 @@ -0,0 +1,31 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package p + +type ( + // these need a comma to disambiguate + _[P *T,] struct{} + _[P *T, _ any] struct{} + _[P (*T),] struct{} + _[P (*T), _ any] struct{} + _[P (T),] struct{} + _[P (T), _ any] struct{} + + // these parse as name followed by type + _[P *struct{}] struct{} + _[P (*struct{})] struct{} + _[P ([]int)] struct{} + + // array declarations + _ [P(T)]struct{} + _ [P((T))]struct{} + _ [P * *T] struct{} // this could be a name followed by a type but it makes the rules more complicated + _ [P * T]struct{} + _ [P(*T)]struct{} + _ [P(**T)]struct{} + _ [P * T - T]struct{} + _ [P*T-T /* ERROR unexpected comma */ ,]struct{} + _ [10 /* ERROR unexpected comma */ ,]struct{} +)