RSACipher.hpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. #pragma once
  2. #include <openssl/err.h>
  3. #include <openssl/pem.h>
  4. #include <openssl/bio.h>
  5. #include <openssl/rsa.h>
  6. #include "Exception.hpp"
  7. #include "ExceptionOpenssl.hpp"
  8. #include "ResourceOwned.hpp"
  9. #include "ResourceTraitsOpenssl.hpp"
  // Source-location helpers consumed by the Exception / OpensslError throws in this header.
  // Any existing definition (e.g. one left behind by another header) is dropped first so the
  // per-file definition below does not collide with it.
  #ifdef NKG_CURRENT_SOURCE_FILE
  #undef NKG_CURRENT_SOURCE_FILE
  #endif
  #define NKG_CURRENT_SOURCE_FILE() TEXT(".\\common\\RSACipher.hpp")

  #ifdef NKG_CURRENT_SOURCE_LINE
  #undef NKG_CURRENT_SOURCE_LINE
  #endif
  #define NKG_CURRENT_SOURCE_LINE() __LINE__
  14. namespace nkg {
  /// Selects which half of an RSA key pair an RSACipher import / export / crypto call uses.
  enum class RSAKeyType : int {
      PrivateKey = 0,
      PublicKey = 1
  };
  /// Text encoding used by RSACipher when reading/writing *public* keys.
  /// Both values produce PEM-armored text; they differ in the structure inside it:
  ///   PEM   -> PEM_*_RSA_PUBKEY     (SubjectPublicKeyInfo, "BEGIN PUBLIC KEY")
  ///   PKCS1 -> PEM_*_RSAPublicKey   (PKCS#1 RSAPublicKey, "BEGIN RSA PUBLIC KEY")
  /// RSACipher ignores this value for private keys.
  enum class RSAKeyFormat : int {
      PEM = 0,
      PKCS1 = 1
  };
  23. class RSACipher final : private ResourceOwned<OpensslRSATraits> {
  24. private:
  25. template<RSAKeyType __Type, RSAKeyFormat __Format>
  26. static void _WriteRSAToBIO(RSA* lpRSA, BIO* lpBIO) {
  27. if constexpr (__Type == RSAKeyType::PrivateKey) {
  28. if (PEM_write_bio_RSAPrivateKey(lpBIO, lpRSA, nullptr, nullptr, 0, nullptr, nullptr) == 0) {
  29. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("PEM_write_bio_RSAPrivateKey failed."));
  30. }
  31. }
  32. if constexpr (__Type == RSAKeyType::PublicKey) {
  33. if constexpr (__Format == RSAKeyFormat::PEM) {
  34. if (PEM_write_bio_RSA_PUBKEY(lpBIO, lpRSA) == 0) {
  35. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("PEM_write_bio_RSA_PUBKEY failed."));
  36. }
  37. }
  38. if constexpr (__Format == RSAKeyFormat::PKCS1) {
  39. if (PEM_write_bio_RSAPublicKey(lpBIO, lpRSA) == 0) {
  40. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("PEM_write_bio_RSAPublicKey failed."));
  41. }
  42. }
  43. static_assert(__Format == RSAKeyFormat::PEM || __Format == RSAKeyFormat::PKCS1);
  44. }
  45. static_assert(__Type == RSAKeyType::PrivateKey || __Type == RSAKeyType::PublicKey);
  46. }
  47. template<RSAKeyType __Type, RSAKeyFormat __Format>
  48. [[nodiscard]]
  49. static RSA* _ReadRSAFromBIO(BIO* lpBIO) {
  50. RSA* lpRSA;
  51. if constexpr (__Type == RSAKeyType::PrivateKey) {
  52. lpRSA = PEM_read_bio_RSAPrivateKey(lpBIO, nullptr, nullptr, nullptr);
  53. if (lpRSA == nullptr) {
  54. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("PEM_read_bio_RSAPrivateKey failed."))
  55. .AddHint(TEXT("Are you sure that you DO provide a valid RSA private key file?"));
  56. }
  57. }
  58. if constexpr (__Type == RSAKeyType::PublicKey) {
  59. if constexpr (__Format == RSAKeyFormat::PEM) {
  60. lpRSA = PEM_read_bio_RSA_PUBKEY(lpBIO, nullptr, nullptr, nullptr);
  61. if (lpRSA == nullptr) {
  62. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("PEM_read_bio_RSA_PUBKEY failed."))
  63. .AddHint(TEXT("Are you sure that you DO provide a valid RSA public key file with PEM format?"));
  64. }
  65. }
  66. if constexpr (__Format == RSAKeyFormat::PKCS1) {
  67. lpRSA = PEM_read_bio_RSAPublicKey(lpBIO, nullptr, nullptr, nullptr);
  68. if (lpRSA == nullptr) {
  69. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("PEM_read_bio_RSAPublicKey failed."))
  70. .AddHint(TEXT("Are you sure that you DO provide a valid RSA public key file with PKCS1 format?"));
  71. }
  72. }
  73. static_assert(__Format == RSAKeyFormat::PEM || __Format == RSAKeyFormat::PKCS1);
  74. }
  75. static_assert(__Type == RSAKeyType::PrivateKey || __Type == RSAKeyType::PublicKey);
  76. return lpRSA;
  77. }
  78. public:
  79. RSACipher() : ResourceOwned<OpensslRSATraits>(RSA_new()) {
  80. if (IsValid() == false) {
  81. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("RSA_new failed."));
  82. }
  83. }
  84. [[nodiscard]]
  85. size_t Bits() const {
  86. RSA* rsa = Get();
  87. if (RSA_get0_n(rsa) == nullptr) {
  88. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("RSA modulus has not been set."));
  89. } else {
  90. return BN_num_bits(RSA_get0_n(rsa));
  91. }
  92. }
  93. void GenerateKey(int bits, unsigned int e = RSA_F4) {
  94. ResourceOwned<OpensslBNTraits> bn_e(BN_new());
  95. if (bn_e.IsValid() == false) {
  96. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("BN_new failed."));
  97. }
  98. if (!BN_set_word(bn_e, e)) {
  99. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("BN_set_word failed."));
  100. }
  101. if (!RSA_generate_key_ex(Get(), bits, bn_e, nullptr)) {
  102. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("RSA_generate_key_ex failed."));
  103. }
  104. }
  105. template<RSAKeyType __Type, RSAKeyFormat __Format>
  106. void ExportKeyToFile(const std::xstring& FileName) const {
  107. ResourceOwned<OpensslBIOTraits> BioFile(BIO_new_file(FileName.explicit_string(CP_UTF8).c_str(), "w"));
  108. if (BioFile.IsValid() == false) {
  109. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("BIO_new_file failed."));
  110. }
  111. _WriteRSAToBIO<__Type, __Format>(Get(), BioFile);
  112. }
  113. template<RSAKeyType __Type, RSAKeyFormat __Format>
  114. [[nodiscard]]
  115. std::string ExportKeyString() const {
  116. ResourceOwned<OpensslBIOTraits> BioMemory(BIO_new(BIO_s_mem()));
  117. long StringLength;
  118. const char* StringChars = nullptr;
  119. if (BioMemory.IsValid() == false) {
  120. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("BIO_new failed."));
  121. }
  122. _WriteRSAToBIO<__Type, __Format>(Get(), BioMemory);
  123. StringLength = BIO_get_mem_data(BioMemory.Get(), &StringChars);
  124. return std::string(StringChars, StringLength);
  125. }
  126. template<RSAKeyType __Type, RSAKeyFormat __Format>
  127. void ImportKeyFromFile(const std::xstring& FileName) {
  128. ResourceOwned<OpensslBIOTraits> BioFile(BIO_new_file(FileName.explicit_string(CP_UTF8).c_str(), "r"));
  129. if (BioFile.IsValid() == false) {
  130. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("BIO_new_file failed."));
  131. }
  132. TakeOver(_ReadRSAFromBIO<__Type, __Format>(BioFile));
  133. }
  134. template<RSAKeyType __Type, RSAKeyFormat __Format>
  135. void ImportKeyString(const std::string& KeyString) {
  136. ResourceOwned<OpensslBIOTraits> BioMemory(BIO_new(BIO_s_mem()));
  137. if (BioMemory.IsValid() == false) {
  138. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("BIO_new failed."));
  139. }
  140. if (BIO_puts(BioMemory.Get(), KeyString.c_str()) <= 0) {
  141. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("BIO_puts failed."));
  142. }
  143. TakeOver(_ReadRSAFromBIO<__Type, __Format>(BioMemory));
  144. }
  145. template<RSAKeyType __Type = RSAKeyType::PublicKey>
  146. size_t Encrypt(const void* lpFrom, size_t cbFrom, void* lpTo, int Padding) const {
  147. int BytesWritten;
  148. if (cbFrom > INT_MAX) {
  149. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("Length overflowed."));
  150. }
  151. if constexpr (__Type == RSAKeyType::PrivateKey) {
  152. BytesWritten = RSA_private_encrypt(
  153. static_cast<int>(cbFrom),
  154. reinterpret_cast<const unsigned char*>(lpFrom),
  155. reinterpret_cast<unsigned char*>(lpTo),
  156. Get(),
  157. Padding
  158. );
  159. if (BytesWritten == -1) {
  160. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("RSA_private_encrypt failed."));
  161. }
  162. } else {
  163. BytesWritten = RSA_public_encrypt(
  164. static_cast<int>(cbFrom),
  165. reinterpret_cast<const unsigned char*>(lpFrom),
  166. reinterpret_cast<unsigned char*>(lpTo),
  167. Get(),
  168. Padding
  169. );
  170. if (BytesWritten == -1) {
  171. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("RSA_public_encrypt failed."));
  172. }
  173. }
  174. return BytesWritten;
  175. }
  176. template<RSAKeyType __Type = RSAKeyType::PrivateKey>
  177. size_t Decrypt(const void* lpFrom, size_t cbFrom, void* lpTo, int Padding) const {
  178. int BytesWritten;
  179. if (cbFrom > INT_MAX) {
  180. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("Length overflowed."));
  181. }
  182. if constexpr (__Type == RSAKeyType::PrivateKey) {
  183. BytesWritten = RSA_private_decrypt(
  184. static_cast<int>(cbFrom),
  185. reinterpret_cast<const unsigned char*>(lpFrom),
  186. reinterpret_cast<unsigned char*>(lpTo),
  187. Get(),
  188. Padding
  189. );
  190. if (BytesWritten == -1) {
  191. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("RSA_private_decrypt failed."))
  192. .AddHint(TEXT("Are your sure you DO provide a correct private key?"));
  193. }
  194. } else {
  195. BytesWritten = RSA_public_decrypt(
  196. static_cast<int>(cbFrom),
  197. reinterpret_cast<const unsigned char*>(lpFrom),
  198. reinterpret_cast<unsigned char*>(lpTo),
  199. Get(),
  200. Padding
  201. );
  202. if (BytesWritten == -1) {
  203. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("RSA_public_decrypt failed."))
  204. .AddHint(TEXT("Are your sure you DO provide a correct public key?"));
  205. }
  206. }
  207. return BytesWritten;
  208. }
  209. };
  210. }