db.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. package core
  2. import (
  3. "database/sql"
  4. "database/sql/driver"
  5. "errors"
  6. "fmt"
  7. "reflect"
  8. "regexp"
  9. )
  10. func MapToSlice(query string, mp interface{}) (string, []interface{}, error) {
  11. vv := reflect.ValueOf(mp)
  12. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  13. return "", []interface{}{}, ErrNoMapPointer
  14. }
  15. args := make([]interface{}, 0, len(vv.Elem().MapKeys()))
  16. var err error
  17. query = re.ReplaceAllStringFunc(query, func(src string) string {
  18. v := vv.Elem().MapIndex(reflect.ValueOf(src[1:]))
  19. if !v.IsValid() {
  20. err = fmt.Errorf("map key %s is missing", src[1:])
  21. } else {
  22. args = append(args, v.Interface())
  23. }
  24. return "?"
  25. })
  26. return query, args, err
  27. }
  28. func StructToSlice(query string, st interface{}) (string, []interface{}, error) {
  29. vv := reflect.ValueOf(st)
  30. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  31. return "", []interface{}{}, ErrNoStructPointer
  32. }
  33. args := make([]interface{}, 0)
  34. var err error
  35. query = re.ReplaceAllStringFunc(query, func(src string) string {
  36. fv := vv.Elem().FieldByName(src[1:]).Interface()
  37. if v, ok := fv.(driver.Valuer); ok {
  38. var value driver.Value
  39. value, err = v.Value()
  40. if err != nil {
  41. return "?"
  42. }
  43. args = append(args, value)
  44. } else {
  45. args = append(args, fv)
  46. }
  47. return "?"
  48. })
  49. if err != nil {
  50. return "", []interface{}{}, err
  51. }
  52. return query, args, nil
  53. }
  // DB wraps *sql.DB and carries the IMapper that it passes on to every Rows,
  // Stmt and Tx it creates (see Query, Prepare and Begin).
  //
  // NOTE: Open and FromDB build this with a positional composite literal, so
  // the field order is part of the contract.
  type DB struct {
  	*sql.DB
  	Mapper IMapper
  }
  58. func Open(driverName, dataSourceName string) (*DB, error) {
  59. db, err := sql.Open(driverName, dataSourceName)
  60. if err != nil {
  61. return nil, err
  62. }
  63. return &DB{db, NewCacheMapper(&SnakeMapper{})}, nil
  64. }
  65. func FromDB(db *sql.DB) *DB {
  66. return &DB{db, NewCacheMapper(&SnakeMapper{})}
  67. }
  68. func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
  69. rows, err := db.DB.Query(query, args...)
  70. if err != nil {
  71. if rows != nil {
  72. rows.Close()
  73. }
  74. return nil, err
  75. }
  76. return &Rows{rows, db.Mapper}, nil
  77. }
  78. func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) {
  79. query, args, err := MapToSlice(query, mp)
  80. if err != nil {
  81. return nil, err
  82. }
  83. return db.Query(query, args...)
  84. }
  85. func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) {
  86. query, args, err := StructToSlice(query, st)
  87. if err != nil {
  88. return nil, err
  89. }
  90. return db.Query(query, args...)
  91. }
  92. func (db *DB) QueryRow(query string, args ...interface{}) *Row {
  93. rows, err := db.Query(query, args...)
  94. if err != nil {
  95. return &Row{nil, err}
  96. }
  97. return &Row{rows, nil}
  98. }
  99. func (db *DB) QueryRowMap(query string, mp interface{}) *Row {
  100. query, args, err := MapToSlice(query, mp)
  101. if err != nil {
  102. return &Row{nil, err}
  103. }
  104. return db.QueryRow(query, args...)
  105. }
  106. func (db *DB) QueryRowStruct(query string, st interface{}) *Row {
  107. query, args, err := StructToSlice(query, st)
  108. if err != nil {
  109. return &Row{nil, err}
  110. }
  111. return db.QueryRow(query, args...)
  112. }
  // Stmt wraps a prepared *sql.Stmt together with the mapper of the DB or Tx
  // that prepared it.
  //
  // names maps each ?name placeholder found by Prepare to the zero-based
  // position of its ? in the rewritten query. The *Map and *Struct methods use
  // it to order their arguments.
  //
  // NOTE: Prepare builds this with a positional composite literal, so the
  // field order matters.
  type Stmt struct {
  	*sql.Stmt
  	Mapper IMapper
  	names  map[string]int
  }
  118. func (db *DB) Prepare(query string) (*Stmt, error) {
  119. names := make(map[string]int)
  120. var i int
  121. query = re.ReplaceAllStringFunc(query, func(src string) string {
  122. names[src[1:]] = i
  123. i += 1
  124. return "?"
  125. })
  126. stmt, err := db.DB.Prepare(query)
  127. if err != nil {
  128. return nil, err
  129. }
  130. return &Stmt{stmt, db.Mapper, names}, nil
  131. }
  132. func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) {
  133. vv := reflect.ValueOf(mp)
  134. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  135. return nil, errors.New("mp should be a map's pointer")
  136. }
  137. args := make([]interface{}, len(s.names))
  138. for k, i := range s.names {
  139. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  140. }
  141. return s.Stmt.Exec(args...)
  142. }
  143. func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) {
  144. vv := reflect.ValueOf(st)
  145. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  146. return nil, errors.New("mp should be a map's pointer")
  147. }
  148. args := make([]interface{}, len(s.names))
  149. for k, i := range s.names {
  150. args[i] = vv.Elem().FieldByName(k).Interface()
  151. }
  152. return s.Stmt.Exec(args...)
  153. }
  154. func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
  155. rows, err := s.Stmt.Query(args...)
  156. if err != nil {
  157. return nil, err
  158. }
  159. return &Rows{rows, s.Mapper}, nil
  160. }
  161. func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) {
  162. vv := reflect.ValueOf(mp)
  163. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  164. return nil, errors.New("mp should be a map's pointer")
  165. }
  166. args := make([]interface{}, len(s.names))
  167. for k, i := range s.names {
  168. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  169. }
  170. return s.Query(args...)
  171. }
  172. func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) {
  173. vv := reflect.ValueOf(st)
  174. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  175. return nil, errors.New("mp should be a map's pointer")
  176. }
  177. args := make([]interface{}, len(s.names))
  178. for k, i := range s.names {
  179. args[i] = vv.Elem().FieldByName(k).Interface()
  180. }
  181. return s.Query(args...)
  182. }
  183. func (s *Stmt) QueryRow(args ...interface{}) *Row {
  184. rows, err := s.Query(args...)
  185. return &Row{rows, err}
  186. }
  187. func (s *Stmt) QueryRowMap(mp interface{}) *Row {
  188. vv := reflect.ValueOf(mp)
  189. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  190. return &Row{nil, errors.New("mp should be a map's pointer")}
  191. }
  192. args := make([]interface{}, len(s.names))
  193. for k, i := range s.names {
  194. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  195. }
  196. return s.QueryRow(args...)
  197. }
  198. func (s *Stmt) QueryRowStruct(st interface{}) *Row {
  199. vv := reflect.ValueOf(st)
  200. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  201. return &Row{nil, errors.New("st should be a struct's pointer")}
  202. }
  203. args := make([]interface{}, len(s.names))
  204. for k, i := range s.names {
  205. args[i] = vv.Elem().FieldByName(k).Interface()
  206. }
  207. return s.QueryRow(args...)
  208. }
  // re matches a named placeholder: a literal '?' followed by one or more ASCII
  // word characters (letters, digits, underscore). Callers take the name as
  // match[1:].
  var re = regexp.MustCompile(`[?](\w+)`)
  212. // insert into (name) values (?)
  213. // insert into (name) values (?name)
  214. func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) {
  215. query, args, err := MapToSlice(query, mp)
  216. if err != nil {
  217. return nil, err
  218. }
  219. return db.DB.Exec(query, args...)
  220. }
  221. func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) {
  222. query, args, err := StructToSlice(query, st)
  223. if err != nil {
  224. return nil, err
  225. }
  226. return db.DB.Exec(query, args...)
  227. }
  // EmptyScanner has a Scan method matching sql.Scanner that ignores its input.
  // A value of this type can therefore act as a scan destination whose column
  // value is discarded.
  type EmptyScanner struct{}

  // Scan discards src and always reports success.
  func (EmptyScanner) Scan(src interface{}) error {
  	return nil
  }
  // Tx wraps *sql.Tx and carries the IMapper inherited from the DB that began
  // it (see DB.Begin). The mapper is passed on to the Rows and Stmts it creates.
  //
  // NOTE: DB.Begin builds this with a positional composite literal, so the
  // field order matters.
  type Tx struct {
  	*sql.Tx
  	Mapper IMapper
  }
  237. func (db *DB) Begin() (*Tx, error) {
  238. tx, err := db.DB.Begin()
  239. if err != nil {
  240. return nil, err
  241. }
  242. return &Tx{tx, db.Mapper}, nil
  243. }
  244. func (tx *Tx) Prepare(query string) (*Stmt, error) {
  245. names := make(map[string]int)
  246. var i int
  247. query = re.ReplaceAllStringFunc(query, func(src string) string {
  248. names[src[1:]] = i
  249. i += 1
  250. return "?"
  251. })
  252. stmt, err := tx.Tx.Prepare(query)
  253. if err != nil {
  254. return nil, err
  255. }
  256. return &Stmt{stmt, tx.Mapper, names}, nil
  257. }
  258. func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
  259. // TODO:
  260. return stmt
  261. }
  262. func (tx *Tx) ExecMap(query string, mp interface{}) (sql.Result, error) {
  263. query, args, err := MapToSlice(query, mp)
  264. if err != nil {
  265. return nil, err
  266. }
  267. return tx.Tx.Exec(query, args...)
  268. }
  269. func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) {
  270. query, args, err := StructToSlice(query, st)
  271. if err != nil {
  272. return nil, err
  273. }
  274. return tx.Tx.Exec(query, args...)
  275. }
  276. func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
  277. rows, err := tx.Tx.Query(query, args...)
  278. if err != nil {
  279. return nil, err
  280. }
  281. return &Rows{rows, tx.Mapper}, nil
  282. }
  283. func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) {
  284. query, args, err := MapToSlice(query, mp)
  285. if err != nil {
  286. return nil, err
  287. }
  288. return tx.Query(query, args...)
  289. }
  290. func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) {
  291. query, args, err := StructToSlice(query, st)
  292. if err != nil {
  293. return nil, err
  294. }
  295. return tx.Query(query, args...)
  296. }
  297. func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
  298. rows, err := tx.Query(query, args...)
  299. return &Row{rows, err}
  300. }
  301. func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row {
  302. query, args, err := MapToSlice(query, mp)
  303. if err != nil {
  304. return &Row{nil, err}
  305. }
  306. return tx.QueryRow(query, args...)
  307. }
  308. func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row {
  309. query, args, err := StructToSlice(query, st)
  310. if err != nil {
  311. return &Row{nil, err}
  312. }
  313. return tx.QueryRow(query, args...)
  314. }