diff --git a/src/go/go2go/instantiate.go b/src/go/go2go/instantiate.go index 0da462479a..6e58c23e46 100644 --- a/src/go/go2go/instantiate.go +++ b/src/go/go2go/instantiate.go @@ -490,6 +490,16 @@ func (t *translator) instantiateExpr(ta *typeArgs, e ast.Expr) ast.Expr { return e case *ast.BasicLit: return e + case *ast.FuncLit: + typ := t.instantiateExpr(ta, e.Type).(*ast.FuncType) + body := t.instantiateBlockStmt(ta, e.Body) + if typ == e.Type && body == e.Body { + return e + } + return &ast.FuncLit{ + Type: typ, + Body: body, + } case *ast.ParenExpr: x := t.instantiateExpr(ta, e.X) if x == e.X { @@ -518,6 +528,16 @@ func (t *translator) instantiateExpr(ta *typeArgs, e ast.Expr) ast.Expr { Star: e.Star, X: x, } + case *ast.UnaryExpr: + x := t.instantiateExpr(ta, e.X) + if x == e.X { + return e + } + r = &ast.UnaryExpr{ + OpPos: e.OpPos, + Op: e.Op, + X: x, + } case *ast.BinaryExpr: x := t.instantiateExpr(ta, e.X) y := t.instantiateExpr(ta, e.Y) diff --git a/src/go/go2go/rewrite.go b/src/go/go2go/rewrite.go index 981f92df7a..d3976f98a9 100644 --- a/src/go/go2go/rewrite.go +++ b/src/go/go2go/rewrite.go @@ -208,6 +208,9 @@ func (t *translator) translate(file *ast.File) { for j := range decl.Specs { t.translateValueSpec(&decl.Specs[j]) } + case token.IDENT: + // A contract. + decl = nil } if decl != nil { newDecls = append(newDecls, decl) @@ -331,6 +334,16 @@ func (t *translator) translateExpr(pe *ast.Expr) { switch e := (*pe).(type) { case *ast.Ident: return + case *ast.BasicLit: + return + case *ast.FuncLit: + t.translateFieldList(e.Type.TParams) + t.translateFieldList(e.Type.Params) + t.translateFieldList(e.Type.Results) + t.translateBlockStmt(e.Body) + case *ast.CompositeLit: + t.translateExpr(&e.Type) + t.translateExprList(e.Elts) case *ast.ParenExpr: t.translateExpr(&e.X) case *ast.BinaryExpr: @@ -348,9 +361,9 @@ func (t *translator) translateExpr(pe *ast.Expr) { t.translateExpr(&e.Max) case *ast.CallExpr: t.translateExprList(e.Args) - if ftyp, ok := t.lookupType(e.Fun).(*types.Signature); ok && ftyp.TParams() != nil { + if ftyp, ok := t.lookupType(e.Fun).(*types.Signature); ok && len(ftyp.TParams()) > 0 { t.translateFunctionInstantiation(pe) - } else if ntyp, ok := t.lookupType(e.Fun).(*types.Named); ok && ntyp.TParams() != nil { + } else if ntyp, ok := t.lookupType(e.Fun).(*types.Named); ok && len(ntyp.TParams()) > 0 && len(ntyp.TArgs()) == 0 { t.translateTypeInstantiation(pe) } t.translateExpr(&e.Fun) @@ -362,11 +375,10 @@ func (t *translator) translateExpr(pe *ast.Expr) { t.translateExpr(&e.Elt) case *ast.StructType: t.translateFieldList(e.Fields) - case *ast.BasicLit: - return - case *ast.CompositeLit: - t.translateExpr(&e.Type) - t.translateExprList(e.Elts) + case *ast.FuncType: + t.translateFieldList(e.TParams) + t.translateFieldList(e.Params) + t.translateFieldList(e.Results) default: panic(fmt.Sprintf("unimplemented Expr %T", e)) } @@ -486,7 +498,7 @@ func (t *translator) instantiatedIdent(call *ast.CallExpr) qualifiedIdent { } return qualifiedIdent{pkg: pn.Imported(), ident: fun.Sel} } - panic(fmt.Sprintf("instantiated object %v is not an identifier", call.Fun)) + panic(fmt.Sprintf("instantiated object %T %v is not an identifier", call.Fun, call.Fun)) } // instantiationTypes returns the type arguments of an instantiation. diff --git a/src/go/go2go/types.go b/src/go/go2go/types.go index 7224b3bb53..7d41d577e1 100644 --- a/src/go/go2go/types.go +++ b/src/go/go2go/types.go @@ -13,11 +13,11 @@ import ( // lookupType returns the types.Type for an AST expression. // Returns nil if the type is not known. func (t *translator) lookupType(e ast.Expr) types.Type { - if t, ok := t.info.Types[e]; ok { - return t.Type + if typ, ok := t.info.Types[e]; ok { + return typ.Type } - if t, ok := t.types[e]; ok { - return t + if typ, ok := t.types[e]; ok { + return typ } return nil } @@ -64,6 +64,10 @@ func (t *translator) instantiateType(ta *typeArgs, typ types.Type) types.Type { // This should only be called from instantiateType. func (t *translator) doInstantiateType(ta *typeArgs, typ types.Type) types.Type { switch typ := typ.(type) { + case *types.Named: + return typ + case *types.Basic: + return typ case *types.TypeParam: if instType, ok := ta.typ(typ); ok { return instType @@ -75,7 +79,7 @@ func (t *translator) doInstantiateType(ta *typeArgs, typ types.Type) types.Type if elem == instElem { return typ } - return types.NewSlice(elem) + return types.NewSlice(instElem) case *types.Signature: params := t.instantiateTypeTuple(ta, typ.Params()) results := t.instantiateTypeTuple(ta, typ.Results()) diff --git a/src/go/types/predicates.go b/src/go/types/predicates.go index b2693d4c75..fe8d94b38b 100644 --- a/src/go/types/predicates.go +++ b/src/go/types/predicates.go @@ -85,6 +85,8 @@ func Comparable(T Type) bool { return true case *Array: return Comparable(t.elem) + case *TypeParam: + return t.Interface().is(Comparable) } return false } diff --git a/test/gen/g006.go2 b/test/gen/g006.go2 new file mode 100644 index 0000000000..5c5375d8e4 --- /dev/null +++ b/test/gen/g006.go2 @@ -0,0 +1,96 @@ +// run + +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "math" + "os" + "sort" +) + +contract Ordered(T) { + T int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, uintptr, + float32, float64, + string +} + +type orderedSlice(type Elem Ordered) []Elem + +func (s orderedSlice(Elem)) Len() int { return len(s) } +func (s orderedSlice(Elem)) Less(i, j int) bool { + if s[i] < s[j] { + return true + } + isNaN := func(f Elem) bool { return f != f } + if isNaN(s[i]) && !isNaN(s[j]) { + return true + } + return false +} +func (s orderedSlice(Elem)) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func OrderedSlice(type Elem Ordered)(s []Elem) { + sort.Sort(orderedSlice(Elem)(s)) +} + +var ints = []int{74, 59, 238, -784, 9845, 959, 905, 0, 0, 42, 7586, -5467984, 7586} +var float64s = []float64{74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3, math.NaN(), math.NaN(), math.Inf(-1), 9845.768, -959.7485, 905, 7.8, 7.8} +var strings = []string{"", "Hello", "foo", "bar", "foo", "f00", "%*&^*&^&", "***"} + +func TestSortOrderedInts() bool { + return testOrdered(int)("ints", ints, sort.Ints) +} + +func TestSortOrderedFloat64s() bool { + return testOrdered(float64)("float64s", float64s, sort.Float64s) +} + +func TestSortOrderedStrings() bool { + return testOrdered(string)("strings", strings, sort.Strings) +} + +func testOrdered(type Elem Ordered)(name string, s []Elem, sorter func([]Elem)) bool { + s1 := make([]Elem, len(s)) + copy(s1, s) + s2 := make([]Elem, len(s)) + copy(s2, s) + OrderedSlice(Elem)(s1) + sorter(s2) + ok := true + if !sliceEq(Elem)(s1, s2) { + fmt.Printf("%s: got %v, want %v", name, s1, s2) + ok = false + } + for i := len(s1) - 1; i > 0; i-- { + if s1[i] < s1[i-1] { + fmt.Printf("%s: element %d (%v) < element %d (%v)", name, i, s1[i], i - 1, s1[i - 1]) + ok = false + } + } + return ok +} + +func sliceEq(type Elem Ordered)(s1, s2[]Elem) bool { + for i, v1 := range s1 { + v2 := s2[i] + if v1 != v2 { + isNaN := func(f Elem) bool { return f != f } + if !isNaN(v1) || !isNaN(v2) { + return false + } + } + } + return true +} + +func main() { + if !TestSortOrderedInts() || !TestSortOrderedFloat64s() || !TestSortOrderedStrings() { + os.Exit(1) + } +}