test_ens.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. """
  2. Test how well address clustering (deposit address reuse and/or
  3. embedding neighbors) is doing by computing the average precision and
  3. recall against a fixed set of ENS clusters.
  4. """
  5. import os, json
  6. import pandas as pd
  7. from tqdm import tqdm
  8. from typing import Any, Set, List, Optional
  9. from app.models import Address, Embedding
  10. class TestENSClusters:
  11. def __init__(self, csv_file: str, mode: str = 'dar'):
  12. assert mode in ['dar', 'node', 'both'], 'Unexpected mode.'
  13. self._mode: str = mode
  14. self._csv_file: str = csv_file
  15. self._df = pd.read_csv(csv_file)
  16. self._clusters: List[Set[str]] = self._get_clusters(self._df)
  17. def _get_clusters(self, df):
  18. clusters: List[Set[str]] = []
  19. for _, group in df.groupby('name'):
  20. cluster: Set[str] = set(group.address)
  21. if len(cluster) > 1: # solo clusters are not worthwhile
  22. clusters.append(cluster)
  23. return clusters
  24. def _get_prediction(self, address) -> Set[str]:
  25. if self._mode == 'dar':
  26. return self._get_dar_prediction(address)
  27. elif self._mode == 'node':
  28. return self._get_node_prediction(address)
  29. elif self._mode == 'both':
  30. dar_cluster: Set[str] = self._get_dar_prediction(address)
  31. node_cluster: Set[str] = self._get_node_prediction(address)
  32. cluster: Set[str] = set()
  33. cluster: Set[str] = cluster.union(dar_cluster)
  34. cluster: Set[str] = cluster.union(node_cluster)
  35. return cluster
  36. def _get_dar_prediction(self, address) -> Set[str]:
  37. addr: Optional[Address] = \
  38. Address.query.filter_by(address = address).first()
  39. if addr is not None:
  40. assert addr.entity == 0, "Address must be an EOA."
  41. cluster: List[Address] = []
  42. if (addr.user_cluster is not None) and (addr.user_cluster != -1):
  43. cluster: List[Address] = Address.query.filter_by(
  44. user_cluster = addr.user_cluster).limit(100000).all()
  45. if cluster is not None:
  46. cluster += cluster
  47. cluster = set([
  48. c.address for c in cluster if c.entity == 0]) # EOA only
  49. else: # if no address, then just return itself
  50. cluster: Set[str] = {address}
  51. return cluster
  52. def _get_node_prediction(self, address) -> Set[str]:
  53. node: Optional[Embedding] = \
  54. Embedding.query.filter_by(address = address).first()
  55. if node is not None:
  56. # I mapped neighbors -> distances so this is actually loading neighbors
  57. cluster: List[str] = json.loads(node.distances)
  58. cluster: Set[str] = set(cluster)
  59. else:
  60. cluster: Set[str] = {address}
  61. return cluster
  62. def evaluate(self):
  63. avg_precision: float = 0
  64. avg_recall: float = 0
  65. pbar = tqdm(total=len(self._clusters))
  66. for cluster in self._clusters:
  67. address: str = list(cluster)[0] # representative address
  68. pred_cluster: Set[str] = self._get_prediction(address)
  69. tp: int = 0 # true positive
  70. fp: int = 0 # false positive
  71. for member in pred_cluster:
  72. member: str = member
  73. if member in cluster:
  74. tp += 1
  75. else:
  76. fp += 1
  77. fn: int = 0 # false negatives
  78. for member in cluster:
  79. member: str = member
  80. if member not in pred_cluster:
  81. fn += 1
  82. precision: float = get_precision(tp, fp)
  83. recall: float = get_recall(tp, fn)
  84. avg_precision += precision
  85. avg_recall += recall
  86. pbar.update()
  87. pbar.close()
  88. avg_precision /= float(len(self._clusters))
  89. avg_recall /= float(len(self._clusters))
  90. return {'precision': avg_precision, 'recall': avg_recall}
  91. def get_precision(tp, fp):
  92. return tp / float(tp + fp)
  93. def get_recall(tp, fn):
  94. return tp / float(tp + fn)
  95. if __name__ == "__main__":
  96. import argparse
  97. cur_dir = os.path.dirname(__file__)
  98. parser = argparse.ArgumentParser()
  99. parser.add_argument(
  100. '--csv-file',
  101. type=str,
  102. default=os.path.join(cur_dir, 'data/ens_pairs.csv'),
  103. help='path to csv file of ENS clusters',
  104. )
  105. parser.add_argument(
  106. '--max-cluster-size',
  107. type=int,
  108. default=100000,
  109. help='maximum amount of clusters to show',
  110. )
  111. parser.add_argument(
  112. '--mode',
  113. type=str,
  114. default='dar',
  115. help='which model to use (dar|node|both)',
  116. choices=['dar', 'node', 'both'],
  117. )
  118. args: Any = parser.parse_args()
  119. print(TestENSClusters(args.csv_file).evaluate())