RSACipher.hpp 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. #pragma once
  2. #include <openssl/err.h>
  3. #include <openssl/pem.h>
  4. #include <openssl/bio.h>
  5. #include <openssl/rsa.h>
  6. #include "ExceptionOpenssl.hpp"
  7. #include "ResourceGuardOpenssl.hpp"
  8. #include <string>
  9. #include <memory.h>
  10. #ifdef _DEBUG
  11. #pragma comment(lib, "libcryptoMTd.lib")
  12. #else
  13. #pragma comment(lib, "libcryptoMT.lib")
  14. #endif
  15. #pragma comment(lib, "WS2_32.lib") // some symbol are used in OpenSSL static lib
  16. #pragma comment(lib, "Crypt32.lib") // some symbol are used in OpenSSL static lib
  17. #undef __BASE_FILE__
  18. #define __BASE_FILE__ "RSACipher.hpp"
  19. class RSACipher {
  20. public:
  21. enum class KeyType {
  22. PrivateKey,
  23. PublicKey
  24. };
  25. enum class KeyFormat {
  26. NotSpecified,
  27. PEM,
  28. PKCS1
  29. };
  30. private:
  31. ResourceGuard<OpensslRSATraits> _RsaObj;
  32. RSACipher(RSA* pRsa) : _RsaObj(pRsa) {}
  33. // Copy constructor is not allowed
  34. RSACipher(const RSACipher&) = delete;
  35. // Copy assignment is not allowed
  36. RSACipher& operator=(const RSACipher&) = delete;
  37. template<KeyType _Type, KeyFormat _Format = KeyFormat::NotSpecified>
  38. static void _RSAToBIO(RSA* pRsaObject, BIO* pBioObject) {
  39. if constexpr (_Type == KeyType::PrivateKey) {
  40. if (!PEM_write_bio_RSAPrivateKey(pBioObject, pRsaObject, nullptr, nullptr, 0, nullptr, nullptr))
  41. throw Exception(__BASE_FILE__, __LINE__,
  42. "PEM_write_bio_RSAPrivateKey fails.");
  43. } else {
  44. if constexpr (_Format == KeyFormat::PEM) {
  45. if (!PEM_write_bio_RSA_PUBKEY(pBioObject, pRsaObject))
  46. throw Exception(__BASE_FILE__, __LINE__,
  47. "PEM_write_bio_RSA_PUBKEY fails.");
  48. } else if constexpr (_Format == KeyFormat::PKCS1) {
  49. if (!PEM_write_bio_RSAPublicKey(pBioObject, pRsaObject))
  50. throw Exception(__BASE_FILE__, __LINE__,
  51. "PEM_write_bio_RSAPublicKey fails.");
  52. } else {
  53. static_assert(_Format == KeyFormat::PEM || _Format == KeyFormat::PKCS1);
  54. }
  55. }
  56. }
  57. template<KeyType _Type, KeyFormat _Format = KeyFormat::NotSpecified>
  58. static RSA* _BIOToRSA(BIO* pBioObject) {
  59. RSA* pNewRsaObject;
  60. if constexpr (_Type == KeyType::PrivateKey) {
  61. pNewRsaObject = PEM_read_bio_RSAPrivateKey(pBioObject, nullptr, nullptr, nullptr);
  62. if (pNewRsaObject == nullptr)
  63. throw Exception(__BASE_FILE__, __LINE__,
  64. "PEM_read_bio_RSAPrivateKey fails.");
  65. } else {
  66. if constexpr (_Format == KeyFormat::PEM) {
  67. pNewRsaObject = PEM_read_bio_RSA_PUBKEY(pBioObject, nullptr, nullptr, nullptr);
  68. if (pNewRsaObject == nullptr)
  69. throw Exception(__BASE_FILE__, __LINE__,
  70. "PEM_read_bio_RSA_PUBKEY fails.");
  71. } else if constexpr (_Format == KeyFormat::PKCS1) {
  72. pNewRsaObject = PEM_read_bio_RSAPublicKey(pBioObject, nullptr, nullptr, nullptr);
  73. if (pNewRsaObject == nullptr)
  74. throw Exception(__BASE_FILE__, __LINE__,
  75. "PEM_read_bio_RSAPublicKey fails.");
  76. } else {
  77. static_assert(_Format == KeyFormat::PEM || _Format == KeyFormat::PKCS1);
  78. }
  79. }
  80. return pNewRsaObject;
  81. }
  82. public:
  83. static RSACipher* Create() {
  84. RSACipher* aCipher = new RSACipher(RSA_new());
  85. if (aCipher->_RsaObj.IsValid() == false) {
  86. delete aCipher;
  87. aCipher = nullptr;
  88. }
  89. return aCipher;
  90. }
  91. RSACipher() : _RsaObj(RSA_new()) {
  92. if (_RsaObj.IsValid() == false)
  93. throw OpensslError(__BASE_FILE__, __LINE__, ERR_get_error(),
  94. "RSA_new fails.");
  95. }
  96. void GenerateKey(int bits, unsigned int e = RSA_F4) {
  97. ResourceGuard<OpensslBNTraits> bn_e;
  98. bn_e.TakeHoldOf(BN_new());
  99. if (bn_e.IsValid() == false)
  100. throw OpensslError(__BASE_FILE__, __LINE__, ERR_get_error(),
  101. "BN_new fails.");
  102. if (!BN_set_word(bn_e, e))
  103. throw Exception(__BASE_FILE__, __LINE__,
  104. "BN_set_word fails.");
  105. if (!RSA_generate_key_ex(_RsaObj, bits, bn_e, nullptr))
  106. throw OpensslError(__BASE_FILE__, __LINE__, ERR_get_error(),
  107. "RSA_generate_key_ex fails.");
  108. }
  109. template<KeyType _Type, KeyFormat _Format = KeyFormat::NotSpecified>
  110. void ExportKeyToFile(const std::string& FileName) {
  111. ResourceGuard<OpensslBIOTraits> bio_file;
  112. bio_file.TakeHoldOf(BIO_new_file(FileName.c_str(), "w"));
  113. if (bio_file.IsValid() == false)
  114. throw Exception(__BASE_FILE__, __LINE__,
  115. "BIO_new_file fails.");
  116. _RSAToBIO<_Type, _Format>(_RsaObj, bio_file);
  117. }
  118. template<KeyType _Type, KeyFormat _Format = KeyFormat::NotSpecified>
  119. std::string ExportKeyString() {
  120. std::string result;
  121. ResourceGuard<OpensslBIOTraits> bio_mem;
  122. int DataSize;
  123. const char* pData = nullptr;
  124. bio_mem.TakeHoldOf(BIO_new(BIO_s_mem()));
  125. if (bio_mem.IsValid() == false)
  126. throw Exception(__BASE_FILE__, __LINE__,
  127. "BIO_new fails.");
  128. _RSAToBIO<_Type, _Format>(_RsaObj, bio_mem);
  129. DataSize = BIO_get_mem_data(bio_mem, &pData);
  130. result.resize(DataSize);
  131. memcpy(result.data(), pData, DataSize);
  132. return result;
  133. }
  134. template<KeyType _Type, KeyFormat _Format = KeyFormat::NotSpecified>
  135. void ImportKeyFromFile(const std::string& FileName) {
  136. bool bSuccess = false;
  137. ResourceGuard<OpensslBIOTraits> bio_file;
  138. RSA* NewRsaObj;
  139. bio_file.TakeHoldOf(BIO_new_file(FileName.c_str(), "r"));
  140. if (bio_file.IsValid() == false)
  141. throw Exception(__BASE_FILE__, __LINE__,
  142. "BIO_new_file fails.");
  143. NewRsaObj = _BIOToRSA<_Type, _Format>(bio_file);
  144. _RsaObj.Release();
  145. _RsaObj.TakeHoldOf(NewRsaObj);
  146. }
  147. template<KeyType _Type, KeyFormat _Format = KeyFormat::NotSpecified>
  148. void ImportKeyString(const std::string& KeyString) {
  149. ResourceGuard<OpensslBIOTraits> bio_mem;
  150. RSA* NewRsaObj;
  151. bio_mem = BIO_new(BIO_s_mem());
  152. if (bio_mem == nullptr)
  153. throw Exception(__BASE_FILE__, __LINE__,
  154. "BIO_new fails.");
  155. if (BIO_puts(bio_mem, KeyString.c_str()) <= 0)
  156. throw Exception(__BASE_FILE__, __LINE__,
  157. "BIO_puts fails.");
  158. NewRsaObj = _BIOToRSA<_Type, _Format>(bio_mem);
  159. _RsaObj.Release();
  160. _RsaObj.TakeHoldOf(NewRsaObj);
  161. }
  162. template<KeyType _Type = KeyType::PublicKey>
  163. int Encrypt(const void* from, int len, void* to, int padding) {
  164. int write_bytes;
  165. if constexpr (_Type == KeyType::PrivateKey) {
  166. write_bytes = RSA_private_encrypt(len,
  167. reinterpret_cast<const unsigned char*>(from),
  168. reinterpret_cast<unsigned char*>(to),
  169. _RsaObj,
  170. padding);
  171. if (write_bytes == -1)
  172. throw OpensslError(__BASE_FILE__, __LINE__, ERR_get_error(),
  173. "RSA_private_encrypt fails.");
  174. } else {
  175. write_bytes = RSA_public_encrypt(len,
  176. reinterpret_cast<const unsigned char*>(from),
  177. reinterpret_cast<unsigned char*>(to),
  178. _RsaObj,
  179. padding);
  180. if (write_bytes == -1)
  181. throw OpensslError(__BASE_FILE__, __LINE__, ERR_get_error(),
  182. "RSA_public_encrypt fails.");
  183. }
  184. return write_bytes;
  185. }
  186. template<KeyType _Type = KeyType::PrivateKey>
  187. int Decrypt(const void* from, int len, void* to, int padding) {
  188. int write_bytes;
  189. if constexpr (_Type == KeyType::PrivateKey) {
  190. write_bytes = RSA_private_decrypt(len,
  191. reinterpret_cast<const unsigned char*>(from),
  192. reinterpret_cast<unsigned char*>(to),
  193. _RsaObj,
  194. padding);
  195. if (write_bytes == -1)
  196. throw OpensslError(__BASE_FILE__, __LINE__, ERR_get_error(),
  197. "RSA_private_decrypt fails.");
  198. } else {
  199. write_bytes = RSA_public_decrypt(len,
  200. reinterpret_cast<const unsigned char*>(from),
  201. reinterpret_cast<unsigned char*>(to),
  202. _RsaObj,
  203. padding);
  204. if (write_bytes == -1)
  205. throw OpensslError(__BASE_FILE__, __LINE__, ERR_get_error(),
  206. "RSA_public_decrypt fails.");
  207. }
  208. return write_bytes;
  209. }
  210. };