agent.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. package main
  2. import (
  3. "log"
  4. "net"
  5. "os"
  6. "os/signal"
  7. "os/user"
  8. "strconv"
  9. "strings"
  10. "syscall"
  11. "time"
  12. "git.sr.ht/~sircmpwn/go-bare"
  13. "github.com/sevlyar/go-daemon"
  14. )
  // PASSWORD is the secret the agent keeps in memory. handle overwrites it
  // when it receives command 0 and sends it back when it receives command 1.
  // It is empty until a password has been stored.
  // todo memguard
  var PASSWORD string
  // getUser returns the user the process is currently running as.
  // A lookup failure is logged and returned with a nil user.
  func getUser() (*user.User, error) {
  	current, err := user.Current()
  	if err == nil {
  		return current, nil
  	}
  	log.Println("error getting user", err)
  	return nil, err
  }
  // getUserByName looks up the account called username.
  // A lookup failure is logged and returned with a nil user.
  func getUserByName(username string) (*user.User, error) {
  	found, err := user.Lookup(username)
  	if err == nil {
  		return found, nil
  	}
  	log.Println("error getting user", err)
  	return nil, err
  }
  // socketName builds the path of the agent socket for the user name u.
  func socketName(u string) string {
  	const prefix = "/tmp/eeze-agent-"
  	return prefix + u
  }
  36. func timeOut(timeout int, sigc chan os.Signal) {
  37. time.Sleep(time.Duration(timeout) * time.Second)
  38. sigc <- syscall.SIGINT
  39. }
  // catch blocks until a value arrives on sigc, then removes the socket
  // file, logs a farewell and ends the process with exit status 0.
  func catch(sigc chan os.Signal, socket string) {
  	<-sigc
  	os.Remove(socket)
  	log.Println("bye")
  	os.Exit(0)
  }
  46. func handle(conn net.Conn) error {
  47. r := bare.NewReader(conn)
  48. cmd, err := r.ReadU8()
  49. if err != nil {
  50. log.Println("error reading command", err)
  51. return err
  52. }
  53. switch cmd {
  54. case 0:
  55. // todo memguard
  56. password, err := r.ReadString()
  57. if err != nil {
  58. log.Println("error reading password to store", err)
  59. return err
  60. }
  61. PASSWORD = password
  62. case 1:
  63. w := bare.NewWriter(conn)
  64. err = w.WriteString(PASSWORD)
  65. if err != nil {
  66. log.Println("error giving password", err)
  67. return err
  68. }
  69. default:
  70. return nil
  71. }
  72. return nil
  73. }
  // parseTimeout converts the command-line argument timeout into a number of
  // seconds.
  //
  // It returns the parsed value when timeout is a non-negative base-10
  // integer that fits in an int. Otherwise it logs the reason and returns the
  // default of 300 seconds.
  func parseTimeout(timeout string) int {
  	const defaultTimeout = 300

  	// Atoi rejects values that do not fit in an int instead of silently
  	// truncating them.
  	n, err := strconv.Atoi(timeout)
  	switch {
  	case err != nil:
  		log.Println("error parsing timeout, defaulting to 300s", err)
  	case n < 0:
  		// err is nil on this path, so there is nothing more to print.
  		log.Println("timeout cannot be < 0, defaulting to 300s")
  	default:
  		return n
  	}
  	return defaultTimeout
  }
  85. func main() {
  86. log.Println("main")
  87. timeout := 300
  88. skipArg := true
  89. user, err := getUser()
  90. if err != nil {
  91. log.Println("error getting user name", err)
  92. return
  93. }
  94. for i, arg := range os.Args {
  95. if skipArg {
  96. skipArg = false
  97. continue
  98. }
  99. if arg == "-u" {
  100. u, err := getUserByName(os.Args[i+1])
  101. if err != nil {
  102. log.Println("error getting user from name", err)
  103. } else {
  104. user = u
  105. }
  106. skipArg = true
  107. } else {
  108. timeout = parseTimeout(arg)
  109. }
  110. }
  111. log.Println("read args ", timeout)
  112. socket := socketName(user.Username)
  113. log.Println("socket name ", socket)
  114. uid, err := strconv.ParseInt(user.Uid, 10, 64)
  115. if err != nil {
  116. log.Println("error parsing uid", err)
  117. return
  118. }
  119. gid, err := strconv.ParseInt(user.Gid, 10, 64)
  120. if err != nil {
  121. log.Println("error parsing gid", err)
  122. return
  123. }
  124. context := new(daemon.Context)
  125. context.LogFileName = socket + ".ctx.log"
  126. log.Println("ctx ", context)
  127. child, err := context.Reborn()
  128. if err != nil {
  129. log.Println("error forking", err)
  130. return
  131. }
  132. log.Println("reborn")
  133. if child != nil {
  134. log.Println("waiting for socket")
  135. i := 10000
  136. for true {
  137. _, err := os.Stat(socket)
  138. if err == nil || !strings.Contains(err.Error(), "no such file or directory") {
  139. os.Chown(socket, int(uid), int(gid))
  140. os.Chmod(socket, 0600)
  141. os.Chmod(context.LogFileName, 0644)
  142. log.Println("chmoded socket")
  143. break
  144. }
  145. i--
  146. }
  147. log.Println("socket exists")
  148. return
  149. } else {
  150. defer context.Release()
  151. sigc := make(chan os.Signal, 1)
  152. log.Println("made channel")
  153. signal.Notify(sigc, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
  154. log.Println("setup signal")
  155. go catch(sigc, socket)
  156. log.Println("ran catch")
  157. if timeout > 0 {
  158. go timeOut(timeout, sigc)
  159. log.Println("ran timeout")
  160. }
  161. server, err := net.Listen("unix", socket)
  162. if err != nil {
  163. log.Println("error listening", err)
  164. return
  165. }
  166. log.Println("listening")
  167. log.Println("accepting")
  168. for {
  169. conn, err := server.Accept()
  170. if err != nil {
  171. log.Println("error accepting", err)
  172. return
  173. }
  174. go handle(conn)
  175. }
  176. }
  177. }