torch_tensorboard_tsne.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. __author__ = "Christian Heider Nielsen"
  4. __doc__ = r"""
  5. Created on 04-11-2020
  6. """
  7. import csv
  8. import os
  9. import re
  10. import numpy
  11. import pandas as pd
  12. import torch
  13. import torchvision.models as models
  14. import torchvision.transforms as transforms
  15. from PIL import Image
  16. if __name__ == "__main__":
  17. def main(im_path="images"):
  18. """ """
  19. def get_vector(input_image):
  20. """ """
  21. image = input_image.convert(
  22. "RGB"
  23. ) # in case input image is not in RGB format
  24. img_t = transform(image)
  25. batch_t = torch.unsqueeze(img_t, 0)
  26. my_embedding = torch.zeros([1, 512, 1, 1])
  27. def copy_data(m, i, o):
  28. """ """
  29. my_embedding.copy_(o.data)
  30. h = layer.register_forward_hook(copy_data)
  31. model(batch_t)
  32. h.remove()
  33. return my_embedding.squeeze().cpu().numpy()
  34. model = models.resnet18(pretrained=True)
  35. layer = model._modules.get("avgpool")
  36. model.eval()
  37. transform = transforms.Compose(
  38. [
  39. transforms.Resize(256),
  40. transforms.CenterCrop(224),
  41. transforms.ToTensor(),
  42. transforms.Normalize(
  43. mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
  44. ),
  45. ]
  46. )
  47. im_names = [
  48. os.path.join(root, name)
  49. for root, dirs, files in os.walk(im_path)
  50. for name in files
  51. if name.endswith(".jpg")
  52. ]
  53. existing_images_df = pd.DataFrame(
  54. [re.findall(r"[\w']+", im_name)[1:3] for im_name in im_names],
  55. columns=["cat_id", "pid"],
  56. )
  57. existing_images_df["impath"] = im_names
  58. vecs = [
  59. list(get_vector(Image.open(impath)))
  60. for _, pid, impath in existing_images_df.values
  61. ]
  62. with open("vis/feature_vecs.tsv", "w") as fw:
  63. csv_writer = csv.writer(fw, delimiter="\t")
  64. csv_writer.writerows(vecs)
  65. images = [
  66. Image.open(filename).resize((300, 300))
  67. for filename in existing_images_df["impath"]
  68. ]
  69. image_width, image_height = images[0].size
  70. one_square_size = int(numpy.ceil(numpy.sqrt(len(images))))
  71. master_width = image_width * one_square_size
  72. master_height = image_height * one_square_size
  73. spriteimage = Image.new(
  74. mode="RGBA", size=(master_width, master_height), color=(0, 0, 0, 0)
  75. ) # fully transparent
  76. for count, image in enumerate(images):
  77. div, mod = divmod(count, one_square_size)
  78. h_loc = image_width * div
  79. w_loc = image_width * mod
  80. spriteimage.paste(image, (w_loc, h_loc))
  81. spriteimage.convert("RGB").save("sprite.jpg", transparency=0)
  82. metadata = existing_images_df[["cat_id", "pid"]]
  83. metadata.to_csv("vis/metadata.tsv", sep="\t", index=False)
  84. main()