Skip to content

Commit

Permalink
GH-43946: [C++][Parquet] Guard against use of decryptor/encryptor aft…
Browse files Browse the repository at this point in the history
…er wipeout
  • Loading branch information
pitrou committed Sep 4, 2024
1 parent 170c599 commit b253529
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions cpp/src/parquet/encryption/encryption_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ class AesEncryptor::AesEncryptorImpl {
}

private:
void CheckValid() {
if (ctx_ == nullptr) {
throw ParquetException("AesEncryptor was wiped out");
}
}

EVP_CIPHER_CTX* ctx_;
int32_t aes_mode_;
int32_t key_length_;
Expand Down Expand Up @@ -156,6 +162,8 @@ AesEncryptor::AesEncryptorImpl::AesEncryptorImpl(ParquetCipher::type alg_id,
int32_t AesEncryptor::AesEncryptorImpl::SignedFooterEncrypt(
span<const uint8_t> footer, span<const uint8_t> key, span<const uint8_t> aad,
span<const uint8_t> nonce, span<uint8_t> encrypted_footer) {
CheckValid();

if (static_cast<size_t>(key_length_) != key.size()) {
std::stringstream ss;
ss << "Wrong key length " << key.size() << ". Should be " << key_length_;
Expand All @@ -180,6 +188,8 @@ int32_t AesEncryptor::AesEncryptorImpl::Encrypt(span<const uint8_t> plaintext,
span<const uint8_t> key,
span<const uint8_t> aad,
span<uint8_t> ciphertext) {
CheckValid();

if (static_cast<size_t>(key_length_) != key.size()) {
std::stringstream ss;
ss << "Wrong key length " << key.size() << ". Should be " << key_length_;
Expand Down Expand Up @@ -413,6 +423,12 @@ class AesDecryptor::AesDecryptorImpl {
}

private:
void CheckValid() {
if (ctx_ == nullptr) {
throw ParquetException("AesDecryptor was wiped out");
}
}

EVP_CIPHER_CTX* ctx_;
int32_t aes_mode_;
int32_t key_length_;
Expand Down Expand Up @@ -714,6 +730,8 @@ int32_t AesDecryptor::AesDecryptorImpl::Decrypt(span<const uint8_t> ciphertext,
span<const uint8_t> key,
span<const uint8_t> aad,
span<uint8_t> plaintext) {
CheckValid();

if (static_cast<size_t>(key_length_) != key.size()) {
std::stringstream ss;
ss << "Wrong key length " << key.size() << ". Should be " << key_length_;
Expand Down Expand Up @@ -806,4 +824,7 @@ void RandBytes(unsigned char* buf, size_t num) {

void EnsureBackendInitialized() { openssl::EnsureInitialized(); }

#undef ENCRYPT_INIT
#undef DECRYPT_INIT

} // namespace parquet::encryption

0 comments on commit b253529

Please sign in to comment.