run_nx.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. """
  2. To scale the run_deposit.py script, we had to stop building the
  3. NetworkX graph in memory there. This script builds that graph instead;
  4. the motivation is to isolate the high-memory parts in a single file.
  5. Input: data CSV, gas price and multi denom JSON files. Output: user_clusters.json, exchange_clusters.json.
  6. """
  7. import os
  8. import itertools
  9. import numpy as np
  10. import pandas as pd
  11. import networkx as nx
  12. from src.utils.utils import to_json, from_json
  13. from typing import Any, List, Set, Tuple
  14. def main(args: Any):
  15. data: pd.DataFrame = pd.read_csv(args.data_file)
  16. gas_price_sets: List[Set[str]] = from_json(args.gas_price_file)
  17. multi_denom_sets: List[Set[str]] = from_json(args.multi_denom_file)
  18. print('making user graph...', end = '', flush=True)
  19. user_graph: nx.DiGraph = make_graph(data.user, data.deposit)
  20. print('adding gas price nodes...', end = '', flush=True)
  21. user_graph: nx.DiGraph = add_to_user_graph(user_graph, gas_price_sets)
  22. print('adding multi denom nodes...', end = '', flush=True)
  23. user_graph: nx.DiGraph = add_to_user_graph(user_graph, multi_denom_sets)
  24. print('making exchange graph...', end = '', flush=True)
  25. exchange_graph: nx.DiGraph = make_graph(data.deposit, data.exchange)
  26. print('making user wcc...', end = '', flush=True)
  27. user_wccs: List[Set[str]] = get_wcc(user_graph)
  28. # algorithm 1 line 13
  29. # We actually want to keep this information!
  30. # user_wccs: List[Set[str]] = self._remove_deposits(
  31. # user_wccs,
  32. # set(store.deposit.to_numpy().tolist()),
  33. # )
  34. print('making exchange wcc...', end = '', flush=True)
  35. exchange_wccs: List[Set[str]] = get_wcc(exchange_graph)
  36. # prune trivial clusters
  37. user_wccs: List[Set[str]] = remove_singletons(user_wccs)
  38. exchange_wccs: List[Set[str]] = remove_singletons(exchange_wccs)
  39. if not os.path.isdir(args.save_dir):
  40. os.makedirs(args.save_dir)
  41. print('writing to disk...\n', end = '', flush=True)
  42. to_json(user_wccs, os.path.join(args.save_dir, 'user_clusters.json'))
  43. to_json(exchange_wccs, os.path.join(args.save_dir, 'exchange_clusters.json'))
  44. def add_to_user_graph(graph: nx.DiGraph, clusters: List[Set[str]]):
  45. for cluster in clusters:
  46. assert len(cluster) == 2, "Only supports edges with two nodes."
  47. node_a, node_b = cluster
  48. graph.add_node(node_a)
  49. graph.add_node(node_b)
  50. graph.add_edge(node_a, node_b)
  51. return graph
  52. def get_wcc(graph: nx.DiGraph) -> List[Set[str]]:
  53. comp_iter: Any = nx.weakly_connected_components(graph)
  54. comps: List[Set[str]] = [c for c in comp_iter]
  55. return comps
  56. def remove_deposits(components: List[Set[str]], deposit: Set[str]):
  57. # remove all deposit addresses from wcc list
  58. new_components: List[Set[str]] = []
  59. for component in components:
  60. new_component: Set[str] = component - deposit
  61. new_components.append(new_component)
  62. return new_components
  63. def remove_singletons(components: List[Set[str]]):
  64. # remove clusters with just one entity... these are not interesting.
  65. return [c for c in components if len(c) > 1]
  66. def make_graph(node_a: pd.Series, node_b: pd.Series) -> nx.DiGraph:
  67. """
  68. DEPRECATED: This assumes we can store all connections in memory.
  69. Make a directed graph connecting each row of node_a to the
  70. corresponding row of node_b.
  71. """
  72. assert node_a.size == node_b.size, "Dataframes are uneven sizes."
  73. graph: nx.DiGraph = nx.DiGraph()
  74. nodes: np.array = np.concatenate([node_a.unique(), node_b.unique()])
  75. edges: List[Tuple[str, str]] = list(
  76. zip(node_a.to_numpy(), node_b.to_numpy())
  77. )
  78. graph.add_nodes_from(nodes)
  79. graph.add_edges_from(edges)
  80. return graph
  81. if __name__ == "__main__":
  82. import argparse
  83. parser = argparse.ArgumentParser()
  84. parser.add_argument('data_file', type=str, help='path to cached out of deposit.py')
  85. parser.add_argument('gas_price_file', type=str, help='path to gas price address sets')
  86. parser.add_argument('multi_denom_file', type=str, help='path to gas price address sets')
  87. parser.add_argument('save_dir', type=str, help='where to save files.')
  88. args: Any = parser.parse_args()
  89. main(args)