1 package de.saly.kafka.crypto;
2
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.security.KeyFactory;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.SecretKey;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import javax.xml.bind.DatatypeConverter;

import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.utils.Utils;
30
31 public abstract class SerdeCryptoBase {
32
33 public static final String CRYPTO_RSA_PRIVATEKEY_FILEPATH = "crypto.rsa.privatekey.filepath";
34 public static final String CRYPTO_RSA_PUBLICKEY_FILEPATH = "crypto.rsa.publickey.filepath";
35 public static final String CRYPTO_HASH_METHOD = "crypto.hash_method";
36 public static final String CRYPTO_IGNORE_DECRYPT_FAILURES = "crypto.ignore_decrypt_failures";
37 public static final String CRYPTO_AES_KEY_LEN = "crypto.aes.key_len";
38 static final byte[] MAGIC_BYTES = new byte[] { (byte) 0xDF, (byte) 0xBB };
39 protected static final String DEFAULT_TRANSFORMATION = "AES/CBC/PKCS5Padding";
40 private static final Map<String, byte[]> aesKeyCache = new HashMap<String, byte[]>();
41 private static final int MAGIC_BYTES_LENGTH = MAGIC_BYTES.length;
42 private static final int HEADER_LENGTH = MAGIC_BYTES_LENGTH + 3;
43 private static final String AES = "AES";
44 private static final String RSA = "RSA";
45 private static final String RSA_TRANFORMATION = "RSA/ECB/OAEPWithSHA-256AndMGF1Padding";
46 private static final int RSA_MULTIPLICATOR = 128;
47 private int opMode;
48 private String hashMethod = "SHA-256";
49 private int aesKeyLen = 128;
50 private boolean ignoreDecryptFailures = false;
51 private ProducerCryptoBundle producerCryptoBundle = null;
52 private ConsumerCryptoBundle consumerCryptoBundle = null;
53
54 protected SerdeCryptoBase() {
55
56 }
57
58
59 private class ConsumerCryptoBundle {
60
61 private Cipher rsaDecrypt;
62 final Cipher aesDecrypt = Cipher.getInstance(DEFAULT_TRANSFORMATION);
63
64 private ConsumerCryptoBundle(PrivateKey privateKey) throws Exception {
65 rsaDecrypt = Cipher.getInstance(RSA_TRANFORMATION);
66 rsaDecrypt.init(Cipher.DECRYPT_MODE, privateKey);
67 }
68
69 private byte[] aesDecrypt(byte[] encrypted) throws KafkaException {
70 try {
71 if (encrypted[0] == MAGIC_BYTES[0] && encrypted[1] == MAGIC_BYTES[1]) {
72 final byte hashLen = encrypted[2];
73 final byte rsaFactor = encrypted[3];
74 final byte ivLen = encrypted[4];
75 final int offset = HEADER_LENGTH + hashLen + (rsaFactor * RSA_MULTIPLICATOR) + ivLen;
76 final String aesHash = DatatypeConverter.printHexBinary(Arrays.copyOfRange(encrypted, HEADER_LENGTH, HEADER_LENGTH + hashLen));
77 final byte[] iv = Arrays.copyOfRange(encrypted, HEADER_LENGTH + hashLen + (rsaFactor * RSA_MULTIPLICATOR),
78 HEADER_LENGTH + hashLen + (rsaFactor * RSA_MULTIPLICATOR) + ivLen);
79
80 byte[] aesKey;
81
82 if ((aesKey = aesKeyCache.get(aesHash)) != null) {
83 aesDecrypt.init(Cipher.DECRYPT_MODE, createAESSecretKey(aesKey), new IvParameterSpec(iv));
84 return crypt(aesDecrypt, encrypted, offset, encrypted.length - offset);
85 } else {
86 byte[] rsaEncryptedAesKey = Arrays.copyOfRange(encrypted, HEADER_LENGTH + hashLen,
87 HEADER_LENGTH + hashLen + (rsaFactor * RSA_MULTIPLICATOR));
88 aesKey = crypt(rsaDecrypt, rsaEncryptedAesKey);
89 aesDecrypt.init(Cipher.DECRYPT_MODE, createAESSecretKey(aesKey), new IvParameterSpec(iv));
90 aesKeyCache.put(aesHash, aesKey);
91 return crypt(aesDecrypt, encrypted, offset, encrypted.length - offset);
92 }
93 } else {
94 return encrypted;
95 }
96 } catch (Exception e) {
97 if(ignoreDecryptFailures) {
98 return encrypted;
99 }
100
101 throw new KafkaException("Decrypt failed",e);
102 }
103 }
104 }
105
106 private class ThreadAwareKeyInfo {
107 private final SecretKey aesKey;
108 private final byte[] aesHash;
109 private final byte[] rsaEncyptedAesKey;
110 private final Cipher rsaCipher;
111 private final Cipher aesCipher;
112 private final SecureRandom random = new SecureRandom();
113
114 protected ThreadAwareKeyInfo(PublicKey publicKey) throws Exception {
115 byte[] aesKeyBytes = new byte[aesKeyLen/8];
116 random.nextBytes(aesKeyBytes);
117 aesCipher = Cipher.getInstance(DEFAULT_TRANSFORMATION);
118 aesKey = createAESSecretKey(aesKeyBytes);
119 aesHash = hash(aesKeyBytes);
120 rsaCipher = Cipher.getInstance(RSA_TRANFORMATION);
121 rsaCipher.init(Cipher.ENCRYPT_MODE, publicKey);
122 rsaEncyptedAesKey = crypt(rsaCipher, aesKeyBytes);
123 }
124 }
125
126
127 private class ProducerCryptoBundle {
128
129 private ThreadLocal<ThreadAwareKeyInfo> keyInfo = new ThreadLocal<ThreadAwareKeyInfo>() {
130 @Override
131 protected ThreadAwareKeyInfo initialValue() {
132 try {
133 return new ThreadAwareKeyInfo(publicKey);
134 } catch (Exception e) {
135 throw new KafkaException(e);
136 }
137 }
138 };
139 private final PublicKey publicKey;
140
141 private ProducerCryptoBundle(PublicKey publicKey) throws Exception {
142 this.publicKey = publicKey;
143 }
144
145 private void newKey() throws Exception {
146 keyInfo.remove();
147 }
148
149 private byte[] aesEncrypt(byte[] plain) throws KafkaException {
150 final ThreadAwareKeyInfo ki = keyInfo.get();
151
152 try {
153 final byte[] aesIv = new byte[16];
154 ki.random.nextBytes(aesIv);
155 ki.aesCipher.init(Cipher.ENCRYPT_MODE, ki.aesKey, new IvParameterSpec(aesIv));
156 return concatenate(MAGIC_BYTES, new byte[] { (byte) ki.aesHash.length,
157 (byte) (ki.rsaEncyptedAesKey.length / RSA_MULTIPLICATOR), (byte) aesIv.length }, ki.aesHash, ki.rsaEncyptedAesKey,
158 aesIv, crypt(ki.aesCipher, plain));
159 } catch (Exception e) {
160 throw new KafkaException(e);
161 }
162 }
163 }
164
165 protected void init(int opMode, Map<String, ?> configs, boolean isKey) throws KafkaException {
166 this.opMode = opMode;
167
168 final String hashMethodProperty = (String) configs.get(CRYPTO_HASH_METHOD);
169
170 if(hashMethodProperty != null && hashMethodProperty.length() != 0) {
171 hashMethod = hashMethodProperty;
172 }
173
174 final String ignoreDecryptFailuresProperty = (String) configs.get(CRYPTO_IGNORE_DECRYPT_FAILURES);
175
176 if(ignoreDecryptFailuresProperty != null && ignoreDecryptFailuresProperty.length() != 0) {
177 ignoreDecryptFailures = Boolean.parseBoolean(ignoreDecryptFailuresProperty);
178 }
179
180 final String aesKeyLenProperty = (String) configs.get(CRYPTO_AES_KEY_LEN);
181
182 if(aesKeyLenProperty != null && aesKeyLenProperty.length() != 0) {
183 aesKeyLen = Integer.parseInt(aesKeyLenProperty);
184 if(aesKeyLen < 128 || aesKeyLen % 8 != 0) {
185 throw new KafkaException("Invalid aes key size, should be 128, 192 or 256");
186 }
187 }
188
189 try {
190 if (opMode == Cipher.DECRYPT_MODE) {
191
192 String rsaPrivateKeyFile = (String) configs.get(CRYPTO_RSA_PRIVATEKEY_FILEPATH);
193 consumerCryptoBundle = new ConsumerCryptoBundle(createRSAPrivateKey(readBytesFromFile(rsaPrivateKeyFile)));
194 } else {
195
196 String rsaPublicKeyFile = (String) configs.get(CRYPTO_RSA_PUBLICKEY_FILEPATH);
197 producerCryptoBundle = new ProducerCryptoBundle(createRSAPublicKey(readBytesFromFile(rsaPublicKeyFile)));
198 }
199 } catch (Exception e) {
200 throw new KafkaException(e);
201 }
202 }
203
204 protected byte[] crypt(byte[] array) throws KafkaException {
205 if (array == null || array.length == 0) {
206 return array;
207 }
208
209 if (opMode == Cipher.DECRYPT_MODE) {
210
211 return consumerCryptoBundle.aesDecrypt(array);
212 } else {
213
214 return producerCryptoBundle.aesEncrypt(array);
215 }
216 }
217
218
219
220
221 protected void newKey() {
222 try {
223 producerCryptoBundle.newKey();
224 } catch (Exception e) {
225 throw new KafkaException(e);
226 }
227 }
228
229
230
231 @SuppressWarnings("unchecked")
232 protected <T> T newInstance(Map<String, ?> map, String key, Class<T> klass) throws KafkaException {
233 Object val = map.get(key);
234 if (val == null) {
235 throw new KafkaException("No value for '" + key + "' found");
236 } else if (val instanceof String) {
237 try {
238 return (T) Utils.newInstance(Class.forName((String) val));
239 } catch (Exception e) {
240 throw new KafkaException(e);
241 }
242 } else if (val instanceof Class) {
243 return (T) Utils.newInstance((Class<T>) val);
244 } else {
245 throw new KafkaException("Unexpected type '" + val.getClass() + "' for '" + key + "'");
246 }
247 }
248
249 private static PrivateKey createRSAPrivateKey(byte[] encodedKey) throws NoSuchAlgorithmException, InvalidKeySpecException {
250 if (encodedKey == null || encodedKey.length == 0) {
251 throw new IllegalArgumentException("Key bytes must not be null or empty");
252 }
253
254 PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(encodedKey);
255 KeyFactory kf = KeyFactory.getInstance(RSA);
256 return kf.generatePrivate(spec);
257 }
258
259 private static SecretKey createAESSecretKey(byte[] encodedKey) {
260 if (encodedKey == null || encodedKey.length == 0) {
261 throw new IllegalArgumentException("Key bytes must not be null or empty");
262 }
263
264 return new SecretKeySpec(encodedKey, AES);
265 }
266
267 private static PublicKey createRSAPublicKey(byte[] encodedKey) throws NoSuchAlgorithmException, InvalidKeySpecException {
268 if (encodedKey == null || encodedKey.length == 0) {
269 throw new IllegalArgumentException("Key bytes must not be null or empty");
270 }
271
272 X509EncodedKeySpec spec = new X509EncodedKeySpec(encodedKey);
273 KeyFactory kf = KeyFactory.getInstance(RSA);
274 return kf.generatePublic(spec);
275 }
276
277 private static byte[] readBytesFromFile(String filename) throws IOException {
278 if (filename == null) {
279 throw new IllegalArgumentException("Filename must not be null");
280 }
281
282 File f = new File(filename);
283 DataInputStream dis = new DataInputStream(new FileInputStream(f));
284 byte[] bytes = new byte[(int) f.length()];
285 dis.readFully(bytes);
286 dis.close();
287 return bytes;
288 }
289
290 private byte[] hash(byte[] toHash) {
291 try {
292 MessageDigest md = MessageDigest.getInstance(hashMethod);
293 md.update(toHash);
294 return md.digest();
295 } catch (Exception e) {
296 throw new KafkaException(e);
297 }
298 }
299
300 private static byte[] crypt(Cipher c, byte[] plain) throws IllegalBlockSizeException, BadPaddingException {
301 return c.doFinal(plain);
302 }
303
304 private static byte[] crypt(Cipher c, byte[] plain, int offset, int len) throws IllegalBlockSizeException, BadPaddingException {
305 return c.doFinal(plain, offset, len);
306 }
307
308 public static byte[] concatenate(byte[] a, byte[] b, byte[] c, byte[] d, byte[] e, byte[] f) {
309 if (a != null && b != null && c != null && d != null && e != null && f != null) {
310 byte[] rv = new byte[a.length + b.length + c.length + d.length + e.length + f.length];
311 System.arraycopy(a, 0, rv, 0, a.length);
312 System.arraycopy(b, 0, rv, a.length, b.length);
313 System.arraycopy(c, 0, rv, a.length + b.length, c.length);
314 System.arraycopy(d, 0, rv, a.length + b.length + c.length, d.length);
315 System.arraycopy(e, 0, rv, a.length + b.length + c.length + d.length, e.length);
316 System.arraycopy(f, 0, rv, a.length + b.length + c.length + d.length + e.length, f.length);
317 return rv;
318 } else {
319 throw new IllegalArgumentException("arrays must not be null");
320 }
321 }
322 }