RSACipher.hpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. #pragma once
  2. #include <openssl/err.h>
  3. #include <openssl/pem.h>
  4. #include <openssl/bio.h>
  5. #include <openssl/rsa.h>
  6. #include "Exception.hpp"
  7. #include "ExceptionOpenssl.hpp"
  8. #include "ResourceOwned.hpp"
  9. #include "ResourceTraitsOpenssl.hpp"
  // Rebind the source-location macros for this header. Any definition that is
  // already active is dropped first, then NKG_CURRENT_SOURCE_FILE() is set to
  // this file's path literal and NKG_CURRENT_SOURCE_LINE() to __LINE__.
  // Every exception constructed below passes these two macros as its first
  // two arguments.
  10. #undef NKG_CURRENT_SOURCE_FILE
  11. #undef NKG_CURRENT_SOURCE_LINE
  12. #define NKG_CURRENT_SOURCE_FILE() TEXT(".\\common\\RSACipher.hpp")
  13. #define NKG_CURRENT_SOURCE_LINE() __LINE__
  14. namespace nkg {
  // Compile-time selector for which half of an RSA key pair an operation
  // works with. Used as a template argument by RSACipher.
  enum class RSAKeyType {
      PrivateKey = 0,   // the private half of the key pair
      PublicKey  = 1    // the public half of the key pair
  };
  // Compile-time selector for the on-disk / in-string layout of an RSA key.
  // Used as a template argument by RSACipher, which only consults it when
  // handling public keys.
  enum class RSAKeyFormat {
      PEM   = 0,   // handled with the PEM_*_bio_RSA_PUBKEY routines
      PKCS1 = 1    // handled with the PEM_*_bio_RSAPublicKey routines
  };
  23. class RSACipher final : private ResourceOwned<OpensslRSATraits> {
  24. private:
  25. template<RSAKeyType __Type, RSAKeyFormat __Format>
  26. static void _WriteRSAToBIO(RSA* lpRSA, BIO* lpBIO) {
  27. if constexpr (__Type == RSAKeyType::PrivateKey) {
  28. if (PEM_write_bio_RSAPrivateKey(lpBIO, lpRSA, nullptr, nullptr, 0, nullptr, nullptr) == 0) {
  29. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("PEM_write_bio_RSAPrivateKey failed."));
  30. }
  31. }
  32. if constexpr (__Type == RSAKeyType::PublicKey) {
  33. if constexpr (__Format == RSAKeyFormat::PEM) {
  34. if (PEM_write_bio_RSA_PUBKEY(lpBIO, lpRSA) == 0) {
  35. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("PEM_write_bio_RSA_PUBKEY failed."));
  36. }
  37. }
  38. if constexpr (__Format == RSAKeyFormat::PKCS1) {
  39. if (PEM_write_bio_RSAPublicKey(lpBIO, lpRSA) == 0) {
  40. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("PEM_write_bio_RSAPublicKey failed."));
  41. }
  42. }
  43. static_assert(__Format == RSAKeyFormat::PEM || __Format == RSAKeyFormat::PKCS1);
  44. }
  45. static_assert(__Type == RSAKeyType::PrivateKey || __Type == RSAKeyType::PublicKey);
  46. }
  47. template<RSAKeyType __Type, RSAKeyFormat __Format>
  48. [[nodiscard]]
  49. static RSA* _ReadRSAFromBIO(BIO* lpBIO) {
  50. RSA* lpRSA;
  51. if constexpr (__Type == RSAKeyType::PrivateKey) {
  52. lpRSA = PEM_read_bio_RSAPrivateKey(lpBIO, nullptr, nullptr, nullptr);
  53. if (lpRSA == nullptr) {
  54. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("PEM_read_bio_RSAPrivateKey failed."))
  55. .AddHint(TEXT("Are you sure that you DO provide a valid RSA private key file?"));
  56. }
  57. }
  58. if constexpr (__Type == RSAKeyType::PublicKey) {
  59. if constexpr (__Format == RSAKeyFormat::PEM) {
  60. lpRSA = PEM_read_bio_RSA_PUBKEY(lpBIO, nullptr, nullptr, nullptr);
  61. if (lpRSA == nullptr) {
  62. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("PEM_read_bio_RSA_PUBKEY failed."))
  63. .AddHint(TEXT("Are you sure that you DO provide a valid RSA public key file with PEM format?"));
  64. }
  65. }
  66. if constexpr (__Format == RSAKeyFormat::PKCS1) {
  67. lpRSA = PEM_read_bio_RSAPublicKey(lpBIO, nullptr, nullptr, nullptr);
  68. if (lpRSA == nullptr) {
  69. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("PEM_read_bio_RSAPublicKey failed."))
  70. .AddHint(TEXT("Are you sure that you DO provide a valid RSA public key file with PKCS1 format?"));
  71. }
  72. }
  73. static_assert(__Format == RSAKeyFormat::PEM || __Format == RSAKeyFormat::PKCS1);
  74. }
  75. static_assert(__Type == RSAKeyType::PrivateKey || __Type == RSAKeyType::PublicKey);
  76. return lpRSA;
  77. }
  78. public:
  79. RSACipher() : ResourceOwned<OpensslRSATraits>(RSA_new()) {
  80. if (IsValid() == false) {
  81. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("RSA_new failed."));
  82. }
  83. }
  84. [[nodiscard]]
  85. size_t Bits() const {
  86. if (Get()->n == nullptr) {
  87. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("RSA modulus has not been set."));
  88. } else {
  89. return BN_num_bits(Get()->n);
  90. }
  91. }
  92. void GenerateKey(int bits, unsigned int e = RSA_F4) {
  93. ResourceOwned<OpensslBNTraits> bn_e(BN_new());
  94. if (bn_e.IsValid() == false) {
  95. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("BN_new failed."));
  96. }
  97. if (!BN_set_word(bn_e, e)) {
  98. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("BN_set_word failed."));
  99. }
  100. if (!RSA_generate_key_ex(Get(), bits, bn_e, nullptr)) {
  101. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("RSA_generate_key_ex failed."));
  102. }
  103. }
  104. template<RSAKeyType __Type, RSAKeyFormat __Format>
  105. void ExportKeyToFile(const std::xstring& FileName) const {
  106. ResourceOwned<OpensslBIOTraits> BioFile(BIO_new_file(FileName.explicit_string(CP_UTF8).c_str(), "w"));
  107. if (BioFile.IsValid() == false) {
  108. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("BIO_new_file failed."));
  109. }
  110. _WriteRSAToBIO<__Type, __Format>(Get(), BioFile);
  111. }
  112. template<RSAKeyType __Type, RSAKeyFormat __Format>
  113. [[nodiscard]]
  114. std::string ExportKeyString() const {
  115. ResourceOwned<OpensslBIOTraits> BioMemory(BIO_new(BIO_s_mem()));
  116. long StringLength;
  117. const char* StringChars = nullptr;
  118. if (BioMemory.IsValid() == false) {
  119. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("BIO_new failed."));
  120. }
  121. _WriteRSAToBIO<__Type, __Format>(Get(), BioMemory);
  122. StringLength = BIO_get_mem_data(BioMemory.Get(), &StringChars);
  123. return std::string(StringChars, StringLength);
  124. }
  125. template<RSAKeyType __Type, RSAKeyFormat __Format>
  126. void ImportKeyFromFile(const std::xstring& FileName) {
  127. ResourceOwned<OpensslBIOTraits> BioFile(BIO_new_file(FileName.explicit_string(CP_UTF8).c_str(), "r"));
  128. if (BioFile.IsValid() == false) {
  129. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("BIO_new_file failed."));
  130. }
  131. TakeOver(_ReadRSAFromBIO<__Type, __Format>(BioFile));
  132. }
  133. template<RSAKeyType __Type, RSAKeyFormat __Format>
  134. void ImportKeyString(const std::string& KeyString) {
  135. ResourceOwned<OpensslBIOTraits> BioMemory(BIO_new(BIO_s_mem()));
  136. if (BioMemory.IsValid() == false) {
  137. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("BIO_new failed."));
  138. }
  139. if (BIO_puts(BioMemory.Get(), KeyString.c_str()) <= 0) {
  140. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("BIO_puts failed."));
  141. }
  142. TakeOver(_ReadRSAFromBIO<__Type, __Format>(BioMemory));
  143. }
  144. template<RSAKeyType __Type = RSAKeyType::PublicKey>
  145. size_t Encrypt(const void* lpFrom, size_t cbFrom, void* lpTo, int Padding) const {
  146. int BytesWritten;
  147. if (cbFrom > INT_MAX) {
  148. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("Length overflowed."));
  149. }
  150. if constexpr (__Type == RSAKeyType::PrivateKey) {
  151. BytesWritten = RSA_private_encrypt(
  152. static_cast<int>(cbFrom),
  153. reinterpret_cast<const unsigned char*>(lpFrom),
  154. reinterpret_cast<unsigned char*>(lpTo),
  155. Get(),
  156. Padding
  157. );
  158. if (BytesWritten == -1) {
  159. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("RSA_private_encrypt failed."));
  160. }
  161. } else {
  162. BytesWritten = RSA_public_encrypt(
  163. static_cast<int>(cbFrom),
  164. reinterpret_cast<const unsigned char*>(lpFrom),
  165. reinterpret_cast<unsigned char*>(lpTo),
  166. Get(),
  167. Padding
  168. );
  169. if (BytesWritten == -1) {
  170. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("RSA_public_encrypt failed."));
  171. }
  172. }
  173. return BytesWritten;
  174. }
  175. template<RSAKeyType __Type = RSAKeyType::PrivateKey>
  176. size_t Decrypt(const void* lpFrom, size_t cbFrom, void* lpTo, int Padding) const {
  177. int BytesWritten;
  178. if (cbFrom > INT_MAX) {
  179. throw Exception(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), TEXT("Length overflowed."));
  180. }
  181. if constexpr (__Type == RSAKeyType::PrivateKey) {
  182. BytesWritten = RSA_private_decrypt(
  183. static_cast<int>(cbFrom),
  184. reinterpret_cast<const unsigned char*>(lpFrom),
  185. reinterpret_cast<unsigned char*>(lpTo),
  186. Get(),
  187. Padding
  188. );
  189. if (BytesWritten == -1) {
  190. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("RSA_private_decrypt failed."))
  191. .AddHint(TEXT("Are your sure you DO provide a correct private key?"));
  192. }
  193. } else {
  194. BytesWritten = RSA_public_decrypt(
  195. static_cast<int>(cbFrom),
  196. reinterpret_cast<const unsigned char*>(lpFrom),
  197. reinterpret_cast<unsigned char*>(lpTo),
  198. Get(),
  199. Padding
  200. );
  201. if (BytesWritten == -1) {
  202. throw OpensslError(NKG_CURRENT_SOURCE_FILE(), NKG_CURRENT_SOURCE_LINE(), ERR_get_error(), TEXT("RSA_public_decrypt failed."))
  203. .AddHint(TEXT("Are your sure you DO provide a correct public key?"));
  204. }
  205. }
  206. return BytesWritten;
  207. }
  208. };
  209. }