View Javadoc
1   package de.saly.kafka.crypto;
2   
3   import java.io.DataInputStream;
4   import java.io.File;
5   import java.io.FileInputStream;
6   import java.io.IOException;
7   import java.security.KeyFactory;
8   import java.security.MessageDigest;
9   import java.security.NoSuchAlgorithmException;
10  import java.security.PrivateKey;
11  import java.security.PublicKey;
12  import java.security.SecureRandom;
13  import java.security.spec.InvalidKeySpecException;
14  import java.security.spec.PKCS8EncodedKeySpec;
15  import java.security.spec.X509EncodedKeySpec;
16  import java.util.Arrays;
17  import java.util.HashMap;
18  import java.util.Map;
19  
20  import javax.crypto.BadPaddingException;
21  import javax.crypto.Cipher;
22  import javax.crypto.IllegalBlockSizeException;
23  import javax.crypto.SecretKey;
24  import javax.crypto.spec.IvParameterSpec;
25  import javax.crypto.spec.SecretKeySpec;
26  import javax.xml.bind.DatatypeConverter;
27  
28  import org.apache.kafka.common.KafkaException;
29  import org.apache.kafka.common.utils.Utils;
30  
31  public abstract class SerdeCryptoBase {
32  
33      public static final String CRYPTO_RSA_PRIVATEKEY_FILEPATH = "crypto.rsa.privatekey.filepath"; //consumer
34      public static final String CRYPTO_RSA_PUBLICKEY_FILEPATH = "crypto.rsa.publickey.filepath"; //producer
35      public static final String CRYPTO_HASH_METHOD = "crypto.hash_method";
36      public static final String CRYPTO_IGNORE_DECRYPT_FAILURES = "crypto.ignore_decrypt_failures";
37      public static final String CRYPTO_AES_KEY_LEN = "crypto.aes.key_len";
38      static final byte[] MAGIC_BYTES = new byte[] { (byte) 0xDF, (byte) 0xBB };
39      protected static final String DEFAULT_TRANSFORMATION = "AES/CBC/PKCS5Padding"; //TODO allow other like GCM
40      private static final Map<String, byte[]> aesKeyCache = new HashMap<String, byte[]>();
41      private static final int MAGIC_BYTES_LENGTH = MAGIC_BYTES.length;
42      private static final int HEADER_LENGTH = MAGIC_BYTES_LENGTH + 3;
43      private static final String AES = "AES";
44      private static final String RSA = "RSA";
45      private static final String RSA_TRANFORMATION = "RSA/ECB/OAEPWithSHA-256AndMGF1Padding";
46      private static final int RSA_MULTIPLICATOR = 128;
47      private int opMode;
48      private String hashMethod = "SHA-256";
49      private int aesKeyLen = 128;
50      private boolean ignoreDecryptFailures = false;
51      private ProducerCryptoBundle producerCryptoBundle = null;
52      private ConsumerCryptoBundle consumerCryptoBundle = null;
53  
54      protected SerdeCryptoBase() {
55  
56      }
57  
58      //not thread safe
59      private class ConsumerCryptoBundle {
60  
61          private Cipher rsaDecrypt;
62          final Cipher aesDecrypt = Cipher.getInstance(DEFAULT_TRANSFORMATION);
63  
64          private ConsumerCryptoBundle(PrivateKey privateKey) throws Exception {
65              rsaDecrypt = Cipher.getInstance(RSA_TRANFORMATION);
66              rsaDecrypt.init(Cipher.DECRYPT_MODE, privateKey);
67          }
68  
69          private byte[] aesDecrypt(byte[] encrypted) throws KafkaException {
70              try {
71                  if (encrypted[0] == MAGIC_BYTES[0] && encrypted[1] == MAGIC_BYTES[1]) {
72                      final byte hashLen = encrypted[2];
73                      final byte rsaFactor = encrypted[3];
74                      final byte ivLen = encrypted[4];
75                      final int offset = HEADER_LENGTH + hashLen + (rsaFactor * RSA_MULTIPLICATOR) + ivLen;
76                      final String aesHash = DatatypeConverter.printHexBinary(Arrays.copyOfRange(encrypted, HEADER_LENGTH, HEADER_LENGTH + hashLen));
77                      final byte[] iv = Arrays.copyOfRange(encrypted, HEADER_LENGTH + hashLen + (rsaFactor * RSA_MULTIPLICATOR),
78                              HEADER_LENGTH + hashLen + (rsaFactor * RSA_MULTIPLICATOR) + ivLen);
79  
80                      byte[] aesKey;
81  
82                      if ((aesKey = aesKeyCache.get(aesHash)) != null) {
83                          aesDecrypt.init(Cipher.DECRYPT_MODE, createAESSecretKey(aesKey), new IvParameterSpec(iv));
84                          return crypt(aesDecrypt, encrypted, offset, encrypted.length - offset);
85                      } else {
86                          byte[] rsaEncryptedAesKey = Arrays.copyOfRange(encrypted, HEADER_LENGTH + hashLen,
87                                  HEADER_LENGTH + hashLen + (rsaFactor * RSA_MULTIPLICATOR));
88                          aesKey = crypt(rsaDecrypt, rsaEncryptedAesKey);
89                          aesDecrypt.init(Cipher.DECRYPT_MODE, createAESSecretKey(aesKey), new IvParameterSpec(iv));
90                          aesKeyCache.put(aesHash, aesKey);
91                          return crypt(aesDecrypt, encrypted, offset, encrypted.length - offset);
92                      }
93                  } else {
94                      return encrypted; //not encrypted, just bypass decryption
95                  }
96              } catch (Exception e) {
97                  if(ignoreDecryptFailures) {
98                      return encrypted; //Probably not encrypted, just bypass decryption
99                  }
100                 
101                 throw new KafkaException("Decrypt failed",e);
102             }
103         }
104     }
105 
106     private class ThreadAwareKeyInfo {
107         private final SecretKey aesKey;
108         private final byte[] aesHash;
109         private final byte[] rsaEncyptedAesKey;
110         private final Cipher rsaCipher;
111         private final Cipher aesCipher;
112         private final SecureRandom random = new SecureRandom();
113 
114         protected ThreadAwareKeyInfo(PublicKey publicKey) throws Exception {
115             byte[] aesKeyBytes = new byte[aesKeyLen/8];
116             random.nextBytes(aesKeyBytes);
117             aesCipher = Cipher.getInstance(DEFAULT_TRANSFORMATION);
118             aesKey = createAESSecretKey(aesKeyBytes);
119             aesHash = hash(aesKeyBytes);
120             rsaCipher = Cipher.getInstance(RSA_TRANFORMATION);
121             rsaCipher.init(Cipher.ENCRYPT_MODE, publicKey);
122             rsaEncyptedAesKey = crypt(rsaCipher, aesKeyBytes);
123         }
124     }
125 
126     //threads safe
127     private class ProducerCryptoBundle {
128 
129         private ThreadLocal<ThreadAwareKeyInfo> keyInfo = new ThreadLocal<ThreadAwareKeyInfo>() {
130             @Override
131             protected ThreadAwareKeyInfo initialValue() {
132                 try {
133                     return new ThreadAwareKeyInfo(publicKey);
134                 } catch (Exception e) {
135                     throw new KafkaException(e);
136                 }
137             }
138         };
139         private final PublicKey publicKey;
140 
141         private ProducerCryptoBundle(PublicKey publicKey) throws Exception {
142             this.publicKey = publicKey;
143         }
144 
145         private void newKey() throws Exception {
146             keyInfo.remove();
147         }
148 
149         private byte[] aesEncrypt(byte[] plain) throws KafkaException {
150             final ThreadAwareKeyInfo ki = keyInfo.get();
151 
152             try {
153                 final byte[] aesIv = new byte[16];
154                 ki.random.nextBytes(aesIv);
155                 ki.aesCipher.init(Cipher.ENCRYPT_MODE, ki.aesKey, new IvParameterSpec(aesIv));
156                 return concatenate(MAGIC_BYTES, new byte[] { (byte) ki.aesHash.length,
157                         (byte) (ki.rsaEncyptedAesKey.length / RSA_MULTIPLICATOR), (byte) aesIv.length }, ki.aesHash, ki.rsaEncyptedAesKey,
158                         aesIv, crypt(ki.aesCipher, plain));
159             } catch (Exception e) {
160                 throw new KafkaException(e);
161             }
162         }
163     }
164 
165     protected void init(int opMode, Map<String, ?> configs, boolean isKey) throws KafkaException {
166         this.opMode = opMode;
167 
168         final String hashMethodProperty = (String) configs.get(CRYPTO_HASH_METHOD);
169         
170         if(hashMethodProperty != null && hashMethodProperty.length() != 0) {
171             hashMethod = hashMethodProperty;
172         }
173         
174         final String ignoreDecryptFailuresProperty = (String) configs.get(CRYPTO_IGNORE_DECRYPT_FAILURES);
175         
176         if(ignoreDecryptFailuresProperty != null && ignoreDecryptFailuresProperty.length() != 0) {
177             ignoreDecryptFailures = Boolean.parseBoolean(ignoreDecryptFailuresProperty);
178         }
179         
180         final String aesKeyLenProperty = (String) configs.get(CRYPTO_AES_KEY_LEN);
181         
182         if(aesKeyLenProperty != null && aesKeyLenProperty.length() != 0) {
183             aesKeyLen = Integer.parseInt(aesKeyLenProperty);
184             if(aesKeyLen < 128 || aesKeyLen % 8 != 0) {
185                 throw new KafkaException("Invalid aes key size, should be 128, 192 or 256");
186             }
187         }
188         
189         try {
190             if (opMode == Cipher.DECRYPT_MODE) {
191                 //Consumer
192                 String rsaPrivateKeyFile = (String) configs.get(CRYPTO_RSA_PRIVATEKEY_FILEPATH);
193                 consumerCryptoBundle = new ConsumerCryptoBundle(createRSAPrivateKey(readBytesFromFile(rsaPrivateKeyFile)));
194             } else {
195                 //Producer
196                 String rsaPublicKeyFile = (String) configs.get(CRYPTO_RSA_PUBLICKEY_FILEPATH);
197                 producerCryptoBundle = new ProducerCryptoBundle(createRSAPublicKey(readBytesFromFile(rsaPublicKeyFile)));
198             }
199         } catch (Exception e) {
200             throw new KafkaException(e);
201         }
202     }
203 
204     protected byte[] crypt(byte[] array) throws KafkaException {
205         if (array == null || array.length == 0) {
206             return array;
207         }
208 
209         if (opMode == Cipher.DECRYPT_MODE) {
210             //Consumer
211             return consumerCryptoBundle.aesDecrypt(array);
212         } else {
213             //Producer
214             return producerCryptoBundle.aesEncrypt(array);
215         }
216     }
217 
218     /**
219      * Generate new AES key for the current thread
220      */
221     protected void newKey() {
222         try {
223             producerCryptoBundle.newKey();
224         } catch (Exception e) {
225             throw new KafkaException(e);
226         }
227     }
228 
229     //Hereafter there are only helper methods
230 
231     @SuppressWarnings("unchecked")
232     protected <T> T newInstance(Map<String, ?> map, String key, Class<T> klass) throws KafkaException {
233         Object val = map.get(key);
234         if (val == null) {
235             throw new KafkaException("No value for '" + key + "' found");
236         } else if (val instanceof String) {
237             try {
238                 return (T) Utils.newInstance(Class.forName((String) val));
239             } catch (Exception e) {
240                 throw new KafkaException(e);
241             }
242         } else if (val instanceof Class) {
243             return (T) Utils.newInstance((Class<T>) val);
244         } else {
245             throw new KafkaException("Unexpected type '" + val.getClass() + "' for '" + key + "'");
246         }
247     }
248 
249     private static PrivateKey createRSAPrivateKey(byte[] encodedKey) throws NoSuchAlgorithmException, InvalidKeySpecException {
250         if (encodedKey == null || encodedKey.length == 0) {
251             throw new IllegalArgumentException("Key bytes must not be null or empty");
252         }
253 
254         PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(encodedKey);
255         KeyFactory kf = KeyFactory.getInstance(RSA);
256         return kf.generatePrivate(spec);
257     }
258 
259     private static SecretKey createAESSecretKey(byte[] encodedKey) {
260         if (encodedKey == null || encodedKey.length == 0) {
261             throw new IllegalArgumentException("Key bytes must not be null or empty");
262         }
263 
264         return new SecretKeySpec(encodedKey, AES);
265     }
266 
267     private static PublicKey createRSAPublicKey(byte[] encodedKey) throws NoSuchAlgorithmException, InvalidKeySpecException {
268         if (encodedKey == null || encodedKey.length == 0) {
269             throw new IllegalArgumentException("Key bytes must not be null or empty");
270         }
271 
272         X509EncodedKeySpec spec = new X509EncodedKeySpec(encodedKey);
273         KeyFactory kf = KeyFactory.getInstance(RSA);
274         return kf.generatePublic(spec);
275     }
276 
277     private static byte[] readBytesFromFile(String filename) throws IOException {
278         if (filename == null) {
279             throw new IllegalArgumentException("Filename must not be null");
280         }
281 
282         File f = new File(filename);
283         DataInputStream dis = new DataInputStream(new FileInputStream(f));
284         byte[] bytes = new byte[(int) f.length()];
285         dis.readFully(bytes);
286         dis.close();
287         return bytes;
288     }
289 
290     private byte[] hash(byte[] toHash) {
291         try {
292             MessageDigest md = MessageDigest.getInstance(hashMethod);
293             md.update(toHash);
294             return md.digest();
295         } catch (Exception e) {
296             throw new KafkaException(e);
297         }
298     }
299 
300     private static byte[] crypt(Cipher c, byte[] plain) throws IllegalBlockSizeException, BadPaddingException {
301         return c.doFinal(plain);
302     }
303 
304     private static byte[] crypt(Cipher c, byte[] plain, int offset, int len) throws IllegalBlockSizeException, BadPaddingException {
305         return c.doFinal(plain, offset, len);
306     }
307     
308     public static byte[] concatenate(byte[] a, byte[] b, byte[] c, byte[] d, byte[] e, byte[] f) {
309         if (a != null && b != null && c != null && d != null && e != null && f != null) {
310             byte[] rv = new byte[a.length + b.length + c.length + d.length + e.length + f.length];
311             System.arraycopy(a, 0, rv, 0, a.length);
312             System.arraycopy(b, 0, rv, a.length, b.length);
313             System.arraycopy(c, 0, rv, a.length + b.length, c.length);
314             System.arraycopy(d, 0, rv, a.length + b.length + c.length, d.length);
315             System.arraycopy(e, 0, rv, a.length + b.length + c.length + d.length, e.length);
316             System.arraycopy(f, 0, rv, a.length + b.length + c.length + d.length + e.length, f.length);
317             return rv;
318         } else {
319             throw new IllegalArgumentException("arrays must not be null");
320         }
321     }
322 }