demix.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. #!/usr/bin/env python3
  2. # Attempt very simple 2-signal separation (demixing) of a stereo .wav file.
  3. # https://en.wikipedia.org/wiki/Signal_separation
  4. import scipy
  5. from os.path import dirname, join as pjoin
  6. from scipy.io import wavfile
  7. import scipy.io
  8. import scipy
  9. import scipy.optimize
  10. import math
  11. import numpy as np
  12. class Optipro:
  13. def __init__(self, wave):
  14. self.wave = wave
  15. def combi1(self, p1, p2):
  16. return math.cos(p1) * self.wave[:,0] + math.sin(p1) * self.wave[:,1]
  17. def combi2(self, p1, p2):
  18. return math.sin(p2) * self.wave[:,0] + math.cos(p2) * self.wave[:,1]
  19. def score(self, p1, p2):
  20. combination1 = self.combi1(p1, p2)
  21. combination2 = self.combi2(p1, p2)
  22. p12 = np.dot(combination1, combination2)
  23. return abs(p12)
  24. def main(wav_fname):
  25. samplerate, data = wavfile.read(wav_fname, mmap=True)
  26. nchan = data.shape[1]
  27. length = data.shape[0] / samplerate
  28. print(f"channels = {nchan} length = {length}s")
  29. data = data.astype(np.float32)
  30. datax = data[:]
  31. op = Optipro(data)
  32. import matplotlib.pyplot as plt
  33. xs = np.linspace(0, math.pi, 20)
  34. ys = np.linspace(0, math.pi, 20)
  35. # https://stackoverflow.com/questions/22774726/numpy-evaluate-function-on-a-grid-of-points
  36. def f(x, y):
  37. return op.score(x, y)
  38. X, Y = np.meshgrid(xs, ys)
  39. # print([X, Y].reshape)
  40. Z = np.fromiter(map(f, X.ravel(), Y.ravel()), X.dtype).reshape(X.shape)
  41. plt.contourf(X, Y, Z, 64, alpha=.75, cmap='jet')
  42. contours = plt.contour(X, Y, Z, 4, colors='black')
  43. plt.clabel(contours, inline=True, fontsize=8)
  44. plt.show()
  45. z = scipy.optimize.minimize(lambda x: op.score(x[0], x[1]), [0.1, 0.2],
  46. method="BFGS",
  47. options={'finite_diff_rel_step':1e-4, 'norm':2},
  48. jac='3-point')
  49. print(z)
  50. p1 = z.x[0]
  51. p2 = z.x[1]
  52. print([p1, p2])
  53. print("at 0,0: ", op.score(0.0, 0.0))
  54. print(op.score(p1, p2))
  55. putsol0 = op.combi1(p1, p2)
  56. putsol1 = op.combi2(p1, p2)
  57. morigl = math.sqrt(np.dot(data[:,0], data[:,0]))
  58. morigr = math.sqrt(np.dot(data[:,1], data[:,1]))
  59. morig = (morigl + morigr)/2.0
  60. n0 = math.sqrt(np.dot(putsol0, putsol0))
  61. n1 = math.sqrt(np.dot(putsol1, putsol1))
  62. print(morigl, morigr, morig, n0, n1, morig/n0, morig/n1)
  63. putsol0 *= morig/n0
  64. putsol1 *= morig/n1
  65. m0 = max(abs(putsol0))
  66. m1 = max(abs(putsol1))
  67. putsol0 /= m0
  68. putsol1 /= m1
  69. scipy.io.wavfile.write("source0.wav", samplerate, putsol0)
  70. scipy.io.wavfile.write("source1.wav", samplerate, putsol1)
  71. if __name__=="__main__":
  72. import sys
  73. filename = "input.wav"
  74. if len(sys.argv) > 1:
  75. filename = sys.argv[1]
  76. main(filename)