auth.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. package auth
  2. import (
  3. "encoding/gob"
  4. "errors"
  5. "net/http"
  6. "time"
  7. "github.com/gorilla/sessions"
  8. )
  // sessionKey is the unexported type used for this package's keys in
// session.Values. Because no other package can name the type, its keys can
// never collide with values stored by other code in the same session.
type sessionKey uint

// Keys under which Authorize stores login state in session.Values.
//
// The explicit "+ 1" keeps userKey == 1 and userTimeout == 2, the values they
// had as the 2nd and 3rd entries of a single const block with
// defaultSessionName. They are written into saved sessions, so changing them
// would stop previously stored sessions from being recognised.
const (
	userKey     sessionKey = iota + 1 // result of Auther.Check
	userTimeout                       // time.Time after which the login expires
)

// defaultSessionName is the session (cookie) name used when none is configured.
const defaultSessionName = "AuthSession"

func init() {
	// gorilla/sessions serializes session.Values with gob, and values stored
	// behind interface{} can only be round-tripped if their concrete types
	// are registered.
	for _, v := range []interface{}{userKey, time.Time{}} {
		gob.Register(v)
	}
}
  // Sentinel errors returned by this package. Compare returned errors against
// them to distinguish failure kinds. The messages keep their original
// capitalization because the default error handler sends them to clients.

// ErrBadLogin reports missing or rejected credentials. Auther implementations
// may return it from Check to signal a wrong user name or password.
var ErrBadLogin = errors.New("Bad Login")

// ErrNotAuthorized reports that a request has no valid, unexpired login session.
var ErrNotAuthorized = errors.New("Not Authorized")
  // Auther is the pluggable credential check used by Handler.Authorize; it lets
// callers plug in their own authentication backend.
type Auther interface {
	// Check validates a user name and password.
	//
	// On failure it must return a non-nil error, for example ErrBadLogin for
	// wrong credentials. On success the first result is stored in the login
	// session (and so saved in the cookie/store) and is what
	// AuthenticateRequest later hands back. NOTE(review): with gorilla's
	// gob-based stores, custom result types presumably need gob.Register.
	Check(username, password string) (interface{}, error)
}
  32. type Handler struct {
  33. auther Auther
  34. store sessions.Store
  35. errorHandler ErrorHandler
  36. notAuthorizedHandler http.Handler
  37. redirLanding string // the url to redirect to after login
  38. redirLogout string // the url to redirect to after logout
  39. // how long should a session life
  40. lifetime time.Duration
  41. // the name of the cookie
  42. sessionName string
  43. }
  44. func NewHandler(a Auther, options ...Option) (*Handler, error) {
  45. var ah Handler
  46. ah.auther = a
  47. for _, o := range options {
  48. if err := o(&ah); err != nil {
  49. return nil, err
  50. }
  51. }
  52. if ah.store == nil {
  53. return nil, errors.New("please set a session.Store")
  54. }
  55. // defaults
  56. if ah.lifetime == 0 {
  57. ah.lifetime = 5 * time.Minute
  58. }
  59. if ah.redirLanding == "" {
  60. ah.redirLanding = "/"
  61. }
  62. if ah.redirLogout == "" {
  63. ah.redirLogout = ah.redirLanding
  64. }
  65. if ah.sessionName == "" {
  66. ah.sessionName = defaultSessionName
  67. }
  68. if ah.errorHandler == nil {
  69. ah.errorHandler = func(w http.ResponseWriter, r *http.Request, err error, code int) {
  70. http.Error(w, err.Error(), code)
  71. }
  72. }
  73. if ah.notAuthorizedHandler == nil {
  74. ah.notAuthorizedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  75. ah.errorHandler(w, r, ErrNotAuthorized, http.StatusUnauthorized)
  76. })
  77. }
  78. return &ah, nil
  79. }
  80. func (ah Handler) Authorize(w http.ResponseWriter, r *http.Request) {
  81. session, err := ah.store.Get(r, ah.sessionName)
  82. if err != nil {
  83. ah.errorHandler(w, r, err, http.StatusInternalServerError)
  84. return
  85. }
  86. if err := r.ParseForm(); err != nil {
  87. ah.errorHandler(w, r, err, http.StatusInternalServerError)
  88. return
  89. }
  90. user := r.Form.Get("user")
  91. pass := r.Form.Get("pass")
  92. if user == "" || pass == "" {
  93. ah.errorHandler(w, r, ErrBadLogin, http.StatusBadRequest)
  94. return
  95. }
  96. id, err := ah.auther.Check(user, pass)
  97. if err != nil {
  98. var code = http.StatusInternalServerError
  99. if err == ErrBadLogin {
  100. code = http.StatusBadRequest
  101. }
  102. ah.errorHandler(w, r, err, code)
  103. return
  104. }
  105. session.Values[userKey] = id
  106. session.Values[userTimeout] = time.Now().Add(ah.lifetime)
  107. if err := session.Save(r, w); err != nil {
  108. ah.errorHandler(w, r, err, http.StatusInternalServerError)
  109. return
  110. }
  111. http.Redirect(w, r, ah.redirLanding, http.StatusSeeOther)
  112. return
  113. }
  114. func (ah Handler) Authenticate(h http.Handler) http.Handler {
  115. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  116. if _, err := ah.AuthenticateRequest(r); err != nil {
  117. ah.notAuthorizedHandler.ServeHTTP(w, r)
  118. return
  119. }
  120. h.ServeHTTP(w, r)
  121. })
  122. }
  123. func (ah Handler) AuthenticateRequest(r *http.Request) (interface{}, error) {
  124. session, err := ah.store.Get(r, ah.sessionName)
  125. if err != nil {
  126. return nil, err
  127. }
  128. if session.IsNew {
  129. return nil, ErrNotAuthorized
  130. }
  131. user, ok := session.Values[userKey]
  132. if !ok {
  133. return nil, ErrNotAuthorized
  134. }
  135. t, ok := session.Values[userTimeout]
  136. if !ok {
  137. return nil, ErrNotAuthorized
  138. }
  139. tout, ok := t.(time.Time)
  140. if !ok {
  141. return nil, ErrNotAuthorized
  142. }
  143. if time.Now().After(tout) {
  144. return nil, ErrNotAuthorized
  145. }
  146. return user, nil
  147. }
  148. func (ah Handler) Logout(w http.ResponseWriter, r *http.Request) {
  149. session, err := ah.store.Get(r, ah.sessionName)
  150. if err != nil {
  151. ah.errorHandler(w, r, err, http.StatusInternalServerError)
  152. return
  153. }
  154. session.Values[userTimeout] = time.Now().Add(-ah.lifetime)
  155. session.Options.MaxAge = -1
  156. if err := session.Save(r, w); err != nil {
  157. ah.errorHandler(w, r, err, http.StatusInternalServerError)
  158. return
  159. }
  160. http.Redirect(w, r, ah.redirLogout, http.StatusSeeOther)
  161. return
  162. }