filter.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ldap
  5. import (
  6. "errors"
  7. "fmt"
  8. "github.com/gogits/gogs/modules/asn1-ber"
  9. )
  // Tags of the LDAP search Filter CHOICE alternatives (RFC 4511, section
// 4.5.1). compileFilter encodes each filter component with ber.ClassContext
// and one of these values as its tag, so the numbering is part of the wire
// format: do not reorder the declarations below.
const (
	FilterAnd             = iota // 0
	FilterOr                     // 1
	FilterNot                    // 2
	FilterEqualityMatch          // 3
	FilterSubstrings             // 4
	FilterGreaterOrEqual         // 5
	FilterLessOrEqual            // 6
	FilterPresent                // 7
	FilterApproxMatch            // 8
	FilterExtensibleMatch        // 9
)

// FilterMap gives a human-readable name for each Filter tag. The names are
// used as packet descriptions when filters are encoded.
var FilterMap = map[uint64]string{
	FilterAnd:             "And",
	FilterOr:              "Or",
	FilterNot:             "Not",
	FilterEqualityMatch:   "Equality Match",
	FilterSubstrings:      "Substrings",
	FilterGreaterOrEqual:  "Greater Or Equal",
	FilterLessOrEqual:     "Less Or Equal",
	FilterPresent:         "Present",
	FilterApproxMatch:     "Approx Match",
	FilterExtensibleMatch: "Extensible Match",
}
  // Tags of the components inside a substrings filter: initial, any and final.
// compileFilter uses them as ber.ClassContext tags on the substring values,
// so, as with the Filter tags, the declaration order is part of the wire
// format.
const (
	FilterSubstringsInitial = iota // 0
	FilterSubstringsAny            // 1
	FilterSubstringsFinal          // 2
)

