go/types, types2: implement reverse type inference for function arguments

Allow function-typed function arguments to be generic and collect
their type parameters together with the callee's type parameters
(if any). Use a single inference step to infer the type arguments
for all type parameters simultaneously.

Requires Go 1.21 and that Config.EnableReverseTypeInference is set.
Does not yet support partially instantiated generic function arguments.
Not yet enabled in the compiler.

Known bug: inference may produce an incorrect result is the same
           generic function is passed twice in the same function
           call.

For #59338.

Change-Id: Ia1faa27a28c6353f0bbfd7f81feafc21bd36652c
Reviewed-on: https://go-review.googlesource.com/c/go/+/483935
Auto-Submit: Robert Griesemer <gri@google.com>
Reviewed-by: Robert Findley <rfindley@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Robert Griesemer <gri@google.com>
Run-TryBot: Robert Griesemer <gri@google.com>
This commit is contained in:
Robert Griesemer 2023-04-11 18:02:26 -07:00 committed by Gopher Robot
parent 57cb47209c
commit ea9097c9f7
14 changed files with 420 additions and 86 deletions

View File

@ -174,7 +174,7 @@ type Config struct {
// partially instantiated generic functions may be assigned // partially instantiated generic functions may be assigned
// (incl. returned) to variables of function type and type // (incl. returned) to variables of function type and type
// inference will attempt to infer the missing type arguments. // inference will attempt to infer the missing type arguments.
// Experimental. Needs a proposal. // See proposal go.dev/issue/59338.
EnableReverseTypeInference bool EnableReverseTypeInference bool
} }

View File

@ -550,18 +550,55 @@ type T[P any] []P
{`package issue51803; func foo[T any](T) {}; func _() { foo[int]( /* leave arg away on purpose */ ) }`, {`package issue51803; func foo[T any](T) {}; func _() { foo[int]( /* leave arg away on purpose */ ) }`,
[]testInst{{`foo`, []string{`int`}, `func(int)`}}, []testInst{{`foo`, []string{`int`}, `func(int)`}},
}, },
// reverse type parameter inference
{`package reverse1a; var f func(int) = g; func g[P any](P) {}`,
[]testInst{{`g`, []string{`int`}, `func(int)`}},
},
{`package reverse1b; func f(func(int)) {}; func g[P any](P) {}; func _() { f(g) }`,
[]testInst{{`g`, []string{`int`}, `func(int)`}},
},
{`package reverse2a; var f func(int) string = g; func g[P, Q any](P) Q { var q Q; return q }`,
[]testInst{{`g`, []string{`int`, `string`}, `func(int) string`}},
},
{`package reverse2b; func f(func(int) string) {}; func g[P, Q any](P) Q { var q Q; return q }; func _() { f(g) }`,
[]testInst{{`g`, []string{`int`, `string`}, `func(int) string`}},
},
// reverse3a not possible (cannot assign to generic function outside of argument passing)
{`package reverse3b; func f[R any](func(int) R) {}; func g[P any](P) string { return "" }; func _() { f(g) }`,
[]testInst{
{`f`, []string{`string`}, `func(func(int) string)`},
{`g`, []string{`int`}, `func(int) string`},
},
},
{`package reverse4a; var _, _ func([]int, *float32) = g, h; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}`,
[]testInst{
{`g`, []string{`int`, `float32`}, `func([]int, *float32)`},
{`h`, []string{`int`}, `func([]int, *float32)`},
},
},
{`package reverse4b; func f(_, _ func([]int, *float32)) {}; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}; func _() { f(g, h) }`,
[]testInst{
{`g`, []string{`int`, `float32`}, `func([]int, *float32)`},
{`h`, []string{`int`}, `func([]int, *float32)`},
},
},
} }
for _, test := range tests { for _, test := range tests {
imports := make(testImporter) imports := make(testImporter)
conf := Config{ conf := Config{
Importer: imports, Importer: imports,
Error: func(error) {}, // ignore errors EnableReverseTypeInference: true,
} }
instMap := make(map[*syntax.Name]Instance) instMap := make(map[*syntax.Name]Instance)
useMap := make(map[*syntax.Name]Object) useMap := make(map[*syntax.Name]Object)
makePkg := func(src string) *Package { makePkg := func(src string) *Package {
pkg, _ := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap}) pkg, err := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap})
// allow error for issue51803
if err != nil && (pkg == nil || pkg.Name() != "issue51803") {
t.Fatal(err)
}
imports[pkg.Name()] = pkg imports[pkg.Name()] = pkg
return pkg return pkg
} }

