main.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. // Copyright (c) 2017 Arista Networks, Inc.
  2. // Use of this source code is governed by the Apache License 2.0
  3. // that can be found in the COPYING file.
  4. package main
  5. import (
  6. "bytes"
  7. "errors"
  8. "flag"
  9. "fmt"
  10. "go/build"
  11. "io/ioutil"
  12. "os"
  13. "path/filepath"
  14. "sort"
  15. "strings"
  16. "golang.org/x/tools/go/vcs"
  17. )
  18. // Implementation taken from "isStandardImportPath" in go's source.
  19. func isStdLibPath(path string) bool {
  20. i := strings.Index(path, "/")
  21. if i < 0 {
  22. i = len(path)
  23. }
  24. elem := path[:i]
  25. return !strings.Contains(elem, ".")
  26. }
  27. // sortImports takes in an "import" body and returns it sorted
  28. func sortImports(in []byte, sections []string) []byte {
  29. type importLine struct {
  30. index int // index into inLines
  31. path string // import path used for sorting
  32. }
  33. // imports holds all the import lines, separated by section. The
  34. // first section is for stdlib imports, the following sections
  35. // hold the user specified sections, the final section is for
  36. // everything else.
  37. imports := make([][]importLine, len(sections)+2)
  38. addImport := func(section, index int, importPath string) {
  39. imports[section] = append(imports[section], importLine{index, importPath})
  40. }
  41. stdlib := 0
  42. offset := 1
  43. other := len(imports) - 1
  44. inLines := bytes.Split(in, []byte{'\n'})
  45. for i, line := range inLines {
  46. if len(line) == 0 {
  47. continue
  48. }
  49. start := bytes.IndexByte(line, '"')
  50. if start == -1 {
  51. continue
  52. }
  53. if comment := bytes.Index(line, []byte("//")); comment > -1 && comment < start {
  54. continue
  55. }
  56. start++ // skip '"'
  57. end := bytes.IndexByte(line[start:], '"') + start
  58. s := string(line[start:end])
  59. found := false
  60. for j, sect := range sections {
  61. if strings.HasPrefix(s, sect) && (len(sect) == len(s) || s[len(sect)] == '/') {
  62. addImport(j+offset, i, s)
  63. found = true
  64. break
  65. }
  66. }
  67. if found {
  68. continue
  69. }
  70. if isStdLibPath(s) {
  71. addImport(stdlib, i, s)
  72. } else {
  73. addImport(other, i, s)
  74. }
  75. }
  76. out := make([]byte, 0, len(in)+2)
  77. needSeperator := false
  78. for _, section := range imports {
  79. if len(section) == 0 {
  80. continue
  81. }
  82. if needSeperator {
  83. out = append(out, '\n')
  84. }
  85. sort.Slice(section, func(a, b int) bool {
  86. return section[a].path < section[b].path
  87. })
  88. for _, s := range section {
  89. out = append(out, inLines[s.index]...)
  90. out = append(out, '\n')
  91. }
  92. needSeperator = true
  93. }
  94. return out
  95. }
  96. func genFile(in []byte, sections []string) ([]byte, error) {
  97. out := make([]byte, 0, len(in)+3) // Add some fudge to avoid re-allocation
  98. for {
  99. const importLine = "\nimport (\n"
  100. const importLineLen = len(importLine)
  101. importStart := bytes.Index(in, []byte(importLine))
  102. if importStart == -1 {
  103. break
  104. }
  105. // Save to `out` everything up to and including "import(\n"
  106. out = append(out, in[:importStart+importLineLen]...)
  107. in = in[importStart+importLineLen:]
  108. importLen := bytes.Index(in, []byte("\n)\n"))
  109. if importLen == -1 {
  110. return nil, errors.New(`parsing error: missing ")"`)
  111. }
  112. // Sort body of "import" and write it to `out`
  113. out = append(out, sortImports(in[:importLen], sections)...)
  114. out = append(out, []byte(")")...)
  115. in = in[importLen+2:]
  116. }
  117. // Write everything leftover to out
  118. out = append(out, in...)
  119. return out, nil
  120. }
  121. // returns true if the file changed
  122. func processFile(filename string, writeFile, listDiffFiles bool, sections []string) (bool, error) {
  123. in, err := ioutil.ReadFile(filename)
  124. if err != nil {
  125. return false, err
  126. }
  127. out, err := genFile(in, sections)
  128. if err != nil {
  129. return false, err
  130. }
  131. equal := bytes.Equal(in, out)
  132. if listDiffFiles {
  133. return !equal, nil
  134. }
  135. if !writeFile {
  136. os.Stdout.Write(out)
  137. return !equal, nil
  138. }
  139. if equal {
  140. return false, nil
  141. }
  142. temp, err := ioutil.TempFile(filepath.Dir(filename), filepath.Base(filename))
  143. if err != nil {
  144. return false, err
  145. }
  146. defer os.RemoveAll(temp.Name())
  147. s, err := os.Stat(filename)
  148. if err != nil {
  149. return false, err
  150. }
  151. if _, err = temp.Write(out); err != nil {
  152. return false, err
  153. }
  154. if err := temp.Close(); err != nil {
  155. return false, err
  156. }
  157. if err := os.Chmod(temp.Name(), s.Mode()); err != nil {
  158. return false, err
  159. }
  160. if err := os.Rename(temp.Name(), filename); err != nil {
  161. return false, err
  162. }
  163. return true, nil
  164. }
  165. // maps directory to vcsRoot
  166. var vcsRootCache = make(map[string]string)
  167. func vcsRootImportPath(f string) (string, error) {
  168. path, err := filepath.Abs(f)
  169. if err != nil {
  170. return "", err
  171. }
  172. dir := filepath.Dir(path)
  173. if root, ok := vcsRootCache[dir]; ok {
  174. return root, nil
  175. }
  176. gopath := build.Default.GOPATH
  177. var root string
  178. _, root, err = vcs.FromDir(dir, filepath.Join(gopath, "src"))
  179. if err != nil {
  180. return "", err
  181. }
  182. vcsRootCache[dir] = root
  183. return root, nil
  184. }
  185. func main() {
  186. writeFile := flag.Bool("w", false, "write result to file instead of stdout")
  187. listDiffFiles := flag.Bool("l", false, "list files whose formatting differs from importsort")
  188. var sections multistring
  189. flag.Var(&sections, "s", "package `prefix` to define an import section,"+
  190. ` ex: "cvshub.com/company". May be specified multiple times.`+
  191. " If not specified the repository root is used.")
  192. flag.Parse()
  193. checkVCSRoot := sections == nil
  194. for _, f := range flag.Args() {
  195. if checkVCSRoot {
  196. root, err := vcsRootImportPath(f)
  197. if err != nil {
  198. fmt.Fprintf(os.Stderr, "error determining VCS root for file %q: %s", f, err)
  199. continue
  200. } else {
  201. sections = multistring{root}
  202. }
  203. }
  204. diff, err := processFile(f, *writeFile, *listDiffFiles, sections)
  205. if err != nil {
  206. fmt.Fprintf(os.Stderr, "error while proccessing file %q: %s", f, err)
  207. continue
  208. }
  209. if *listDiffFiles && diff {
  210. fmt.Println(f)
  211. }
  212. }
  213. }
  214. type multistring []string
  215. func (m *multistring) String() string {
  216. return strings.Join(*m, ", ")
  217. }
  218. func (m *multistring) Set(s string) error {
  219. *m = append(*m, s)
  220. return nil
  221. }