09-pooling.diff

diff --git a/src/llama.cpp b/src/llama.cpp
index 721b8f4e..cfe7ac40 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8420,14 +8420,14 @@ struct llm_build_context {
     }
 
     struct ggml_tensor * build_inp_mean() {
-        lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
+        lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, cparams.n_seq_max);
         cb(lctx.inp_mean, "inp_mean", -1);
         ggml_set_input(lctx.inp_mean);
         return lctx.inp_mean;
     }
 
     struct ggml_tensor * build_inp_cls() {
-        lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_seq_max);
         cb(lctx.inp_cls, "inp_cls", -1);
         ggml_set_input(lctx.inp_cls);
         return lctx.inp_cls;
@@ -13847,19 +13847,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
 
         float * data = (float *) lctx.inp_mean->data;
-        memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
+        memset(lctx.inp_mean->data, 0, n_tokens * cparams.n_seq_max * ggml_element_size(lctx.inp_mean));
 
-        std::vector<uint64_t> sum(n_tokens, 0);
+        std::vector<uint64_t> sum(cparams.n_seq_max, 0);
         for (int i = 0; i < n_tokens; ++i) {
             const llama_seq_id seq_id = batch.seq_id[i][0];
-
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
-
             sum[seq_id] += 1;
         }
 
-        std::vector<float> div(n_tokens, 0.0f);
-        for (int i = 0; i < n_tokens; ++i) {
+        std::vector<float> div(cparams.n_seq_max, 0.0f);
+        for (uint32_t i = 0; i < cparams.n_seq_max; ++i) {
             const uint64_t s = sum[i];
             if (s > 0) {
                 div[i] = 1.0f/float(s);
@@ -13879,14 +13876,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
 
         uint32_t * data = (uint32_t *) lctx.inp_cls->data;
-        memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
+        memset(lctx.inp_cls->data, 0, cparams.n_seq_max * ggml_element_size(lctx.inp_cls));
 
         for (int i = 0; i < n_tokens; ++i) {
             const llama_seq_id seq_id = batch.seq_id[i][0];
             const llama_pos pos = batch.pos[i];
-
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
-
             if (pos == 0) {
                 data[seq_id] = i;
             }
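
Note: the patch resizes the pooling inputs from n_tokens to cparams.n_seq_max because inp_mean is consumed as an n_tokens x n_seq_max weight matrix (entry (i, s) carries the 1/count averaging weight when token i belongs to sequence s) and inp_cls holds one first-token index per sequence; a seq_id may therefore legitimately exceed n_tokens, which the removed asserts used to reject. The standalone C++ program below is a minimal sketch of that buffer layout using plain std::vectors in place of the ggml tensors: the names n_tokens, n_seq_max, seq_id and pos mirror the diff, while the sample batch and the main() driver are hypothetical.

// pooling_sketch.cpp -- standalone illustration of the patched buffer sizing.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int      n_tokens  = 5;  // tokens in the current batch
    const uint32_t n_seq_max = 8;  // max sequences the context was created with

    // One sequence id and position per token. seq_id 7 >= n_tokens is exactly
    // the case the old n_tokens-sized buffers could not represent.
    const std::vector<int> seq_id = { 0, 0, 7, 7, 7 };
    const std::vector<int> pos    = { 0, 1, 0, 1, 2 };

    // inp_mean: n_tokens x n_seq_max weights, token index contiguous, matching
    // ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, cparams.n_seq_max).
    std::vector<float> inp_mean(n_tokens * n_seq_max, 0.0f);

    std::vector<uint64_t> sum(n_seq_max, 0);  // tokens per sequence
    for (int i = 0; i < n_tokens; ++i) {
        sum[seq_id[i]] += 1;
    }

    std::vector<float> div(n_seq_max, 0.0f);  // 1/count per sequence
    for (uint32_t s = 0; s < n_seq_max; ++s) {
        if (sum[s] > 0) {
            div[s] = 1.0f/float(sum[s]);
        }
    }

    // The weights for sequence s sum to 1, so multiplying them with the token
    // embeddings yields the per-sequence mean.
    for (int i = 0; i < n_tokens; ++i) {
        inp_mean[seq_id[i]*n_tokens + i] = div[seq_id[i]];
    }

    // inp_cls: one slot per sequence (n_seq_max, not n_tokens), holding the
    // batch index of that sequence's first token.
    std::vector<uint32_t> inp_cls(n_seq_max, 0);
    for (int i = 0; i < n_tokens; ++i) {
        if (pos[i] == 0) {
            inp_cls[seq_id[i]] = (uint32_t) i;
        }
    }

    for (uint32_t s = 0; s < n_seq_max; ++s) {
        if (sum[s] > 0) {
            printf("seq %u: %llu token(s), mean weight %.3f, cls index %u\n",
                   s, (unsigned long long) sum[s], div[s], inp_cls[s]);
        }
    }
    return 0;
}

With this batch (n_tokens = 5, tokens for sequences 0 and 7), every access stays in bounds only because sum, div, inp_cls and the second dimension of inp_mean are sized by n_seq_max; sized by n_tokens, as before the patch, the same batch would index past the end of all four.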