bar.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. package progress
  2. import (
  3. "fmt"
  4. "os"
  5. "strings"
  6. "time"
  7. "github.com/ollama/ollama/format"
  8. "golang.org/x/term"
  9. )
// Bar holds the state of a single progress bar: the label shown in front of
// it, the value range it covers, and the timing samples used to estimate the
// transfer rate.
type Bar struct {
	// message is the label rendered before the percentage.
	message string
	// messageWidth is the width the label is truncated and padded to when
	// positive; NewBar initializes it to -1, which disables both.
	messageWidth int

	// maxValue is the value at which the bar counts as complete.
	maxValue int64
	// initialValue is the value the bar started from; rate subtracts it
	// so that pre-existing progress does not count as throughput.
	initialValue int64
	// currentValue is the most recent value recorded by NewBar or Set.
	currentValue int64

	// started is the creation time recorded by NewBar.
	started time.Time
	// stopped is set once currentValue reaches maxValue; it stays the zero
	// time while the bar is still in progress.
	stopped time.Time

	// maxBuckets caps how many samples are kept in buckets.
	maxBuckets int
	// buckets holds the recent (time, value) samples appended by Set and
	// read by rate.
	buckets []bucket
}
// bucket is one progress sample: the value observed and when it was observed.
type bucket struct {
	// updated is the time the sample was taken.
	updated time.Time
	// value is the bar's value at that time.
	value int64
}
  25. func NewBar(message string, maxValue, initialValue int64) *Bar {
  26. b := Bar{
  27. message: message,
  28. messageWidth: -1,
  29. maxValue: maxValue,
  30. initialValue: initialValue,
  31. currentValue: initialValue,
  32. started: time.Now(),
  33. maxBuckets: 10,
  34. }
  35. if initialValue >= maxValue {
  36. b.stopped = time.Now()
  37. }
  38. return &b
  39. }
  40. // formatDuration limits the rendering of a time.Duration to 2 units
  41. func formatDuration(d time.Duration) string {
  42. switch {
  43. case d >= 100*time.Hour:
  44. return "99h+"
  45. case d >= time.Hour:
  46. return fmt.Sprintf("%dh%dm", int(d.Hours()), int(d.Minutes())%60)
  47. default:
  48. return d.Round(time.Second).String()
  49. }
  50. }
  51. func (b *Bar) String() string {
  52. termWidth, _, err := term.GetSize(int(os.Stderr.Fd()))
  53. if err != nil {
  54. termWidth = 80
  55. }
  56. var pre strings.Builder
  57. if len(b.message) > 0 {
  58. message := strings.TrimSpace(b.message)
  59. if b.messageWidth > 0 && len(message) > b.messageWidth {
  60. message = message[:b.messageWidth]
  61. }
  62. fmt.Fprintf(&pre, "%s", message)
  63. if padding := b.messageWidth - pre.Len(); padding > 0 {
  64. pre.WriteString(repeat(" ", padding))
  65. }
  66. pre.WriteString(" ")
  67. }
  68. fmt.Fprintf(&pre, "%3.0f%%", b.percent())
  69. var suf strings.Builder
  70. // max 13 characters: "999 MB/999 MB"
  71. if b.stopped.IsZero() {
  72. curValue := format.HumanBytes(b.currentValue)
  73. suf.WriteString(repeat(" ", 6-len(curValue)))
  74. suf.WriteString(curValue)
  75. suf.WriteString("/")
  76. maxValue := format.HumanBytes(b.maxValue)
  77. suf.WriteString(repeat(" ", 6-len(maxValue)))
  78. suf.WriteString(maxValue)
  79. } else {
  80. maxValue := format.HumanBytes(b.maxValue)
  81. suf.WriteString(repeat(" ", 6-len(maxValue)))
  82. suf.WriteString(maxValue)
  83. suf.WriteString(repeat(" ", 7))
  84. }
  85. rate := b.rate()
  86. // max 10 characters: " 999 MB/s"
  87. if b.stopped.IsZero() && rate > 0 {
  88. suf.WriteString(" ")
  89. humanRate := format.HumanBytes(int64(rate))
  90. suf.WriteString(repeat(" ", 6-len(humanRate)))
  91. suf.WriteString(humanRate)
  92. suf.WriteString("/s")
  93. } else {
  94. suf.WriteString(repeat(" ", 10))
  95. }
  96. // max 8 characters: " 59m59s"
  97. if b.stopped.IsZero() && rate > 0 {
  98. suf.WriteString(" ")
  99. var remaining time.Duration
  100. if rate > 0 {
  101. remaining = time.Duration(int64(float64(b.maxValue-b.currentValue)/rate)) * time.Second
  102. }
  103. humanRemaining := formatDuration(remaining)
  104. suf.WriteString(repeat(" ", 6-len(humanRemaining)))
  105. suf.WriteString(humanRemaining)
  106. } else {
  107. suf.WriteString(repeat(" ", 8))
  108. }
  109. var mid strings.Builder
  110. // add 5 extra spaces: 2 boundary characters and 1 space at each end
  111. f := termWidth - pre.Len() - suf.Len() - 5
  112. n := int(float64(f) * b.percent() / 100)
  113. mid.WriteString(" ▕")
  114. if n > 0 {
  115. mid.WriteString(repeat("█", n))
  116. }
  117. if f-n > 0 {
  118. mid.WriteString(repeat(" ", f-n))
  119. }
  120. mid.WriteString("▏ ")
  121. return pre.String() + mid.String() + suf.String()
  122. }
  123. func (b *Bar) Set(value int64) {
  124. if value >= b.maxValue {
  125. value = b.maxValue
  126. }
  127. b.currentValue = value
  128. if b.currentValue >= b.maxValue {
  129. b.stopped = time.Now()
  130. }
  131. // throttle bucket updates to 1 per second
  132. if len(b.buckets) == 0 || time.Since(b.buckets[len(b.buckets)-1].updated) > time.Second {
  133. b.buckets = append(b.buckets, bucket{
  134. updated: time.Now(),
  135. value: value,
  136. })
  137. if len(b.buckets) > b.maxBuckets {
  138. b.buckets = b.buckets[1:]
  139. }
  140. }
  141. }
  142. func (b *Bar) percent() float64 {
  143. if b.maxValue > 0 {
  144. return float64(b.currentValue) / float64(b.maxValue) * 100
  145. }
  146. return 0
  147. }
  148. func (b *Bar) rate() float64 {
  149. var numerator, denominator float64
  150. if !b.stopped.IsZero() {
  151. numerator = float64(b.currentValue - b.initialValue)
  152. denominator = b.stopped.Sub(b.started).Round(time.Second).Seconds()
  153. } else {
  154. switch len(b.buckets) {
  155. case 0:
  156. // noop
  157. case 1:
  158. numerator = float64(b.buckets[0].value - b.initialValue)
  159. denominator = b.buckets[0].updated.Sub(b.started).Round(time.Second).Seconds()
  160. default:
  161. first, last := b.buckets[0], b.buckets[len(b.buckets)-1]
  162. numerator = float64(last.value - first.value)
  163. denominator = last.updated.Sub(first.updated).Round(time.Second).Seconds()
  164. }
  165. }
  166. if denominator != 0 {
  167. return numerator / denominator
  168. }
  169. return 0
  170. }
  171. func repeat(s string, n int) string {
  172. if n > 0 {
  173. return strings.Repeat(s, n)
  174. }
  175. return ""
  176. }