get_neighbors.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
"""
Precompute nearest neighbors for the trained Word2Vec vectors.

We can't keep the Word2Vec vectors in RAM on the server, so instead we
store a map from each address index to the cluster of indices that are
its nearest neighbors.
"""
  5. import os
  6. import faiss
  7. import numpy as np
  8. from typing import Any, List
  9. from tqdm import tqdm
  10. def main(args: Any):
  11. print('loading vectors...', end=' ')
  12. vectors: np.array = np.load(args.vectors_npy)
  13. print('done')
  14. size = vectors.shape[0]
  15. # https://github.com/facebookresearch/faiss/issues/112
  16. nlist = int(4 * np.sqrt(size))
  17. if args.index_file is None:
  18. quantizer: faiss.IndexFlatL2 = faiss.IndexFlatL2(128)
  19. index: faiss.IndexIVFFlat = \
  20. faiss.IndexIVFFlat(quantizer, 128, nlist, faiss.METRIC_L2)
  21. assert not index.is_trained
  22. print('training FAISS index...', end=' ')
  23. index.train(vectors)
  24. print('done')
  25. assert index.is_trained
  26. print('adding vectors to index...', end=' ')
  27. index.add(vectors)
  28. print('done')
  29. print('saving to disk...', end=' ')
  30. faiss.write_index(index, os.path.join(args.save_dir, 'faiss.index'))
  31. print('done')
  32. else:
  33. print('reading to disk...', end=' ')
  34. index = faiss.read_index(os.path.join(args.save_dir, 'faiss.index'))
  35. print('done')
  36. print('computing neighbors')
  37. distances: List[np.array] = []
  38. neighbors: List[np.array] = []
  39. batch_size: int = 100
  40. num_batches: int = (size // batch_size) + int(size % batch_size)
  41. for i in tqdm(range(num_batches)):
  42. query: np.array = vectors[batch_size*i:batch_size*(i+1)]
  43. D, I = index.search(query, args.k)
  44. distances.append(D)
  45. neighbors.append(I)
  46. distances = np.concatenate(distances, axis=0)
  47. neighbors = np.concatenate(neighbors, axis=0)
  48. np.save(os.path.join(args.save_dir, f'distances-k{args.k}.npy'), distances)
  49. np.save(os.path.join(args.save_dir, f'neighbors-k{args.k}.npy'), neighbors)
  50. if __name__ == "__main__":
  51. import argparse
  52. parser = argparse.ArgumentParser()
  53. parser.add_argument('vectors_npy', type=str, help='path to trained word2vec vectors.')
  54. parser.add_argument('save_dir', type=str, help='where to save outputs.')
  55. parser.add_argument('--index-file', type=str, default=None,
  56. help='optional path to cached index file')
  57. parser.add_argument('--k', type=int, default=10, help='number of neighbors to find.')
  58. args: Any = parser.parse_args()
  59. main(args)