// FilterSubstringsMap gives a human-readable name for each substring
// component tag.
var FilterSubstringsMap = map[uint64]string{
	FilterSubstringsInitial: "Substrings Initial",
	FilterSubstringsAny:     "Substrings Any",
	FilterSubstringsFinal:   "Substrings Final",
}
  44. func CompileFilter(filter string) (*ber.Packet, error) {
  45. if len(filter) == 0 || filter[0] != '(' {
  46. return nil, NewError(ErrorFilterCompile, errors.New("ldap: filter does not start with an '('"))
  47. }
  48. packet, pos, err := compileFilter(filter, 1)
  49. if err != nil {
  50. return nil, err
  51. }
  52. if pos != len(filter) {
  53. return nil, NewError(ErrorFilterCompile, errors.New("ldap: finished compiling filter with extra at end: "+fmt.Sprint(filter[pos:])))
  54. }
  55. return packet, nil
  56. }
  57. func DecompileFilter(packet *ber.Packet) (ret string, err error) {
  58. defer func() {
  59. if r := recover(); r != nil {
  60. err = NewError(ErrorFilterDecompile, errors.New("ldap: error decompiling filter"))
  61. }
  62. }()
  63. ret = "("
  64. err = nil
  65. childStr := ""
  66. switch packet.Tag {
  67. case FilterAnd:
  68. ret += "&"
  69. for _, child := range packet.Children {
  70. childStr, err = DecompileFilter(child)
  71. if err != nil {
  72. return
  73. }
  74. ret += childStr
  75. }
  76. case FilterOr:
  77. ret += "|"
  78. for _, child := range packet.Children {
  79. childStr, err = DecompileFilter(child)
  80. if err != nil {
  81. return
  82. }
  83. ret += childStr
  84. }
  85. case FilterNot:
  86. ret += "!"
  87. childStr, err = DecompileFilter(packet.Children[0])
  88. if err != nil {
  89. return
  90. }
  91. ret += childStr
  92. case FilterSubstrings:
  93. ret += ber.DecodeString(packet.Children[0].Data.Bytes())
  94. ret += "="
  95. switch packet.Children[1].Children[0].Tag {
  96. case FilterSubstringsInitial:
  97. ret += ber.DecodeString(packet.Children[1].Children[0].Data.Bytes()) + "*"
  98. case FilterSubstringsAny:
  99. ret += "*" + ber.DecodeString(packet.Children[1].Children[0].Data.Bytes()) + "*"
  100. case FilterSubstringsFinal:
  101. ret += "*" + ber.DecodeString(packet.Children[1].Children[0].Data.Bytes())
  102. }
  103. case FilterEqualityMatch:
  104. ret += ber.DecodeString(packet.Children[0].Data.Bytes())
  105. ret += "="
  106. ret += ber.DecodeString(packet.Children[1].Data.Bytes())
  107. case FilterGreaterOrEqual:
  108. ret += ber.DecodeString(packet.Children[0].Data.Bytes())
  109. ret += ">="
  110. ret += ber.DecodeString(packet.Children[1].Data.Bytes())
  111. case FilterLessOrEqual:
  112. ret += ber.DecodeString(packet.Children[0].Data.Bytes())
  113. ret += "<="
  114. ret += ber.DecodeString(packet.Children[1].Data.Bytes())
  115. case FilterPresent:
  116. ret += ber.DecodeString(packet.Children[0].Data.Bytes())
  117. ret += "=*"
  118. case FilterApproxMatch:
  119. ret += ber.DecodeString(packet.Children[0].Data.Bytes())
  120. ret += "~="
  121. ret += ber.DecodeString(packet.Children[1].Data.Bytes())
  122. }
  123. ret += ")"
  124. return
  125. }
  126. func compileFilterSet(filter string, pos int, parent *ber.Packet) (int, error) {
  127. for pos < len(filter) && filter[pos] == '(' {
  128. child, newPos, err := compileFilter(filter, pos+1)
  129. if err != nil {
  130. return pos, err
  131. }
  132. pos = newPos
  133. parent.AppendChild(child)
  134. }
  135. if pos == len(filter) {
  136. return pos, NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
  137. }
  138. return pos + 1, nil
  139. }
  140. func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
  141. var packet *ber.Packet
  142. var err error
  143. defer func() {
  144. if r := recover(); r != nil {
  145. err = NewError(ErrorFilterCompile, errors.New("ldap: error compiling filter"))
  146. }
  147. }()
  148. newPos := pos
  149. switch filter[pos] {
  150. case '(':
  151. packet, newPos, err = compileFilter(filter, pos+1)
  152. newPos++
  153. return packet, newPos, err
  154. case '&':
  155. packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[FilterAnd])
  156. newPos, err = compileFilterSet(filter, pos+1, packet)
  157. return packet, newPos, err
  158. case '|':
  159. packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[FilterOr])
  160. newPos, err = compileFilterSet(filter, pos+1, packet)
  161. return packet, newPos, err
  162. case '!':
  163. packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[FilterNot])
  164. var child *ber.Packet
  165. child, newPos, err = compileFilter(filter, pos+1)
  166. packet.AppendChild(child)
  167. return packet, newPos, err
  168. default:
  169. attribute := ""
  170. condition := ""
  171. for newPos < len(filter) && filter[newPos] != ')' {
  172. switch {
  173. case packet != nil:
  174. condition += fmt.Sprintf("%c", filter[newPos])
  175. case filter[newPos] == '=':
  176. packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch])
  177. case filter[newPos] == '>' && filter[newPos+1] == '=':
  178. packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual])
  179. newPos++
  180. case filter[newPos] == '<' && filter[newPos+1] == '=':
  181. packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual])
  182. newPos++
  183. case filter[newPos] == '~' && filter[newPos+1] == '=':
  184. packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterLessOrEqual])
  185. newPos++
  186. case packet == nil:
  187. attribute += fmt.Sprintf("%c", filter[newPos])
  188. }
  189. newPos++
  190. }
  191. if newPos == len(filter) {
  192. err = NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
  193. return packet, newPos, err
  194. }
  195. if packet == nil {
  196. err = NewError(ErrorFilterCompile, errors.New("ldap: error parsing filter"))
  197. return packet, newPos, err
  198. }
  199. packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
  200. switch {
  201. case packet.Tag == FilterEqualityMatch && condition == "*":
  202. packet.Tag = FilterPresent
  203. packet.Description = FilterMap[uint64(packet.Tag)]
  204. case packet.Tag == FilterEqualityMatch && condition[0] == '*' && condition[len(condition)-1] == '*':
  205. // Any
  206. packet.Tag = FilterSubstrings
  207. packet.Description = FilterMap[uint64(packet.Tag)]
  208. seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
  209. seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsAny, condition[1:len(condition)-1], "Any Substring"))
  210. packet.AppendChild(seq)
  211. case packet.Tag == FilterEqualityMatch && condition[0] == '*':
  212. // Final
  213. packet.Tag = FilterSubstrings
  214. packet.Description = FilterMap[uint64(packet.Tag)]
  215. seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
  216. seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsFinal, condition[1:], "Final Substring"))
  217. packet.AppendChild(seq)
  218. case packet.Tag == FilterEqualityMatch && condition[len(condition)-1] == '*':
  219. // Initial
  220. packet.Tag = FilterSubstrings
  221. packet.Description = FilterMap[uint64(packet.Tag)]
  222. seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
  223. seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsInitial, condition[:len(condition)-1], "Initial Substring"))
  224. packet.AppendChild(seq)
  225. default:
  226. packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, condition, "Condition"))
  227. }
  228. newPos++
  229. return packet, newPos, err
  230. }
  231. }