diff --git a/src/go/ast/walk.go b/src/go/ast/walk.go index 8ca21959b1..2167de97aa 100644 --- a/src/go/ast/walk.go +++ b/src/go/ast/walk.go @@ -39,6 +39,16 @@ func walkDeclList(v Visitor, list []Decl) { } } +func walkConstraintList(v Visitor, constraints []*Constraint) { + for _, x := range constraints { + if x.Param != nil { + Walk(v, x.Param) + } + walkIdentList(v, x.MNames) + walkExprList(v, x.Types) + } +} + // TODO(gri): Investigate if providing a closure to Walk leads to // simpler use (and may help eliminate Inspect in turn). @@ -71,7 +81,9 @@ func Walk(v Visitor, node Node) { Walk(v, n.Doc) } walkIdentList(v, n.Names) - Walk(v, n.Type) + if n.Type != nil { + Walk(v, n.Type) + } if n.Tag != nil { Walk(v, n.Tag) } @@ -161,6 +173,9 @@ func Walk(v Visitor, node Node) { Walk(v, n.Fields) case *FuncType: + if n.TParams != nil { + Walk(v, n.TParams) + } if n.Params != nil { Walk(v, n.Params) } @@ -315,11 +330,24 @@ func Walk(v Visitor, node Node) { Walk(v, n.Doc) } Walk(v, n.Name) + if n.TParams != nil { + Walk(v, n.TParams) + } Walk(v, n.Type) if n.Comment != nil { Walk(v, n.Comment) } + case *ContractSpec: + if n.Doc != nil { + Walk(v, n.Doc) + } + walkIdentList(v, n.TParams) + walkConstraintList(v, n.Constraints) + if n.Comment != nil { + Walk(v, n.Comment) + } + case *BadDecl: // nothing to do diff --git a/src/go/printer/nodes.go b/src/go/printer/nodes.go index e903b9465d..ce476d786b 100644 --- a/src/go/printer/nodes.go +++ b/src/go/printer/nodes.go @@ -1072,10 +1072,10 @@ func (p *printer) expr(x ast.Expr) { // TODO(gri) complete this func (p *printer) constraint(x *ast.Constraint) { if x.Param != nil { - p.print(x.Param, blank) + p.linebreak(p.lineFor(x.Types[0].Pos()), 1, ignore, false) + p.print(indent, x.Param, blank) if len(x.MNames) > 0 { // method names - p.print(blank) for i, m := range x.MNames { p.print(m) t := x.Types[i].(*ast.FuncType) @@ -1083,8 +1083,12 @@ func (p *printer) constraint(x *ast.Constraint) { //p.print(token.COMMA) } } - } else { - // embedded contract + p.print(unindent) + } else if len(x.Types) > 0 { + p.linebreak(p.lineFor(x.Types[0].Pos()), 1, ignore, false) + p.print(indent) + p.exprList(token.NoPos, x.Types, 1, 0, token.NoPos, false) + p.print(unindent) } } @@ -1630,10 +1634,17 @@ func (p *printer) spec(spec ast.Spec, n int, doIndent bool) { } p.print(par) } - p.print(token.RPAREN, s.Lbrace, token.LBRACE) + p.print(token.RPAREN) + if len(s.Constraints) > 0 { + p.print(blank) + } + p.print(s.Lbrace, token.LBRACE) for _, c := range s.Constraints { p.constraint(c) } + if len(s.Constraints) > 0 { + p.linebreak(p.lineFor(s.Rbrace), 1, ignore, true) + } p.print(s.Rbrace, token.RBRACE) p.setComment(s.Comment) diff --git a/src/go/printer/printer.go b/src/go/printer/printer.go index 9d0add40b6..7b639bf30e 100644 --- a/src/go/printer/printer.go +++ b/src/go/printer/printer.go @@ -1038,6 +1038,8 @@ func getDoc(n ast.Node) *ast.CommentGroup { return n.Doc case *ast.TypeSpec: return n.Doc + case *ast.ContractSpec: + return n.Doc case *ast.GenDecl: return n.Doc case *ast.FuncDecl: @@ -1058,6 +1060,8 @@ func getLastComment(n ast.Node) *ast.CommentGroup { return n.Comment case *ast.TypeSpec: return n.Comment + case *ast.ContractSpec: + return n.Comment case *ast.GenDecl: if len(n.Specs) > 0 { return getLastComment(n.Specs[len(n.Specs)-1]) diff --git a/src/go/printer/testdata/contracts.golden b/src/go/printer/testdata/contracts.golden index 1a8695b264..499261d0e5 100644 --- a/src/go/printer/testdata/contracts.golden +++ b/src/go/printer/testdata/contracts.golden @@ -12,7 +12,9 @@ contract ( C2(A, B, C){} ) -contract _(T){T m()} +contract _(T) { + T m() +} type _(type T) struct{} type _(type A C) struct{}