// cli.go (Go source)

package cli

import (
	"fmt"
	"os"
	"sort"
	"strconv"
	"strings"
	"time"

	"golang.org/x/term"
)

// IndentState tracks a stack of nesting levels and renders a level as a run
// of spaces, baseUnit spaces per level.
type IndentState struct {
	levels   []uint8
	baseUnit uint8
}

// NewIndentState returns an IndentState at level 0 that uses a four-space
// indentation unit.
func NewIndentState() *IndentState {
	return &IndentState{
		levels:   []uint8{0},
		baseUnit: 4,
	}
}

// Push enters a level one deeper than the current one. The depth saturates
// at 255, the largest level the uint8 stack can hold, instead of wrapping
// back to 0.
func (is *IndentState) Push() {
	level := is.levels[len(is.levels)-1]
	if level < ^uint8(0) {
		level++
	}
	is.levels = append(is.levels, level)
}

// Pop returns to the previous level. The base level is never popped.
func (is *IndentState) Pop() {
	if len(is.levels) > 1 {
		is.levels = is.levels[:len(is.levels)-1]
	}
}

// Current returns the indentation string for the current level.
func (is *IndentState) Current() string {
	return is.AtLevel(int(is.levels[len(is.levels)-1]))
}

// AtLevel returns the indentation string for an arbitrary level, independent
// of the stack. Negative levels yield "". The width is computed in int so
// that levels of 64 and above do not overflow the uint8 unit arithmetic.
func (is *IndentState) AtLevel(level int) string {
	if level < 0 {
		level = 0
	}
	return strings.Repeat(" ", int(is.baseUnit)*level)
}

// Value is the storage behind a flag. String renders the current value,
// Set parses text and stores the result, and Get returns the typed value.
type Value interface {
	String() string
	Set(string) error
	Get() any
}

// stringValue stores the flag text verbatim in the string it points to.
type stringValue struct{ p *string }

func newStringValue(p *string) *stringValue { return &stringValue{p: p} }

func (v *stringValue) Set(s string) error {
	*v.p = s
	return nil
}

func (v *stringValue) String() string { return *v.p }

func (v *stringValue) Get() any { return *v.p }

// boolValue stores a parsed boolean. An empty input (a bare flag with no
// explicit value) means true; any other input must be accepted by
// strconv.ParseBool, otherwise the stored value is left untouched.
type boolValue struct{ p *bool }

func newBoolValue(p *bool) *boolValue { return &boolValue{p: p} }

func (v *boolValue) Set(s string) error {
	if s == "" {
		*v.p = true
		return nil
	}
	parsed, err := strconv.ParseBool(s)
	if err != nil {
		return fmt.Errorf("invalid boolean value '%s': %w", s, err)
	}
	*v.p = parsed
	return nil
}

func (v *boolValue) String() string { return strconv.FormatBool(*v.p) }

func (v *boolValue) Get() any { return *v.p }

// listValue accumulates every value it is given, in order.
type listValue struct{ p *[]string }

func newListValue(p *[]string) *listValue { return &listValue{p: p} }

func (v *listValue) Set(s string) error {
	*v.p = append(*v.p, s)
	return nil
}

func (v *listValue) String() string { return strings.Join(*v.p, ", ") }

func (v *listValue) Get() any { return *v.p }

