torch_tensorboard_tsne.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. __author__ = "Christian Heider Nielsen"
  4. __doc__ = r"""
  5. Created on 04-11-2020
  6. """
  7. import csv
  8. import os
  9. import re
  10. from pathlib import Path
  11. import numpy
  12. import pandas
  13. import torch
  14. from PIL import Image
  15. from torchvision import models, transforms
  16. if __name__ == "__main__":
  17. def main(im_path="images"):
  18. """description"""
  19. def get_vector(input_image):
  20. """description"""
  21. image = input_image.convert(
  22. "RGB"
  23. ) # in case input image is not in RGB format
  24. img_t = transform(image)
  25. batch_t = torch.unsqueeze(img_t, 0)
  26. my_embedding = torch.zeros([1, 512, 1, 1])
  27. def copy_data(m, i, o):
  28. """description"""
  29. my_embedding.copy_(o.data)
  30. h = layer.register_forward_hook(copy_data)
  31. model(batch_t)
  32. h.remove()
  33. return my_embedding.squeeze().cpu().numpy()
  34. model = models.resnet18(pretrained=True)
  35. layer = model._modules.get("avgpool")
  36. model.eval()
  37. transform = transforms.Compose(
  38. [
  39. transforms.Resize(256),
  40. transforms.CenterCrop(224),
  41. transforms.ToTensor(),
  42. transforms.Normalize(
  43. mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
  44. ),
  45. ]
  46. )
  47. im_names = [
  48. str(Path(root) / name)
  49. for root, dirs, files in os.walk(im_path)
  50. for name in files
  51. if name.endswith(".jpg")
  52. ]
  53. existing_images_df = pandas.DataFrame(
  54. [re.findall(r"[\w']+", im_name)[1:3] for im_name in im_names],
  55. columns=["cat_id", "pid"],
  56. )
  57. existing_images_df["impath"] = im_names
  58. vecs = [
  59. list(get_vector(Image.open(impath)))
  60. for _, pid, impath in existing_images_df.values
  61. ]
  62. with open("vis/feature_vecs.tsv", "w") as fw:
  63. csv_writer = csv.writer(fw, delimiter="\t")
  64. csv_writer.writerows(vecs)
  65. images = [
  66. Image.open(filename).resize((300, 300))
  67. for filename in existing_images_df["impath"]
  68. ]
  69. image_width, image_height = images[0].size
  70. one_square_size = int(numpy.ceil(numpy.sqrt(len(images))))
  71. master_width = image_width * one_square_size
  72. master_height = image_height * one_square_size
  73. spriteimage = Image.new(
  74. mode="RGBA", size=(master_width, master_height), color=(0, 0, 0, 0)
  75. ) # fully transparent
  76. for count, image in enumerate(images):
  77. div, mod = divmod(count, one_square_size)
  78. h_loc = image_width * div
  79. w_loc = image_width * mod
  80. spriteimage.paste(image, (w_loc, h_loc))
  81. spriteimage.convert("RGB").save("sprite.jpg", transparency=0)
  82. metadata = existing_images_df[["cat_id", "pid"]]
  83. metadata.to_csv("vis/metadata.tsv", sep="\t", index=False)
  84. main()