diff --git a/cmd/internal/pkgsite-cli/command.go b/cmd/internal/pkgsite-cli/command.go
index 9297fa9..2955028 100644
--- a/cmd/internal/pkgsite-cli/command.go
+++ b/cmd/internal/pkgsite-cli/command.go
@@ -21,7 +21,7 @@
args string // e.g. "<package>[@version]"; empty for no-arg commands
summary string // one-line description
flags *flag.FlagSet
- run func(fs *flag.FlagSet, stdout, stderr io.Writer) int
+ run func(fs *flag.FlagSet, args []string, stdout, stderr io.Writer) int
}
func (c *command) usageLine() string {
@@ -111,21 +111,38 @@
return 0
}
}
- return c.run(nil, stdout, stderr)
+ return c.run(nil, nil, stdout, stderr)
}
c.flags.SetOutput(stderr)
c.flags.Usage = func() { printCommandUsage(stderr, c) }
- // TODO: Consider supporting flags after positional arguments for better UX.
- // Currently, flags must appear before positional arguments.
- // Works: pkgsite-cli package -doc=text -examples -imports -json -module golang.org/x/tools golang.org/x/tools/go/packages
- // Fails: pkgsite-cli package golang.org/x/tools/go/packages -doc=text -examples -imports -json -module golang.org/x/tools
- if err := c.flags.Parse(args); err != nil {
- if err == flag.ErrHelp {
- return 0
+
+ // Support flags after positional arguments. The flag package stops at
+ // the first non-flag argument, so parse repeatedly, setting aside one
+ // positional argument per round. A "--" argument ends flag parsing:
+ // everything after it is positional, even if it looks like a flag.
+ var posArgs []string
+ for len(args) > 0 {
+ if err := c.flags.Parse(args); err != nil {
+ if err == flag.ErrHelp {
+ return 0
+ }
+ return 2
}
- return 2
+ rest := c.flags.Args()
+ if n := len(args) - len(rest); n > 0 && args[n-1] == "--" {
+ // This round consumed a "--" terminator.
+ posArgs = append(posArgs, rest...)
+ break
+ }
+ if len(rest) > 0 {
+ // Parsing stopped at a non-flag argument, which is positional.
+ posArgs = append(posArgs, rest[0])
+ rest = rest[1:]
+ }
+ args = rest
}
- return c.run(c.flags, stdout, stderr)
+
+ return c.run(c.flags, posArgs, stdout, stderr)
}
func versionInfo() string {
diff --git a/cmd/internal/pkgsite-cli/command_test.go b/cmd/internal/pkgsite-cli/command_test.go
index d8129cc..79a87dd 100644
--- a/cmd/internal/pkgsite-cli/command_test.go
+++ b/cmd/internal/pkgsite-cli/command_test.go
@@ -7,6 +7,7 @@
import (
"bytes"
"flag"
+ "fmt"
"io"
"os"
"path/filepath"
@@ -46,19 +47,25 @@
dummyCmd := &command{
name: "dummy",
summary: "dummy command",
- run: func(fs *flag.FlagSet, stdout, stderr io.Writer) int {
+ run: func(fs *flag.FlagSet, args []string, stdout, stderr io.Writer) int {
return 0
},
}
+ var testBool bool
flagsCmd := &command{
name: "flags",
summary: "command with flags",
flags: flag.NewFlagSet("flags", flag.ContinueOnError),
- run: func(fs *flag.FlagSet, stdout, stderr io.Writer) int {
+ run: func(fs *flag.FlagSet, args []string, stdout, stderr io.Writer) int {
+ fmt.Fprint(stdout, strings.Join(args, ","))
+ if testBool {
+ fmt.Fprint(stdout, "+bool")
+ }
return 0
},
}
+ flagsCmd.flags.BoolVar(&testBool, "bool", false, "test bool flag")
cmds := []*command{dummyCmd, flagsCmd}
@@ -94,6 +101,30 @@
wantStdout: "",
},
{
+ name: "flags after positional",
+ args: []string{"flags", "arg1", "-bool"},
+ wantExit: 0,
+ wantStdout: "arg1+bool",
+ },
+ {
+ name: "flags before positional",
+ args: []string{"flags", "-bool", "arg1"},
+ wantExit: 0,
+ wantStdout: "arg1+bool",
+ },
+ {
+ name: "flags between positionals",
+ args: []string{"flags", "arg1", "-bool", "arg2"},
+ wantExit: 0,
+ wantStdout: "arg1,arg2+bool",
+ },
+ {
+ name: "double dash ends flag parsing",
+ args: []string{"flags", "arg1", "--", "-bool", "arg2"},
+ wantExit: 0,
+ wantStdout: "arg1,-bool,arg2",
+ },
+ {
name: "unknown command",
args: []string{"unknown"},
wantExit: 2,
@@ -104,12 +135,16 @@
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var stdout, stderr bytes.Buffer
+ // testBool is captured by flagsCmd's closure and shared by every case;
+ // reset it so a -bool set by one case does not leak into the next.
+ testBool = false
+
exit := dispatch(tt.args, cmds, &stdout, &stderr)
if exit != tt.wantExit {
t.Errorf("dispatch() exit = %d, want %d", exit, tt.wantExit)
}
- if tt.wantStdout != "" && !strings.Contains(stdout.String(), tt.wantStdout) {
- t.Errorf("dispatch() stdout = %q, want to contain %q", stdout.String(), tt.wantStdout)
+ if tt.wantStdout != "" && stdout.String() != tt.wantStdout {
+ t.Errorf("dispatch() stdout = %q, want %q", stdout.String(), tt.wantStdout)
}
if tt.wantStderr != "" && !strings.Contains(stderr.String(), tt.wantStderr) {
t.Errorf("dispatch() stderr = %q, want to contain %q", stderr.String(), tt.wantStderr)
diff --git a/cmd/internal/pkgsite-cli/main.go b/cmd/internal/pkgsite-cli/main.go
index 341ef03..3b2854c 100644
--- a/cmd/internal/pkgsite-cli/main.go
+++ b/cmd/internal/pkgsite-cli/main.go
@@ -47,7 +47,9 @@
searchFS := flag.NewFlagSet(filepath.Base(os.Args[0])+" search", flag.ContinueOnError)
sf.register(searchFS)
- pkgRun := func(fs *flag.FlagSet, stdout, stderr io.Writer) int { return runPackage(fs, &pf, stdout, stderr) }
+ pkgRun := func(set *flag.FlagSet, pos []string, out, errOut io.Writer) int {
+ return runPackage(set, pos, &pf, out, errOut)
+ }
var cmds []*command
cmds = []*command{
@@ -63,24 +65,34 @@
args: "<module>[@version]",
summary: "module information",
flags: modFS,
- run: func(fs *flag.FlagSet, stdout, stderr io.Writer) int { return runModule(fs, &mf, stdout, stderr) },
+ run: func(set *flag.FlagSet, pos []string, out, errOut io.Writer) int {
+ return runModule(set, pos, &mf, out, errOut)
+ },
},
{
name: "search",
args: "<query>",
summary: "search for packages",
flags: searchFS,
- run: func(fs *flag.FlagSet, stdout, stderr io.Writer) int { return runSearch(fs, &sf, stdout, stderr) },
+ run: func(set *flag.FlagSet, pos []string, out, errOut io.Writer) int {
+ return runSearch(set, pos, &sf, out, errOut)
+ },
},
{
name: "help",
summary: "show this help message",
- run: func(_ *flag.FlagSet, stdout, _ io.Writer) int { printUsage(stdout, cmds); return 0 },
+ run: func(_ *flag.FlagSet, _ []string, out, _ io.Writer) int {
+ printUsage(out, cmds)
+ return 0
+ },
},
{
name: "version",
summary: "print version information",
- run: func(_ *flag.FlagSet, stdout, _ io.Writer) int { fmt.Fprintln(stdout, versionInfo()); return 0 },
+ run: func(_ *flag.FlagSet, _ []string, out, _ io.Writer) int {
+ fmt.Fprintln(out, versionInfo())
+ return 0
+ },
},
}
return cmds
diff --git a/cmd/internal/pkgsite-cli/module.go b/cmd/internal/pkgsite-cli/module.go
index 457bd07..e82da40 100644
--- a/cmd/internal/pkgsite-cli/module.go
+++ b/cmd/internal/pkgsite-cli/module.go
@@ -14,12 +14,15 @@
"golang.org/x/sync/errgroup"
)
- func runModule(fs *flag.FlagSet, m *moduleFlags, stdout, stderr io.Writer) int {
- if fs.NArg() != 1 {
+ // runModule runs the module command. args holds the positional arguments
+ // collected by dispatch, which allows flags after them; use args rather
+ // than fs.Args, which may not list every positional argument.
+ func runModule(fs *flag.FlagSet, args []string, m *moduleFlags, stdout, stderr io.Writer) int {
+ if len(args) != 1 {
fs.Usage()
return 2
}
- path, version := splitPathVersion(fs.Arg(0))
+ path, version := splitPathVersion(args[0])
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
diff --git a/cmd/internal/pkgsite-cli/package.go b/cmd/internal/pkgsite-cli/package.go
index 728f89f..f49dd91 100644
--- a/cmd/internal/pkgsite-cli/package.go
+++ b/cmd/internal/pkgsite-cli/package.go
@@ -16,13 +16,16 @@
"golang.org/x/pkgsite/cmd/internal/pkgsite-cli/client"
)
- func runPackage(fs *flag.FlagSet, p *packageFlags, stdout, stderr io.Writer) int {
- if fs.NArg() != 1 {
- fmt.Fprintf(stderr, "Error: expected exactly 1 package argument, got %d\n", fs.NArg())
+ // runPackage runs the package command. args holds the positional arguments
+ // collected by dispatch, which allows flags after them; use args rather
+ // than fs.Args, which may not list every positional argument.
+ func runPackage(fs *flag.FlagSet, args []string, p *packageFlags, stdout, stderr io.Writer) int {
+ if len(args) != 1 {
+ fmt.Fprintf(stderr, "Error: expected exactly 1 package argument, got %d\n", len(args))
fs.Usage()
return 2
}
- path, version := splitPathVersion(fs.Arg(0))
+ path, version := splitPathVersion(args[0])
goos, goarch, err := defaultGOOSGOARCH()
if err != nil {
diff --git a/cmd/internal/pkgsite-cli/search.go b/cmd/internal/pkgsite-cli/search.go
index e2585f9..56a022b 100644
--- a/cmd/internal/pkgsite-cli/search.go
+++ b/cmd/internal/pkgsite-cli/search.go
@@ -14,12 +14,15 @@
"golang.org/x/pkgsite/cmd/internal/pkgsite-cli/client"
)
- func runSearch(fs *flag.FlagSet, s *searchFlags, stdout, stderr io.Writer) int {
- if fs.NArg() < 1 {
+ // runSearch runs the search command. args holds the positional arguments
+ // collected by dispatch (flags may appear before or after them); they are
+ // joined with spaces to form the query. Do not use fs.Args instead.
+ func runSearch(fs *flag.FlagSet, args []string, s *searchFlags, stdout, stderr io.Writer) int {
+ if len(args) < 1 {
fs.Usage()
return 2
}
- query := strings.Join(fs.Args(), " ")
+ query := strings.Join(args, " ")
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()