RSACipher.hpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. #pragma once
#include <climits>
#include <cstddef>
#include <string>
#include <string_view>
#include <openssl/opensslv.h>
#include <openssl/err.h>
#include <openssl/pem.h>
#include <openssl/bio.h>
#include <openssl/bn.h>
#include <openssl/rsa.h>
#include "Exception.hpp"
#include "ExceptionOpenssl.hpp"
#include "ResourceWrapper.hpp"
#include "ResourceTraitsOpenssl.hpp"
  12. namespace nkg {
  13. enum class RSAKeyType {
  14. PrivateKey,
  15. PublicKey
  16. };
  17. enum class RSAKeyFormat {
  18. PEM,
  19. PKCS1
  20. };
  21. class RSACipher final : private ARL::ResourceWrapper<ARL::ResourceTraits::OpensslRSA> {
  22. private:
  23. template<RSAKeyType __Type, RSAKeyFormat __Format>
  24. static void _WriteRSAToBIO(RSA* lpRSA, BIO* lpBIO) {
  25. if constexpr (__Type == RSAKeyType::PrivateKey) {
  26. if (PEM_write_bio_RSAPrivateKey(lpBIO, lpRSA, nullptr, nullptr, 0, nullptr, nullptr) == 0) {
  27. throw ARL::Exception(__BASE_FILE__, __LINE__, "PEM_write_bio_RSAPrivateKey failed.");
  28. }
  29. }
  30. if constexpr (__Type == RSAKeyType::PublicKey) {
  31. if constexpr (__Format == RSAKeyFormat::PEM) {
  32. if (PEM_write_bio_RSA_PUBKEY(lpBIO, lpRSA) == 0) {
  33. throw ARL::Exception(__BASE_FILE__, __LINE__, "PEM_write_bio_RSA_PUBKEY failed.");
  34. }
  35. }
  36. if constexpr (__Format == RSAKeyFormat::PKCS1) {
  37. if (PEM_write_bio_RSAPublicKey(lpBIO, lpRSA) == 0) {
  38. throw ARL::Exception(__BASE_FILE__, __LINE__, "PEM_write_bio_RSAPublicKey failed.");
  39. }
  40. }
  41. static_assert(__Format == RSAKeyFormat::PEM || __Format == RSAKeyFormat::PKCS1);
  42. }
  43. static_assert(__Type == RSAKeyType::PrivateKey || __Type == RSAKeyType::PublicKey);
  44. }
  45. template<RSAKeyType __Type, RSAKeyFormat __Format>
  46. [[nodiscard]]
  47. static RSA* _ReadRSAFromBIO(BIO* lpBIO) {
  48. RSA* lpRSA;
  49. if constexpr (__Type == RSAKeyType::PrivateKey) {
  50. lpRSA = PEM_read_bio_RSAPrivateKey(lpBIO, nullptr, nullptr, nullptr);
  51. if (lpRSA == nullptr) {
  52. throw ARL::Exception(__BASE_FILE__, __LINE__, "PEM_read_bio_RSAPrivateKey failed.")
  53. .PushHint("Are you sure that you DO provide a valid RSA private key file?");
  54. }
  55. }
  56. if constexpr (__Type == RSAKeyType::PublicKey) {
  57. if constexpr (__Format == RSAKeyFormat::PEM) {
  58. lpRSA = PEM_read_bio_RSA_PUBKEY(lpBIO, nullptr, nullptr, nullptr);
  59. if (lpRSA == nullptr) {
  60. throw ARL::Exception(__BASE_FILE__, __LINE__, "PEM_read_bio_RSA_PUBKEY failed.")
  61. .PushHint("Are you sure that you DO provide a valid RSA public key file with PEM format?");
  62. }
  63. }
  64. if constexpr (__Format == RSAKeyFormat::PKCS1) {
  65. lpRSA = PEM_read_bio_RSAPublicKey(lpBIO, nullptr, nullptr, nullptr);
  66. if (lpRSA == nullptr) {
  67. throw ARL::Exception(__BASE_FILE__, __LINE__, "PEM_read_bio_RSAPublicKey failed.")
  68. .PushHint("Are you sure that you DO provide a valid RSA public key file with PKCS1 format?");
  69. }
  70. }
  71. static_assert(__Format == RSAKeyFormat::PEM || __Format == RSAKeyFormat::PKCS1);
  72. }
  73. static_assert(__Type == RSAKeyType::PrivateKey || __Type == RSAKeyType::PublicKey);
  74. return lpRSA;
  75. }
  76. public:
  77. RSACipher() : ARL::ResourceWrapper<ARL::ResourceTraits::OpensslRSA>(RSA_new()) {
  78. if (IsValid() == false) {
  79. throw ARL::OpensslError(__BASE_FILE__, __LINE__, ERR_get_error(), "RSA_new failed.");
  80. }
  81. }
  82. [[nodiscard]]
  83. size_t Bits() const {
  84. #if (OPENSSL_VERSION_NUMBER & 0xffff0000) == 0x10000000 // openssl 1.0.x
  85. if (Get()->n == nullptr) {
  86. throw ARL::Exception(__BASE_FILE__, __LINE__, "RSA modulus has not been set.");
  87. } else {
  88. return BN_num_bits(Get()->n);
  89. }
  90. #elif (OPENSSL_VERSION_NUMBER & 0xffff0000) == 0x10100000 // openssl 1.1.x
  91. return RSA_bits(Get());
  92. #else
  93. #error "Unexpected openssl version!"
  94. #endif
  95. }
  96. void GenerateKey(int bits, unsigned int e = RSA_F4) {
  97. ARL::ResourceWrapper bn_e{ ARL::ResourceTraits::OpensslBIGNUM{} };
  98. bn_e.TakeOver(BN_new());
  99. if (bn_e.IsValid() == false) {
  100. throw ARL::OpensslError(__BASE_FILE__, __LINE__, ERR_get_error(), "BN_new failed.");
  101. }
  102. if (!BN_set_word(bn_e, e)) {
  103. throw ARL::Exception(__BASE_FILE__, __LINE__, "BN_set_word failed.");
  104. }
  105. if (!RSA_generate_key_ex(Get(), bits, bn_e, nullptr)) {
  106. throw ARL::OpensslError(__BASE_FILE__, __LINE__, ERR_get_error(), "RSA_generate_key_ex failed.");
  107. }
  108. }
  109. template<RSAKeyType __Type, RSAKeyFormat __Format>
  110. void ExportKeyToFile(std::string_view FileName) const {
  111. ARL::ResourceWrapper KeyFile{ ARL::ResourceTraits::OpensslBIO{} };
  112. KeyFile.TakeOver(BIO_new_file(FileName.data(), "w"));
  113. if (KeyFile.IsValid() == false) {
  114. throw ARL::Exception(__BASE_FILE__, __LINE__, "BIO_new_file failed.");
  115. }
  116. _WriteRSAToBIO<__Type, __Format>(Get(), KeyFile);
  117. }
  118. template<RSAKeyType __Type, RSAKeyFormat __Format>
  119. [[nodiscard]]
  120. std::string ExportKeyString() const {
  121. ARL::ResourceWrapper TempMemory{ ARL::ResourceTraits::OpensslBIO{} };
  122. const char* lpsz = nullptr;
  123. TempMemory.TakeOver(BIO_new(BIO_s_mem()));
  124. if (TempMemory.IsValid() == false) {
  125. throw ARL::Exception(__BASE_FILE__, __LINE__, "BIO_new failed.");
  126. }
  127. _WriteRSAToBIO<__Type, __Format>(Get(), TempMemory);
  128. auto l = BIO_get_mem_data(TempMemory.Get(), &lpsz);
  129. std::string KeyString(lpsz, l);
  130. while (KeyString.back() == '\n' || KeyString.back() == '\r') {
  131. KeyString.pop_back();
  132. }
  133. return KeyString;
  134. }
  135. template<RSAKeyType __Type, RSAKeyFormat __Format>
  136. void ImportKeyFromFile(std::string_view FileName) {
  137. ARL::ResourceWrapper KeyFile{ ARL::ResourceTraits::OpensslBIO{} };
  138. KeyFile.TakeOver(BIO_new_file(FileName.data(), "r"));
  139. if (KeyFile.IsValid() == false) {
  140. throw ARL::Exception(__BASE_FILE__, __LINE__, "BIO_new_file failed.");
  141. }
  142. ReleaseAndTakeOver(_ReadRSAFromBIO<__Type, __Format>(KeyFile));
  143. }
  144. template<RSAKeyType __Type, RSAKeyFormat __Format>
  145. void ImportKeyString(std::string_view KeyString) {
  146. ARL::ResourceWrapper TempMemory{ ARL::ResourceTraits::OpensslBIO{} };
  147. TempMemory.TakeOver(BIO_new(BIO_s_mem()));
  148. if (TempMemory.IsValid() == false) {
  149. throw ARL::Exception(__BASE_FILE__, __LINE__, "BIO_new failed.");
  150. }
  151. if (BIO_puts(TempMemory.Get(), KeyString.data()) <= 0) {
  152. throw ARL::Exception(__BASE_FILE__, __LINE__, "BIO_puts failed.");
  153. }
  154. TakeOver(_ReadRSAFromBIO<__Type, __Format>(TempMemory));
  155. }
  156. template<RSAKeyType __Type = RSAKeyType::PublicKey>
  157. size_t Encrypt(const void* lpFrom, size_t cbFrom, void* lpTo, int Padding) const {
  158. int BytesWritten;
  159. if (cbFrom > static_cast<size_t>(INT_MAX)) {
  160. throw ARL::Exception(__BASE_FILE__, __LINE__, "Length overflowed.");
  161. }
  162. if constexpr (__Type == RSAKeyType::PrivateKey) {
  163. BytesWritten = RSA_private_encrypt(
  164. static_cast<int>(cbFrom),
  165. reinterpret_cast<const unsigned char*>(lpFrom),
  166. reinterpret_cast<unsigned char*>(lpTo),
  167. Get(),
  168. Padding
  169. );
  170. if (BytesWritten == -1) {
  171. throw ARL::OpensslError(__BASE_FILE__, __LINE__, ERR_get_error(), "RSA_private_encrypt failed.");
  172. }
  173. } else {
  174. BytesWritten = RSA_public_encrypt(
  175. static_cast<int>(cbFrom),
  176. reinterpret_cast<const unsigned char*>(lpFrom),
  177. reinterpret_cast<unsigned char*>(lpTo),
  178. Get(),
  179. Padding
  180. );
  181. if (BytesWritten == -1) {
  182. throw ARL::OpensslError(__BASE_FILE__, __LINE__, ERR_get_error(), "RSA_public_encrypt failed.");
  183. }
  184. }
  185. return BytesWritten;
  186. }
  187. template<RSAKeyType __Type = RSAKeyType::PrivateKey>
  188. size_t Decrypt(const void* lpFrom, size_t cbFrom, void* lpTo, int Padding) const {
  189. int BytesWritten;
  190. if (cbFrom > static_cast<size_t>(INT_MAX)) {
  191. throw ARL::Exception(__BASE_FILE__, __LINE__, "Length overflowed.");
  192. }
  193. if constexpr (__Type == RSAKeyType::PrivateKey) {
  194. BytesWritten = RSA_private_decrypt(
  195. static_cast<int>(cbFrom),
  196. reinterpret_cast<const unsigned char*>(lpFrom),
  197. reinterpret_cast<unsigned char*>(lpTo),
  198. Get(),
  199. Padding
  200. );
  201. if (BytesWritten == -1) {
  202. throw ARL::OpensslError(__BASE_FILE__, __LINE__, ERR_get_error(), "RSA_private_decrypt failed.")
  203. .PushHint("Are your sure you DO provide a correct private key?");
  204. }
  205. } else {
  206. BytesWritten = RSA_public_decrypt(
  207. static_cast<int>(cbFrom),
  208. reinterpret_cast<const unsigned char*>(lpFrom),
  209. reinterpret_cast<unsigned char*>(lpTo),
  210. Get(),
  211. Padding
  212. );
  213. if (BytesWritten == -1) {
  214. throw ARL::OpensslError(__BASE_FILE__, __LINE__, ERR_get_error(), "RSA_public_decrypt failed.")
  215. .PushHint("Are your sure you DO provide a correct public key?");
  216. }
  217. }
  218. return BytesWritten;
  219. }
  220. };
  221. }