test_ens.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. """
  2. Test how well deposit address reuse is doing by computing recall
  3. of a fixed set of ENS clusters.
  4. """
import os
from typing import Any, Dict, List, Optional, Set, Tuple

import pandas as pd
from tqdm import tqdm

from app.models import Address
  10. class TestENSClusters:
  11. def __init__(self, csv_file: str):
  12. self._csv_file: str = csv_file
  13. self._df = pd.read_csv(csv_file)
  14. self._clusters: List[Set[str]] = self._get_clusters(self._df)
  15. def _get_clusters(self, df):
  16. clusters: List[Set[str]] = []
  17. for _, group in df.groupby('name'):
  18. cluster: Set[str] = set(group.address)
  19. if len(cluster) > 1: # solo clusters are not worthwhile
  20. clusters.append(cluster)
  21. return clusters
  22. def _get_prediction(self, address):
  23. addr: Optional[Address] = \
  24. Address.query.filter_by(address = address).first()
  25. if addr is not None:
  26. assert addr.entity == 1, "Address must be an EOA."
  27. cluster: List[Address] = []
  28. if (addr.user_cluster is not None) and (addr.user_cluster != -1):
  29. cluster: List[Address] = Address.query.filter_by(
  30. user_cluster = addr.user_cluster).limit(100000).all()
  31. if cluster is not None:
  32. cluster += cluster
  33. cluster = set([
  34. c.address for c in cluster if c.entity == 0]) # EOA only
  35. else: # if no address, then just return itself
  36. cluster: Set[str] = {address}
  37. return cluster
  38. def evaluate(self):
  39. avg_precision: float = 0
  40. avg_recall: float = 0
  41. pbar = tqdm(total=len(self._clusters))
  42. for cluster in self._clusters:
  43. address: str = list(cluster)[0] # representative address
  44. pred_cluster: Set[str] = self._get_prediction(address)
  45. tp: int = 0 # true positive
  46. fp: int = 0 # false positive
  47. for member in pred_cluster:
  48. member: str = member
  49. if member in cluster:
  50. tp += 1
  51. else:
  52. fp += 1
  53. fn: int = 0 # false negatives
  54. for member in cluster:
  55. member: str = member
  56. if member not in pred_cluster:
  57. fn += 1
  58. precision: float = get_precision(tp, fp)
  59. recall: float = get_recall(tp, fn)
  60. avg_precision += precision
  61. avg_recall += recall
  62. pbar.update()
  63. pbar.close()
  64. avg_precision /= float(len(self._clusters))
  65. avg_recall /= float(len(self._clusters))
  66. return {'precision': avg_precision, 'recall': avg_recall}
  67. def get_precision(tp, fp):
  68. return tp / float(tp + fp)
  69. def get_recall(tp, fn):
  70. return tp / float(tp + fn)
  71. if __name__ == "__main__":
  72. import argparse
  73. cur_dir = os.path.dirname(__file__)
  74. parser = argparse.ArgumentParser()
  75. parser.add_argument(
  76. '--csv-file',
  77. type=str,
  78. default=os.path.join(cur_dir, 'data/ens_pairs.csv'),
  79. help='path to csv file of ENS clusters',
  80. )
  81. parser.add_argument(
  82. '--max-cluster-size',
  83. type=int,
  84. default=100000,
  85. help='maximum amount of clusters to show',
  86. )
  87. args: Any = parser.parse_args()
  88. print(TestENSClusters(args.csv_file).evaluate())