Skip to content

Commit dabd771

Browse files
committed
[SPARK-43038][SQL] Support the CBC mode by aes_encrypt()/aes_decrypt()
### What changes were proposed in this pull request? In the PR, I propose a new AES mode for the `aes_encrypt()`/`aes_decrypt()` functions - `CBC` ([Cipher Block Chaining](https://www.ibm.com/docs/en/linux-on-systems?topic=operation-cipher-block-chaining-cbc-mode)) with the padding `PKCS7(5)`. The `aes_encrypt()` function returns a binary value which consists of the following fields: 1. The salt magic prefix `Salted__` with the length of 8 bytes. 2. A salt generated for every `aes_encrypt()` call using `java.security.SecureRandom`. Its length is 8 bytes. 3. The encrypted input. The encrypt function derives the secret key and initialization vector (16 bytes) from the salt and user's key using the same algorithm as OpenSSL's `EVP_BytesToKey()` (versions >= 1.1.0c). The `aes_decrypt()` function assumes that its input has the fields as shown above. For example: ```sql spark-sql> SELECT base64(aes_encrypt('Apache Spark', '0000111122223333', 'CBC', 'PKCS')); U2FsdGVkX1/ERGxwEOTDpDD4bQvDtQaNe+gXGudCcUk= spark-sql> SELECT aes_decrypt(unbase64('U2FsdGVkX1/ERGxwEOTDpDD4bQvDtQaNe+gXGudCcUk='), '0000111122223333', 'CBC', 'PKCS'); Apache Spark ``` ### Why are the changes needed? To achieve feature parity with other systems/frameworks, and make the migration process from them to Spark SQL easier. For example, the `CBC` mode is supported by: - BigQuery: https://cloud.google.com/bigquery/docs/reference/standard-sql/aead-encryption-concepts#block_cipher_modes - Snowflake: https://docs.snowflake.com/en/sql-reference/functions/encrypt.html ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? 
By running new checks: ``` $ build/sbt "sql/testOnly *QueryExecutionErrorsSuite" $ build/sbt "sql/test:testOnly org.apache.spark.sql.expressions.ExpressionInfoSuite" $ build/sbt "test:testOnly org.apache.spark.sql.MiscFunctionsSuite" $ build/sbt "core/testOnly *SparkThrowableSuite" ``` and checked compatibility with LibreSSL/OpenSSL: ``` $ openssl version LibreSSL 3.3.6 $ echo -n 'Apache Spark' | openssl enc -e -aes-128-cbc -pass pass:0000111122223333 -a U2FsdGVkX1+5GyAmmG7wDWWDBAuUuxjMy++cMFytpls= ``` ```sql spark-sql (default)> SELECT aes_decrypt(unbase64('U2FsdGVkX1+5GyAmmG7wDWWDBAuUuxjMy++cMFytpls='), '0000111122223333', 'CBC'); Apache Spark ``` decrypt Spark's output by OpenSSL: ```sql spark-sql (default)> SELECT base64(aes_encrypt('Apache Spark', 'abcdefghijklmnop12345678ABCDEFGH', 'CBC', 'PKCS')); U2FsdGVkX1+maU2vmxrulgxXuQSyZ3ODnlHKqnt2fDA= ``` ``` $ echo 'U2FsdGVkX1+maU2vmxrulgxXuQSyZ3ODnlHKqnt2fDA=' | openssl aes-256-cbc -a -d -pass pass:abcdefghijklmnop12345678ABCDEFGH Apache Spark ``` Closes apache#40704 from MaxGekk/aes-cbc. Authored-by: Max Gekk <[email protected]> Signed-off-by: Max Gekk <[email protected]>
1 parent 74d840c commit dabd771

File tree

6 files changed

+141
-25
lines changed

6 files changed

+141
-25
lines changed

