combine_metadata.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. """
  2. Combine clusters into metadata.csv. To be run after `prune_metadata.csv`.
  3. """
  4. import numpy as np
  5. import pandas as pd
  6. from tqdm import tqdm
  7. from typing import Any, Dict
  8. from src.utils.utils import from_json
  9. def main(args: Any):
  10. df: pd.DataFrame = pd.read_csv(args.metadata_joined_csv)
  11. user_clusters = from_json(args.user_clusters_json)
  12. exchange_clusters = from_json(args.exchange_clusters_json)
  13. print('adding user clusters...')
  14. user_map: Dict[str, int] = {}
  15. pbar = tqdm(total=len(user_clusters))
  16. for i, cluster in enumerate(user_clusters):
  17. for address in cluster:
  18. user_map[address] = i
  19. pbar.update()
  20. pbar.close()
  21. print('adding exchange clusters...')
  22. exchange_map: Dict[str, int] = {}
  23. pbar = tqdm(total=len(exchange_clusters))
  24. for i, cluster in enumerate(exchange_clusters):
  25. for address in cluster:
  26. exchange_map[address] = i
  27. pbar.update()
  28. pbar.close()
  29. df['user_cluster'] = df.address.apply(
  30. lambda address: user_map.get(address, np.nan),
  31. )
  32. df['exchange_cluster'] = df.address.apply(
  33. lambda address: exchange_map.get(address, np.nan),
  34. )
  35. # cast to the right type
  36. df['user_cluster'] = df['user_cluster'].astype(pd.Int64Dtype())
  37. df['exchange_cluster'] = df['exchange_cluster'].astype(pd.Int64Dtype())
  38. print('saving to disk...')
  39. df.to_csv(args.out_csv, index=False)
  40. if __name__ == "__main__":
  41. import argparse
  42. parser = argparse.ArgumentParser()
  43. parser.add_argument('metadata_joined_csv', type=str)
  44. parser.add_argument('user_clusters_json', type=str)
  45. parser.add_argument('exchange_clusters_json', type=str)
  46. parser.add_argument('out_csv', type=str)
  47. args = parser.parse_args()
  48. main(args)