spv.coopmat.comp 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. #version 450 core
  2. #extension GL_KHR_memory_scope_semantics : enable
  3. #extension GL_NV_cooperative_matrix : enable
  4. #extension GL_EXT_shader_explicit_arithmetic_types_float16 : enable
  5. #extension GL_EXT_buffer_reference : enable
  6. layout (local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
  // Size constants: X is a plain constant, Y is specialization constant 0
  // (default 2), so Z = X*Y depends on the specialization value.
  7. const int X = 8;
  8. layout(constant_id = 0) const int Y = 2;
  9. const int Z = X*Y;
  // Global cooperative matrices (one scalar variable, one array of three)
  // whose first dimension is the specialization-dependent Z.
  10. fcoopmatNV<16, gl_ScopeSubgroup, Z, 8> mC;
  11. fcoopmatNV<16, gl_ScopeSubgroup, Z, 8> mC2[3];
  // Arrays sized by length() of a matrix variable and of a matrix-array element.
  12. int arr[mC.length()];
  13. int arr2[mC2[1].length()];
  // Specialization constant 1 (default 3.0). main() declares a local that is
  // also named F.
  14. layout(constant_id = 1) const float F = 3.0;
  // Constant matrices built with the single-scalar constructor; mD uses a
  // float literal and mD2 an integer literal.
  15. const fcoopmatNV<32, gl_ScopeSubgroup, Z, 8> mD = fcoopmatNV<32, gl_ScopeSubgroup, Z, 8>(0.0);
  16. const fcoopmatNV<16, gl_ScopeSubgroup, 8, 8> mD2 = fcoopmatNV<16, gl_ScopeSubgroup, 8, 8>(1);
  // Plain struct and a constant instance of it; neither is referenced in the
  // visible part of the shader.
  17. struct S { int a; int b; int c; };
  18. const S s = S(12, 23, 34);
  // Storage block of 32-bit floats: a fixed-size array y followed by an
  // unsized array x. Declared with buffer_reference, so Block is also usable
  // as a member type (see Block16.b).
  19. layout(set = 0, binding = 0, buffer_reference) coherent buffer Block {
  20. float y[1024*1024];
  21. float x[];
  22. } block;
  // Storage block of 16-bit floats with the same y/x shape, plus a member b of
  // type Block. Both block declarations name set 0, binding 0.
  // NOTE(review): x[] is unsized but is followed by another member (b);
  // presumably intentional for this test — confirm before changing.
  23. layout(set = 0, binding = 0) coherent buffer Block16 {
  24. float16_t y[1024*1024];
  25. float16_t x[];
  26. Block b;
  27. } block16;
  // Takes a 16-bit 8x8 subgroup-scope cooperative matrix by value and returns
  // its negation; main() calls it on p1.
  28. fcoopmatNV<16, gl_ScopeSubgroup, 8, 8> f16(fcoopmatNV<16, gl_ScopeSubgroup, 8, 8> m) { return -m; }
  // 32-bit counterpart of f16: takes an 8x8 subgroup-scope cooperative matrix
  // by value and returns its negation; main() calls it on p2.
  29. fcoopmatNV<32, gl_ScopeSubgroup, 8, 8> f32(fcoopmatNV<32, gl_ScopeSubgroup, 8, 8> m) { return -m; }
  // Specialization constant 2 (default 1) used for both matrix dimensions and
  // both array extents of scm.
  30. layout(constant_id = 2) const int SC = 1;
  31. fcoopmatNV<16, gl_ScopeSubgroup, SC, SC> scm[SC][SC];
  // Shared-memory backing store used by the final load/store in main().
  // 16*16*2/16 evaluates to 32 uvec4 elements.
  32. // sized for fcoopmatNV<16, gl_ScopeSubgroup, 16, 16>
  33. shared uvec4 shmatrix[16*16*2/16];
  // Entry point. Exercises cooperative-matrix declarations, arithmetic,
  // conversion, indexing, load/store, multiply-add, and function calls.
  // Statement order is significant for a compiler test; do not reorder.
  34. void main()
  35. {
  // 32-bit matrix whose second dimension is a constant expression (evaluates
  // to 8), initialised from a scalar.
  36. fcoopmatNV<32, gl_ScopeSubgroup, 16, (2>1?8:4)> m = fcoopmatNV<32, gl_ScopeSubgroup, 16, (2>1?8:4)>(0.0);
  // Component-wise operators: add, subtract, negate, scalar on either side.
  37. m = m + m;
  38. m = m - m;
  39. m = -m;
  40. m = 2.0*m;
  41. m = m*2.0;
  // Construct a 16-bit matrix from the 32-bit one (component-width conversion).
  42. fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> m2 = fcoopmatNV<16, gl_ScopeSubgroup, 16, 8>(m);
  // Element read and element write through operator[].
  43. float x = m[1];
  44. m[0] = x;
  // Load/store through the unsized x arrays: float block, float16_t block, and
  // the float array reached through the buffer reference block16.b.
  // Trailing arguments are presumably (element, stride, colMajor) — confirm
  // against the GL_NV_cooperative_matrix specification.
  45. coopMatLoadNV(m, block.x, 16, 128, false);
  46. coopMatStoreNV(m, block.x, 16, 128, false);
  47. coopMatLoadNV(m2, block16.x, 16, 128, false);
  48. coopMatStoreNV(m2, block16.x, 16, 128, false);
  49. coopMatLoadNV(m, block16.b.x, 16, 128, false);
  50. coopMatStoreNV(m, block16.b.x, 16, 128, false);
  // Multiply-add: 16-bit A (16x8) and B (8x8) with 32-bit C and D (16x8).
  // A, B and C are never assigned before use here.
  51. fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> A;
  52. fcoopmatNV<16, gl_ScopeSubgroup, 8, 8> B;
  53. fcoopmatNV<32, gl_ScopeSubgroup, 16, 8> C;
  54. fcoopmatNV<32, gl_ScopeSubgroup, 16, 8> D;
  55. D = coopMatMulAddNV(A, B, C);
  // length() on a local matrix.
  56. int l = D.length();
  // E is declared but not used afterwards in this function.
  57. fcoopmatNV<16, gl_ScopeSubgroup, 8, 8> E;
  // Local F hides the global specialization constant F for the rest of main();
  // both of its dimensions are the specialization-dependent Z.
  58. fcoopmatNV<16, gl_ScopeSubgroup, Z, Z> F = fcoopmatNV<16, gl_ScopeSubgroup, Z, Z>(0.0);
  // Local array of matrices, then a write to one component of one element.
  59. fcoopmatNV<32, gl_ScopeSubgroup, 16, (2>1?8:4)> a[5];
  60. a[3][0] = 1.0;
  // Component read from the constant global matrix mD.
  61. float md1 = mD[1];
  // Indexes the result of a compound assignment.
  // NOTE(review): index 1234 looks larger than any plausible per-invocation
  // length for this matrix; presumably deliberate for the test — confirm.
  62. md1 += (m += m)[1234];
  // Whole-matrix assignment between elements of the global matrix array.
  63. mC2[1] = mC2[2];
  // Same load/store pattern as above, now through the fixed-size y arrays.
  64. coopMatLoadNV(m, block.y, 16, 128, false);
  65. coopMatStoreNV(m, block.y, 16, 128, false);
  66. coopMatLoadNV(m2, block16.y, 16, 128, false);
  67. coopMatStoreNV(m2, block16.y, 16, 128, false);
  // Matrices passed to and returned from the user functions f16 and f32, then
  // reassigned from scalar constructors.
  68. fcoopmatNV<16, gl_ScopeSubgroup, 8, 8> p1;
  69. fcoopmatNV<32, gl_ScopeSubgroup, 8, 8> p2;
  70. p1 = f16(p1);
  71. p2 = f32(p2);
  72. p1 = fcoopmatNV<16, gl_ScopeSubgroup, 8, 8>(0.0);
  73. p2 = fcoopmatNV<32, gl_ScopeSubgroup, 8, 8>(0.0);
  // Compound operators: matrix /= matrix, matrix *= float16_t scalar,
  // matrix *= float literal.
  74. p1 /= p1;
  75. p1 *= float16_t(2.0);
  76. p2 *= 4.0;
  // Load/store through the shared uvec4 array shmatrix rather than a buffer.
  77. fcoopmatNV<16, gl_ScopeSubgroup, 16, 8> ms;
  78. coopMatLoadNV(ms, shmatrix, 1, 2, false);
  79. coopMatStoreNV(ms, shmatrix, 1, 2, false);
  80. }