// RSACipher.hpp (11 KB)
  #pragma once
  #include <climits>
  #include <cstddef>
  #include <string>
  #include <openssl/err.h>
  #include <openssl/pem.h>
  #include <openssl/bio.h>
  #include <openssl/rsa.h>
  #include "Exception.hpp"
  #include "ExceptionOpenssl.hpp"
  #include "ResourceOwned.hpp"
  #include "ResourceTraitsOpenssl.hpp"
  // Extra system libraries to link; both are required by libcrypto.lib.
  #pragma comment(lib, "crypt32") // required by libcrypto.lib
  #pragma comment(lib, "ws2_32") // required by libcrypto.lib

  // Drop any definition of the source-location macros left by a previously
  // included header, then redefine them so that exceptions thrown from this
  // header report this file and the line of the throw site.
  #undef NKG_CURRENT_SOURCE_FILE
  #undef NKG_CURRENT_SOURCE_LINE
  #define NKG_CURRENT_SOURCE_FILE() TEXT(".\\common\\RSACipher.hpp")
  #define NKG_CURRENT_SOURCE_LINE() __LINE__
  16. namespace nkg {
  // Selects which half of an RSA key pair an RSACipher operation works with.
  enum class RSAKeyType {
      PrivateKey,     // the private key
      PublicKey,      // the public key
  };
  // Selects the encoding used when a public key is read or written.
  // RSACipher maps the enumerators onto OpenSSL routines as follows:
  //   PEM   -> PEM_{read,write}_bio_RSA_PUBKEY
  //   PKCS1 -> PEM_{read,write}_bio_RSAPublicKey
  // Private keys always go through PEM_{read,write}_bio_RSAPrivateKey.
  enum class RSAKeyFormat {
      PEM,
      PKCS1,
  };
  25. class RSACipher final : private ResourceOwned<OpensslRSATraits> {
  26. private:
  27. template<RSAKeyType __Type, RSAKeyFormat __Format>
  28. static void _WriteRSAToBIO(RSA* lpRSA, BIO* lpBIO) {
  29. if constexpr (__Type == RSAKeyType::PrivateKey) {
  30. if (PEM_write_bio_RSAPrivateKey(lpBIO, lpRSA, nullptr, nullptr, 0, nullptr, nullptr) == 0) {
  31. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("PEM_write_bio_RSAPrivateKey failed."));
  32. }
  33. }
  34. if constexpr (__Type == RSAKeyType::PublicKey) {
  35. if constexpr (__Format == RSAKeyFormat::PEM) {
  36. if (PEM_write_bio_RSA_PUBKEY(lpBIO, lpRSA) == 0) {
  37. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("PEM_write_bio_RSA_PUBKEY failed."));
  38. }
  39. }
  40. if constexpr (__Format == RSAKeyFormat::PKCS1) {
  41. if (PEM_write_bio_RSAPublicKey(lpBIO, lpRSA) == 0) {
  42. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("PEM_write_bio_RSAPublicKey failed."));
  43. }
  44. }
  45. static_assert(__Format == RSAKeyFormat::PEM || __Format == RSAKeyFormat::PKCS1);
  46. }
  47. static_assert(__Type == RSAKeyType::PrivateKey || __Type == RSAKeyType::PublicKey);
  48. }
  49. template<RSAKeyType __Type, RSAKeyFormat __Format>
  50. [[nodiscard]]
  51. static RSA* _ReadRSAFromBIO(BIO* lpBIO) {
  52. RSA* lpRSA;
  53. if constexpr (__Type == RSAKeyType::PrivateKey) {
  54. lpRSA = PEM_read_bio_RSAPrivateKey(lpBIO, nullptr, nullptr, nullptr);
  55. if (lpRSA == nullptr) {
  56. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("PEM_read_bio_RSAPrivateKey failed."))
  57. .AddHint(TEXT("Are you sure that you DO provide a valid RSA private key file?"));
  58. }
  59. }
  60. if constexpr (__Type == RSAKeyType::PublicKey) {
  61. if constexpr (__Format == RSAKeyFormat::PEM) {
  62. lpRSA = PEM_read_bio_RSA_PUBKEY(lpBIO, nullptr, nullptr, nullptr);
  63. if (lpRSA == nullptr) {
  64. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("PEM_read_bio_RSA_PUBKEY failed."))
  65. .AddHint(TEXT("Are you sure that you DO provide a valid RSA public key file with PEM format?"));
  66. }
  67. }
  68. if constexpr (__Format == RSAKeyFormat::PKCS1) {
  69. lpRSA = PEM_read_bio_RSAPublicKey(lpBIO, nullptr, nullptr, nullptr);
  70. if (lpRSA == nullptr) {
  71. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("PEM_read_bio_RSAPublicKey failed."))
  72. .AddHint(TEXT("Are you sure that you DO provide a valid RSA public key file with PKCS1 format?"));
  73. }
  74. }
  75. static_assert(__Format == RSAKeyFormat::PEM || __Format == RSAKeyFormat::PKCS1);
  76. }
  77. static_assert(__Type == RSAKeyType::PrivateKey || __Type == RSAKeyType::PublicKey);
  78. return lpRSA;
  79. }
  80. public:
  81. RSACipher() : ResourceOwned<OpensslRSATraits>(RSA_new()) {
  82. if (IsValid() == false) {
  83. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("RSA_new failed."));
  84. }
  85. }
  86. [[nodiscard]]
  87. size_t Bits() const {
  88. #if (OPENSSL_VERSION_NUMBER & 0xffff0000) == 0x10000000 // openssl 1.0.x
  89. if (Get()->n == nullptr) {
  90. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("RSA modulus has not been set."));
  91. } else {
  92. return BN_num_bits(Get()->n);
  93. }
  94. #elif (OPENSSL_VERSION_NUMBER & 0xffff0000) == 0x10100000 // openssl 1.1.x
  95. return RSA_bits(Get());
  96. #else
  97. #error "RSACipher.hpp: unexpected OpenSSL version."
  98. #endif
  99. }
  100. void GenerateKey(int bits, unsigned int e = RSA_F4) {
  101. ResourceOwned<OpensslBNTraits> bn_e(BN_new());
  102. if (bn_e.IsValid() == false) {
  103. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("BN_new failed."));
  104. }
  105. if (!BN_set_word(bn_e, e)) {
  106. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("BN_set_word failed."));
  107. }
  108. if (!RSA_generate_key_ex(Get(), bits, bn_e, nullptr)) {
  109. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("RSA_generate_key_ex failed."));
  110. }
  111. }
  112. template<RSAKeyType __Type, RSAKeyFormat __Format>
  113. void ExportKeyToFile(const std::xstring& FileName) const {
  114. ResourceOwned<OpensslBIOTraits> BioFile(BIO_new_file(FileName.explicit_string(CP_UTF8).c_str(), "w"));
  115. if (BioFile.IsValid() == false) {
  116. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("BIO_new_file failed."));
  117. }
  118. _WriteRSAToBIO<__Type, __Format>(Get(), BioFile);
  119. }
  120. template<RSAKeyType __Type, RSAKeyFormat __Format>
  121. [[nodiscard]]
  122. std::string ExportKeyString() const {
  123. ResourceOwned<OpensslBIOTraits> BioMemory(BIO_new(BIO_s_mem()));
  124. long StringLength;
  125. const char* StringChars = nullptr;
  126. if (BioMemory.IsValid() == false) {
  127. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("BIO_new failed."));
  128. }
  129. _WriteRSAToBIO<__Type, __Format>(Get(), BioMemory);
  130. StringLength = BIO_get_mem_data(BioMemory.Get(), &StringChars);
  131. return std::string(StringChars, StringLength);
  132. }
  133. template<RSAKeyType __Type, RSAKeyFormat __Format>
  134. void ImportKeyFromFile(const std::xstring& FileName) {
  135. ResourceOwned<OpensslBIOTraits> BioFile(BIO_new_file(FileName.explicit_string(CP_UTF8).c_str(), "r"));
  136. if (BioFile.IsValid() == false) {
  137. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("BIO_new_file failed."));
  138. }
  139. TakeOver(_ReadRSAFromBIO<__Type, __Format>(BioFile));
  140. }
  141. template<RSAKeyType __Type, RSAKeyFormat __Format>
  142. void ImportKeyString(const std::string& KeyString) {
  143. ResourceOwned<OpensslBIOTraits> BioMemory(BIO_new(BIO_s_mem()));
  144. if (BioMemory.IsValid() == false) {
  145. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("BIO_new failed."));
  146. }
  147. if (BIO_puts(BioMemory.Get(), KeyString.c_str()) <= 0) {
  148. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("BIO_puts failed."));
  149. }
  150. TakeOver(_ReadRSAFromBIO<__Type, __Format>(BioMemory));
  151. }
  152. template<RSAKeyType __Type = RSAKeyType::PublicKey>
  153. size_t Encrypt(const void* lpFrom, size_t cbFrom, void* lpTo, int Padding) const {
  154. int BytesWritten;
  155. if (cbFrom > INT_MAX) {
  156. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("Length overflowed."));
  157. }
  158. if constexpr (__Type == RSAKeyType::PrivateKey) {
  159. BytesWritten = RSA_private_encrypt(
  160. static_cast<int>(cbFrom),
  161. reinterpret_cast<const unsigned char*>(lpFrom),
  162. reinterpret_cast<unsigned char*>(lpTo),
  163. Get(),
  164. Padding
  165. );
  166. if (BytesWritten == -1) {
  167. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("RSA_private_encrypt failed."));
  168. }
  169. } else {
  170. BytesWritten = RSA_public_encrypt(
  171. static_cast<int>(cbFrom),
  172. reinterpret_cast<const unsigned char*>(lpFrom),
  173. reinterpret_cast<unsigned char*>(lpTo),
  174. Get(),
  175. Padding
  176. );
  177. if (BytesWritten == -1) {
  178. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("RSA_public_encrypt failed."));
  179. }
  180. }
  181. return BytesWritten;
  182. }
  183. template<RSAKeyType __Type = RSAKeyType::PrivateKey>
  184. size_t Decrypt(const void* lpFrom, size_t cbFrom, void* lpTo, int Padding) const {
  185. int BytesWritten;
  186. if (cbFrom > INT_MAX) {
  187. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("Length overflowed."));
  188. }
  189. if constexpr (__Type == RSAKeyType::PrivateKey) {
  190. BytesWritten = RSA_private_decrypt(
  191. static_cast<int>(cbFrom),
  192. reinterpret_cast<const unsigned char*>(lpFrom),
  193. reinterpret_cast<unsigned char*>(lpTo),
  194. Get(),
  195. Padding
  196. );
  197. if (BytesWritten == -1) {
  198. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("RSA_private_decrypt failed."))
  199. .AddHint(TEXT("Are your sure you DO provide a correct private key?"));
  200. }
  201. } else {
  202. BytesWritten = RSA_public_decrypt(
  203. static_cast<int>(cbFrom),
  204. reinterpret_cast<const unsigned char*>(lpFrom),
  205. reinterpret_cast<unsigned char*>(lpTo),
  206. Get(),
  207. Padding
  208. );
  209. if (BytesWritten == -1) {
  210. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("RSA_public_decrypt failed."))
  211. .AddHint(TEXT("Are your sure you DO provide a correct public key?"));
  212. }
  213. }
  214. return BytesWritten;
  215. }
  216. };
  217. }