rewrite.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. // Copyright 2009 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package main
  5. import (
  6. "fmt"
  7. "go/ast"
  8. "go/parser"
  9. "go/token"
  10. "os"
  11. "reflect"
  12. "strings"
  13. "unicode"
  14. "unicode/utf8"
  15. )
  16. func initRewrite() {
  17. if *rewriteRule == "" {
  18. rewrite = nil // disable any previous rewrite
  19. return
  20. }
  21. f := strings.Split(*rewriteRule, "->")
  22. if len(f) != 2 {
  23. fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n")
  24. os.Exit(2)
  25. }
  26. pattern := parseExpr(f[0], "pattern")
  27. replace := parseExpr(f[1], "replacement")
  28. rewrite = func(fset *token.FileSet, p *ast.File) *ast.File {
  29. return rewriteFile(fset, pattern, replace, p)
  30. }
  31. }
  32. // parseExpr parses s as an expression.
  33. // It might make sense to expand this to allow statement patterns,
  34. // but there are problems with preserving formatting and also
  35. // with what a wildcard for a statement looks like.
  36. func parseExpr(s, what string) ast.Expr {
  37. x, err := parser.ParseExpr(s)
  38. if err != nil {
  39. fmt.Fprintf(os.Stderr, "parsing %s %s at %s\n", what, s, err)
  40. os.Exit(2)
  41. }
  42. return x
  43. }
  44. // Keep this function for debugging.
  45. /*
  46. func dump(msg string, val reflect.Value) {
  47. fmt.Printf("%s:\n", msg)
  48. ast.Print(fileSet, val.Interface())
  49. fmt.Println()
  50. }
  51. */
  52. // rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file.
  53. func rewriteFile(fileSet *token.FileSet, pattern, replace ast.Expr, p *ast.File) *ast.File {
  54. cmap := ast.NewCommentMap(fileSet, p, p.Comments)
  55. m := make(map[string]reflect.Value)
  56. pat := reflect.ValueOf(pattern)
  57. repl := reflect.ValueOf(replace)
  58. var rewriteVal func(val reflect.Value) reflect.Value
  59. rewriteVal = func(val reflect.Value) reflect.Value {
  60. // don't bother if val is invalid to start with
  61. if !val.IsValid() {
  62. return reflect.Value{}
  63. }
  64. val = apply(rewriteVal, val)
  65. for k := range m {
  66. delete(m, k)
  67. }
  68. if match(m, pat, val) {
  69. val = subst(m, repl, reflect.ValueOf(val.Interface().(ast.Node).Pos()))
  70. }
  71. return val
  72. }
  73. r := apply(rewriteVal, reflect.ValueOf(p)).Interface().(*ast.File)
  74. r.Comments = cmap.Filter(r).Comments() // recreate comments list
  75. return r
  76. }
  77. // set is a wrapper for x.Set(y); it protects the caller from panics if x cannot be changed to y.
  78. func set(x, y reflect.Value) {
  79. // don't bother if x cannot be set or y is invalid
  80. if !x.CanSet() || !y.IsValid() {
  81. return
  82. }
  83. defer func() {
  84. if x := recover(); x != nil {
  85. if s, ok := x.(string); ok &&
  86. (strings.Contains(s, "type mismatch") || strings.Contains(s, "not assignable")) {
  87. // x cannot be set to y - ignore this rewrite
  88. return
  89. }
  90. panic(x)
  91. }
  92. }()
  93. x.Set(y)
  94. }
  95. // Values/types for special cases.
  96. var (
  97. objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
  98. scopePtrNil = reflect.ValueOf((*ast.Scope)(nil))
  99. identType = reflect.TypeOf((*ast.Ident)(nil))
  100. objectPtrType = reflect.TypeOf((*ast.Object)(nil))
  101. positionType = reflect.TypeOf(token.NoPos)
  102. callExprType = reflect.TypeOf((*ast.CallExpr)(nil))
  103. scopePtrType = reflect.TypeOf((*ast.Scope)(nil))
  104. )
  105. // apply replaces each AST field x in val with f(x), returning val.
  106. // To avoid extra conversions, f operates on the reflect.Value form.
  107. func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
  108. if !val.IsValid() {
  109. return reflect.Value{}
  110. }
  111. // *ast.Objects introduce cycles and are likely incorrect after
  112. // rewrite; don't follow them but replace with nil instead
  113. if val.Type() == objectPtrType {
  114. return objectPtrNil
  115. }
  116. // similarly for scopes: they are likely incorrect after a rewrite;
  117. // replace them with nil
  118. if val.Type() == scopePtrType {
  119. return scopePtrNil
  120. }
  121. switch v := reflect.Indirect(val); v.Kind() {
  122. case reflect.Slice:
  123. for i := 0; i < v.Len(); i++ {
  124. e := v.Index(i)
  125. set(e, f(e))
  126. }
  127. case reflect.Struct:
  128. for i := 0; i < v.NumField(); i++ {
  129. e := v.Field(i)
  130. set(e, f(e))
  131. }
  132. case reflect.Interface:
  133. e := v.Elem()
  134. set(v, f(e))
  135. }
  136. return val
  137. }
  138. func isWildcard(s string) bool {
  139. rune, size := utf8.DecodeRuneInString(s)
  140. return size == len(s) && unicode.IsLower(rune)
  141. }
  142. // match reports whether pattern matches val,
  143. // recording wildcard submatches in m.
  144. // If m == nil, match checks whether pattern == val.
  145. func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
  146. // Wildcard matches any expression. If it appears multiple
  147. // times in the pattern, it must match the same expression
  148. // each time.
  149. if m != nil && pattern.IsValid() && pattern.Type() == identType {
  150. name := pattern.Interface().(*ast.Ident).Name
  151. if isWildcard(name) && val.IsValid() {
  152. // wildcards only match valid (non-nil) expressions.
  153. if _, ok := val.Interface().(ast.Expr); ok && !val.IsNil() {
  154. if old, ok := m[name]; ok {
  155. return match(nil, old, val)
  156. }
  157. m[name] = val
  158. return true
  159. }
  160. }
  161. }
  162. // Otherwise, pattern and val must match recursively.
  163. if !pattern.IsValid() || !val.IsValid() {
  164. return !pattern.IsValid() && !val.IsValid()
  165. }
  166. if pattern.Type() != val.Type() {
  167. return false
  168. }
  169. // Special cases.
  170. switch pattern.Type() {
  171. case identType:
  172. // For identifiers, only the names need to match
  173. // (and none of the other *ast.Object information).
  174. // This is a common case, handle it all here instead
  175. // of recursing down any further via reflection.
  176. p := pattern.Interface().(*ast.Ident)
  177. v := val.Interface().(*ast.Ident)
  178. return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name
  179. case objectPtrType, positionType:
  180. // object pointers and token positions always match
  181. return true
  182. case callExprType:
  183. // For calls, the Ellipsis fields (token.Position) must
  184. // match since that is how f(x) and f(x...) are different.
  185. // Check them here but fall through for the remaining fields.
  186. p := pattern.Interface().(*ast.CallExpr)
  187. v := val.Interface().(*ast.CallExpr)
  188. if p.Ellipsis.IsValid() != v.Ellipsis.IsValid() {
  189. return false
  190. }
  191. }
  192. p := reflect.Indirect(pattern)
  193. v := reflect.Indirect(val)
  194. if !p.IsValid() || !v.IsValid() {
  195. return !p.IsValid() && !v.IsValid()
  196. }
  197. switch p.Kind() {
  198. case reflect.Slice:
  199. if p.Len() != v.Len() {
  200. return false
  201. }
  202. for i := 0; i < p.Len(); i++ {
  203. if !match(m, p.Index(i), v.Index(i)) {
  204. return false
  205. }
  206. }
  207. return true
  208. case reflect.Struct:
  209. for i := 0; i < p.NumField(); i++ {
  210. if !match(m, p.Field(i), v.Field(i)) {
  211. return false
  212. }
  213. }
  214. return true
  215. case reflect.Interface:
  216. return match(m, p.Elem(), v.Elem())
  217. }
  218. // Handle token integers, etc.
  219. return p.Interface() == v.Interface()
  220. }
  221. // subst returns a copy of pattern with values from m substituted in place
  222. // of wildcards and pos used as the position of tokens from the pattern.
  223. // if m == nil, subst returns a copy of pattern and doesn't change the line
  224. // number information.
  225. func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value {
  226. if !pattern.IsValid() {
  227. return reflect.Value{}
  228. }
  229. // Wildcard gets replaced with map value.
  230. if m != nil && pattern.Type() == identType {
  231. name := pattern.Interface().(*ast.Ident).Name
  232. if isWildcard(name) {
  233. if old, ok := m[name]; ok {
  234. return subst(nil, old, reflect.Value{})
  235. }
  236. }
  237. }
  238. if pos.IsValid() && pattern.Type() == positionType {
  239. // use new position only if old position was valid in the first place
  240. if old := pattern.Interface().(token.Pos); !old.IsValid() {
  241. return pattern
  242. }
  243. return pos
  244. }
  245. // Otherwise copy.
  246. switch p := pattern; p.Kind() {
  247. case reflect.Slice:
  248. if p.IsNil() {
  249. // Do not turn nil slices into empty slices. go/ast
  250. // guarantees that certain lists will be nil if not
  251. // populated.
  252. return reflect.Zero(p.Type())
  253. }
  254. v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
  255. for i := 0; i < p.Len(); i++ {
  256. v.Index(i).Set(subst(m, p.Index(i), pos))
  257. }
  258. return v
  259. case reflect.Struct:
  260. v := reflect.New(p.Type()).Elem()
  261. for i := 0; i < p.NumField(); i++ {
  262. v.Field(i).Set(subst(m, p.Field(i), pos))
  263. }
  264. return v
  265. case reflect.Pointer:
  266. v := reflect.New(p.Type()).Elem()
  267. if elem := p.Elem(); elem.IsValid() {
  268. v.Set(subst(m, elem, pos).Addr())
  269. }
  270. return v
  271. case reflect.Interface:
  272. v := reflect.New(p.Type()).Elem()
  273. if elem := p.Elem(); elem.IsValid() {
  274. v.Set(subst(m, elem, pos))
  275. }
  276. return v
  277. }
  278. return pattern
  279. }