RSACipher.hpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. #pragma once
  2. #include <openssl/err.h>
  3. #include <openssl/pem.h>
  4. #include <openssl/bio.h>
  5. #include <openssl/rsa.h>
  6. #include <memory.h>
  7. #include "Exception.hpp"
  8. #include "ResourceObject.hpp"
  9. #ifdef _DEBUG
  10. #pragma comment(lib, "libcryptoMTd.lib")
  11. #else
  12. #pragma comment(lib, "libcryptoMT.lib")
  13. #endif
  14. #pragma comment(lib, "WS2_32.lib") // some symbol are used in OpenSSL static lib
  15. #pragma comment(lib, "Crypt32.lib") // some symbol are used in OpenSSL static lib
  16. #undef __BASE_FILE__
  17. #define __BASE_FILE__ TEXT("RSACipher.hpp")
  18. class OpensslException : public Exception {
  19. private:
  20. unsigned long _ErrorCode;
  21. TString _ErrorString;
  22. public:
  23. OpensslException(const TString& FileName,
  24. size_t LineNumber,
  25. unsigned long OpensslErrorCode,
  26. const TString& CustomMsg) noexcept :
  27. Exception(FileName, LineNumber, CustomMsg),
  28. _ErrorCode(OpensslErrorCode),
  29. _ErrorString(TStringBuilder(ERR_error_string(OpensslErrorCode, nullptr), CP_UTF8)) {}
  30. virtual bool HasErrorCode() const noexcept override {
  31. return true;
  32. }
  33. virtual uintptr_t ErrorCode() const noexcept override {
  34. return _ErrorCode;
  35. }
  36. virtual const TString& ErrorString() const noexcept override {
  37. return _ErrorString;
  38. }
  39. };
  40. struct OpensslBIOTraits {
  41. using HandleType = BIO*;
  42. static inline const HandleType InvalidValue = nullptr;
  43. static constexpr auto& Releasor = BIO_free;
  44. };
  45. struct OpensslBIOChainTraits {
  46. using HandleType = BIO*;
  47. static inline const HandleType InvalidValue = nullptr;
  48. static constexpr auto& Releasor = BIO_free_all;
  49. };
  50. struct OpensslBNTraits {
  51. using HandleType = BIGNUM*;
  52. static inline const HandleType InvalidValue = nullptr;
  53. static constexpr auto& Releasor = BN_free;
  54. };
  55. struct OpensslRSATraits {
  56. using HandleType = RSA*;
  57. static inline const HandleType InvalidValue = nullptr;
  58. static constexpr auto& Releasor = RSA_free;
  59. };
  60. enum class RSAKeyType {
  61. PrivateKey,
  62. PublicKey
  63. };
  64. enum class RSAKeyFormat {
  65. NotSpecified,
  66. PEM,
  67. PKCS1
  68. };
  69. class RSACipher {
  70. private:
  71. ResourceObject<OpensslRSATraits> _RSAObj;
  72. RSACipher(RSA* pRsa) noexcept : _RSAObj(pRsa) {}
  73. //
  74. // Copy constructor is not allowed
  75. //
  76. RSACipher(const RSACipher&) = delete;
  77. //
  78. // Copy assignment is not allowed
  79. //
  80. RSACipher& operator=(const RSACipher&) = delete;
  81. template<RSAKeyType __Type, RSAKeyFormat __Format = RSAKeyFormat::NotSpecified>
  82. static void _WriteRSAToBIO(RSA* PtrToRSA, BIO* PtrToBIO) {
  83. if constexpr (__Type == RSAKeyType::PrivateKey) {
  84. if (!PEM_write_bio_RSAPrivateKey(PtrToBIO, PtrToRSA, nullptr, nullptr, 0, nullptr, nullptr))
  85. throw Exception(__BASE_FILE__, __LINE__,
  86. TEXT("PEM_write_bio_RSAPrivateKey failed."));
  87. } else {
  88. if constexpr (__Format == RSAKeyFormat::PEM) {
  89. if (!PEM_write_bio_RSA_PUBKEY(PtrToBIO, PtrToRSA))
  90. throw Exception(__BASE_FILE__, __LINE__,
  91. TEXT("PEM_write_bio_RSA_PUBKEY failed."));
  92. } else if constexpr (__Format == RSAKeyFormat::PKCS1) {
  93. if (!PEM_write_bio_RSAPublicKey(PtrToBIO, PtrToRSA))
  94. throw Exception(__BASE_FILE__, __LINE__,
  95. TEXT("PEM_write_bio_RSAPublicKey failed."));
  96. } else {
  97. static_assert(__Format == RSAKeyFormat::PEM || __Format == RSAKeyFormat::PKCS1);
  98. }
  99. }
  100. }
  101. template<RSAKeyType __Type, RSAKeyFormat __Format = RSAKeyFormat::NotSpecified>
  102. static RSA* _ReadRSAFromBIO(BIO* PtrToBIO) {
  103. RSA* PtrToRSA;
  104. if constexpr (__Type == RSAKeyType::PrivateKey) {
  105. PtrToRSA = PEM_read_bio_RSAPrivateKey(PtrToBIO, nullptr, nullptr, nullptr);
  106. if (PtrToRSA == nullptr)
  107. throw Exception(__BASE_FILE__, __LINE__,
  108. TEXT("PEM_read_bio_RSAPrivateKey failed."));
  109. } else {
  110. if constexpr (__Format == RSAKeyFormat::PEM) {
  111. PtrToRSA = PEM_read_bio_RSA_PUBKEY(PtrToBIO, nullptr, nullptr, nullptr);
  112. if (PtrToRSA == nullptr)
  113. throw Exception(__BASE_FILE__, __LINE__,
  114. TEXT("PEM_read_bio_RSA_PUBKEY failed."));
  115. } else if constexpr (__Format == RSAKeyFormat::PKCS1) {
  116. PtrToRSA = PEM_read_bio_RSAPublicKey(PtrToBIO, nullptr, nullptr, nullptr);
  117. if (PtrToRSA == nullptr)
  118. throw Exception(__BASE_FILE__, __LINE__,
  119. TEXT("PEM_read_bio_RSAPublicKey failed."));
  120. } else {
  121. static_assert(__Format == RSAKeyFormat::PEM || __Format == RSAKeyFormat::PKCS1);
  122. }
  123. }
  124. return PtrToRSA;
  125. }
  126. public:
  127. RSACipher() : _RSAObj(RSA_new()) {
  128. if (_RSAObj.IsValid() == false)
  129. throw OpensslException(__BASE_FILE__, __LINE__, ERR_get_error(),
  130. TEXT("RSA_new failed."));
  131. }
  132. void GenerateKey(int bits, unsigned int e = RSA_F4) {
  133. ResourceObject<OpensslBNTraits> bn_e;
  134. bn_e.TakeOver(BN_new());
  135. if (bn_e.IsValid() == false)
  136. throw OpensslException(__BASE_FILE__, __LINE__, ERR_get_error(),
  137. TEXT("BN_new failed."));
  138. if (!BN_set_word(bn_e, e))
  139. throw Exception(__BASE_FILE__, __LINE__,
  140. TEXT("BN_set_word failed."));
  141. if (!RSA_generate_key_ex(_RSAObj, bits, bn_e, nullptr))
  142. throw OpensslException(__BASE_FILE__, __LINE__, ERR_get_error(),
  143. TEXT("RSA_generate_key_ex failed."));
  144. }
  145. template<RSAKeyType __Type, RSAKeyFormat __Format = RSAKeyFormat::NotSpecified>
  146. void ExportKeyToFile(const std::string& FileName) {
  147. ResourceObject<OpensslBIOTraits> BIOKeyFile;
  148. BIOKeyFile.TakeOver(BIO_new_file(FileName.c_str(), "w"));
  149. if (BIOKeyFile.IsValid() == false)
  150. throw Exception(__BASE_FILE__, __LINE__,
  151. TEXT("BIO_new_file failed."));
  152. _WriteRSAToBIO<__Type, __Format>(_RSAObj, BIOKeyFile);
  153. }
  154. template<RSAKeyType __Type, RSAKeyFormat __Format = RSAKeyFormat::NotSpecified>
  155. std::string ExportKeyString() {
  156. std::string KeyString;
  157. ResourceObject<OpensslBIOTraits> BIOKeyMemory;
  158. long s;
  159. const char* p = nullptr;
  160. BIOKeyMemory.TakeOver(BIO_new(BIO_s_mem()));
  161. if (BIOKeyMemory.IsValid() == false)
  162. throw Exception(__BASE_FILE__, __LINE__,
  163. TEXT("BIO_new failed."));
  164. _WriteRSAToBIO<__Type, __Format>(_RSAObj, BIOKeyMemory);
  165. s = BIO_get_mem_data(BIOKeyMemory, &p);
  166. KeyString.resize(s);
  167. memcpy(KeyString.data(), p, s);
  168. return KeyString;
  169. }
  170. template<RSAKeyType __Type, RSAKeyFormat __Format = RSAKeyFormat::NotSpecified>
  171. void ImportKeyFromFile(const std::string& FileName) {
  172. ResourceObject<OpensslBIOTraits> BIOKeyFile;
  173. RSA* NewRSAObj;
  174. BIOKeyFile.TakeOver(BIO_new_file(FileName.c_str(), "r"));
  175. if (BIOKeyFile.IsValid() == false)
  176. throw Exception(__BASE_FILE__, __LINE__,
  177. TEXT("BIO_new_file failed."));
  178. NewRSAObj = _ReadRSAFromBIO<__Type, __Format>(BIOKeyFile);
  179. _RSAObj.Release();
  180. _RSAObj.TakeOver(NewRSAObj);
  181. }
  182. template<RSAKeyType __Type, RSAKeyFormat __Format = RSAKeyFormat::NotSpecified>
  183. void ImportKeyString(const std::string& KeyString) {
  184. ResourceObject<OpensslBIOTraits> BIOKeyMemory;
  185. RSA* NewRSAObj;
  186. BIOKeyMemory = BIO_new(BIO_s_mem());
  187. if (BIOKeyMemory == nullptr)
  188. throw Exception(__BASE_FILE__, __LINE__,
  189. TEXT("BIO_new failed."));
  190. if (BIO_puts(BIOKeyMemory, KeyString.c_str()) <= 0)
  191. throw Exception(__BASE_FILE__, __LINE__,
  192. TEXT("BIO_puts failed."));
  193. NewRSAObj = _ReadRSAFromBIO<__Type, __Format>(BIOKeyMemory);
  194. _RSAObj.Release();
  195. _RSAObj.TakeOver(NewRSAObj);
  196. }
  197. template<RSAKeyType __Type = RSAKeyType::PublicKey>
  198. int Encrypt(const void* from, int len, void* to, int padding) {
  199. int WriteBytes;
  200. if constexpr (__Type == RSAKeyType::PrivateKey) {
  201. WriteBytes = RSA_private_encrypt(len,
  202. reinterpret_cast<const unsigned char*>(from),
  203. reinterpret_cast<unsigned char*>(to),
  204. _RSAObj,
  205. padding);
  206. if (WriteBytes == -1)
  207. throw OpensslException(__BASE_FILE__, __LINE__, ERR_get_error(),
  208. TEXT("RSA_private_encrypt failed."));
  209. } else {
  210. WriteBytes = RSA_public_encrypt(len,
  211. reinterpret_cast<const unsigned char*>(from),
  212. reinterpret_cast<unsigned char*>(to),
  213. _RSAObj,
  214. padding);
  215. if (WriteBytes == -1)
  216. throw OpensslException(__BASE_FILE__, __LINE__, ERR_get_error(),
  217. TEXT("RSA_public_encrypt failed."));
  218. }
  219. return WriteBytes;
  220. }
  221. template<RSAKeyType __Type = RSAKeyType::PrivateKey>
  222. int Decrypt(const void* from, int len, void* to, int padding) {
  223. int WriteBytes;
  224. if constexpr (__Type == RSAKeyType::PrivateKey) {
  225. WriteBytes = RSA_private_decrypt(len,
  226. reinterpret_cast<const unsigned char*>(from),
  227. reinterpret_cast<unsigned char*>(to),
  228. _RSAObj,
  229. padding);
  230. if (WriteBytes == -1)
  231. throw OpensslException(__BASE_FILE__, __LINE__, ERR_get_error(),
  232. TEXT("RSA_private_decrypt failed."));
  233. } else {
  234. WriteBytes = RSA_public_decrypt(len,
  235. reinterpret_cast<const unsigned char*>(from),
  236. reinterpret_cast<unsigned char*>(to),
  237. _RSAObj,
  238. padding);
  239. if (WriteBytes == -1)
  240. throw OpensslException(__BASE_FILE__, __LINE__, ERR_get_error(),
  241. TEXT("RSA_public_decrypt failed."));
  242. }
  243. return WriteBytes;
  244. }
  245. };