diff --git a/internal/refactor/inline/inline.go b/internal/refactor/inline/inline.go
index 03ef071..46873c9 100644
--- a/internal/refactor/inline/inline.go
+++ b/internal/refactor/inline/inline.go
@@ -786,7 +786,7 @@
// Gather the effective call arguments, including the receiver.
// Later, elements will be eliminated (=> nil) by parameter substitution.
- args, err := st.arguments(caller, calleeDecl, assign1)
+ functionCall, err := st.unpackFuncCall(logf, caller, calleeDecl, assign1)
if err != nil {
return nil, err // e.g. implicit field selection cannot be made explicit
}
@@ -839,7 +839,7 @@
// TODO(adonovan): extract this to a function.
if sig.Variadic() {
lastParam := last(params)
- if len(args) > 0 && last(args).spread {
+ if len(functionCall.callArgs) > 0 && last(functionCall.callArgs).spread {
// spread call to variadic: tricky
lastParam.variadic = true
} else {
@@ -861,7 +861,7 @@
// g([]T{a1, ..., aN}...), which we simplify to g(a1, ..., an)
// later; see replaceCalleeID.
n := len(params) - 1
- ordinary, extra := args[:n], args[n:]
+ ordinary, extra := functionCall.callArgs[:n], functionCall.callArgs[n:]
var elts []ast.Expr
freevars := make(map[string]bool)
pure, effects := true, false
@@ -871,7 +871,7 @@
effects = effects || arg.effects
maps.Copy(freevars, arg.freevars)
}
- args = append(ordinary, &argument{
+ functionCall.callArgs = append(ordinary, &argument{
expr: &ast.CompositeLit{
Type: lastParamField.Type,
Elts: elts,
@@ -888,37 +888,45 @@
}
}
}
-
- typeArgs := st.typeArguments(caller.Call)
- if len(typeArgs) != len(callee.TypeParams) {
- return nil, fmt.Errorf("cannot inline: type parameter inference is not yet supported")
+ if len(functionCall.typeArgs) == 0 && len(callee.TypeParams) > 0 {
+ // Calling a generic function with type parameter inference
+ for _, param := range params {
+ param.generic = true
+ }
}
- if err := substituteTypeParams(logf, callee.TypeParams, typeArgs, params, replaceCalleeID); err != nil {
- return nil, err
+ if len(functionCall.typeArgs) > 0 && len(functionCall.typeArgs) == len(callee.TypeParams) {
+ // All type arguments specified at the callsite
+ if err := substituteTypeParams(logf, callee.TypeParams, functionCall.typeArgs, params, replaceCalleeID); err != nil {
+ return nil, err
+ }
+ }
+ if len(functionCall.typeArgs) > 0 && len(functionCall.typeArgs) != len(callee.TypeParams) {
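+ // TODO(kusano): this does happen: Go allows partial type argument lists, e.g. f[int](x, y) for func f[T, U any](x T, y U), with the remaining type arguments inferred. Support it and add a test.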
+ return nil, fmt.Errorf("Number of callsite type arguments (%d) does not match number of callee type parameters (%d)", len(functionCall.typeArgs), len(callee.TypeParams))
}
// Log effective arguments.
- for i, arg := range args {
+ for i, arg := range functionCall.callArgs {
logf("arg #%d: %s pure=%t effects=%t duplicable=%t free=%v type=%v",
i, debugFormatNode(caller.Fset, arg.expr),
arg.pure, arg.effects, arg.duplicable, arg.freevars, arg.typ)
}
// Note: computation below should be expressed in terms of
- // the args and params slices, not the raw material.
+ // the functionCall.callArgs and params slices, not the raw material.
// Perform parameter substitution.
// May eliminate some elements of params/args.
- substitute(logf, caller, params, args, callee.Effects, callee.Falcon, replaceCalleeID)
+ substitute(logf, caller, params, functionCall.callArgs, callee.Effects, callee.Falcon, replaceCalleeID)
// Update the callee's signature syntax.
updateCalleeParams(calleeDecl, params)
// Create a var (param = arg; ...) decl for use by some strategies.
- bindingDecl := createBindingDecl(logf, caller, args, calleeDecl, callee.Results)
+ bindingDecl := createBindingDecl(logf, caller, functionCall.callArgs, calleeDecl, callee.Results)
var remainingArgs []ast.Expr
- for _, arg := range args {
+ for _, arg := range functionCall.callArgs {
if arg != nil {
remainingArgs = append(remainingArgs, arg.expr)
}
@@ -967,20 +975,20 @@
// Make correction for spread calls
// f(g()) or recv.f(g()) where g() is a tuple.
- if last := last(args); last != nil && last.spread {
+ if last := last(functionCall.callArgs); last != nil && last.spread {
nspread := last.typ.(*types.Tuple).Len()
- if len(args) > 1 { // [recv, g()]
+ if len(functionCall.callArgs) > 1 { // [recv, g()]
// A single AssignStmt cannot discard both, so use a 2-spec var decl.
res.new = &ast.GenDecl{
Tok: token.VAR,
Specs: []ast.Spec{
&ast.ValueSpec{
Names: []*ast.Ident{makeIdent("_")},
- Values: []ast.Expr{args[0].expr},
+ Values: []ast.Expr{functionCall.callArgs[0].expr},
},
&ast.ValueSpec{
Names: blanks[*ast.Ident](nspread),
- Values: []ast.Expr{args[1].expr},
+ Values: []ast.Expr{functionCall.callArgs[1].expr},
},
},
}
@@ -1297,7 +1305,7 @@
// func (recv *T, x, y int) { body }(new(T), g()),
// which is not a valid argument list because g() must appear alone.
// Reject this case for now.
- if len(args) == 2 && args[0] != nil && args[1] != nil && is[*types.Tuple](args[1].typ) {
+ if len(functionCall.callArgs) == 2 && functionCall.callArgs[0] != nil && functionCall.callArgs[1] != nil && is[*types.Tuple](functionCall.callArgs[1].typ) {
return nil, fmt.Errorf("can't yet inline spread call to method")
}
@@ -1444,16 +1452,19 @@
// typeArguments returns the type arguments of the call.
// It only collects the arguments that are explicitly provided; it does
// not attempt type inference.
-func (st *state) typeArguments(call *ast.CallExpr) []*argument {
+func (st *state) typeArguments(logf logger, call *ast.CallExpr) ([]*argument, ast.Expr) {
var exprs []ast.Expr
+ calledFun := call.Fun
switch d := ast.Unparen(call.Fun).(type) {
case *ast.IndexExpr:
exprs = []ast.Expr{d.Index}
+ calledFun = d.X
case *ast.IndexListExpr:
exprs = d.Indices
+ calledFun = d.X
default:
- // No type arguments
- return nil
+ // No type arguments
+ return nil, calledFun
}
var args []*argument
for _, e := range exprs {
@@ -1469,7 +1480,24 @@
}
args = append(args, arg)
}
- return args
+ return args, calledFun
+}
+
+// functionCall holds a call's explicit type arguments and effective value arguments.
+type functionCall struct {
+ typeArgs, callArgs []*argument // from typeArguments and arguments, respectively
+}
+
+// unpackFuncCall splits the call into its explicitly written type
+// arguments (see typeArguments) and its effective value arguments,
+// including any receiver (see arguments).
+func (st *state) unpackFuncCall(logf logger, caller *Caller, calleeDecl *ast.FuncDecl, assign1 func(*types.Var) bool) (*functionCall, error) {
+ targs, fun := st.typeArguments(logf, caller.Call) // fun is call.Fun with any [type args] stripped
+ vargs, err := st.arguments(logf, caller, fun, calleeDecl, assign1)
+ if err != nil {
+ return nil, err
+ }
+ return &functionCall{typeArgs: targs, callArgs: vargs}, nil
}
// arguments returns the effective arguments of the call.
@@ -1502,15 +1530,12 @@
//
// We compute type-based predicates like pure, duplicable,
// freevars, etc, now, before we start modifying syntax.
-func (st *state) arguments(caller *Caller, calleeDecl *ast.FuncDecl, assign1 func(*types.Var) bool) ([]*argument, error) {
+func (st *state) arguments(logf logger, caller *Caller, calledFun ast.Expr, calleeDecl *ast.FuncDecl, assign1 func(*types.Var) bool) ([]*argument, error) {
var args []*argument
callArgs := caller.Call.Args
if calleeDecl.Recv != nil {
- if len(st.callee.impl.TypeParams) > 0 {
- return nil, fmt.Errorf("cannot inline: generic methods not yet supported")
- }
- sel := ast.Unparen(caller.Call.Fun).(*ast.SelectorExpr)
+ sel := ast.Unparen(calledFun).(*ast.SelectorExpr)
seln := caller.Info.Selections[sel]
var recvArg ast.Expr
switch seln.Kind() {
@@ -1628,6 +1653,7 @@
fieldType ast.Expr // syntax of type, from calleeDecl.Type.{Recv,Params}
info *paramInfo // information from AnalyzeCallee
variadic bool // (final) parameter is unsimplified ...T
+ generic bool // parameter type is generic
}
// A replacer replaces an identifier at the given offset in the callee.
@@ -1966,7 +1992,9 @@
(!ref.Assignable && !trivialConversion(arg.constant, arg.typ, param.obj.Type()))
if needType &&
+ !param.generic &&
!types.Identical(types.Default(arg.typ), param.obj.Type()) {
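+ // TODO: can an arg for a generic param ever need a conversion too? Also note that when type args are inferred, every param is marked generic, even one whose type mentions no type parameter (e.g. y in func f[T any](x T, y int64)), so its conversion is skipped by the check above.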
// If arg.expr is already an interface call, strip it.
if call, ok := argExpr.(*ast.CallExpr); ok && len(call.Args) == 1 {
diff --git a/internal/refactor/inline/inline_test.go b/internal/refactor/inline/inline_test.go
index 31e20af..2255752 100644
--- a/internal/refactor/inline/inline_test.go
+++ b/internal/refactor/inline/inline_test.go
@@ -33,6 +33,10 @@
"golang.org/x/tools/txtar"
)
+// TODO(before submission): add tests for:
+// (1) multiple type parameters (type arguments extracted from an IndexListExpr);
+// (2) partial type argument lists, which Go permits (e.g. f[int](x, y) for func f[T, U any](x T, y U)).
+
// TestData executes test scenarios specified by files in testdata/*.txtar.
// Each txtar file describes two sets of files, some containing Go source
// and others expected results.
@@ -373,18 +377,6 @@
func TestErrors(t *testing.T) {
runTests(t, []testcase{
{
- "Inference of type parameters is not yet supported.",
- `func f[T any](x T) T { return x }`,
- `var _ = f(0)`,
- `error: type parameter inference is not yet supported`,
- },
- {
- "Methods on generic types are not yet supported.",
- `type G[T any] struct{}; func (G[T]) f(x T) T { return x }`,
- `var _ = G[int]{}.f(0)`,
- `error: generic methods not yet supported`,
- },
- {
"[NoPackageClause] Can't inline a callee using newer Go to a caller using older Go (#75726).",
"//go:build go1.23\n\npackage p\nfunc f() int { return 0 }",
"//go:build go1.22\n\npackage p\nvar _ = f()",
@@ -823,6 +815,24 @@
})
}
+// TestGenerics checks inlining of calls to generic functions and methods.
+func TestGenerics(t *testing.T) {
+ runTests(t, []testcase{
+ {
+ "Inference of type parameters.",
+ `func f[T any](x T) T { return x }`,
+ `var _ = f(0)`,
+ `var _ = 0`,
+ },
+ {
+ "Methods on generic types.",
+ `type G[T any] struct{}; func (G[T]) f(x T) T { return x }`,
+ `var _ = G[int]{}.f(0)`,
+ `var _ = 0`,
+ },
+ })
+}
+
func TestTailCallStrategy(t *testing.T) {
runTests(t, []testcase{
{
diff --git a/internal/refactor/inline/testdata/generic.txtar b/internal/refactor/inline/testdata/generic.txtar
index ea0f5bf..ed932a9 100644
--- a/internal/refactor/inline/testdata/generic.txtar
+++ b/internal/refactor/inline/testdata/generic.txtar
@@ -42,7 +42,13 @@
package a
func _() {
- f(1) //@ inline(re"f", re"cannot inline.*type.*inference")
+ f(1) //@ inline(re"f", a2)
+}
+
+-- a2 --
+...
+func _() {
+ print(1) //@ inline(re"f", a2)
}
-- a/a3.go --
@@ -91,5 +97,10 @@
func (G[T]) f(x T) { print(x) }
func _() {
- G[int]{}.f[bool]() //@ inline(re"f", re"generic methods not yet supported")
+ G[int]{}.f[bool](0) //@ inline(re"f", a6)
+}
+-- a6 --
+...
+func _() {
+ print(0) //@ inline(re"f", a6)
}