Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package pluginrpc
import (
"fmt"
"io"
"sort"
"strconv"
"strings"

Expand All @@ -32,6 +33,7 @@ const (
FormatFlagName = "format"

protocolVersion = 1
flagWrapping = 140
)

type flags struct {
Expand All @@ -40,10 +42,13 @@ type flags struct {
format Format
}

func parseFlags(output io.Writer, args []string) (*flags, []string, error) {
func parseFlags(output io.Writer, args []string, spec Spec, doc string) (*flags, []string, error) {
flags := &flags{}
var formatString string
flagSet := pflag.NewFlagSet("plugin", pflag.ContinueOnError)
flagSet.Usage = func() {
_, _ = fmt.Fprint(output, getFlagUsage(flagSet, spec, doc))
}
flagSet.SetOutput(output)
flagSet.BoolVar(&flags.printProtocol, ProtocolFlagName, false, "Print the protocol to stdout and exit.")
flagSet.BoolVar(&flags.printSpec, SpecFlagName, false, "Print the spec to stdout in the specified format and exit.")
Expand All @@ -68,6 +73,35 @@ func parseFlags(output io.Writer, args []string) (*flags, []string, error) {
return flags, flagSet.Args(), nil
}

func getFlagUsage(flagSet *pflag.FlagSet, spec Spec, doc string) string {
var sb strings.Builder
if doc != "" {
_, _ = sb.WriteString(doc)
_, _ = sb.WriteString("\n\n")
}
_, _ = sb.WriteString("Commands:\n\n")
var argBasedProcedureStrings []string
var pathBasedProcedureStrings []string
for _, procedure := range spec.Procedures() {
if args := procedure.Args(); len(args) > 0 {
argBasedProcedureStrings = append(argBasedProcedureStrings, strings.Join(args, " "))
} else {
pathBasedProcedureStrings = append(pathBasedProcedureStrings, procedure.Path())
}
}
sort.Strings(argBasedProcedureStrings)
sort.Strings(pathBasedProcedureStrings)
for _, procedureString := range append(argBasedProcedureStrings, pathBasedProcedureStrings...) {
_, _ = sb.WriteString(" ")
_, _ = sb.WriteString(procedureString)
_, _ = sb.WriteString("\n")
}
_, _ = sb.WriteString("\nFlags:\n\n")
_, _ = sb.WriteString(flagSet.FlagUsagesWrapped(flagWrapping))
_, _ = sb.WriteString(" -h, --help Show this help.\n")
return sb.String()
}

func marshalProtocol(value int) []byte {
return []byte(strconv.Itoa(value) + "\n")
}
Expand Down
6 changes: 5 additions & 1 deletion internal/example/cmd/echo-plugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ func newServer() (pluginrpc.Server, error) {
serverRegistrar := pluginrpc.NewServerRegistrar()
echoServiceServer := examplev1pluginrpc.NewEchoServiceServer(pluginrpc.NewHandler(spec), echoServiceHandler{})
examplev1pluginrpc.RegisterEchoServiceServer(serverRegistrar, echoServiceServer)
return pluginrpc.NewServer(spec, serverRegistrar)
return pluginrpc.NewServer(
spec,
serverRegistrar,
pluginrpc.ServerWithDoc("An example plugin that implements the EchoService."),
)
}

type echoServiceHandler struct{}
Expand Down
31 changes: 26 additions & 5 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,35 @@ type Server interface {
//
// Once passed to this constructor, the ServerRegistrar can no longer have new
// paths registered to it.
func NewServer(spec Spec, serverRegistrar ServerRegistrar, _ ...ServerOption) (Server, error) {
return newServer(spec, serverRegistrar)
func NewServer(spec Spec, serverRegistrar ServerRegistrar, options ...ServerOption) (Server, error) {
return newServer(spec, serverRegistrar, options...)
}

// ServerOption is an option for a new Server.
type ServerOption func(*serverOptions)

// ServerWithDoc will attach the given documentation to the server.
//
// This will add ths given docs as a prefix when the flag -h/--help is used.
func ServerWithDoc(doc string) ServerOption {
return func(serverOptions *serverOptions) {
serverOptions.doc = doc
}
}

// *** PRIVATE ***

type server struct {
spec Spec
pathToHandleFunc map[string]func(context.Context, HandleEnv, ...HandleOption) error
doc string
}

func newServer(spec Spec, serverRegistrar ServerRegistrar) (*server, error) {
func newServer(spec Spec, serverRegistrar ServerRegistrar, options ...ServerOption) (*server, error) {
serverOptions := newServerOptions()
for _, option := range options {
option(serverOptions)
}
pathToHandleFunc, err := serverRegistrar.pathToHandleFunc()
if err != nil {
return nil, err
Expand All @@ -72,11 +86,12 @@ func newServer(spec Spec, serverRegistrar ServerRegistrar) (*server, error) {
return &server{
spec: spec,
pathToHandleFunc: pathToHandleFunc,
doc: serverOptions.doc,
}, nil
}

func (s *server) Serve(ctx context.Context, env Env) error {
flags, args, err := parseFlags(env.Stderr, env.Args)
flags, args, err := parseFlags(env.Stderr, env.Args, s.spec, s.doc)
if err != nil {
if errors.Is(err, pflag.ErrHelp) {
return nil
Expand Down Expand Up @@ -111,4 +126,10 @@ func (s *server) Serve(ctx context.Context, env Env) error {

func (*server) isServer() {}

type serverOptions struct{}
type serverOptions struct {
doc string
}

func newServerOptions() *serverOptions {
return &serverOptions{}
}
Loading