79type Flag struct {
80 Name string
81 Shorthand string
82 Usage string
83 Value Value
84 DefValue string
85 ExpectedType string
86}
87
88type FlagGroup struct {
89 Name string
90 Description string
91 Flags []FlagGroupEntry
92 GroupType string
93 AvailableFlagsHeader string
94}
95
96type FlagGroupEntry struct {
97 Name string
98 Prefix string
99 Usage string
100 Enabled *bool
101 Disabled *bool
102}
103
104type FlagSet struct {
105 name string
106 flags map[string]*Flag
107 shorthands map[string]*Flag
108 specialPrefix map[string]*Flag
109 args []string
110 flagGroups []FlagGroup
111}
112
113func NewFlagSet(name string) *FlagSet {
114 return &FlagSet{
115 name: name,
116 flags: make(map[string]*Flag),
117 shorthands: make(map[string]*Flag),
118 specialPrefix: make(map[string]*Flag),
119 }
120}
121
122func (f *FlagSet) Args() []string { return f.args }
123
124func (f *FlagSet) String(p *string, name, shorthand, value, usage, expectedType string) {
125 *p = value
126 f.Var(newStringValue(p), name, shorthand, usage, value, expectedType)
127}
128
129func (f *FlagSet) Bool(p *bool, name, shorthand string, value bool, usage string) {
130 *p = value
131 f.Var(newBoolValue(p), name, shorthand, usage, strconv.FormatBool(value), "")
132}
133
134func (f *FlagSet) List(p *[]string, name, shorthand string, value []string, usage, expectedType string) {
135 *p = value
136 f.Var(newListValue(p), name, shorthand, usage, fmt.Sprintf("%v", value), expectedType)
137}
138
139func (f *FlagSet) Special(p *[]string, prefix, usage, expectedType string) {
140 *p = []string{}
141 f.Var(newListValue(p), prefix, "", usage, "", expectedType)
142 f.specialPrefix[prefix] = f.flags[prefix]
143}
144
145func (f *FlagSet) DefineGroupFlags(entries []FlagGroupEntry) {
146 for i := range entries {
147 if entries[i].Enabled != nil {
148 f.Bool(entries[i].Enabled, entries[i].Prefix+entries[i].Name, "", *entries[i].Enabled, entries[i].Usage)
149 }
150 if entries[i].Disabled != nil {
151 disableUsage := "Disable '" + entries[i].Name + "'"
152 f.Bool(entries[i].Disabled, entries[i].Prefix+"no-"+entries[i].Name, "", *entries[i].Disabled, disableUsage)
153 }
154 }
155}
156
157func (f *FlagSet) AddFlagGroup(name, description, groupType, availableFlagsHeader string, entries []FlagGroupEntry) {
158 f.DefineGroupFlags(entries)
159 f.flagGroups = append(f.flagGroups, FlagGroup{
160 Name: name,
161 Description: description,
162 Flags: entries,
163 GroupType: groupType,
164 AvailableFlagsHeader: availableFlagsHeader,
165 })
166}
167
168func (f *FlagSet) Var(value Value, name, shorthand, usage, defValue, expectedType string) {
169 if name == "" {
170 panic("flag name cannot be empty")
171 }
172 flag := &Flag{Name: name, Shorthand: shorthand, Usage: usage, Value: value, DefValue: defValue, ExpectedType: expectedType}
173 if _, ok := f.flags[name]; ok {
174 panic(fmt.Sprintf("flag redefined: %s", name))
175 }
176 f.flags[name] = flag
177 if shorthand != "" {
178 if _, ok := f.shorthands[shorthand]; ok {
179 panic(fmt.Sprintf("shorthand flag redefined: %s", shorthand))
180 }
181 f.shorthands[shorthand] = flag
182 }
183}
184
185func (f *FlagSet) Parse(arguments []string) error {
186 f.args = []string{}
187 for i := 0; i < len(arguments); i++ {
188 arg := arguments[i]
189 if len(arg) < 2 || arg[0] != '-' {
190 f.args = append(f.args, arg)
191 continue
192 }
193 if arg == "--" {
194 f.args = append(f.args, arguments[i+1:]...)
195 break
196 }
197 if strings.HasPrefix(arg, "--") {
198 if err := f.parseLongFlag(arg, arguments, &i); err != nil {
199 return err
200 }
201 } else {
202 name := arg[1:]
203 if strings.Contains(name, "=") {
204 name = strings.SplitN(name, "=", 2)[0]
205 }
206
207 flag, ok := f.flags[name]
208 if ok {
209 parts := strings.SplitN(arg[1:], "=", 2)
210 if len(parts) == 2 {
211 if err := flag.Value.Set(parts[1]); err != nil {
212 return err
213 }
214 } else {
215 if _, isBool := flag.Value.(*boolValue); isBool {
216 if err := flag.Value.Set(""); err != nil {
217 return err
218 }
219 } else {
220 if i+1 >= len(arguments) {
221 return fmt.Errorf("flag needs an argument: -%s", name)
222 }
223 i++
224 if err := flag.Value.Set(arguments[i]); err != nil {
225 return err
226 }
227 }
228 }
229 } else {
230 if err := f.parseShortFlag(arg, arguments, &i); err != nil {
231 return err
232 }
233 }
234 }
235 }
236 return nil
237}
238
239func (f *FlagSet) parseLongFlag(arg string, arguments []string, i *int) error {
240 parts := strings.SplitN(arg[2:], "=", 2)
241 name := parts[0]
242 if name == "" {
243 return fmt.Errorf("empty flag name")
244 }
245 flag, ok := f.flags[name]
246 if !ok {
247 return fmt.Errorf("unknown flag: --%s", name)
248 }
249 if len(parts) == 2 {
250 return flag.Value.Set(parts[1])
251 }
252 if _, isBool := flag.Value.(*boolValue); isBool {
253 return flag.Value.Set("")
254 }
255 if *i+1 >= len(arguments) {
256 return fmt.Errorf("flag needs an argument: --%s", name)
257 }
258 *i++
259 return flag.Value.Set(arguments[*i])
260}
261
262func (f *FlagSet) parseShortFlag(arg string, arguments []string, i *int) error {
263 for prefix, flag := range f.specialPrefix {
264 if strings.HasPrefix(arg, "-"+prefix) && len(arg) > len(prefix)+1 {
265 return flag.Value.Set(arg[len(prefix)+1:])
266 }
267 }
268
269 shorthand := arg[1:2]
270 flag, ok := f.shorthands[shorthand]
271 if !ok {
272 return fmt.Errorf("unknown shorthand flag: -%s", shorthand)
273 }
274 if _, isBool := flag.Value.(*boolValue); isBool {
275 return flag.Value.Set("")
276 }
277 value := arg[2:]
278 if value == "" {
279 if *i+1 >= len(arguments) {
280 return fmt.Errorf("flag needs an argument: -%s", shorthand)
281 }
282 *i++
283 value = arguments[*i]
284 }
285 return flag.Value.Set(value)
286}
287
288type App struct {
289 Name string
290 Synopsis string
291 Description string
292 Authors []string
293 Repository string
294 Since int
295 FlagSet *FlagSet
296 Action func(args []string) error
297}
298
299func NewApp(name string) *App {
300 return &App{
301 Name: name,
302 FlagSet: NewFlagSet(name),
303 }
304}
305
306func (f *FlagSet) Lookup(name string) *Flag {
307 return f.flags[name]
308}
309
310func (a *App) Run(arguments []string) error {
311 help := false
312 a.FlagSet.Bool(&help, "help", "h", false, "Display this information")
313
314 if err := a.FlagSet.Parse(arguments); err != nil {
315 fmt.Fprintln(os.Stderr, err)
316 a.generateUsagePage(os.Stderr)
317 return err
318 }
319 if help {
320 a.generateHelpPage(os.Stdout)
321 return nil
322 }
323 if a.Action != nil {
324 return a.Action(a.FlagSet.Args())
325 }
326 return nil
327}
328
329func (a *App) generateUsagePage(w *os.File) {
330 var sb strings.Builder
331 termWidth := getTerminalWidth()
332 indent := NewIndentState()
333
334 fmt.Fprintf(&sb, "Usage: %s <options> [input.b] ...\n", a.Name)
335
336 optionFlags := a.getOptionFlags()
337 if len(optionFlags) > 0 {
338 maxFlagWidth := 0
339 maxUsageWidth := 0
340 for _, flag := range optionFlags {
341 flagStrLen := len(a.formatFlagString(flag))
342 if flagStrLen > maxFlagWidth {
343 maxFlagWidth = flagStrLen
344 }
345 usageLen := len(flag.Usage)
346 if usageLen > maxUsageWidth {
347 maxUsageWidth = usageLen
348 }
349 }
350
351 sb.WriteString("\n")
352 fmt.Fprintf(&sb, "%sOptions\n", indent.AtLevel(1))
353 sort.Slice(optionFlags, func(i, j int) bool { return optionFlags[i].Name < optionFlags[j].Name })
354 for _, flag := range optionFlags {
355 a.formatFlagLine(&sb, flag, indent, termWidth, maxFlagWidth, maxUsageWidth)
356 }
357 }
358
359 fmt.Fprintf(&sb, "\nRun '%s --help' for all available options and flags.\n", a.Name)
360 fmt.Fprint(w, sb.String())
361}
362
363func (a *App) generateHelpPage(w *os.File) {
364 var sb strings.Builder
365 termWidth := getTerminalWidth()
366 indent := NewIndentState()
367
368 globalMaxWidth := a.calculateGlobalMaxWidth()
369
370 globalMaxUsageWidth := 0
371 updateMaxUsage := func(s string) {
372 if len(s) > globalMaxUsageWidth {
373 globalMaxUsageWidth = len(s)
374 }
375 }
376 optionFlags := a.getOptionFlags()
377 for _, flag := range optionFlags {
378 updateMaxUsage(flag.Usage)
379 }
380 for _, group := range a.FlagSet.flagGroups {
381 for _, entry := range group.Flags {
382 updateMaxUsage(entry.Usage)
383 }
384 }
385
386 year := time.Now().Year()
387 sb.WriteString("\n")
388 fmt.Fprintf(&sb, "%sCopyright (c) %d: %s\n", indent.AtLevel(1), year, strings.Join(a.Authors, ", ")+" and contributors")
389 if a.Repository != "" {
390 fmt.Fprintf(&sb, "%sFor more details refer to %s\n", indent.AtLevel(1), a.Repository)
391 }
392
393 if a.Synopsis != "" {
394 sb.WriteString("\n")
395 fmt.Fprintf(&sb, "%sSynopsis\n", indent.AtLevel(1))
396 synopsis := strings.ReplaceAll(a.Synopsis, "[", "<")
397 synopsis = strings.ReplaceAll(synopsis, "]", ">")
398 fmt.Fprintf(&sb, "%s%s %s\n", indent.AtLevel(2), a.Name, synopsis)
399 }
400
401 if a.Description != "" {
402 sb.WriteString("\n")
403 fmt.Fprintf(&sb, "%sDescription\n", indent.AtLevel(1))
404 fmt.Fprintf(&sb, "%s%s\n", indent.AtLevel(2), a.Description)
405 }
406
407 if len(optionFlags) > 0 {
408 sb.WriteString("\n")
409 fmt.Fprintf(&sb, "%sOptions\n", indent.AtLevel(1))
410 sort.Slice(optionFlags, func(i, j int) bool { return optionFlags[i].Name < optionFlags[j].Name })
411 for _, flag := range optionFlags {
412 a.formatFlagLine(&sb, flag, indent, termWidth, globalMaxWidth, globalMaxUsageWidth)
413 }
414 }
415
416 if len(a.FlagSet.flagGroups) > 0 {
417 sortedGroups := make([]FlagGroup, len(a.FlagSet.flagGroups))
418 copy(sortedGroups, a.FlagSet.flagGroups)
419 sort.Slice(sortedGroups, func(i, j int) bool { return sortedGroups[i].Name < sortedGroups[j].Name })
420 for _, group := range sortedGroups {
421 a.formatFlagGroup(&sb, group, indent, termWidth, globalMaxWidth, globalMaxUsageWidth)
422 }
423 }
424 fmt.Fprint(w, sb.String())
425}
426
427func (a *App) getOptionFlags() []*Flag {
428 var optionFlags []*Flag
429 for _, flag := range a.FlagSet.flags {
430 if _, isSpecial := a.FlagSet.specialPrefix[flag.Name]; isSpecial {
431 continue
432 }
433 if a.isGroupFlag(flag.Name) {
434 continue
435 }
436 optionFlags = append(optionFlags, flag)
437 }
438 return optionFlags
439}
440
441func (a *App) isGroupFlag(flagName string) bool {
442 for _, group := range a.FlagSet.flagGroups {
443 for _, entry := range group.Flags {
444 if flagName == entry.Prefix+entry.Name || flagName == entry.Prefix+"no-"+entry.Name {
445 return true
446 }
447 }
448 }
449 return false
450}
451
452func (a *App) calculateGlobalMaxWidth() int {
453 maxWidth := 0
454 checkWidth := func(s string) {
455 if len(s) > maxWidth {
456 maxWidth = len(s)
457 }
458 }
459 for _, flag := range a.getOptionFlags() {
460 checkWidth(a.formatFlagString(flag))
461 }
462 for _, group := range a.FlagSet.flagGroups {
463 prefix := group.Flags[0].Prefix
464 groupType := strings.ToLower(strings.TrimSuffix(group.Name, "s"))
465 checkWidth(fmt.Sprintf("-%s<%s>", prefix, groupType))
466 checkWidth(fmt.Sprintf("-%sno-<%s>", prefix, groupType))
467 for _, entry := range group.Flags {
468 checkWidth(entry.Name)
469 }
470 }
471 return maxWidth
472}
473
474func (a *App) formatFlagString(flag *Flag) string {
475 var flagStr strings.Builder
476 _, isBool := flag.Value.(*boolValue)
477
478 if flag.Shorthand != "" {
479 fmt.Fprintf(&flagStr, "-%s", flag.Shorthand)
480 if !isBool {
481 fmt.Fprintf(&flagStr, " <%s>", flag.ExpectedType)
482 }
483 fmt.Fprintf(&flagStr, ", --%s", flag.Name)
484 if !isBool {
485 fmt.Fprintf(&flagStr, " <%s>", flag.ExpectedType)
486 }
487 } else {
488 fmt.Fprintf(&flagStr, "--%s", flag.Name)
489 if !isBool {
490 if flag.ExpectedType != "" {
491 fmt.Fprintf(&flagStr, "=%s", flag.ExpectedType)
492 }
493 }
494 }
495 return flagStr.String()
496}
497
498func (a *App) formatEntry(sb *strings.Builder, indent *IndentState, termWidth int, leftPart, usagePart, rightPart string, globalLeftWidth, globalMaxUsageWidth int) {
499 indentStr := indent.AtLevel(2)
500 indentWidth := len(indentStr)
501 spaceWidth := 1
502
503 fixedPartsWidth := indentWidth + globalLeftWidth + spaceWidth + 2 + len(rightPart)
504 maxFirstUsageWidth := termWidth - fixedPartsWidth
505 if maxFirstUsageWidth < 10 {
506 maxFirstUsageWidth = 10
507 }
508
509 usageLines := wrapText(usagePart, maxFirstUsageWidth)
510
511 firstUsageLine := ""
512 if len(usageLines) > 0 {
513 firstUsageLine = usageLines[0]
514 }
515
516 desiredUsageWidth := globalMaxUsageWidth
517 if desiredUsageWidth > maxFirstUsageWidth {
518 desiredUsageWidth = maxFirstUsageWidth
519 }
520
521 if rightPart != "" {
522 fmt.Fprintf(sb, "%s%-*s %-*s %s\n", indent.AtLevel(2), globalLeftWidth, leftPart, desiredUsageWidth, firstUsageLine, rightPart)
523 } else {
524 fmt.Fprintf(sb, "%s%-*s %s\n", indent.AtLevel(2), globalLeftWidth, leftPart, firstUsageLine)
525 }
526
527 wrappedIndent := strings.Repeat(" ", globalLeftWidth+spaceWidth)
528
529 availableWrappedWidth := termWidth - (indentWidth + globalLeftWidth + spaceWidth)
530 if availableWrappedWidth < 10 {
531 availableWrappedWidth = 10
532 }
533
534 wrappedLineMaxWidth := desiredUsageWidth + 2
535 termAvailable := termWidth - (indentWidth + globalLeftWidth + spaceWidth)
536 if wrappedLineMaxWidth > termAvailable {
537 wrappedLineMaxWidth = termAvailable
538 }
539
540 for i := 1; i < len(usageLines); i++ {
541 fmt.Fprintf(sb, "%s%s%s\n", indentStr, wrappedIndent, usageLines[i])
542 }
543}
544
545func (a *App) formatFlagLine(sb *strings.Builder, flag *Flag, indent *IndentState, termWidth, globalMaxWidth, globalMaxUsageWidth int) {
546 leftPart := a.formatFlagString(flag)
547 usagePart := flag.Usage
548
549 rightPart := ""
550 if flag.DefValue != "" && flag.DefValue != "false" && flag.DefValue != "[]" {
551 if _, isBool := flag.Value.(*boolValue); !isBool {
552 rightPart = fmt.Sprintf("|%s|", flag.DefValue)
553 }
554 }
555 a.formatEntry(sb, indent, termWidth, leftPart, usagePart, rightPart, globalMaxWidth, globalMaxUsageWidth)
556}
557
558func (a *App) formatFlagGroup(sb *strings.Builder, group FlagGroup, indent *IndentState, termWidth, globalMaxWidth, globalMaxUsageWidth int) {
559 sb.WriteString("\n")
560 fmt.Fprintf(sb, "%s%s\n", indent.AtLevel(1), group.Name)
561
562 prefix := group.Flags[0].Prefix
563 groupType := group.GroupType
564 if groupType == "" {
565 groupType = "flag"
566 }
567
568 fmt.Fprintf(sb, "%s%-*s Enable a specific %s\n", indent.AtLevel(2), globalMaxWidth, fmt.Sprintf("-%s<%s>", prefix, groupType), groupType)
569 fmt.Fprintf(sb, "%s%-*s Disable a specific %s\n", indent.AtLevel(2), globalMaxWidth, fmt.Sprintf("-%sno-<%s>", prefix, groupType), groupType)
570
571 if group.AvailableFlagsHeader != "" {
572 fmt.Fprintf(sb, "%s%s\n", indent.AtLevel(1), group.AvailableFlagsHeader)
573 }
574
575 sortedEntries := make([]FlagGroupEntry, len(group.Flags))
576 copy(sortedEntries, group.Flags)
577 sort.Slice(sortedEntries, func(i, j int) bool { return sortedEntries[i].Name < sortedEntries[j].Name })
578
579 for _, entry := range sortedEntries {
580 rightPart := ""
581 if entry.Enabled != nil && *entry.Enabled && (entry.Disabled == nil || !*entry.Disabled) {
582 rightPart = "|x|"
583 } else {
584 rightPart = "|-|"
585 }
586 a.formatEntry(sb, indent, termWidth, entry.Name, entry.Usage, rightPart, globalMaxWidth, globalMaxUsageWidth)
587 }
588}
589
590func getTerminalWidth() int {
591 width, _, err := term.GetSize(int(os.Stdout.Fd()))
592 if err != nil { return 80 }
593 if width < 20 {
594 return 20
595 }
596 return width
597}
598
// wrapText greedily splits text into lines of at most maxWidth characters,
// breaking only at whitespace.
//
// Runs of whitespace, including newlines, collapse to single spaces. A word
// longer than maxWidth is placed on a line of its own rather than split.
// Empty or all-whitespace text yields no lines. A non-positive maxWidth
// disables wrapping and returns text unchanged as the only line.
//
// Widths are counted in runes (Unicode code points), not bytes, which
// matches how fmt's %-*s pads the surrounding columns. Counting bytes
// wrapped non-ASCII text too early.
func wrapText(text string, maxWidth int) []string {
	if maxWidth <= 0 {
		return []string{text}
	}
	words := strings.Fields(text)
	if len(words) == 0 {
		return []string{}
	}

	var lines []string
	var current strings.Builder
	currentLen := 0
	for _, word := range words {
		wordLen := len([]rune(word))
		// +1 accounts for the space that would join word to the line.
		if currentLen > 0 && currentLen+1+wordLen > maxWidth {
			lines = append(lines, current.String())
			current.Reset()
			currentLen = 0
		}
		if currentLen > 0 {
			current.WriteByte(' ')
			currentLen++
		}
		current.WriteString(word)
		currentLen += wordLen
	}
	if current.Len() > 0 {
		lines = append(lines, current.String())
	}
	return lines
}