View File

@ -23,14 +23,16 @@ import (
func (check *Checker) funcInst(tsig *Signature, pos syntax.Pos, x *operand, inst *syntax.IndexExpr) { func (check *Checker) funcInst(tsig *Signature, pos syntax.Pos, x *operand, inst *syntax.IndexExpr) {
assert(tsig != nil || inst != nil) assert(tsig != nil || inst != nil)
var versionErr bool // set if version error was reported
var instErrPos poser // position for instantion error
if inst != nil {
instErrPos = inst.Pos()
} else {
instErrPos = pos
}
if !check.allowVersion(check.pkg, pos, 1, 18) { if !check.allowVersion(check.pkg, pos, 1, 18) {
var posn poser check.versionErrorf(instErrPos, "go1.18", "function instantiation")
if inst != nil { versionErr = true
posn = inst.Pos()
} else {
posn = pos
}
check.versionErrorf(posn, "go1.18", "function instantiation")
} }
// targs and xlist are the type arguments and corresponding type expressions, or nil. // targs and xlist are the type arguments and corresponding type expressions, or nil.
@ -72,6 +74,13 @@ func (check *Checker) funcInst(tsig *Signature, pos syntax.Pos, x *operand, inst
// of a synthetic function f where f's parameters are the parameters and results // of a synthetic function f where f's parameters are the parameters and results
// of x and where the arguments to the call of f are values of the parameter and // of x and where the arguments to the call of f are values of the parameter and
// result types of x. // result types of x.
if !versionErr && !check.allowVersion(check.pkg, pos, 1, 21) {
if inst != nil {
check.versionErrorf(instErrPos, "go1.21", "partially instantiated function in assignment")
} else {
check.versionErrorf(instErrPos, "go1.21", "implicitly instantiated function in assignment")
}
}
n := tsig.params.Len() n := tsig.params.Len()
m := tsig.results.Len() m := tsig.results.Len()
args = make([]*operand, n+m) args = make([]*operand, n+m)
@ -303,7 +312,7 @@ func (check *Checker) callExpr(x *operand, call *syntax.CallExpr) exprKind {
} }
// evaluate arguments // evaluate arguments
args := check.exprList(call.ArgList) args := check.genericExprList(call.ArgList)
sig = check.arguments(call, sig, targs, args, xlist) sig = check.arguments(call, sig, targs, args, xlist)
if wasGeneric && sig.TypeParams().Len() == 0 { if wasGeneric && sig.TypeParams().Len() == 0 {
@ -338,6 +347,8 @@ func (check *Checker) callExpr(x *operand, call *syntax.CallExpr) exprKind {
return statement return statement
} }
// exprList evaluates a list of expressions and returns the corresponding operands.
// A single-element expression list may evaluate to multiple operands.
func (check *Checker) exprList(elist []syntax.Expr) (xlist []*operand) { func (check *Checker) exprList(elist []syntax.Expr) (xlist []*operand) {
switch len(elist) { switch len(elist) {
case 0: case 0:
@ -356,6 +367,25 @@ func (check *Checker) exprList(elist []syntax.Expr) (xlist []*operand) {
return return
} }
// genericExprList is like exprList but result operands may be generic (not fully instantiated).
func (check *Checker) genericExprList(elist []syntax.Expr) (xlist []*operand) {
switch len(elist) {
case 0:
// nothing to do
case 1:
xlist = check.genericMultiExpr(elist[0])
default:
// multiple (possibly invalid) values
xlist = make([]*operand, len(elist))
for i, e := range elist {
var x operand
check.genericExpr(&x, e)
xlist[i] = &x
}
}
return
}
// xlist is the list of type argument expressions supplied in the source code. // xlist is the list of type argument expressions supplied in the source code.
func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, targs []Type, args []*operand, xlist []syntax.Expr) (rsig *Signature) { func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, targs []Type, args []*operand, xlist []syntax.Expr) (rsig *Signature) {
rsig = sig rsig = sig
@ -386,7 +416,7 @@ func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, targs []T
// set up parameters // set up parameters
sigParams := sig.params // adjusted for variadic functions (may be nil for empty parameter lists!) sigParams := sig.params // adjusted for variadic functions (may be nil for empty parameter lists!)
adjusted := false // indicates if sigParams is different from t.params adjusted := false // indicates if sigParams is different from sig.params
if sig.variadic { if sig.variadic {
if ddd { if ddd {
// variadic_func(a, b, c...) // variadic_func(a, b, c...)
@ -451,8 +481,12 @@ func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, targs []T
return return
} }
// infer type arguments and instantiate signature if necessary // collect type parameters of callee and generic function arguments
if sig.TypeParams().Len() > 0 { var tparams []*TypeParam
// collect type parameters of callee
n := sig.TypeParams().Len()
if n > 0 {
if !check.allowVersion(check.pkg, call.Pos(), 1, 18) { if !check.allowVersion(check.pkg, call.Pos(), 1, 18) {
if iexpr, _ := call.Fun.(*syntax.IndexExpr); iexpr != nil { if iexpr, _ := call.Fun.(*syntax.IndexExpr); iexpr != nil {
check.versionErrorf(iexpr.Pos(), "go1.18", "function instantiation") check.versionErrorf(iexpr.Pos(), "go1.18", "function instantiation")
@ -460,29 +494,72 @@ func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, targs []T
check.versionErrorf(call.Pos(), "go1.18", "implicit function instantiation") check.versionErrorf(call.Pos(), "go1.18", "implicit function instantiation")
} }
} }
// rename type parameters to avoid problems with recursive calls
// Rename type parameters to avoid problems with recursive calls.
var tparams []*TypeParam
tparams, sigParams = check.renameTParams(call.Pos(), sig.TypeParams().list(), sigParams) tparams, sigParams = check.renameTParams(call.Pos(), sig.TypeParams().list(), sigParams)
}
targs := check.infer(call.Pos(), tparams, targs, sigParams, args) // collect type parameters from generic function arguments
var genericArgs []int // indices of generic function arguments
if check.conf.EnableReverseTypeInference {
for i, arg := range args {
// generic arguments cannot have a defined (*Named) type - no need for underlying type below
if asig, _ := arg.typ.(*Signature); asig != nil && asig.TypeParams().Len() > 0 {
// TODO(gri) need to also rename type parameters for cases like f(g, g)
tparams = append(tparams, asig.TypeParams().list()...)
genericArgs = append(genericArgs, i)
}
}
}
if len(genericArgs) > 0 && !check.allowVersion(check.pkg, call.Pos(), 1, 21) {
// at the moment we only support implicit instantiations of argument functions
check.versionErrorf(args[genericArgs[0]].Pos(), "go1.21", "implicitly instantiated function as argument")
}
// tparams holds the type parameters of the callee and generic function arguments, if any:
// the first n type parameters belong to the callee, followed by mi type parameters for each
// of the generic function arguments, where mi = args[i].typ.(*Signature).TypeParams().Len().
// infer missing type arguments of callee and function arguments
if len(tparams) > 0 {
targs = check.infer(call.Pos(), tparams, targs, sigParams, args)
if targs == nil { if targs == nil {
// TODO(gri) If infer inferred the first targs[:n], consider instantiating
// the call signature for better error messages/gopls behavior.
// Perhaps instantiate as much as we can, also for arguments.
// This will require changes to how infer returns its results.
return // error already reported return // error already reported
} }
// compute result signature // compute result signature: instantiate if needed
rsig = check.instantiateSignature(call.Pos(), sig, targs, xlist) rsig = sig
assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore if n > 0 {
check.recordInstance(call.Fun, targs, rsig) rsig = check.instantiateSignature(call.Pos(), sig, targs[:n], xlist)
assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore
check.recordInstance(call.Fun, targs[:n], rsig)
}
// Optimization: Only if the parameter list was adjusted do we // Optimization: Only if the callee's parameter list was adjusted do we need to
// need to compute it from the adjusted list; otherwise we can // compute it from the adjusted list; otherwise we can simply use the result
// simply use the result signature's parameter list. // signature's parameter list. We only need the n type parameters and arguments
if adjusted { // of the callee.
sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams, targs), nil, check.context()).(*Tuple) if n > 0 && adjusted {
sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams[:n], targs[:n]), nil, check.context()).(*Tuple)
} else { } else {
sigParams = rsig.params sigParams = rsig.params
} }
// compute argument signatures: instantiate if needed
j := n
for _, i := range genericArgs {
asig := args[i].typ.(*Signature)
k := j + asig.TypeParams().Len()
// targs[j:k] are the inferred type arguments for asig
asig = check.instantiateSignature(call.Pos(), asig, targs[j:k], nil) // TODO(gri) provide xlist if possible (partial instantiations)
assert(asig.TypeParams().Len() == 0) // signature is not generic anymore
args[i].typ = asig
check.recordInstance(args[i].expr, targs[j:k], asig)
j = k
}
} }
// check arguments // check arguments

View File

@ -1882,6 +1882,27 @@ func (check *Checker) multiExpr(e syntax.Expr, allowCommaOk bool) (list []*opera
return return
} }
// genericMultiExpr is like multiExpr but a one-element result may also be generic
// and potential comma-ok expressions are returned as single values.
func (check *Checker) genericMultiExpr(e syntax.Expr) (list []*operand) {
var x operand
check.rawExpr(nil, &x, e, nil, true)
check.exclude(&x, 1<<novalue|1<<builtin|1<<typexpr)
if t, ok := x.typ.(*Tuple); ok && x.mode != invalid {
// multiple values - cannot be generic
list = make([]*operand, t.Len())
for i, v := range t.vars {
list[i] = &operand{mode: value, expr: e, typ: v.typ}
}
return
}
// exactly one (possible invalid or generic) value
list = []*operand{&x}
return
}
// exprWithHint typechecks expression e and initializes x with the expression value; // exprWithHint typechecks expression e and initializes x with the expression value;
// hint is the type of a composite literal element. // hint is the type of a composite literal element.
// If an error occurred, x.mode is set to invalid. // If an error occurred, x.mode is set to invalid.

View File

@ -140,17 +140,17 @@ func (check *Checker) infer(pos syntax.Pos, tparams []*TypeParam, targs []Type,
} }
for i, arg := range args { for i, arg := range args {
if arg.mode == invalid {
// An error was reported earlier. Ignore this arg
// and continue, we may still be able to infer all
// targs resulting in fewer follow-on errors.
// TODO(gri) determine if we still need this check
continue
}
par := params.At(i) par := params.At(i)
// If we permit bidirectional unification, this conditional code needs to be if isParameterized(tparams, par.typ) || isParameterized(tparams, arg.typ) {
// executed even if par.typ is not parameterized since the argument may be a // Function parameters are always typed. Arguments may be untyped.
// generic function (for which we want to infer its type arguments). // Collect the indices of untyped arguments and handle them later.
if isParameterized(tparams, par.typ) {
if arg.mode == invalid {
// An error was reported earlier. Ignore this targ
// and continue, we may still be able to infer all
// targs resulting in fewer follow-on errors.
continue
}
if isTyped(arg.typ) { if isTyped(arg.typ) {
if !u.unify(par.typ, arg.typ) { if !u.unify(par.typ, arg.typ) {
errorf("type", par.typ, arg.typ, arg) errorf("type", par.typ, arg.typ, arg)
@ -263,7 +263,7 @@ func (check *Checker) infer(pos syntax.Pos, tparams []*TypeParam, targs []Type,
} }
// --- 3 --- // --- 3 ---
// use information from untyped contants // use information from untyped constants
if traceInference { if traceInference {
u.tracef("== untyped arguments: %v", untyped) u.tracef("== untyped arguments: %v", untyped)
@ -541,7 +541,6 @@ func (w *tpWalker) isParameterized(typ Type) (res bool) {
} }
case *TypeParam: case *TypeParam:
// t must be one of w.tparams
return tparamIndex(w.tparams, t) >= 0 return tparamIndex(w.tparams, t) >= 0
default: default:

View File

@ -175,7 +175,7 @@ type Config struct {
// partially instantiated generic functions may be assigned // partially instantiated generic functions may be assigned
// (incl. returned) to variables of function type and type // (incl. returned) to variables of function type and type
// inference will attempt to infer the missing type arguments. // inference will attempt to infer the missing type arguments.
// Experimental. Needs a proposal. // See proposal go.dev/issue/59338.
_EnableReverseTypeInference bool _EnableReverseTypeInference bool
} }

View File

@ -550,18 +550,57 @@ type T[P any] []P
{`package issue51803; func foo[T any](T) {}; func _() { foo[int]( /* leave arg away on purpose */ ) }`, {`package issue51803; func foo[T any](T) {}; func _() { foo[int]( /* leave arg away on purpose */ ) }`,
[]testInst{{`foo`, []string{`int`}, `func(int)`}}, []testInst{{`foo`, []string{`int`}, `func(int)`}},
}, },
// reverse type parameter inference
{`package reverse1a; var f func(int) = g; func g[P any](P) {}`,
[]testInst{{`g`, []string{`int`}, `func(int)`}},
},
{`package reverse1b; func f(func(int)) {}; func g[P any](P) {}; func _() { f(g) }`,
[]testInst{{`g`, []string{`int`}, `func(int)`}},
},
{`package reverse2a; var f func(int) string = g; func g[P, Q any](P) Q { var q Q; return q }`,
[]testInst{{`g`, []string{`int`, `string`}, `func(int) string`}},
},
{`package reverse2b; func f(func(int) string) {}; func g[P, Q any](P) Q { var q Q; return q }; func _() { f(g) }`,
[]testInst{{`g`, []string{`int`, `string`}, `func(int) string`}},
},
// reverse3a not possible (cannot assign to generic function outside of argument passing)
{`package reverse3b; func f[R any](func(int) R) {}; func g[P any](P) string { return "" }; func _() { f(g) }`,
[]testInst{
{`f`, []string{`string`}, `func(func(int) string)`},
{`g`, []string{`int`}, `func(int) string`},
},
},
{`package reverse4a; var _, _ func([]int, *float32) = g, h; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}`,
[]testInst{
{`g`, []string{`int`, `float32`}, `func([]int, *float32)`},
{`h`, []string{`int`}, `func([]int, *float32)`},
},
},
{`package reverse4b; func f(_, _ func([]int, *float32)) {}; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}; func _() { f(g, h) }`,
[]testInst{
{`g`, []string{`int`, `float32`}, `func([]int, *float32)`},
{`h`, []string{`int`}, `func([]int, *float32)`},
},
},
} }
for _, test := range tests { for _, test := range tests {
imports := make(testImporter) imports := make(testImporter)
conf := Config{ conf := Config{
Importer: imports, Importer: imports,
Error: func(error) {}, // ignore errors // Unexported field: set below with boolFieldAddr
// _EnableReverseTypeInference: true,
} }
*boolFieldAddr(&conf, "_EnableReverseTypeInference") = true
instMap := make(map[*ast.Ident]Instance) instMap := make(map[*ast.Ident]Instance)
useMap := make(map[*ast.Ident]Object) useMap := make(map[*ast.Ident]Object)
makePkg := func(src string) *Package { makePkg := func(src string) *Package {
pkg, _ := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap}) pkg, err := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap})
// allow error for issue51803
if err != nil && (pkg == nil || pkg.Name() != "issue51803") {
t.Fatal(err)
}
imports[pkg.Name()] = pkg imports[pkg.Name()] = pkg
return pkg return pkg
} }

View File

@ -25,14 +25,16 @@ import (
func (check *Checker) funcInst(tsig *Signature, pos token.Pos, x *operand, ix *typeparams.IndexExpr) { func (check *Checker) funcInst(tsig *Signature, pos token.Pos, x *operand, ix *typeparams.IndexExpr) {
assert(tsig != nil || ix != nil) assert(tsig != nil || ix != nil)
var versionErr bool // set if version error was reported
var instErrPos positioner // position for instantion error
if ix != nil {
instErrPos = inNode(ix.Orig, ix.Lbrack)
} else {
instErrPos = atPos(pos)
}
if !check.allowVersion(check.pkg, pos, 1, 18) { if !check.allowVersion(check.pkg, pos, 1, 18) {
var posn positioner check.softErrorf(instErrPos, UnsupportedFeature, "function instantiation requires go1.18 or later")
if ix != nil { versionErr = true
posn = inNode(ix.Orig, ix.Lbrack)
} else {
posn = atPos(pos)
}
check.softErrorf(posn, UnsupportedFeature, "function instantiation requires go1.18 or later")
} }
// targs and xlist are the type arguments and corresponding type expressions, or nil. // targs and xlist are the type arguments and corresponding type expressions, or nil.
@ -74,6 +76,13 @@ func (check *Checker) funcInst(tsig *Signature, pos token.Pos, x *operand, ix *t
// of a synthetic function f where f's parameters are the parameters and results // of a synthetic function f where f's parameters are the parameters and results
// of x and where the arguments to the call of f are values of the parameter and // of x and where the arguments to the call of f are values of the parameter and
// result types of x. // result types of x.
if !versionErr && !check.allowVersion(check.pkg, pos, 1, 21) {
if ix != nil {
check.softErrorf(instErrPos, UnsupportedFeature, "partially instantiated function in assignment requires go1.21 or later")
} else {
check.softErrorf(instErrPos, UnsupportedFeature, "implicitly instantiated function in assignment requires go1.21 or later")
}
}
n := tsig.params.Len() n := tsig.params.Len()
m := tsig.results.Len() m := tsig.results.Len()
args = make([]*operand, n+m) args = make([]*operand, n+m)
@ -308,7 +317,7 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind {
} }
// evaluate arguments // evaluate arguments
args := check.exprList(call.Args) args := check.genericExprList(call.Args)
sig = check.arguments(call, sig, targs, args, xlist) sig = check.arguments(call, sig, targs, args, xlist)
if wasGeneric && sig.TypeParams().Len() == 0 { if wasGeneric && sig.TypeParams().Len() == 0 {
@ -343,6 +352,8 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind {
return statement return statement
} }
// exprList evaluates a list of expressions and returns the corresponding operands.
// A single-element expression list may evaluate to multiple operands.
func (check *Checker) exprList(elist []ast.Expr) (xlist []*operand) { func (check *Checker) exprList(elist []ast.Expr) (xlist []*operand) {
switch len(elist) { switch len(elist) {
case 0: case 0:
@ -361,6 +372,25 @@ func (check *Checker) exprList(elist []ast.Expr) (xlist []*operand) {
return return
} }
// genericExprList is like exprList but result operands may be generic (not fully instantiated).
func (check *Checker) genericExprList(elist []ast.Expr) (xlist []*operand) {
switch len(elist) {
case 0:
// nothing to do
case 1:
xlist = check.genericMultiExpr(elist[0])
default:
// multiple (possibly invalid) values
xlist = make([]*operand, len(elist))
for i, e := range elist {
var x operand
check.genericExpr(&x, e)
xlist[i] = &x
}
}
return
}
// xlist is the list of type argument expressions supplied in the source code. // xlist is the list of type argument expressions supplied in the source code.
func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type, args []*operand, xlist []ast.Expr) (rsig *Signature) { func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type, args []*operand, xlist []ast.Expr) (rsig *Signature) {
rsig = sig rsig = sig
@ -391,7 +421,7 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type
// set up parameters // set up parameters
sigParams := sig.params // adjusted for variadic functions (may be nil for empty parameter lists!) sigParams := sig.params // adjusted for variadic functions (may be nil for empty parameter lists!)
adjusted := false // indicates if sigParams is different from t.params adjusted := false // indicates if sigParams is different from sig.params
if sig.variadic { if sig.variadic {
if ddd { if ddd {
// variadic_func(a, b, c...) // variadic_func(a, b, c...)
@ -452,8 +482,12 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type
return return
} }
// infer type arguments and instantiate signature if necessary // collect type parameters of callee and generic function arguments
if sig.TypeParams().Len() > 0 { var tparams []*TypeParam
// collect type parameters of callee
n := sig.TypeParams().Len()
if n > 0 {
if !check.allowVersion(check.pkg, call.Pos(), 1, 18) { if !check.allowVersion(check.pkg, call.Pos(), 1, 18) {
switch call.Fun.(type) { switch call.Fun.(type) {
case *ast.IndexExpr, *ast.IndexListExpr: case *ast.IndexExpr, *ast.IndexListExpr:
@ -463,29 +497,72 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type
check.softErrorf(inNode(call, call.Lparen), UnsupportedFeature, "implicit function instantiation requires go1.18 or later") check.softErrorf(inNode(call, call.Lparen), UnsupportedFeature, "implicit function instantiation requires go1.18 or later")
} }
} }
// rename type parameters to avoid problems with recursive calls
// Rename type parameters to avoid problems with recursive calls.
var tparams []*TypeParam
tparams, sigParams = check.renameTParams(call.Pos(), sig.TypeParams().list(), sigParams) tparams, sigParams = check.renameTParams(call.Pos(), sig.TypeParams().list(), sigParams)
}
targs := check.infer(call, tparams, targs, sigParams, args) // collect type parameters from generic function arguments
var genericArgs []int // indices of generic function arguments
if check.conf._EnableReverseTypeInference {
for i, arg := range args {
// generic arguments cannot have a defined (*Named) type - no need for underlying type below
if asig, _ := arg.typ.(*Signature); asig != nil && asig.TypeParams().Len() > 0 {
// TODO(gri) need to also rename type parameters for cases like f(g, g)
tparams = append(tparams, asig.TypeParams().list()...)
genericArgs = append(genericArgs, i)
}
}
}
if len(genericArgs) > 0 && !check.allowVersion(check.pkg, call.Pos(), 1, 21) {
// at the moment we only support implicit instantiations of argument functions
check.softErrorf(inNode(call, call.Lparen), UnsupportedFeature, "implicitly instantiated function as argument requires go1.21 or later")
}
// tparams holds the type parameters of the callee and generic function arguments, if any:
// the first n type parameters belong to the callee, followed by mi type parameters for each
// of the generic function arguments, where mi = args[i].typ.(*Signature).TypeParams().Len().
// infer missing type arguments of callee and function arguments
if len(tparams) > 0 {
targs = check.infer(call, tparams, targs, sigParams, args)
if targs == nil { if targs == nil {
// TODO(gri) If infer inferred the first targs[:n], consider instantiating
// the call signature for better error messages/gopls behavior.
// Perhaps instantiate as much as we can, also for arguments.
// This will require changes to how infer returns its results.
return // error already reported return // error already reported
} }
// compute result signature // compute result signature: instantiate if needed
rsig = check.instantiateSignature(call.Pos(), sig, targs, xlist) rsig = sig
assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore if n > 0 {
check.recordInstance(call.Fun, targs, rsig) rsig = check.instantiateSignature(call.Pos(), sig, targs[:n], xlist)
assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore
check.recordInstance(call.Fun, targs[:n], rsig)
}
// Optimization: Only if the parameter list was adjusted do we // Optimization: Only if the callee's parameter list was adjusted do we need to
// need to compute it from the adjusted list; otherwise we can // compute it from the adjusted list; otherwise we can simply use the result
// simply use the result signature's parameter list. // signature's parameter list. We only need the n type parameters and arguments
if adjusted { // of the callee.
sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams, targs), nil, check.context()).(*Tuple) if n > 0 && adjusted {
sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams[:n], targs[:n]), nil, check.context()).(*Tuple)
} else { } else {
sigParams = rsig.params sigParams = rsig.params
} }
// compute argument signatures: instantiate if needed
j := n
for _, i := range genericArgs {
asig := args[i].typ.(*Signature)
k := j + asig.TypeParams().Len()
// targs[j:k] are the inferred type arguments for asig
asig = check.instantiateSignature(call.Pos(), asig, targs[j:k], nil) // TODO(gri) provide xlist if possible (partial instantiations)
assert(asig.TypeParams().Len() == 0) // signature is not generic anymore
args[i].typ = asig
check.recordInstance(args[i].expr, targs[j:k], asig)
j = k
}
} }
// check arguments // check arguments

View File

@ -39,6 +39,7 @@ import (
"go/scanner" "go/scanner"
"go/token" "go/token"
"internal/testenv" "internal/testenv"
"internal/types/errors"
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
@ -295,9 +296,9 @@ func testFiles(t *testing.T, sizes Sizes, filenames []string, srcs [][]byte, man
} }
} }
func readCode(err Error) int { func readCode(err Error) errors.Code {
v := reflect.ValueOf(err) v := reflect.ValueOf(err)
return int(v.FieldByName("go116code").Int()) return errors.Code(v.FieldByName("go116code").Int())
} }
// boolFieldAddr(conf, name) returns the address of the boolean field conf.<name>. // boolFieldAddr(conf, name) returns the address of the boolean field conf.<name>.

View File

@ -1829,6 +1829,27 @@ func (check *Checker) multiExpr(e ast.Expr, allowCommaOk bool) (list []*operand,
return return
} }
// genericMultiExpr is like multiExpr but a one-element result may also be generic
// and potential comma-ok expressions are returned as single values.
func (check *Checker) genericMultiExpr(e ast.Expr) (list []*operand) {
var x operand
check.rawExpr(nil, &x, e, nil, true)
check.exclude(&x, 1<<novalue|1<<builtin|1<<typexpr)
if t, ok := x.typ.(*Tuple); ok && x.mode != invalid {
// multiple values - cannot be generic
list = make([]*operand, t.Len())
for i, v := range t.vars {
list[i] = &operand{mode: value, expr: e, typ: v.typ}
}
return
}
// exactly one (possible invalid or generic) value
list = []*operand{&x}
return
}
// exprWithHint typechecks expression e and initializes x with the expression value; // exprWithHint typechecks expression e and initializes x with the expression value;
// hint is the type of a composite literal element. // hint is the type of a composite literal element.
// If an error occurred, x.mode is set to invalid. // If an error occurred, x.mode is set to invalid.

View File

@ -142,17 +142,17 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type,
} }
for i, arg := range args { for i, arg := range args {
if arg.mode == invalid {
// An error was reported earlier. Ignore this arg
// and continue, we may still be able to infer all
// targs resulting in fewer follow-on errors.
// TODO(gri) determine if we still need this check
continue
}
par := params.At(i) par := params.At(i)
// If we permit bidirectional unification, this conditional code needs to be if isParameterized(tparams, par.typ) || isParameterized(tparams, arg.typ) {
// executed even if par.typ is not parameterized since the argument may be a // Function parameters are always typed. Arguments may be untyped.
// generic function (for which we want to infer its type arguments). // Collect the indices of untyped arguments and handle them later.
if isParameterized(tparams, par.typ) {
if arg.mode == invalid {
// An error was reported earlier. Ignore this targ
// and continue, we may still be able to infer all
// targs resulting in fewer follow-on errors.
continue
}
if isTyped(arg.typ) { if isTyped(arg.typ) {
if !u.unify(par.typ, arg.typ) { if !u.unify(par.typ, arg.typ) {
errorf("type", par.typ, arg.typ, arg) errorf("type", par.typ, arg.typ, arg)
@ -265,7 +265,7 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type,
} }
// --- 3 --- // --- 3 ---
// use information from untyped contants // use information from untyped constants
if traceInference { if traceInference {
u.tracef("== untyped arguments: %v", untyped) u.tracef("== untyped arguments: %v", untyped)
@ -543,7 +543,6 @@ func (w *tpWalker) isParameterized(typ Type) (res bool) {
} }
case *TypeParam: case *TypeParam:
// t must be one of w.tparams
return tparamIndex(w.tparams, t) >= 0 return tparamIndex(w.tparams, t) >= 0
default: default:

View File

@ -10,11 +10,13 @@
package p package p
func f1[P any](P) {} func f1[P any](P) {}
func f2[P any]() P { var x P; return x } func f2[P any]() P { var x P; return x }
func f3[P, Q any](P) Q { var x Q; return x } func f3[P, Q any](P) Q { var x Q; return x }
func f4[P any](P, P) {} func f4[P any](P, P) {}
func f5[P any](P) []P { return nil } func f5[P any](P) []P { return nil }
func f6[P any](int) P { var x P; return x }
func f7[P any](P) string { return "" }
// initialization expressions // initialization expressions
var ( var (
@ -71,3 +73,22 @@ func _() func(string) []int {
func _() (_, _ func(int)) { return f1, f1 } func _() (_, _ func(int)) { return f1, f1 }
func _() (_, _ func(int)) { return f1, f2 /* ERROR "cannot infer P" */ } func _() (_, _ func(int)) { return f1, f2 /* ERROR "cannot infer P" */ }
// Argument passing
func g1(func(int)) {}
func g2(func(int, int)) {}
func g3(func(int) string) {}
func g4[P any](func(P) string) {}
func g5[P, Q any](func(P) string, func(P) Q) {}
func g6(func(int), func(string)) {}
func _() {
g1(f1)
g1(f2 /* ERROR "cannot infer P" */)
g2(f4)
g4(f6)
g5(f6, f7)
// TODO(gri) this should work (requires type parameter renaming for f1)
g6(f1, f1 /* ERROR "type func[P any](P) of f1 does not match func(string)" */)
}

View File

@ -0,0 +1,21 @@
// -reverseTypeInference -lang=go1.20
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package p
func g[P any](P) {}
func h[P, Q any](P) Q { panic(0) }
var _ func(int) = g /* ERROR "implicitly instantiated function in assignment requires go1.21 or later" */
var _ func(int) string = h[ /* ERROR "partially instantiated function in assignment requires go1.21 or later" */ int]
func f1(func(int)) {}
func f2(int, func(int)) {}
func _() {
f1( /* ERROR "implicitly instantiated function as argument requires go1.21 or later" */ g)
f2( /* ERROR "implicitly instantiated function as argument requires go1.21 or later" */ 0, g)
}

View File

@ -0,0 +1,21 @@
// -reverseTypeInference
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package p
func g[P any](P) {}
func h[P, Q any](P) Q { panic(0) }
var _ func(int) = g
var _ func(int) string = h[int]
func f1(func(int)) {}
func f2(int, func(int)) {}
func _() {
f1(g)
f2(0, g)
}