shm.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. // License: GPLv3 Copyright: 2022, Kovid Goyal, <kovid at kovidgoyal.net>
  2. package shm
  3. import (
  4. "encoding/binary"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "io/fs"
  9. "os"
  10. "strings"
  11. "kitty/tools/cli"
  12. "golang.org/x/sys/unix"
  13. )
  14. var _ = fmt.Print
  15. var ErrPatternHasSeparator = errors.New("The specified pattern has file path separators in it")
  16. var ErrPatternTooLong = errors.New("The specified pattern for the SHM name is too long")
  17. type ErrNotSupported struct {
  18. err error
  19. }
  20. func (self *ErrNotSupported) Error() string {
  21. return fmt.Sprintf("POSIX shared memory not supported on this platform: with underlying error: %v", self.err)
  22. }
  23. // prefix_and_suffix splits pattern by the last wildcard "*", if applicable,
  24. // returning prefix as the part before "*" and suffix as the part after "*".
  25. func prefix_and_suffix(pattern string) (prefix, suffix string, err error) {
  26. for i := 0; i < len(pattern); i++ {
  27. if os.IsPathSeparator(pattern[i]) {
  28. return "", "", ErrPatternHasSeparator
  29. }
  30. }
  31. if pos := strings.LastIndexByte(pattern, '*'); pos != -1 {
  32. prefix, suffix = pattern[:pos], pattern[pos+1:]
  33. } else {
  34. prefix = pattern
  35. }
  36. return prefix, suffix, nil
  37. }
  38. type MMap interface {
  39. Close() error
  40. Unlink() error
  41. Slice() []byte
  42. Name() string
  43. IsFileSystemBacked() bool
  44. FileSystemName() string
  45. Stat() (fs.FileInfo, error)
  46. Flush() error
  47. Seek(offset int64, whence int) (ret int64, err error)
  48. Read(b []byte) (n int, err error)
  49. Write(b []byte) (n int, err error)
  50. }
  51. type AccessFlags int
  52. const (
  53. READ AccessFlags = iota
  54. WRITE
  55. COPY
  56. )
  57. func mmap(sz int, access AccessFlags, fd int, off int64) ([]byte, error) {
  58. flags := unix.MAP_SHARED
  59. prot := unix.PROT_READ
  60. switch access {
  61. case COPY:
  62. prot |= unix.PROT_WRITE
  63. flags = unix.MAP_PRIVATE
  64. case WRITE:
  65. prot |= unix.PROT_WRITE
  66. }
  67. b, err := unix.Mmap(fd, off, sz, prot, flags)
  68. if err != nil {
  69. return nil, err
  70. }
  71. return b, nil
  72. }
  73. func munmap(s []byte) error {
  74. return unix.Munmap(s)
  75. }
  76. func CreateTemp(pattern string, size uint64) (MMap, error) {
  77. return create_temp(pattern, size)
  78. }
  79. func truncate_or_unlink(ans *os.File, size uint64, unlink func(string) error) (err error) {
  80. fd := int(ans.Fd())
  81. sz := int64(size)
  82. if err = Fallocate_simple(fd, sz); err != nil {
  83. if !errors.Is(err, errors.ErrUnsupported) {
  84. return fmt.Errorf("fallocate() failed on fd from shm_open(%s) with size: %d with error: %w", ans.Name(), size, err)
  85. }
  86. for {
  87. if err = unix.Ftruncate(fd, sz); !errors.Is(err, unix.EINTR) {
  88. break
  89. }
  90. }
  91. }
  92. if err != nil {
  93. _ = ans.Close()
  94. _ = unlink(ans.Name())
  95. return fmt.Errorf("Failed to ftruncate() SHM file %s to size: %d with error: %w", ans.Name(), size, err)
  96. }
  97. return
  98. }
  99. const NUM_BYTES_FOR_SIZE = 4
  100. var ErrRegionTooSmall = errors.New("mmaped region too small")
  101. func WriteWithSize(self MMap, b []byte, at int) error {
  102. if len(self.Slice()) < at+len(b)+NUM_BYTES_FOR_SIZE {
  103. return ErrRegionTooSmall
  104. }
  105. binary.BigEndian.PutUint32(self.Slice()[at:], uint32(len(b)))
  106. copy(self.Slice()[at+NUM_BYTES_FOR_SIZE:], b)
  107. return nil
  108. }
  109. func ReadWithSize(self MMap, at int) ([]byte, error) {
  110. s := self.Slice()[at:]
  111. if len(s) < NUM_BYTES_FOR_SIZE {
  112. return nil, ErrRegionTooSmall
  113. }
  114. size := int(binary.BigEndian.Uint32(self.Slice()[at : at+NUM_BYTES_FOR_SIZE]))
  115. s = s[NUM_BYTES_FOR_SIZE:]
  116. if len(s) < size {
  117. return nil, ErrRegionTooSmall
  118. }
  119. return s[:size], nil
  120. }
  121. func ReadWithSizeAndUnlink(name string, file_callback ...func(fs.FileInfo) error) ([]byte, error) {
  122. mmap, err := Open(name, 0)
  123. if err != nil {
  124. return nil, err
  125. }
  126. if len(file_callback) > 0 {
  127. s, err := mmap.Stat()
  128. if err != nil {
  129. return nil, fmt.Errorf("Failed to stat SHM file with error: %w", err)
  130. }
  131. for _, f := range file_callback {
  132. err = f(s)
  133. if err != nil {
  134. return nil, err
  135. }
  136. }
  137. }
  138. defer func() {
  139. mmap.Close()
  140. _ = mmap.Unlink()
  141. }()
  142. slice, err := ReadWithSize(mmap, 0)
  143. if err != nil {
  144. return nil, err
  145. }
  146. ans := make([]byte, len(slice))
  147. copy(ans, slice)
  148. return ans, nil
  149. }
  150. func Read(self MMap, b []byte) (n int, err error) {
  151. pos, err := self.Seek(0, io.SeekCurrent)
  152. if err != nil {
  153. return 0, err
  154. }
  155. if pos < 0 {
  156. pos = 0
  157. }
  158. s := self.Slice()
  159. sz := int64(len(s))
  160. if pos >= sz {
  161. return 0, io.EOF
  162. }
  163. n = copy(b, s[pos:])
  164. _, err = self.Seek(int64(n), io.SeekCurrent)
  165. return
  166. }
  167. func Write(self MMap, b []byte) (n int, err error) {
  168. if len(b) == 0 {
  169. return 0, nil
  170. }
  171. pos, _ := self.Seek(0, io.SeekCurrent)
  172. if pos < 0 {
  173. pos = 0
  174. }
  175. s := self.Slice()
  176. if pos >= int64(len(s)) {
  177. return 0, io.ErrShortWrite
  178. }
  179. n = copy(s[pos:], b)
  180. if _, err = self.Seek(int64(n), io.SeekCurrent); err != nil {
  181. return n, err
  182. }
  183. if n < len(b) {
  184. return n, io.ErrShortWrite
  185. }
  186. return n, nil
  187. }
  188. func test_integration_with_python(args []string) (rc int, err error) {
  189. switch args[0] {
  190. default:
  191. return 1, fmt.Errorf("Unknown test type: %s", args[0])
  192. case "read":
  193. data, err := ReadWithSizeAndUnlink(args[1])
  194. if err != nil {
  195. return 1, err
  196. }
  197. _, err = os.Stdout.Write(data)
  198. if err != nil {
  199. return 1, err
  200. }
  201. case "write":
  202. data, err := io.ReadAll(os.Stdin)
  203. if err != nil {
  204. return 1, err
  205. }
  206. mmap, err := CreateTemp("shmtest-", uint64(len(data)+NUM_BYTES_FOR_SIZE))
  207. if err != nil {
  208. return 1, err
  209. }
  210. if err = WriteWithSize(mmap, data, 0); err != nil {
  211. return 1, err
  212. }
  213. mmap.Close()
  214. fmt.Println(mmap.Name())
  215. }
  216. return 0, nil
  217. }
  218. func TestEntryPoint(root *cli.Command) {
  219. root.AddSubCommand(&cli.Command{
  220. Name: "shm",
  221. OnlyArgsAllowed: true,
  222. Run: func(cmd *cli.Command, args []string) (rc int, err error) {
  223. return test_integration_with_python(args)
  224. },
  225. })
  226. }