pytorch_tensorboard_writer_sample.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. __author__ = "Christian Heider Nielsen"
  4. __doc__ = r"""
  5. Created on 07/07/2020
  6. """
  7. import librosa
  8. import numpy
  9. import torch
  10. from librosa.display import specshow
  11. from matplotlib import pyplot
  12. from draugr import PROJECT_APP_PATH
  13. from draugr.torch_utilities import (
  14. TensorBoardPytorchWriter,
  15. constant_init,
  16. fan_in_init,
  17. normal_init,
  18. weight_bias_histograms,
  19. xavier_init,
  20. )
  21. if __name__ == "__main__":
  22. NFFT = 256
  23. STEP_SIZE = NFFT // 2
  24. DELTA = 0.001
  25. time_ = numpy.arange(0, 1, DELTA)
  26. SAMPLING_RATE = int(1 / DELTA)
  27. SIGNAL = numpy.sin(2 * numpy.pi * 50 * time_) + numpy.sin(
  28. 2 * numpy.pi * 120 * time_
  29. )
  30. def module_param_histograms() -> None:
  31. """
  32. :rtype: None
  33. """
  34. with TensorBoardPytorchWriter(
  35. PROJECT_APP_PATH.user_log / "Tests" / "Writers"
  36. ) as writer:
  37. input_f = 4
  38. n_classes = 10
  39. num_updates = 20
  40. model = torch.nn.Sequential(
  41. torch.nn.Linear(input_f, 20),
  42. torch.nn.ReLU(),
  43. torch.nn.Linear(20, n_classes),
  44. torch.nn.LogSoftmax(-1),
  45. )
  46. for i in range(num_updates):
  47. normal_init(
  48. model, 0.2 * float((i - num_updates * 0.5) ** 2), 1 / (i + 1)
  49. )
  50. weight_bias_histograms(writer, model, step=i, prefix="normal")
  51. xavier_init(model)
  52. weight_bias_histograms(writer, model, step=i, prefix="xavier")
  53. constant_init(model, i)
  54. weight_bias_histograms(writer, model, step=i, prefix="constant")
  55. fan_in_init(model)
  56. weight_bias_histograms(writer, model, step=i, prefix="fan_in")
  57. def signal_plot() -> None:
  58. """
  59. :rtype: None
  60. """
  61. with TensorBoardPytorchWriter(
  62. PROJECT_APP_PATH.user_log / "Tests" / "Writers"
  63. ) as writer:
  64. writer.line("Signal", SIGNAL, step=0)
  65. def fft_plot() -> None:
  66. """
  67. :rtype: None
  68. """
  69. with TensorBoardPytorchWriter(
  70. PROJECT_APP_PATH.user_log / "Tests" / "Writers"
  71. ) as writer:
  72. spectral = numpy.fft.fft(SIGNAL, NFFT)
  73. writer.line("FFT", spectral, title="Frequency", step=0)
  74. def spectral_plot() -> None:
  75. """
  76. :rtype: None
  77. """
  78. with TensorBoardPytorchWriter(
  79. PROJECT_APP_PATH.user_log / "Tests" / "Writers"
  80. ) as writer:
  81. writer.spectrogram(
  82. "STFT", SIGNAL, int(1 / DELTA), step=0, n_fft=NFFT, step_size=STEP_SIZE
  83. )
  84. def spectral_plot_scipy() -> None:
  85. """
  86. :rtype: None
  87. """
  88. with TensorBoardPytorchWriter(
  89. PROJECT_APP_PATH.user_log / "Tests" / "Writers"
  90. ) as writer:
  91. writer.spectrogram(
  92. "STFT_Scipy",
  93. SIGNAL,
  94. int(1 / DELTA),
  95. step=0,
  96. n_fft=NFFT,
  97. step_size=STEP_SIZE,
  98. )
  99. def cepstral_plot() -> None:
  100. """
  101. :rtype: None
  102. """
  103. with TensorBoardPytorchWriter(
  104. PROJECT_APP_PATH.user_log / "Tests" / "Writers"
  105. ) as writer:
  106. fig = pyplot.figure()
  107. stft = librosa.core.stft(SIGNAL, n_fft=NFFT, hop_length=STEP_SIZE)
  108. specshow(stft, sr=SAMPLING_RATE, x_axis="time")
  109. pyplot.colorbar()
  110. writer.figure("STFT_Rosa", fig, step=0)
  111. def mel_cepstral_plot() -> None:
  112. """
  113. :rtype: None
  114. """
  115. with TensorBoardPytorchWriter(
  116. PROJECT_APP_PATH.user_log / "Tests" / "Writers"
  117. ) as writer:
  118. fig = pyplot.figure()
  119. mfccs = librosa.feature.mfcc(
  120. SIGNAL, sr=SAMPLING_RATE, n_mfcc=20, n_fft=NFFT, hop_length=STEP_SIZE
  121. )
  122. specshow(mfccs, sr=SAMPLING_RATE, x_axis="time")
  123. pyplot.colorbar()
  124. writer.figure("MFCC_Rosa", fig, step=0)
  125. module_param_histograms()
  126. signal_plot()
  127. fft_plot()
  128. spectral_plot()
  129. spectral_plot_scipy()
  130. cepstral_plot()
  131. mel_cepstral_plot()