diff --git a/src/go/ast/ast.go b/src/go/ast/ast.go index fd83d39712..46a8d0cb5e 100644 --- a/src/go/ast/ast.go +++ b/src/go/ast/ast.go @@ -192,7 +192,7 @@ func isDirective(c string) bool { type Field struct { Doc *CommentGroup // associated documentation; or nil Names []*Ident // field/method/(type) parameter names; or nil - Type Expr // field/method/parameter type or contract + Type Expr // field/method/parameter type or contract; or nil Tag *BasicLit // field tag; or nil Comment *CommentGroup // line comments; or nil } @@ -201,14 +201,23 @@ func (f *Field) Pos() token.Pos { if len(f.Names) > 0 { return f.Names[0].Pos() } - return f.Type.Pos() + if f.Type != nil { + return f.Type.Pos() + } + return token.NoPos } func (f *Field) End() token.Pos { if f.Tag != nil { return f.Tag.End() } - return f.Type.End() + if f.Type != nil { + return f.Type.End() + } + if len(f.Names) > 0 { + return f.Names[len(f.Names)-1].End() + } + return token.NoPos } // A FieldList represents a list of Fields, enclosed by parentheses or braces. @@ -463,7 +472,7 @@ type ( ) type Constraint struct { - Param *Ident // constrained type parameter; or nil (for embedded constraints) + Param *Ident // constrained type parameter; or nil (for embedded contracts) MNames []*Ident // list of method names; or nil (for embedded contracts or type constraints) Types []Expr // embedded constraint (single *CallExpr), list of types, or list of method types (*FuncType) } diff --git a/src/go/printer/nodes.go b/src/go/printer/nodes.go index cac9c09701..c9c4483c3e 100644 --- a/src/go/printer/nodes.go +++ b/src/go/printer/nodes.go @@ -319,8 +319,11 @@ func (p *printer) exprList(prev0 token.Pos, list []ast.Expr, depth int, mode exp } } -func (p *printer) parameters(fields *ast.FieldList) { +func (p *printer) parameters(isTypeParam bool, fields *ast.FieldList) { p.print(fields.Opening, token.LPAREN) + if isTypeParam { + p.print(token.TYPE) + } if len(fields.List) > 0 { prevLine := p.lineFor(fields.Opening) ws := indent @@ -328,13 +331,8 @@ func (p *printer) parameters(fields *ast.FieldList) { // determine par begin and end line (may be different // if there are multiple parameter names for this par // or the type is on a separate line) - var parLineBeg int - if len(par.Names) > 0 { - parLineBeg = p.lineFor(par.Names[0].Pos()) - } else { - parLineBeg = p.lineFor(par.Type.Pos()) - } - var parLineEnd = p.lineFor(par.Type.End()) + parLineBeg := p.lineFor(par.Pos()) + parLineEnd := p.lineFor(par.End()) // separating "," if needed needsLinebreak := 0 < prevLine && prevLine < parLineBeg if i > 0 { @@ -350,7 +348,7 @@ func (p *printer) parameters(fields *ast.FieldList) { if needsLinebreak && p.linebreak(parLineBeg, 0, ws, true) > 0 { // break line if the opening "(" or previous parameter ended on a different line ws = ignore - } else if i > 0 { + } else if isTypeParam && len(par.Names) > 0 || i > 0 { p.print(blank) } // parameter names @@ -362,10 +360,14 @@ func (p *printer) parameters(fields *ast.FieldList) { // by a linebreak call after a type, or in the next multi-line identList // will do the right thing. p.identList(par.Names, ws == indent) - p.print(blank) + if par.Type != nil { + p.print(blank) + } } // parameter type - p.expr(stripParensAlways(par.Type)) + if par.Type != nil { + p.expr(stripParensAlways(par.Type)) + } prevLine = parLineEnd } // if the closing ")" is on a separate line from the last parameter, @@ -384,7 +386,7 @@ func (p *printer) parameters(fields *ast.FieldList) { func (p *printer) signature(params, result *ast.FieldList) { if params != nil { - p.parameters(params) + p.parameters(false, params) } else { p.print(token.LPAREN, token.RPAREN) } @@ -397,7 +399,7 @@ func (p *printer) signature(params, result *ast.FieldList) { p.expr(stripParensAlways(result.List[0].Type)) return } - p.parameters(result) + p.parameters(false, result) } } @@ -969,6 +971,23 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) { p.print(blank) p.expr(x.Value) + case *ast.ContractType: + if p.mode&noContractKeyword == 0 { + p.print(&ast.Ident{NamePos: x.Contract, Name: "contract"}) + } + p.print(token.LPAREN) + for i, par := range x.TParams { + if i > 0 { + p.print(token.COMMA, blank) + } + p.print(par) + } + p.print(token.RPAREN, x.Lbrace, token.LBRACE) + for _, c := range x.Constraints { + p.constraint(c) + } + p.print(x.Rbrace, token.RBRACE) + default: panic("unreachable") } @@ -1063,6 +1082,25 @@ func (p *printer) expr(x ast.Expr) { p.expr1(x, token.LowestPrec, depth) } +// TODO(gri) complete this +func (p *printer) constraint(x *ast.Constraint) { + if x.Param != nil { + p.print(x.Param, blank) + if len(x.MNames) > 0 { + // method names + p.print(blank) + for i, m := range x.MNames { + p.print(m) + t := x.Types[i].(*ast.FuncType) + p.signature(t.Params, t.Results) + //p.print(token.COMMA) + } + } + } else { + // embedded contract + } +} + // ---------------------------------------------------------------------------- // Statements @@ -1581,6 +1619,9 @@ func (p *printer) spec(spec ast.Spec, n int, doIndent bool) { case *ast.TypeSpec: p.setComment(s.Doc) p.expr(s.Name) + if s.TParams != nil { + p.parameters(true, s.TParams) + } if n == 1 { p.print(blank) } else { @@ -1599,7 +1640,15 @@ func (p *printer) spec(spec ast.Spec, n int, doIndent bool) { func (p *printer) genDecl(d *ast.GenDecl) { p.setComment(d.Doc) - p.print(d.Pos(), d.Tok, blank) + // Contract declarations rely on the pseudo-keyword (identifier) "contract"; + // in the AST the respective token is ast.IDENT. Catch and correct this here. + if d.Tok == token.IDENT { + p.print(&ast.Ident{NamePos: d.Pos(), Name: "contract"}, noContractKeyword) + defer p.print(noContractKeyword) + } else { + p.print(d.Pos(), d.Tok) + } + p.print(blank) if d.Lparen.IsValid() || len(d.Specs) > 1 { // group of parenthesized declarations @@ -1769,12 +1818,12 @@ func (p *printer) funcDecl(d *ast.FuncDecl) { // FUNC is emitted). startCol := p.out.Column - len("func ") if d.Recv != nil { - p.parameters(d.Recv) // method: print receiver + p.parameters(false, d.Recv) // method: print receiver p.print(blank) } p.expr(d.Name) - if d.TParams.NumFields() != 0 { - p.parameters(d.TParams) + if d.TParams != nil { + p.parameters(true, d.TParams) } p.signature(d.Type.Params, d.Type.Results) p.funcBody(p.distanceFrom(d.Pos(), startCol), vtab, d.Body) diff --git a/src/go/printer/printer.go b/src/go/printer/printer.go index 9d0add40b6..fb26c59568 100644 --- a/src/go/printer/printer.go +++ b/src/go/printer/printer.go @@ -38,8 +38,9 @@ const ( type pmode int const ( - noExtraBlank pmode = 1 << iota // disables extra blank after /*-style comment - noExtraLinebreak // disables extra line break after /*-style comment + noExtraBlank pmode = 1 << iota // disables extra blank after /*-style comment + noExtraLinebreak // disables extra line break after /*-style comment + noContractKeyword // disables printing of "contract" pseudo-keyword when printing a contract type ) type commentInfo struct { diff --git a/src/go/printer/printer_test.go b/src/go/printer/printer_test.go index 1e9d47ce73..63e02475a6 100644 --- a/src/go/printer/printer_test.go +++ b/src/go/printer/printer_test.go @@ -206,6 +206,7 @@ var data = []entry{ {"complit.input", "complit.x", export}, {"go2numbers.input", "go2numbers.golden", idempotent}, {"go2numbers.input", "go2numbers.stdfmt", stdFormat | idempotent}, + {"contracts.input", "contracts.golden", idempotent}, } func TestFiles(t *testing.T) {