dump_neighbors.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. """
After running `get_neighbors.py`, dump its distance and neighbor arrays
into a CSV, converting integer node indices back to address strings.
  4. """
  5. import os
  6. import json
  7. import numpy as np
  8. import pandas as pd
  9. from tqdm import tqdm
  10. from typing import Any, List, Dict
  11. from src.utils.utils import from_json
  12. def main(args: Any):
  13. distances: np.array = np.load(args.distance_file)
  14. neighbors: np.array = np.load(args.neighbor_file)
  15. index2addr: Dict[str, int] = from_json(args.address_file)
  16. size: int = len(distances)
  17. address_df: List[str] = []
  18. distance_df: List[str] = []
  19. neighbor_df: List[str] = []
  20. for index in tqdm(range(size)):
  21. distance: List[float] = distances[index].tolist()
  22. distance: str = json.dumps(distance)
  23. neighbor: List[int] = neighbors[index].tolist()
  24. neighbor: List[str] = [index2addr[str(nei)] for nei in neighbor if nei >= 0]
  25. neighbor: str = json.dumps(neighbor)
  26. distance_df.append(distance)
  27. neighbor_df.append(neighbor)
  28. address_df.append(index2addr[str(index)])
  29. df_dict: Dict[str, List[str]] = dict(
  30. address=address_df,
  31. distance=distance_df,
  32. neighbor=neighbor_df,
  33. )
  34. df: pd.DataFrame = pd.DataFrame.from_dict(df_dict)
  35. df.to_csv(os.path.join(args.save_dir, 'diff2vec-processed.csv'), index=False)
  36. if __name__ == "__main__":
  37. import argparse
  38. parser = argparse.ArgumentParser()
  39. parser.add_argument('distance_file', type=str, help='path to distance numpy file.')
  40. parser.add_argument('neighbor_file', type=str, help='path to neighbor numpy file.')
  41. parser.add_argument('address_file', type=str, help='path to address lookup file.')
  42. parser.add_argument('save_dir', type=str, help='where to save outpouts.')
  43. args: Any = parser.parse_args()
  44. main(args)