From 5da07bf767b270813c55857a72d00af3719182f8 Mon Sep 17 00:00:00 2001 From: Robert Griesemer Date: Mon, 23 Mar 2020 22:16:22 -0700 Subject: [PATCH] go/types: implemented (bidirectional) unifier and use it for type inference - Implemented bidirectional unifier as a stand-alone mechanism separate from Checker.identical0. - Use it instead of Checker.identical0 where we need unification. - Missing: Bidirection functionality not fully implemented because we don't use it yet, but the basic outline is present. Change-Id: I1666c9e4c9094eda749084bb69c700f1b5e879bb --- src/go/types/infer.go | 31 ++- src/go/types/lookup.go | 18 +- src/go/types/predicates.go | 1 + src/go/types/testdata/tmp.go2 | 9 + src/go/types/unify.go | 345 ++++++++++++++++++++++++++++++++++ 5 files changed, 383 insertions(+), 21 deletions(-) create mode 100644 src/go/types/unify.go diff --git a/src/go/types/infer.go b/src/go/types/infer.go index 06e52c0c34..f1606e0fbc 100644 --- a/src/go/types/infer.go +++ b/src/go/types/infer.go @@ -15,8 +15,8 @@ import "go/token" func (check *Checker) infer(pos token.Pos, tparams []*TypeName, params *Tuple, args []*operand) []Type { assert(params.Len() == len(args)) - // targs is the list of inferred type parameter types. - targs := make([]Type, len(tparams)) + u := check.unifier() + u.x.init(tparams) // Terminology: TPP = type-parameterized function parameter @@ -25,6 +25,9 @@ func (check *Checker) infer(pos token.Pos, tparams []*TypeName, params *Tuple, a var indices []int for i, arg := range args { par := params.At(i) + // If we permit bidirectional unification, this conditional code needs to be + // executed even if par.typ is not parameterized since the argument may be a + // generic function (for which we want to infer // its type arguments). 
if IsParameterized(par.typ) { if arg.mode == invalid { // TODO(gri) we might still be able to infer all targs by @@ -32,7 +35,11 @@ func (check *Checker) infer(pos token.Pos, tparams []*TypeName, params *Tuple, a return nil // error was reported earlier } if isTyped(arg.typ) { - if !check.identical0(par.typ, arg.typ, true, nil, targs) { + // If we permit bidirectional unification, and arg.typ is + // a generic function, we need to initialize u.y with the + // respective type parameters of arg.typ. + if !u.unify(par.typ, arg.typ) { + //if !check.identical0(par.typ, arg.typ, true, nil, targs) { // Calling subst for an error message can cause problems. // TODO(gri) Determine best approach here. // check.errorf(arg.pos(), "type %s for %s does not match %s = %s", @@ -57,7 +64,7 @@ func (check *Checker) infer(pos token.Pos, tparams []*TypeName, params *Tuple, a // only parameter type it can possibly match against is a *TypeParam. // Thus, only keep the indices of TPPs that are unstructured and which // don't have a type inferred yet. - if tpar, _ := par.typ.(*TypeParam); tpar != nil && targs[tpar.index] == nil { + if tpar, _ := par.typ.(*TypeParam); tpar != nil && u.x.at(tpar.index) == nil { indices[j] = i j++ } @@ -72,7 +79,8 @@ func (check *Checker) infer(pos token.Pos, tparams []*TypeName, params *Tuple, a // The default type for an untyped nil is untyped nil. We must not // infer an untyped nil type as type parameter type. Ignore untyped // nil by making sure all default argument types are typed. 
- if isTyped(targ) && !check.identical0(par.typ, targ, true, nil, targs) { + if isTyped(targ) && !u.unify(par.typ, targ) { + //if isTyped(targ) && !check.identical0(par.typ, targ, true, nil, targs) { // TODO(gri) see TODO comment above // check.errorf(arg.pos(), "default type %s for %s does not match %s = %s", // Default(arg.typ), arg.expr, par.typ, check.subst(pos, par.typ, tparams, targs), @@ -82,15 +90,20 @@ func (check *Checker) infer(pos token.Pos, tparams []*TypeName, params *Tuple, a } } - // Check if all type parameters have been determined. + // Collect type arguments and check if they all have been determined. // TODO(gri) consider moving this outside this function and then we won't need to pass in pos - for i, t := range targs { - if t == nil { - tpar := tparams[i] + var targs []Type // lazily allocated + for i, tpar := range tparams { + targ := u.x.at(i) + if targ == nil { ppos := check.fset.Position(tpar.pos).String() check.errorf(pos, "cannot infer %s (%s)", tpar.name, ppos) return nil } + if targs == nil { + targs = make([]Type, len(tparams)) + } + targs[i] = targ } return targs diff --git a/src/go/types/lookup.go b/src/go/types/lookup.go index 04353da498..3fa8565cb5 100644 --- a/src/go/types/lookup.go +++ b/src/go/types/lookup.go @@ -310,12 +310,9 @@ func (check *Checker) missingMethod(V Type, T *Interface, static bool) (method, // comparison in that case. // TODO(gri) is this always correct? what about type bounds? // (Alternative is to rename/subst type parameters and compare.) - var tparams []Type - if len(mtyp.tparams) > 0 { - tparams = make([]Type, len(mtyp.tparams)) - } - - if !check.identical0(ftyp, mtyp, true, nil, tparams) { + u := check.unifier() + u.x.init(mtyp.tparams) + if !u.unify(ftyp, mtyp) { return m, f } } @@ -376,12 +373,9 @@ func (check *Checker) missingMethod(V Type, T *Interface, static bool) (method, // comparison (provide non-nil tparams to identical0) in that case. // TODO(gri) is this always correct? what about type bounds? 
// (Alternative is to rename/subst type parameters and compare.) - var tparams []Type - if len(mtyp.tparams) > 0 { - tparams = make([]Type, len(mtyp.tparams)) - } - - if !check.identical0(ftyp, mtyp, true, nil, tparams) { + u := check.unifier() + u.x.init(mtyp.tparams) + if !u.unify(ftyp, mtyp) { return m, f } } diff --git a/src/go/types/predicates.go b/src/go/types/predicates.go index eb919b06bd..1babe3238d 100644 --- a/src/go/types/predicates.go +++ b/src/go/types/predicates.go @@ -128,6 +128,7 @@ func (p *ifacePair) identical(q *ifacePair) bool { } // If a non-nil tparams is provided, type inference is done for type parameters in x. +// For changes to this code the corresponding changes should be made to unifier.nify. func (check *Checker) identical0(x, y Type, cmpTags bool, p *ifacePair, tparams []Type) bool { // If we want type inference, do not shortcut for equal types. Instead // keep comparing them element-wise so we can infer the matching (and diff --git a/src/go/types/testdata/tmp.go2 b/src/go/types/testdata/tmp.go2 index 28916cc839..fdcef50337 100644 --- a/src/go/types/testdata/tmp.go2 +++ b/src/go/types/testdata/tmp.go2 @@ -4,6 +4,15 @@ package p +/* +func f(type T)(T) + +func _() { + var x int + f(x) +} +*/ + /* func f(func(int)) func g(type T)(T) diff --git a/src/go/types/unify.go b/src/go/types/unify.go new file mode 100644 index 0000000000..bad6ad199e --- /dev/null +++ b/src/go/types/unify.go @@ -0,0 +1,345 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements type unification. + +package types + +import ( + "go/token" + "sort" +) + +// A unifier maintains the current type parameters for x and y +// and the respective types inferred for each type parameter. +// A unifier is created by calling Checker.unifier. 
+type unifier struct {
+	check *Checker
+	x, y  typeDesc // x and y must be initialized via typeDesc.init
+	types []Type   // inferred types, shared by x and y; typeDesc.set appends to it
+}
+
+// unifier returns a new unifier whose x and y descriptors link back to it.
+func (check *Checker) unifier() *unifier {
+	u := &unifier{check: check}
+	u.x.uplink = u
+	u.y.uplink = u
+	return u
+}
+
+// unify attempts to unify x and y and reports whether it succeeded.
+func (u *unifier) unify(x, y Type) bool {
+	return u.nify(x, y, nil)
+}
+
+// A typeDesc describes a list of type parameters and the types inferred for them.
+type typeDesc struct {
+	uplink  *unifier // owning unifier; its types slice stores the inferred types
+	tparams []*TypeName
+	indices []int // len(d.indices) == len(d.tparams); 0 = nothing inferred yet, else 1 + index into uplink.types
+}
+
+func (d *typeDesc) init(tparams []*TypeName) {
+	if len(tparams) == 0 {
+		return
+	}
+	d.tparams = tparams
+	d.indices = make([]int, len(tparams))
+}
+
+// at returns the type inferred (via unification) for the i'th type parameter; or nil.
+// The index i must be a valid type parameter index: 0 <= i < len(d.tparams).
+func (d *typeDesc) at(i int) Type {
+	if i := d.indices[i]; i != 0 {
+		typ := d.uplink.types[i-1]
+		assert(typ != nil)
+		return typ
+	}
+	return nil
+}
+
+// set sets the type typ inferred (via unification) for the i'th type parameter; typ must not be nil.
+// The index i must be a valid type parameter index: 0 <= i < len(d.tparams).
+func (d *typeDesc) set(i int, typ Type) {
+	assert(typ != nil)
+	u := d.uplink
+	u.types = append(u.types, typ)
+	d.indices[i] = len(u.types)
+}
+
+// If typ is a type parameter in tparams, index returns the
+// corresponding tparams index. Otherwise, the result is < 0.
+func (u *unifier) index(tparams []*TypeName, typ Type) int {
+	if t, ok := typ.(*TypeParam); ok {
+		// typ is a type parameter; check that it is the one declared at tparams[t.index]
+		if i := t.index; i < len(tparams) && tparams[i].typ == t {
+			return i
+		}
+	}
+	return -1
+}
+
+// nify must only be called by unifier.unify.
+// nify implements the core unification algorithm (adapted from Checker.identical0; keep the two in sync).
+// It reports whether x and y unify, recording types inferred for type parameters of u.x and u.y as it goes.
+// p is the stack of interface pairs currently being compared; it is used to detect interface cycles.
+func (u *unifier) nify(x, y Type, p *ifacePair) bool {
+	//u.check.dump("### u.nify(%s, %s)", x, y)
+	i := u.index(u.x.tparams, x)
+	j := u.index(u.y.tparams, y)
+	switch {
+	case i >= 0 && j >= 0:
+		//u.check.dump("### i = %d, j = %d", i, j)
+		// x and y are type parameters
+		// This code is only needed for bidirectional type inference.
+		// TODO(gri) We should be able to combine this code with the simple case.
+		tx := u.x.at(i)
+		ty := u.y.at(j)
+		switch {
+		case tx != nil && ty != nil:
+			// both x and y have an inferred type - they must match
+			if tx == ty {
+				return true
+			}
+			return u.nify(tx, ty, p)
+
+		case tx != nil:
+			// x has an inferred type
+			// TODO(gri) fill this in (only needed for bidirectional type inference)
+			panic("unimplemented: x has an inferred type")
+
+		case ty != nil:
+			// y has an inferred type
+			// TODO(gri) fill this in (only needed for bidirectional type inference)
+			panic("unimplemented: y has an inferred type")
+
+		default:
+			// neither x nor y have an inferred type - unify the type parameters
+			// TODO(gri) fill this in (only needed for bidirectional type inference)
+			panic("unimplemented: neither x nor y have an inferred type")
+		}
+
+	case i >= 0:
+		//u.check.dump("### i = %d", i)
+		// x is a type parameter
+		if tx := u.x.at(i); tx != nil {
+			// If we have inferred a type tx and it matches y, we
+			// are done. u.nify won't do this check, so do it now
+			// to avoid endless recursion.
+			if tx == y {
+				return true
+			}
+			return u.nify(tx, y, p)
+		}
+		// otherwise, infer type from y (which is known not to be a type parameter)
+		u.x.set(i, y)
+		return true
+
+	case j >= 0:
+		//u.check.dump("### j = %d", j)
+		// y is a type parameter
+		if ty := u.y.at(j); ty != nil {
+			// If we have inferred a type ty and it matches x, we
+			// are done. u.nify won't do this check, so do it now
+			// to avoid endless recursion.
+			if x == ty {
+				return true
+			}
+			return u.nify(x, ty, p)
+		}
+		// otherwise, infer type from x for y's type parameter j (i is < 0 in this case and must not be used)
+		u.y.set(j, x)
+		return true
+
+	}
+
+	switch x := x.(type) {
+	case *Basic:
+		// Basic types are singletons except for the rune and byte
+		// aliases, thus we cannot solely rely on pointer identity
+		// (x == y). See also comment in TypeName.IsAlias.
+		if y, ok := y.(*Basic); ok {
+			return x.kind == y.kind
+		}
+
+	case *Array:
+		// Two array types are identical if they have identical element types
+		// and the same array length.
+		if y, ok := y.(*Array); ok {
+			// If one or both array lengths are unknown (< 0) due to some error,
+			// assume they are the same to avoid spurious follow-on errors.
+			return (x.len < 0 || y.len < 0 || x.len == y.len) && u.nify(x.elem, y.elem, p)
+		}
+
+	case *Slice:
+		// Two slice types are identical if they have identical element types.
+		if y, ok := y.(*Slice); ok {
+			return u.nify(x.elem, y.elem, p)
+		}
+
+	case *Struct:
+		// Two struct types are identical if they have the same sequence of fields,
+		// and if corresponding fields have the same names, and identical types,
+		// and identical tags. Two embedded fields are considered to have the same
+		// name. Lower-case field names from different packages are always different.
+		if y, ok := y.(*Struct); ok {
+			if x.NumFields() == y.NumFields() {
+				for i, f := range x.fields {
+					g := y.fields[i]
+					if f.embedded != g.embedded ||
+						x.Tag(i) != y.Tag(i) ||
+						!f.sameId(g.pkg, g.name) ||
+						!u.nify(f.typ, g.typ, p) {
+						return false
+					}
+				}
+				return true
+			}
+		}
+
+	case *Pointer:
+		// Two pointer types are identical if they have identical base types.
+		if y, ok := y.(*Pointer); ok {
+			return u.nify(x.base, y.base, p)
+		}
+
+	case *Tuple:
+		// Two tuple types are identical if they have the same number of elements
+		// and corresponding elements have identical types.
+		if y, ok := y.(*Tuple); ok {
+			if x.Len() == y.Len() {
+				if x != nil {
+					for i, v := range x.vars {
+						w := y.vars[i]
+						if !u.nify(v.typ, w.typ, p) {
+							return false
+						}
+					}
+				}
+				return true
+			}
+		}
+
+	case *Signature:
+		// Two function types are identical if they have the same number of parameters
+		// and result values, corresponding parameter and result types are identical,
+		// and either both functions are variadic or neither is. Parameter and result
+		// names are not required to match.
+		// TODO(gri) handle type parameters or document why we can ignore them.
+		if y, ok := y.(*Signature); ok {
+			return x.variadic == y.variadic &&
+				u.nify(x.params, y.params, p) &&
+				u.nify(x.results, y.results, p)
+		}
+
+	case *Interface:
+		// Two interface types are identical if they have the same set of methods with
+		// the same names and identical function types. Lower-case method names from
+		// different packages are always different. The order of the methods is irrelevant.
+		if y, ok := y.(*Interface); ok {
+			// If identical0 is called (indirectly) via an external API entry point
+			// (such as Identical, IdenticalIgnoreTags, etc.), check is nil. But in
+			// that case, interfaces are expected to be complete and lazy completion
+			// here is not needed.
+			if u.check != nil {
+				u.check.completeInterface(token.NoPos, x)
+				u.check.completeInterface(token.NoPos, y)
+			}
+			a := x.allMethods
+			b := y.allMethods
+			if len(a) == len(b) {
+				// Interface types are the only types where cycles can occur
+				// that are not "terminated" via named types; and such cycles
+				// can only be created via method parameter types that are
+				// anonymous interfaces (directly or indirectly) embedding
+				// the current interface. Example:
+				//
+				//    type T interface {
+				//        m() interface{T}
+				//    }
+				//
+				// If two such (differently named) interfaces are compared,
+				// endless recursion occurs if the cycle is not detected.
+				//
+				// If x and y were compared before, they must be equal
+				// (if they were not, the recursion would have stopped);
+				// search the ifacePair stack for the same pair.
+				//
+				// This is a quadratic algorithm, but in practice these stacks
+				// are extremely short (bounded by the nesting depth of interface
+				// type declarations that recur via parameter types, an extremely
+				// rare occurrence). An alternative implementation might use a
+				// "visited" map, but that is probably less efficient overall.
+				q := &ifacePair{x, y, p}
+				for p != nil {
+					if p.identical(q) {
+						return true // same pair was compared before
+					}
+					p = p.prev
+				}
+				if debug {
+					assert(sort.IsSorted(byUniqueMethodName(a)))
+					assert(sort.IsSorted(byUniqueMethodName(b)))
+				}
+				for i, f := range a {
+					g := b[i]
+					if f.Id() != g.Id() || !u.nify(f.typ, g.typ, q) {
+						return false
+					}
+				}
+				return true
+			}
+		}
+
+	case *Map:
+		// Two map types are identical if they have identical key and value types.
+		if y, ok := y.(*Map); ok {
+			return u.nify(x.key, y.key, p) && u.nify(x.elem, y.elem, p)
+		}
+
+	case *Chan:
+		// Two channel types are identical if they have identical value types.
+		// For type unification, channel direction is ignored.
+		if y, ok := y.(*Chan); ok {
+			return u.nify(x.elem, y.elem, p)
+		}
+
+	case *Named:
+		// Two named types are identical if their type names originate
+		// in the same type declaration.
+		// if y, ok := y.(*Named); ok {
+		// 	return x.obj == y.obj
+		// }
+		if y, ok := y.(*Named); ok {
+			// TODO(gri) This is not always correct: two types may have the same names
+			//           in the same package if one of them is nested in a function.
+			//           Extremely unlikely but we need an always correct solution.
+			if x.obj.pkg == y.obj.pkg && stripArgNames(x.obj.name) == stripArgNames(y.obj.name) {
+				assert(len(x.targs) == len(y.targs))
+				for i, x := range x.targs {
+					if !u.nify(x, y.targs[i], p) {
+						return false
+					}
+				}
+				return true
+			}
+		}
+
+	case *TypeParam:
+		// Two type parameters (which are not part of the type parameters of the
+		// enclosing type) are identical if they originate in the same declaration.
+		if y, ok := y.(*TypeParam); ok {
+			return x == y
+		}
+
+	case nil:
+		// avoid a crash in case of nil type
+
+	default:
+		//u.check.dump("### u.nify(%s, %s), u.x.tparams = %s", x, y, u.x.tparams)
+		unreachable()
+	}
+
+	return false
+}