run_word2vec.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import os
  2. from typing import Any
  3. from src.diff2vec.word2vec import Word2Vec
  4. from gensim.models.callbacks import CallbackAny2Vec
  5. class TrainCallback(CallbackAny2Vec):
  6. """
  7. Save model every epoch and log epoch completion.
  8. """
  9. def __init__(self, save_dir: str):
  10. self._save_dir: str = save_dir
  11. self._epoch: int = 0
  12. def on_train_begin(self, _):
  13. print('training start')
  14. def on_train_end(self, _):
  15. print('done')
  16. def on_epoch_begin(self, _):
  17. print(f'epoch {self._epoch} start')
  18. self._epoch += 1
  19. def on_epoch_end(self, model):
  20. out_path: str = os.path.join(self._save_dir, f'word2vec-epoch{self._epoch}.model')
  21. print(f'epoch {self._epoch} end')
  22. model.save(out_path)
  23. print(f'saved model to {out_path}')
  24. def main(args: Any):
  25. cache_dir: str = os.path.join(args.model_dir)
  26. Word2Vec(
  27. corpus_file = args.corpus_file,
  28. corpus_size = args.corpus_size,
  29. vector_size = args.dim,
  30. workers = args.workers,
  31. epochs = args.epochs,
  32. alpha = args.lr,
  33. seed = args.seed,
  34. cache_dir = cache_dir,
  35. callbacks = [TrainCallback(args.model_dir)],
  36. )
  37. if __name__ == "__main__":
  38. from argparse import ArgumentParser
  39. parser: ArgumentParser = ArgumentParser()
  40. parser.add_argument('corpus_file', type=str, help='path to cached sequences')
  41. parser.add_argument('model_dir', type=str, help='path to save model')
  42. parser.add_argument('--corpus-size', type=int, default=263644512, help='epochs (default: 263644512)')
  43. parser.add_argument('--epochs', type=int, default=5, help='epochs (default: 5)')
  44. parser.add_argument('--workers', type=int, default=4, help='workers (default: 4)')
  45. parser.add_argument('--min-count', type=int, default=5, help='min count (default: 5)')
  46. parser.add_argument('--lr', type=float, default=0.025, help='learning rate (default: 0.025)')
  47. parser.add_argument('--dim', type=float, default=128, help='dimensionality (default: 128)')
  48. parser.add_argument('--seed', type=float, default=42, help='random seed (default: 42)')
  49. args: Any = parser.parse_args()
  50. main(args)