1// Mostly copied from Go's src/cmd/gofmt:
2// Copyright 2009 The Go Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file.
5
6package main
7
8import (
9	"bytes"
10	"flag"
11	"fmt"
12	"io"
13	"io/ioutil"
14	"os"
15	"os/exec"
16	"path/filepath"
17	"strings"
18
19	"github.com/google/blueprint/parser"
20)
21
// Command-line flags selecting how bpfmt emits its results. main requires at
// least one of -d, -l, -o or -w; -s additionally sorts lists before printing.
var (
	list                = flag.Bool("l", false, "list files whose formatting differs from bpfmt's")
	overwriteSourceFile = flag.Bool("w", false, "write result to (source) file")
	writeToStout        = flag.Bool("o", false, "write result to stdout")
	doDiff              = flag.Bool("d", false, "display diffs instead of rewriting files")
	sortLists           = flag.Bool("s", false, "sort arrays")
)

// exitCode is the status passed to os.Exit when bpfmt finishes; it starts at
// zero and report raises it to 2 whenever an error is encountered.
var exitCode int
34
35func report(err error) {
36	fmt.Fprintln(os.Stderr, err)
37	exitCode = 2
38}
39
40func usage() {
41	usageViolation("")
42}
43
// usageViolation writes violation on its own line (an empty line when
// violation is ""), then the usage synopsis and the flag defaults, and
// terminates the process with exit status 2.
func usageViolation(violation string) {
	stderr := os.Stderr
	fmt.Fprintln(stderr, violation)
	fmt.Fprintln(stderr, "usage: bpfmt [flags] [path ...]")
	flag.PrintDefaults()
	os.Exit(2)
}
50
51func processFile(filename string, out io.Writer) error {
52	f, err := os.Open(filename)
53	if err != nil {
54		return err
55	}
56	defer f.Close()
57
58	return processReader(filename, f, out)
59}
60
61func processReader(filename string, in io.Reader, out io.Writer) error {
62	src, err := ioutil.ReadAll(in)
63	if err != nil {
64		return err
65	}
66
67	r := bytes.NewBuffer(src)
68
69	file, errs := parser.Parse(filename, r, parser.NewScope(nil))
70	if len(errs) > 0 {
71		for _, err := range errs {
72			fmt.Fprintln(os.Stderr, err)
73		}
74		return fmt.Errorf("%d parsing errors", len(errs))
75	}
76
77	if *sortLists {
78		parser.SortLists(file)
79	}
80
81	res, err := parser.Print(file)
82	if err != nil {
83		return err
84	}
85
86	if !bytes.Equal(src, res) {
87		// formatting has changed
88		if *list {
89			fmt.Fprintln(out, filename)
90		}
91		if *overwriteSourceFile {
92			err = ioutil.WriteFile(filename, res, 0644)
93			if err != nil {
94				return err
95			}
96		}
97		if *doDiff {
98			data, err := diff(src, res)
99			if err != nil {
100				return fmt.Errorf("computing diff: %s", err)
101			}
102			fmt.Printf("diff %s bpfmt/%s\n", filename, filename)
103			out.Write(data)
104		}
105	}
106
107	if !*list && !*overwriteSourceFile && !*doDiff {
108		_, err = out.Write(res)
109	}
110
111	return err
112}
113
// isBlueprintFile reports whether f is a non-directory entry that bpfmt
// should format: a file named exactly "Blueprints", or a file whose name ends
// in ".bp" and does not start with "." (hidden files are skipped).
func isBlueprintFile(f os.FileInfo) bool {
	if f.IsDir() {
		return false
	}
	switch name := f.Name(); {
	case name == "Blueprints":
		return true
	case strings.HasPrefix(name, "."):
		return false
	default:
		return strings.HasSuffix(name, ".bp")
	}
}
118
119func walkDir(path string) {
120	visitFile := func(path string, f os.FileInfo, err error) error {
121		if err == nil && isBlueprintFile(f) {
122			err = processFile(path, os.Stdout)
123		}
124		if err != nil {
125			report(err)
126		}
127		return nil
128	}
129
130	filepath.Walk(path, visitFile)
131}
132
133func main() {
134	flag.Usage = usage
135	flag.Parse()
136
137	if !*writeToStout && !*overwriteSourceFile && !*doDiff && !*list {
138		usageViolation("one of -d, -l, -o, or -w is required")
139	}
140
141	if flag.NArg() == 0 {
142		// file to parse is stdin
143		if *overwriteSourceFile {
144			fmt.Fprintln(os.Stderr, "error: cannot use -w with standard input")
145			os.Exit(2)
146		}
147		if err := processReader("<standard input>", os.Stdin, os.Stdout); err != nil {
148			report(err)
149		}
150		os.Exit(exitCode)
151	}
152
153	for i := 0; i < flag.NArg(); i++ {
154		path := flag.Arg(i)
155		switch dir, err := os.Stat(path); {
156		case err != nil:
157			report(err)
158		case dir.IsDir():
159			walkDir(path)
160		default:
161			if err := processFile(path, os.Stdout); err != nil {
162				report(err)
163			}
164		}
165	}
166
167	os.Exit(exitCode)
168}
169
// diff returns the output of the external "diff -u" command run on temporary
// copies of b1 and b2; the result is empty when the inputs are identical.
// An error is returned if the temporary files cannot be written, the diff
// command cannot be run, or diff exits with a status other than 0 or 1.
func diff(b1, b2 []byte) (data []byte, err error) {
	name1, err := writeTempFile(b1)
	if err != nil {
		return nil, err
	}
	defer os.Remove(name1)

	name2, err := writeTempFile(b2)
	if err != nil {
		return nil, err
	}
	defer os.Remove(name2)

	data, err = exec.Command("diff", "-u", name1, name2).CombinedOutput()
	// diff exits with status 1 when the inputs differ, which is the expected
	// outcome here. Any other non-zero status (e.g. 2 = trouble) is a real
	// failure and must not be mistaken for diff output.
	if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
		err = nil
	}
	return data, err
}

// writeTempFile writes contents to a new temporary file, closes it, and
// returns its name. The caller is responsible for removing the file. On
// failure the partially written file is removed and an error is returned.
func writeTempFile(contents []byte) (string, error) {
	f, err := ioutil.TempFile("", "bpfmt")
	if err != nil {
		return "", err
	}
	if _, err := f.Write(contents); err != nil {
		f.Close()
		os.Remove(f.Name())
		return "", err
	}
	if err := f.Close(); err != nil {
		os.Remove(f.Name())
		return "", err
	}
	return f.Name(), nil
}
197