core/src/main/resources/error/error-classes.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,11 @@
978978
"expects a binary value with 16, 24 or 32 bytes, but got <actualLength> bytes."
979979
]
980980
},
981+
"AES_SALTED_MAGIC" : {
982+
"message" : [
983+
"Initial bytes from input <saltedMagic> do not match 'Salted__' (0x53616C7465645F5F)."
984+
]
985+
},
981986
"PATTERN" : {
982987
"message" : [
983988
"<value>."

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,16 @@
2222

2323
import javax.crypto.Cipher;
2424
import javax.crypto.spec.GCMParameterSpec;
25+
import javax.crypto.spec.IvParameterSpec;
2526
import javax.crypto.spec.SecretKeySpec;
2627
import java.nio.ByteBuffer;
2728
import java.security.GeneralSecurityException;
29+
import java.security.MessageDigest;
30+
import java.security.NoSuchAlgorithmException;
2831
import java.security.SecureRandom;
32+
import java.util.Arrays;
33+
34+
import static java.nio.charset.StandardCharsets.US_ASCII;
2935

3036
/**
3137
* An utility class for constructing expressions.
@@ -35,6 +41,13 @@ public class ExpressionImplUtils {
3541
private static final int GCM_IV_LEN = 12;
3642
private static final int GCM_TAG_LEN = 128;
3743

44+
private static final int CBC_IV_LEN = 16;
45+
private static final int CBC_SALT_LEN = 8;
46+
/** OpenSSL's magic initial bytes. */
47+
private static final String SALTED_STR = "Salted__";
48+
private static final byte[] SALTED_MAGIC = SALTED_STR.getBytes(US_ASCII);
49+
50+
3851
/**
3952
* Function to check if a given number string is a valid Luhn number
4053
* @param numberString
@@ -115,11 +128,70 @@ private static byte[] aesInternal(
115128
cipher.init(Cipher.DECRYPT_MODE, secretKey, parameterSpec);
116129
return cipher.doFinal(input, GCM_IV_LEN, input.length - GCM_IV_LEN);
117130
}
131+
} else if (mode.equalsIgnoreCase("CBC") &&
132+
(padding.equalsIgnoreCase("PKCS") || padding.equalsIgnoreCase("DEFAULT"))) {
133+
Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
134+
if (opmode == Cipher.ENCRYPT_MODE) {
135+
byte[] salt = new byte[CBC_SALT_LEN];
136+
secureRandom.nextBytes(salt);
137+
final byte[] keyAndIv = getKeyAndIv(key, salt);
138+
final byte[] keyValue = Arrays.copyOfRange(keyAndIv, 0, key.length);
139+
final byte[] iv = Arrays.copyOfRange(keyAndIv, key.length, key.length + CBC_IV_LEN);
140+
cipher.init(
141+
Cipher.ENCRYPT_MODE,
142+
new SecretKeySpec(keyValue, "AES"),
143+
new IvParameterSpec(iv));
144+
byte[] encrypted = cipher.doFinal(input, 0, input.length);
145+
ByteBuffer byteBuffer = ByteBuffer.allocate(
146+
SALTED_MAGIC.length + CBC_SALT_LEN + encrypted.length);
147+
byteBuffer.put(SALTED_MAGIC);
148+
byteBuffer.put(salt);
149+
byteBuffer.put(encrypted);
150+
return byteBuffer.array();
151+
} else {
152+
assert(opmode == Cipher.DECRYPT_MODE);
153+
final byte[] shouldBeMagic = Arrays.copyOfRange(input, 0, SALTED_MAGIC.length);
154+
if (!Arrays.equals(shouldBeMagic, SALTED_MAGIC)) {
155+
throw QueryExecutionErrors.aesInvalidSalt(shouldBeMagic);
156+
}
157+
final byte[] salt = Arrays.copyOfRange(
158+
input, SALTED_MAGIC.length, SALTED_MAGIC.length + CBC_SALT_LEN);
159+
final byte[] keyAndIv = getKeyAndIv(key, salt);
160+
final byte[] keyValue = Arrays.copyOfRange(keyAndIv, 0, key.length);
161+
final byte[] iv = Arrays.copyOfRange(keyAndIv, key.length, key.length + CBC_IV_LEN);
162+
cipher.init(
163+
Cipher.DECRYPT_MODE,
164+
new SecretKeySpec(keyValue, "AES"),
165+
new IvParameterSpec(iv, 0, CBC_IV_LEN));
166+
return cipher.doFinal(input, CBC_IV_LEN, input.length - CBC_IV_LEN);
167+
}
118168
} else {
119169
throw QueryExecutionErrors.aesModeUnsupportedError(mode, padding);
120170
}
121171
} catch (GeneralSecurityException e) {
122172
throw QueryExecutionErrors.aesCryptoError(e.getMessage());
123173
}
124174
}
175+
176+
// Derive the key and init vector in the same way as OpenSSL's EVP_BytesToKey
177+
// since the version 1.1.0c which switched to SHA-256 as the hash.
178+
private static byte[] getKeyAndIv(byte[] key, byte[] salt) throws NoSuchAlgorithmException {
179+
final byte[] keyAndSalt = arrConcat(key, salt);
180+
byte[] hash = new byte[0];
181+
byte[] keyAndIv = new byte[0];
182+
for (int i = 0; i < 3 && keyAndIv.length < key.length + CBC_IV_LEN; i++) {
183+
final byte[] hashData = arrConcat(hash, keyAndSalt);
184+
final MessageDigest md = MessageDigest.getInstance("SHA-256");
185+
hash = md.digest(hashData);
186+
keyAndIv = arrConcat(keyAndIv, hash);
187+
}
188+
return keyAndIv;
189+
}
190+
191+
private static byte[] arrConcat(final byte[] arr1, final byte[] arr2) {
192+
final byte[] res = new byte[arr1.length + arr2.length];
193+
System.arraycopy(arr1, 0, res, 0, arr1.length);
194+
System.arraycopy(arr2, 0, res, arr1.length, arr2.length);
195+
return res;
196+
}
125197
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -313,17 +313,17 @@ case class CurrentUser() extends LeafExpression with Unevaluable {
313313
@ExpressionDescription(
314314
usage = """
315315
_FUNC_(expr, key[, mode[, padding]]) - Returns an encrypted value of `expr` using AES in given `mode` with the specified `padding`.
316-
Key lengths of 16, 24 and 32 bits are supported. Supported combinations of (`mode`, `padding`) are ('ECB', 'PKCS') and ('GCM', 'NONE').
316+
Key lengths of 16, 24 and 32 bits are supported. Supported combinations of (`mode`, `padding`) are ('ECB', 'PKCS'), ('GCM', 'NONE') and ('CBC', 'PKCS').
317317
The default mode is GCM.
318318
""",
319319
arguments = """
320320
Arguments:
321321
* expr - The binary value to encrypt.
322322
* key - The passphrase to use to encrypt the data.
323323
* mode - Specifies which block cipher mode should be used to encrypt messages.
324-
Valid modes: ECB, GCM.
324+
Valid modes: ECB, GCM, CBC.
325325
* padding - Specifies how to pad messages whose length is not a multiple of the block size.
326-
Valid values: PKCS, NONE, DEFAULT. The DEFAULT padding means PKCS for ECB and NONE for GCM.
326+
Valid values: PKCS, NONE, DEFAULT. The DEFAULT padding means PKCS for ECB, NONE for GCM and PKCS for CBC.
327327
""",
328328
examples = """
329329
Examples:
@@ -333,6 +333,8 @@ case class CurrentUser() extends LeafExpression with Unevaluable {
333333
6E7CA17BBB468D3084B5744BCA729FB7B2B7BCB8E4472847D02670489D95FA97DBBA7D3210
334334
> SELECT base64(_FUNC_('Spark SQL', '1234567890abcdef', 'ECB', 'PKCS'));
335335
3lmwu+Mw0H3fi5NDvcu9lg==
336+
> SELECT base64(_FUNC_('Apache Spark', '1234567890abcdef', 'CBC', 'DEFAULT'));
337+
U2FsdGVkX18JQ84pfRUwonUrFzpWQ46vKu4+MkJVFGM=
336338
""",
337339
since = "3.3.0",
338340
group = "misc_funcs")
@@ -377,17 +379,17 @@ case class AesEncrypt(
377379
@ExpressionDescription(
378380
usage = """
379381
_FUNC_(expr, key[, mode[, padding]]) - Returns a decrypted value of `expr` using AES in `mode` with `padding`.
380-
Key lengths of 16, 24 and 32 bits are supported. Supported combinations of (`mode`, `padding`) are ('ECB', 'PKCS') and ('GCM', 'NONE').
382+
Key lengths of 16, 24 and 32 bits are supported. Supported combinations of (`mode`, `padding`) are ('ECB', 'PKCS'), ('GCM', 'NONE') and ('CBC', 'PKCS').
381383
The default mode is GCM.
382384
""",
383385
arguments = """
384386
Arguments:
385387
* expr - The binary value to decrypt.
386388
* key - The passphrase to use to decrypt the data.
387389
* mode - Specifies which block cipher mode should be used to decrypt messages.
388-
Valid modes: ECB, GCM.
390+
Valid modes: ECB, GCM, CBC.
389391
* padding - Specifies how to pad messages whose length is not a multiple of the block size.
390-
Valid values: PKCS, NONE, DEFAULT. The DEFAULT padding means PKCS for ECB and NONE for GCM.
392+
Valid values: PKCS, NONE, DEFAULT. The DEFAULT padding means PKCS for ECB, NONE for GCM and PKCS for CBC.
391393
""",
392394
examples = """
393395
Examples:
@@ -397,6 +399,8 @@ case class AesEncrypt(
397399
Spark SQL
398400
> SELECT _FUNC_(unbase64('3lmwu+Mw0H3fi5NDvcu9lg=='), '1234567890abcdef', 'ECB', 'PKCS');
399401
Spark SQL
402+
> SELECT _FUNC_(unbase64('U2FsdGVkX18JQ84pfRUwonUrFzpWQ46vKu4+MkJVFGM='), '1234567890abcdef', 'CBC');
403+
Apache Spark
400404
""",
401405
since = "3.3.0",
402406
group = "misc_funcs")

sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2651,6 +2651,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
26512651
"detailMessage" -> detailMessage))
26522652
}
26532653

2654+
/**
 * Builds the error raised when the leading bytes of the `aes_decrypt` input
 * are not the expected 'Salted__' magic.
 *
 * @param saltedMagic the leading bytes actually found in the input; they are
 *                    rendered in the message as an upper-case hex string
 *                    prefixed with "0x"
 */
def aesInvalidSalt(saltedMagic: Array[Byte]): RuntimeException = {
  val magicAsHex = saltedMagic.map(b => "%02X".format(b)).mkString("0x", "", "")
  val params = Map(
    "parameter" -> toSQLId("expr"),
    "functionName" -> toSQLId("aes_decrypt"),
    "saltedMagic" -> magicAsHex)
  new SparkRuntimeException(
    errorClass = "INVALID_PARAMETER_VALUE.AES_SALTED_MAGIC",
    messageParameters = params)
}
2662+
26542663
def hiveTableWithAnsiIntervalsError(tableName: String): SparkUnsupportedOperationException = {
26552664
new SparkUnsupportedOperationException(
26562665
errorClass = "_LEGACY_ERROR_TEMP_2276",

sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,21 +62,26 @@ class MiscFunctionsSuite extends QueryTest with SharedSparkSession {
6262
}
6363
}
6464

65-
test("SPARK-37591: AES functions - GCM mode") {
65+
test("SPARK-37591, SPARK-43038: AES functions - GCM/CBC mode") {
6666
Seq(
67-
("abcdefghijklmnop", ""),
68-
("abcdefghijklmnop", "abcdefghijklmnop"),
69-
("abcdefghijklmnop12345678", "Spark"),
70-
("abcdefghijklmnop12345678ABCDEFGH", "GCM mode")
71-
).foreach { case (key, input) =>
72-
val df = Seq((key, input)).toDF("key", "input")
73-
val encrypted = df.selectExpr("aes_encrypt(input, key, 'GCM', 'NONE') AS enc", "input", "key")
74-
assert(encrypted.schema("enc").dataType === BinaryType)
75-
assert(encrypted.filter($"enc" === $"input").isEmpty)
76-
val result = encrypted.selectExpr(
77-
"CAST(aes_decrypt(enc, key, 'GCM', 'NONE') AS STRING) AS res", "input")
78-
assert(!result.filter($"res" === $"input").isEmpty &&
79-
result.filter($"res" =!= $"input").isEmpty)
67+
"GCM" -> "NONE",
68+
"CBC" -> "PKCS").foreach { case (mode, padding) =>
69+
Seq(
70+
("abcdefghijklmnop", ""),
71+
("abcdefghijklmnop", "abcdefghijklmnop"),
72+
("abcdefghijklmnop12345678", "Spark"),
73+
("abcdefghijklmnop12345678ABCDEFGH", "GCM mode")
74+
).foreach { case (key, input) =>
75+
val df = Seq((key, input)).toDF("key", "input")
76+
val encrypted = df.selectExpr(
77+
s"aes_encrypt(input, key, '$mode', '$padding') AS enc", "input", "key")
78+
assert(encrypted.schema("enc").dataType === BinaryType)
79+
assert(encrypted.filter($"enc" === $"input").isEmpty)
80+
val result = encrypted.selectExpr(
81+
s"CAST(aes_decrypt(enc, key, '$mode', '$padding') AS STRING) AS res", "input")
82+
assert(!result.filter($"res" === $"input").isEmpty &&
83+
result.filter($"res" =!= $"input").isEmpty)
84+
}
8085
}
8186
}
8287
}

sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,25 @@ class QueryExecutionErrorsSuite
140140
}
141141
}
142142

143+
test("INVALID_PARAMETER_VALUE.AES_SALTED_MAGIC: AES decrypt failure - invalid salt") {
  // The decoded input does not start with the 'Salted__' magic bytes,
  // so CBC decryption is expected to fail with AES_SALTED_MAGIC.
  val query =
    """
      |SELECT aes_decrypt(
      | unbase64('INVALID_SALT_ERGxwEOTDpDD4bQvDtQaNe+gXGudCcUk='),
      | '0000111122223333',
      | 'CBC', 'PKCS')
      |""".stripMargin
  val expectedParameters = Map(
    "parameter" -> "`expr`",
    "functionName" -> "`aes_decrypt`",
    "saltedMagic" -> "0x20D5402C80D200B4")
  checkError(
    exception = intercept[SparkRuntimeException] {
      sql(query).collect()
    },
    errorClass = "INVALID_PARAMETER_VALUE.AES_SALTED_MAGIC",
    parameters = expectedParameters,
    sqlState = "22023")
}
161+
143162
test("UNSUPPORTED_FEATURE: unsupported combinations of AES modes and padding") {
144163
val key16 = "abcdefghijklmnop"
145164
val key32 = "abcdefghijklmnop12345678ABCDEFGH"
@@ -157,18 +176,20 @@ class QueryExecutionErrorsSuite
157176
}
158177

159178
// Unsupported AES mode and padding in encrypt
160-
checkUnsupportedMode(df1.selectExpr(s"aes_encrypt(value, '$key16', 'CBC')"),
161-
"CBC", "DEFAULT")
179+
checkUnsupportedMode(df1.selectExpr(s"aes_encrypt(value, '$key16', 'CBC', 'None')"),
180+
"CBC", "None")
162181
checkUnsupportedMode(df1.selectExpr(s"aes_encrypt(value, '$key16', 'ECB', 'NoPadding')"),
163182
"ECB", "NoPadding")
164183

165184
// Unsupported AES mode and padding in decrypt
166185
checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value16, '$key16', 'GSM')"),
167-
"GSM", "DEFAULT")
186+
"GSM", "DEFAULT")
168187
checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value16, '$key16', 'GCM', 'PKCS')"),
169-
"GCM", "PKCS")
188+
"GCM", "PKCS")
170189
checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value32, '$key32', 'ECB', 'None')"),
171-
"ECB", "None")
190+
"ECB", "None")
191+
checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value32, '$key32', 'CBC', 'NoPadding')"),
192+
"CBC", "NoPadding")
172193
}
173194

174195
test("UNSUPPORTED_FEATURE: unsupported types (map and struct) in lit()") {

0 commit comments

Comments
 (0)