go/types: implemented (bidirectional) unifier and use it for type inference

- Implemented bidirectional unifier as a stand-alone mechanism
  separate from Checker.identical0.
- Use it instead of Checker.identical0 where we need unification.
- Missing: Bidirection functionality not fully implemented because
  we don't use it yet, but the basic outline is present.

Change-Id: I1666c9e4c9094eda749084bb69c700f1b5e879bb
This commit is contained in:
Robert Griesemer 2020-03-23 22:16:22 -07:00
parent 42e310c4a6
commit 5da07bf767
5 changed files with 383 additions and 21 deletions

View File

@ -15,8 +15,8 @@ import "go/token"
func (check *Checker) infer(pos token.Pos, tparams []*TypeName, params *Tuple, args []*operand) []Type {
assert(params.Len() == len(args))
// targs is the list of inferred type parameter types.
targs := make([]Type, len(tparams))
u := check.unifier()
u.x.init(tparams)
// Terminology: TPP = type-parameterized function parameter
@ -25,6 +25,9 @@ func (check *Checker) infer(pos token.Pos, tparams []*TypeName, params *Tuple, a
var indices []int
for i, arg := range args {
par := params.At(i)
// If we permit bidirectional unification, this conditional code needs to be
// executed even if par.typ is not parameterized since the argument may be a
// generic function (for which we want to infer // its type arguments).
if IsParameterized(par.typ) {
if arg.mode == invalid {
// TODO(gri) we might still be able to infer all targs by
@ -32,7 +35,11 @@ func (check *Checker) infer(pos token.Pos, tparams []*TypeName, params *Tuple, a
return nil // error was reported earlier
}
if isTyped(arg.typ) {
if !check.identical0(par.typ, arg.typ, true, nil, targs) {
// If we permit bidirectional unification, and arg.typ is
// a generic function, we need to initialize u.y with the
// respectice type parameters of arg.typ.
if !u.unify(par.typ, arg.typ) {
//if !check.identical0(par.typ, arg.typ, true, nil, targs) {
// Calling subst for an error message can cause problems.
// TODO(gri) Determine best approach here.
// check.errorf(arg.pos(), "type %s for %s does not match %s = %s",
@ -57,7 +64,7 @@ func (check *Checker) infer(pos token.Pos, tparams []*TypeName, params *Tuple, a
// only parameter type it can possibly match against is a *TypeParam.
// Thus, only keep the indices of TPPs that are unstructured and which
// don't have a type inferred yet.
if tpar, _ := par.typ.(*TypeParam); tpar != nil && targs[tpar.index] == nil {
if tpar, _ := par.typ.(*TypeParam); tpar != nil && u.x.at(tpar.index) == nil {
indices[j] = i
j++
}
@ -72,7 +79,8 @@ func (check *Checker) infer(pos token.Pos, tparams []*TypeName, params *Tuple, a
// The default type for an untyped nil is untyped nil. We must not
// infer an untyped nil type as type parameter type. Ignore untyped
// nil by making sure all default argument types are typed.
if isTyped(targ) && !check.identical0(par.typ, targ, true, nil, targs) {
if isTyped(targ) && !u.unify(par.typ, targ) {
//if isTyped(targ) && !check.identical0(par.typ, targ, true, nil, targs) {
// TODO(gri) see TODO comment above
// check.errorf(arg.pos(), "default type %s for %s does not match %s = %s",
// Default(arg.typ), arg.expr, par.typ, check.subst(pos, par.typ, tparams, targs),
@ -82,15 +90,20 @@ func (check *Checker) infer(pos token.Pos, tparams []*TypeName, params *Tuple, a
}
}
// Check if all type parameters have been determined.
// Collect type arguments and check if they all have been determined.
// TODO(gri) consider moving this outside this function and then we won't need to pass in pos
for i, t := range targs {
if t == nil {
tpar := tparams[i]
var targs []Type // lazily allocated
for i, tpar := range tparams {
targ := u.x.at(i)
if targ == nil {
ppos := check.fset.Position(tpar.pos).String()
check.errorf(pos, "cannot infer %s (%s)", tpar.name, ppos)
return nil
}
if targs == nil {
targs = make([]Type, len(tparams))
}
targs[i] = targ
}
return targs

View File

@ -310,12 +310,9 @@ func (check *Checker) missingMethod(V Type, T *Interface, static bool) (method,
// comparison in that case.
// TODO(gri) is this always correct? what about type bounds?
// (Alternative is to rename/subst type parameters and compare.)
var tparams []Type
if len(mtyp.tparams) > 0 {
tparams = make([]Type, len(mtyp.tparams))
}
if !check.identical0(ftyp, mtyp, true, nil, tparams) {
u := check.unifier()
u.x.init(mtyp.tparams)
if !u.unify(ftyp, mtyp) {
return m, f
}
}
@ -376,12 +373,9 @@ func (check *Checker) missingMethod(V Type, T *Interface, static bool) (method,
// comparison (provide non-nil tparams to identical0) in that case.
// TODO(gri) is this always correct? what about type bounds?
// (Alternative is to rename/subst type parameters and compare.)
var tparams []Type
if len(mtyp.tparams) > 0 {
tparams = make([]Type, len(mtyp.tparams))
}
if !check.identical0(ftyp, mtyp, true, nil, tparams) {
u := check.unifier()
u.x.init(mtyp.tparams)
if !u.unify(ftyp, mtyp) {
return m, f
}
}

View File

@ -128,6 +128,7 @@ func (p *ifacePair) identical(q *ifacePair) bool {
}
// If a non-nil tparams is provided, type inference is done for type parameters in x.
// For changes to this code the corresponding changes should be made to unifier.nify.
func (check *Checker) identical0(x, y Type, cmpTags bool, p *ifacePair, tparams []Type) bool {
// If we want type inference, do not shortcut for equal types. Instead
// keep comparing them element-wise so we can infer the matching (and

View File

@ -4,6 +4,15 @@
package p
/*
func f(type T)(T)
func _() {
var x int
f(x)
}
*/
/*
func f(func(int))
func g(type T)(T)

345
src/go/types/unify.go Normal file
View File

@ -0,0 +1,345 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements type unification.
package types
import (
"go/token"
"sort"
)
// A unifier maintains the current type parameters for x and y
// and the respective types inferred for each type parameter.
// A uninifier is created by calling Checker.unifier.
type unifier struct {
check *Checker
x, y typeDesc // x and y must initialized via typeDesc.init
types []Type // inferred types, shared by x and y
}
// unifier returns a new unifier.
func (check *Checker) unifier() *unifier {
u := &unifier{check: check}
u.x.uplink = u
u.y.uplink = u
return u
}
// unify attempts to unify x and y and reports whether it succeeded.
func (u *unifier) unify(x, y Type) bool {
return u.nify(x, y, nil)
}
// A typeDesc describes a list of type parameters and the types inferred for them.
type typeDesc struct {
uplink *unifier
tparams []*TypeName
indices []int // len(d.indices) == len(d.tparams)
}
func (d *typeDesc) init(tparams []*TypeName) {
if len(tparams) == 0 {
return
}
d.tparams = tparams
d.indices = make([]int, len(tparams))
}
// at returns the type inferred (via unification) for the i'th type parameter; or nil.
// The index i must be a valid type parameter index: 0 <= i < len(d.tparams).
func (d *typeDesc) at(i int) Type {
if i := d.indices[i]; i != 0 {
typ := d.uplink.types[i-1]
assert(typ != nil)
return typ
}
return nil
}
// set sets the type typ inferred (via unification) for the i'th type parameter; typ must not be nil.
// The index i must be a valid type parameter index: 0 <= i < len(d.tparams).
func (d *typeDesc) set(i int, typ Type) {
assert(typ != nil)
u := d.uplink
u.types = append(u.types, typ)
d.indices[i] = len(u.types)
}
// If typ is a type parameter in tparams, index returns the
// corresponding tparams index. Otherwise, the result is < 0.
func (u *unifier) index(tparams []*TypeName, typ Type) int {
if t, ok := typ.(*TypeParam); ok {
// typ is a type parameter; check that it belongs to the (enclosing) type
if i := t.index; i < len(tparams) && tparams[i].typ == t {
return i
}
}
return -1
}
// nify must only be called by unifier.unify.
// nify implements the core unification algorithm which is an
// adapted version of Checker.identical0. For changes to that
// code the corresponding changes should be made here.
func (u *unifier) nify(x, y Type, p *ifacePair) bool {
//u.check.dump("### u.nify(%s, %s)", x, y)
i := u.index(u.x.tparams, x)
j := u.index(u.y.tparams, y)
switch {
case i >= 0 && j >= 0:
//u.check.dump("### i = %d, j = %d", i, j)
// x and y are type parameters
// This code is only needed for bidirectional type inference.
// TODO(gri) We should be able to combine this code with the simple case.
tx := u.x.at(i)
ty := u.y.at(j)
switch {
case tx != nil && ty != nil:
// both x and y have an inferred type - they must match
if tx == ty {
return true
}
return u.nify(tx, ty, p)
case tx != nil:
// x has an inferred type
// TODO(gri) fill this in (only needed for bidirection type inference)
panic("unimplemented: x has an inferred type")
case ty != nil:
// y has an inferred type
// TODO(gri) fill this in (only needed for bidirection type inference)
panic("unimplemented: y has an inferred type")
default:
// neither x nor y have an inferred type - unify the type parameters
// TODO(gri) fill this in (only needed for bidirection type inference)
panic("unimplemented: neither x nor y have an inferred type")
}
case i >= 0:
//u.check.dump("### i = %d", i)
// x is a type parameter
if tx := u.x.at(i); tx != nil {
// If we have inferred a type tx and it matches y, we
// are done. u.nify won't do this check, so do it now
// to avoid endless recursion.
if tx == y {
return true
}
return u.nify(tx, y, p)
}
// otherwise, infer type from y (which is known not to be a type parameter)
u.x.set(i, y)
return true
case j >= 0:
//u.check.dump("### j = %d", j)
// y is a type parameter
if ty := u.y.at(j); ty != nil {
// If we have inferred a type ty and it matches x, we
// are done. u.nify won't do this check, so do it now
// to avoid endless recursion.
if x == ty {
return true
}
return u.nify(x, ty, p)
}
// otherwise, infer type from x (which is known not to be a type parameter)
u.y.set(i, x)
return true
}
switch x := x.(type) {
case *Basic:
// Basic types are singletons except for the rune and byte
// aliases, thus we cannot solely rely on the x == y check
// above. See also comment in TypeName.IsAlias.
if y, ok := y.(*Basic); ok {
return x.kind == y.kind
}
case *Array:
// Two array types are identical if they have identical element types
// and the same array length.
if y, ok := y.(*Array); ok {
// If one or both array lengths are unknown (< 0) due to some error,
// assume they are the same to avoid spurious follow-on errors.
return (x.len < 0 || y.len < 0 || x.len == y.len) && u.nify(x.elem, y.elem, p)
}
case *Slice:
// Two slice types are identical if they have identical element types.
if y, ok := y.(*Slice); ok {
return u.nify(x.elem, y.elem, p)
}
case *Struct:
// Two struct types are identical if they have the same sequence of fields,
// and if corresponding fields have the same names, and identical types,
// and identical tags. Two embedded fields are considered to have the same
// name. Lower-case field names from different packages are always different.
if y, ok := y.(*Struct); ok {
if x.NumFields() == y.NumFields() {
for i, f := range x.fields {
g := y.fields[i]
if f.embedded != g.embedded ||
x.Tag(i) != y.Tag(i) ||
!f.sameId(g.pkg, g.name) ||
!u.nify(f.typ, g.typ, p) {
return false
}
}
return true
}
}
case *Pointer:
// Two pointer types are identical if they have identical base types.
if y, ok := y.(*Pointer); ok {
return u.nify(x.base, y.base, p)
}
case *Tuple:
// Two tuples types are identical if they have the same number of elements
// and corresponding elements have identical types.
if y, ok := y.(*Tuple); ok {
if x.Len() == y.Len() {
if x != nil {
for i, v := range x.vars {
w := y.vars[i]
if !u.nify(v.typ, w.typ, p) {
return false
}
}
}
return true
}
}
case *Signature:
// Two function types are identical if they have the same number of parameters
// and result values, corresponding parameter and result types are identical,
// and either both functions are variadic or neither is. Parameter and result
// names are not required to match.
// TODO(gri) handle type parameters or document why we can ignore them.
if y, ok := y.(*Signature); ok {
return x.variadic == y.variadic &&
u.nify(x.params, y.params, p) &&
u.nify(x.results, y.results, p)
}
case *Interface:
// Two interface types are identical if they have the same set of methods with
// the same names and identical function types. Lower-case method names from
// different packages are always different. The order of the methods is irrelevant.
if y, ok := y.(*Interface); ok {
// If identical0 is called (indirectly) via an external API entry point
// (such as Identical, IdenticalIgnoreTags, etc.), check is nil. But in
// that case, interfaces are expected to be complete and lazy completion
// here is not needed.
if u.check != nil {
u.check.completeInterface(token.NoPos, x)
u.check.completeInterface(token.NoPos, y)
}
a := x.allMethods
b := y.allMethods
if len(a) == len(b) {
// Interface types are the only types where cycles can occur
// that are not "terminated" via named types; and such cycles
// can only be created via method parameter types that are
// anonymous interfaces (directly or indirectly) embedding
// the current interface. Example:
//
// type T interface {
// m() interface{T}
// }
//
// If two such (differently named) interfaces are compared,
// endless recursion occurs if the cycle is not detected.
//
// If x and y were compared before, they must be equal
// (if they were not, the recursion would have stopped);
// search the ifacePair stack for the same pair.
//
// This is a quadratic algorithm, but in practice these stacks
// are extremely short (bounded by the nesting depth of interface
// type declarations that recur via parameter types, an extremely
// rare occurrence). An alternative implementation might use a
// "visited" map, but that is probably less efficient overall.
q := &ifacePair{x, y, p}
for p != nil {
if p.identical(q) {
return true // same pair was compared before
}
p = p.prev
}
if debug {
assert(sort.IsSorted(byUniqueMethodName(a)))
assert(sort.IsSorted(byUniqueMethodName(b)))
}
for i, f := range a {
g := b[i]
if f.Id() != g.Id() || !u.nify(f.typ, g.typ, q) {
return false
}
}
return true
}
}
case *Map:
// Two map types are identical if they have identical key and value types.
if y, ok := y.(*Map); ok {
return u.nify(x.key, y.key, p) && u.nify(x.elem, y.elem, p)
}
case *Chan:
// Two channel types are identical if they have identical value types.
// For type unification, channel direction is ignored.
if y, ok := y.(*Chan); ok {
return u.nify(x.elem, y.elem, p)
}
case *Named:
// Two named types are identical if their type names originate
// in the same type declaration.
// if y, ok := y.(*Named); ok {
// return x.obj == y.obj
// }
if y, ok := y.(*Named); ok {
// TODO(gri) This is not always correct: two types may have the same names
// in the same package if one of them is nested in a function.
// Extremely unlikely but we need an always correct solution.
if x.obj.pkg == y.obj.pkg && stripArgNames(x.obj.name) == stripArgNames(y.obj.name) {
assert(len(x.targs) == len(y.targs))
for i, x := range x.targs {
if !u.nify(x, y.targs[i], p) {
return false
}
}
return true
}
}
case *TypeParam:
// Two type parameters (which are not part of the type parameters of the
// enclosing type) are identical if they originate in the same declaration.
if y, ok := y.(*TypeParam); ok {
return x == y
}
case nil:
// avoid a crash in case of nil type
default:
//u.check.dump("### u.nify(%s, %s), u.x.tparams = %s", x, y, u.x.tparams)
unreachable()
}
return false
}