123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194 |
- package main
- import (
- "log"
- "net"
- "os"
- "os/signal"
- "os/user"
- "strconv"
- "strings"
- "syscall"
- "time"
- "git.sr.ht/~sircmpwn/go-bare"
- "github.com/sevlyar/go-daemon"
- )
// PASSWORD is the secret most recently stored by a client. handle overwrites
// it on command 0 and sends it back on command 1. It starts out empty.
//
// TODO: keep this in memguard-protected memory instead of a plain Go string.
var PASSWORD string
// getUser returns the account the current process runs as. On failure the
// error is logged and returned together with a nil user.
func getUser() (*user.User, error) {
	current, lookupErr := user.Current()
	if lookupErr == nil {
		return current, nil
	}
	log.Println("error getting user", lookupErr)
	return nil, lookupErr
}
// getUserByName resolves username to its account record. If the lookup
// fails, the error is logged and returned with a nil user.
func getUserByName(username string) (*user.User, error) {
	found, lookupErr := user.Lookup(username)
	if lookupErr == nil {
		return found, nil
	}
	log.Println("error getting user", lookupErr)
	return nil, lookupErr
}
// socketName returns the unix-socket path of the agent belonging to user u.
//
// NOTE(review): the path sits in shared /tmp and is predictable from the user
// name alone, so another local user could pre-create it. A per-user runtime
// directory would be safer, but any client that derives the path the same way
// would have to change in step.
func socketName(u string) string {
	const socketPrefix = "/tmp/eeze-agent-"
	return socketPrefix + u
}
// timeOut blocks for timeout seconds and then pushes SIGINT onto sigc. It
// uses the same channel the real signal handler listens on, so the agent
// shuts down the same way it would on an actual interrupt.
func timeOut(timeout int, sigc chan os.Signal) {
	<-time.After(time.Second * time.Duration(timeout))
	sigc <- syscall.SIGINT
}
// catch waits for the first value on sigc. That value is either a registered
// OS signal or the synthetic SIGINT sent by timeOut. catch then removes the
// agent's socket file and exits the process with status 0.
//
// os.Exit does not run deferred functions, so cleanup deferred elsewhere never
// happens on this path; the socket has to be removed explicitly here.
func catch(sigc chan os.Signal, socket string) {
	<-sigc
	// A missing socket is fine (e.g. never created); anything else is worth
	// reporting because it leaves a stale socket behind.
	if err := os.Remove(socket); err != nil && !os.IsNotExist(err) {
		log.Println("error removing socket", err)
	}
	log.Println("bye")
	os.Exit(0)
}
- func handle(conn net.Conn) error {
- r := bare.NewReader(conn)
- cmd, err := r.ReadU8()
- if err != nil {
- log.Println("error reading command", err)
- return err
- }
- switch cmd {
- case 0:
- // todo memguard
- password, err := r.ReadString()
- if err != nil {
- log.Println("error reading password to store", err)
- return err
- }
- PASSWORD = password
- case 1:
- w := bare.NewWriter(conn)
- err = w.WriteString(PASSWORD)
- if err != nil {
- log.Println("error giving password", err)
- return err
- }
- default:
- return nil
- }
- return nil
- }
// parseTimeout converts a command-line argument into a timeout in whole
// seconds. "0" is returned as 0; the caller only arms a timer when the
// timeout is positive. Unparsable or negative input is logged and replaced
// by the 300-second default. Values too large to be expressed as a
// time.Duration are clamped to the largest representable number of seconds,
// instead of overflowing into a negative duration.
func parseTimeout(timeout string) int {
	// Largest whole-second count whose time.Duration does not overflow
	// (about 292 years). Kept in a variable so the int conversion below is a
	// run-time one and still compiles where int is 32 bits.
	limit := int64(time.Duration(1<<63-1) / time.Second)

	n, err := strconv.Atoi(timeout)
	switch {
	case err != nil:
		log.Println("error parsing timeout, defaulting to 300s", err)
		return 300
	case n < 0:
		log.Println("timeout cannot be < 0, defaulting to 300s")
		return 300
	case int64(n) > limit:
		log.Println("timeout too large, clamping to", limit, "seconds")
		return int(limit)
	}
	return n
}
- func main() {
- log.Println("main")
- timeout := 300
- skipArg := true
- user, err := getUser()
- if err != nil {
- log.Println("error getting user name", err)
- return
- }
- for i, arg := range os.Args {
- if skipArg {
- skipArg = false
- continue
- }
- if arg == "-u" {
- u, err := getUserByName(os.Args[i+1])
- if err != nil {
- log.Println("error getting user from name", err)
- } else {
- user = u
- }
- skipArg = true
- } else {
- timeout = parseTimeout(arg)
- }
- }
- log.Println("read args ", timeout)
- socket := socketName(user.Username)
- log.Println("socket name ", socket)
- uid, err := strconv.ParseInt(user.Uid, 10, 64)
- if err != nil {
- log.Println("error parsing uid", err)
- return
- }
- gid, err := strconv.ParseInt(user.Gid, 10, 64)
- if err != nil {
- log.Println("error parsing gid", err)
- return
- }
- context := new(daemon.Context)
- context.LogFileName = socket + ".ctx.log"
- log.Println("ctx ", context)
- child, err := context.Reborn()
- if err != nil {
- log.Println("error forking", err)
- return
- }
- log.Println("reborn")
- if child != nil {
- log.Println("waiting for socket")
- i := 10000
- for true {
- _, err := os.Stat(socket)
- if err == nil || !strings.Contains(err.Error(), "no such file or directory") {
- os.Chown(socket, int(uid), int(gid))
- os.Chmod(socket, 0600)
- os.Chmod(context.LogFileName, 0644)
- log.Println("chmoded socket")
- break
- }
- i--
- }
- log.Println("socket exists")
- return
- } else {
- defer context.Release()
- sigc := make(chan os.Signal, 1)
- log.Println("made channel")
- signal.Notify(sigc, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
- log.Println("setup signal")
- go catch(sigc, socket)
- log.Println("ran catch")
- if timeout > 0 {
- go timeOut(timeout, sigc)
- log.Println("ran timeout")
- }
- server, err := net.Listen("unix", socket)
- if err != nil {
- log.Println("error listening", err)
- return
- }
- log.Println("listening")
- log.Println("accepting")
- for {
- conn, err := server.Accept()
- if err != nil {
- log.Println("error accepting", err)
- return
- }
- go handle(conn)
- }
- }
- }
|