001package de.saly.kafka.crypto;
002
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.security.KeyFactory;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.SecretKey;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import javax.xml.bind.DatatypeConverter;

import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.utils.Utils;
030
031public abstract class SerdeCryptoBase {
032
033    public static final String CRYPTO_RSA_PRIVATEKEY_FILEPATH = "crypto.rsa.privatekey.filepath"; //consumer
034    public static final String CRYPTO_RSA_PUBLICKEY_FILEPATH = "crypto.rsa.publickey.filepath"; //producer
035    public static final String CRYPTO_HASH_METHOD = "crypto.hash_method";
036    public static final String CRYPTO_IGNORE_DECRYPT_FAILURES = "crypto.ignore_decrypt_failures";
037    public static final String CRYPTO_AES_KEY_LEN = "crypto.aes.key_len";
038    static final byte[] MAGIC_BYTES = new byte[] { (byte) 0xDF, (byte) 0xBB };
039    protected static final String DEFAULT_TRANSFORMATION = "AES/CBC/PKCS5Padding"; //TODO allow other like GCM
040    private static final Map<String, byte[]> aesKeyCache = new HashMap<String, byte[]>();
041    private static final int MAGIC_BYTES_LENGTH = MAGIC_BYTES.length;
042    private static final int HEADER_LENGTH = MAGIC_BYTES_LENGTH + 3;
043    private static final String AES = "AES";
044    private static final String RSA = "RSA";
045    private static final String RSA_TRANFORMATION = "RSA/ECB/OAEPWithSHA-256AndMGF1Padding";
046    private static final int RSA_MULTIPLICATOR = 128;
047    private int opMode;
048    private String hashMethod = "SHA-256";
049    private int aesKeyLen = 128;
050    private boolean ignoreDecryptFailures = false;
051    private ProducerCryptoBundle producerCryptoBundle = null;
052    private ConsumerCryptoBundle consumerCryptoBundle = null;
053
054    protected SerdeCryptoBase() {
055
056    }
057
058    //not thread safe
059    private class ConsumerCryptoBundle {
060
061        private Cipher rsaDecrypt;
062        final Cipher aesDecrypt = Cipher.getInstance(DEFAULT_TRANSFORMATION);
063
064        private ConsumerCryptoBundle(PrivateKey privateKey) throws Exception {
065            rsaDecrypt = Cipher.getInstance(RSA_TRANFORMATION);
066            rsaDecrypt.init(Cipher.DECRYPT_MODE, privateKey);
067        }
068
069        private byte[] aesDecrypt(byte[] encrypted) throws KafkaException {
070            try {
071                if (encrypted[0] == MAGIC_BYTES[0] && encrypted[1] == MAGIC_BYTES[1]) {
072                    final byte hashLen = encrypted[2];
073                    final byte rsaFactor = encrypted[3];
074                    final byte ivLen = encrypted[4];
075                    final int offset = HEADER_LENGTH + hashLen + (rsaFactor * RSA_MULTIPLICATOR) + ivLen;
076                    final String aesHash = DatatypeConverter.printHexBinary(Arrays.copyOfRange(encrypted, HEADER_LENGTH, HEADER_LENGTH + hashLen));
077                    final byte[] iv = Arrays.copyOfRange(encrypted, HEADER_LENGTH + hashLen + (rsaFactor * RSA_MULTIPLICATOR),
078                            HEADER_LENGTH + hashLen + (rsaFactor * RSA_MULTIPLICATOR) + ivLen);
079
080                    byte[] aesKey;
081
082                    if ((aesKey = aesKeyCache.get(aesHash)) != null) {
083                        aesDecrypt.init(Cipher.DECRYPT_MODE, createAESSecretKey(aesKey), new IvParameterSpec(iv));
084                        return crypt(aesDecrypt, encrypted, offset, encrypted.length - offset);
085                    } else {
086                        byte[] rsaEncryptedAesKey = Arrays.copyOfRange(encrypted, HEADER_LENGTH + hashLen,
087                                HEADER_LENGTH + hashLen + (rsaFactor * RSA_MULTIPLICATOR));
088                        aesKey = crypt(rsaDecrypt, rsaEncryptedAesKey);
089                        aesDecrypt.init(Cipher.DECRYPT_MODE, createAESSecretKey(aesKey), new IvParameterSpec(iv));
090                        aesKeyCache.put(aesHash, aesKey);
091                        return crypt(aesDecrypt, encrypted, offset, encrypted.length - offset);
092                    }
093                } else {
094                    return encrypted; //not encrypted, just bypass decryption
095                }
096            } catch (Exception e) {
097                if(ignoreDecryptFailures) {
098                    return encrypted; //Probably not encrypted, just bypass decryption
099                }
100                
101                throw new KafkaException("Decrypt failed",e);
102            }
103        }
104    }
105
106    private class ThreadAwareKeyInfo {
107        private final SecretKey aesKey;
108        private final byte[] aesHash;
109        private final byte[] rsaEncyptedAesKey;
110        private final Cipher rsaCipher;
111        private final Cipher aesCipher;
112        private final SecureRandom random = new SecureRandom();
113
114        protected ThreadAwareKeyInfo(PublicKey publicKey) throws Exception {
115            byte[] aesKeyBytes = new byte[aesKeyLen/8];
116            random.nextBytes(aesKeyBytes);
117            aesCipher = Cipher.getInstance(DEFAULT_TRANSFORMATION);
118            aesKey = createAESSecretKey(aesKeyBytes);
119            aesHash = hash(aesKeyBytes);
120            rsaCipher = Cipher.getInstance(RSA_TRANFORMATION);
121            rsaCipher.init(Cipher.ENCRYPT_MODE, publicKey);
122            rsaEncyptedAesKey = crypt(rsaCipher, aesKeyBytes);
123        }
124    }
125
126    //threads safe
127    private class ProducerCryptoBundle {
128
129        private ThreadLocal<ThreadAwareKeyInfo> keyInfo = new ThreadLocal<ThreadAwareKeyInfo>() {
130            @Override
131            protected ThreadAwareKeyInfo initialValue() {
132                try {
133                    return new ThreadAwareKeyInfo(publicKey);
134                } catch (Exception e) {
135                    throw new KafkaException(e);
136                }
137            }
138        };
139        private final PublicKey publicKey;
140
141        private ProducerCryptoBundle(PublicKey publicKey) throws Exception {
142            this.publicKey = publicKey;
143        }
144
145        private void newKey() throws Exception {
146            keyInfo.remove();
147        }
148
149        private byte[] aesEncrypt(byte[] plain) throws KafkaException {
150            final ThreadAwareKeyInfo ki = keyInfo.get();
151
152            try {
153                final byte[] aesIv = new byte[16];
154                ki.random.nextBytes(aesIv);
155                ki.aesCipher.init(Cipher.ENCRYPT_MODE, ki.aesKey, new IvParameterSpec(aesIv));
156                return concatenate(MAGIC_BYTES, new byte[] { (byte) ki.aesHash.length,
157                        (byte) (ki.rsaEncyptedAesKey.length / RSA_MULTIPLICATOR), (byte) aesIv.length }, ki.aesHash, ki.rsaEncyptedAesKey,
158                        aesIv, crypt(ki.aesCipher, plain));
159            } catch (Exception e) {
160                throw new KafkaException(e);
161            }
162        }
163    }
164
165    protected void init(int opMode, Map<String, ?> configs, boolean isKey) throws KafkaException {
166        this.opMode = opMode;
167
168        final String hashMethodProperty = (String) configs.get(CRYPTO_HASH_METHOD);
169        
170        if(hashMethodProperty != null && hashMethodProperty.length() != 0) {
171            hashMethod = hashMethodProperty;
172        }
173        
174        final String ignoreDecryptFailuresProperty = (String) configs.get(CRYPTO_IGNORE_DECRYPT_FAILURES);
175        
176        if(ignoreDecryptFailuresProperty != null && ignoreDecryptFailuresProperty.length() != 0) {
177            ignoreDecryptFailures = Boolean.parseBoolean(ignoreDecryptFailuresProperty);
178        }
179        
180        final String aesKeyLenProperty = (String) configs.get(CRYPTO_AES_KEY_LEN);
181        
182        if(aesKeyLenProperty != null && aesKeyLenProperty.length() != 0) {
183            aesKeyLen = Integer.parseInt(aesKeyLenProperty);
184            if(aesKeyLen < 128 || aesKeyLen % 8 != 0) {
185                throw new KafkaException("Invalid aes key size, should be 128, 192 or 256");
186            }
187        }
188        
189        try {
190            if (opMode == Cipher.DECRYPT_MODE) {
191                //Consumer
192                String rsaPrivateKeyFile = (String) configs.get(CRYPTO_RSA_PRIVATEKEY_FILEPATH);
193                consumerCryptoBundle = new ConsumerCryptoBundle(createRSAPrivateKey(readBytesFromFile(rsaPrivateKeyFile)));
194            } else {
195                //Producer
196                String rsaPublicKeyFile = (String) configs.get(CRYPTO_RSA_PUBLICKEY_FILEPATH);
197                producerCryptoBundle = new ProducerCryptoBundle(createRSAPublicKey(readBytesFromFile(rsaPublicKeyFile)));
198            }
199        } catch (Exception e) {
200            throw new KafkaException(e);
201        }
202    }
203
204    protected byte[] crypt(byte[] array) throws KafkaException {
205        if (array == null || array.length == 0) {
206            return array;
207        }
208
209        if (opMode == Cipher.DECRYPT_MODE) {
210            //Consumer
211            return consumerCryptoBundle.aesDecrypt(array);
212        } else {
213            //Producer
214            return producerCryptoBundle.aesEncrypt(array);
215        }
216    }
217
218    /**
219     * Generate new AES key for the current thread
220     */
221    protected void newKey() {
222        try {
223            producerCryptoBundle.newKey();
224        } catch (Exception e) {
225            throw new KafkaException(e);
226        }
227    }
228
229    //Hereafter there are only helper methods
230
231    @SuppressWarnings("unchecked")
232    protected <T> T newInstance(Map<String, ?> map, String key, Class<T> klass) throws KafkaException {
233        Object val = map.get(key);
234        if (val == null) {
235            throw new KafkaException("No value for '" + key + "' found");
236        } else if (val instanceof String) {
237            try {
238                return (T) Utils.newInstance(Class.forName((String) val));
239            } catch (Exception e) {
240                throw new KafkaException(e);
241            }
242        } else if (val instanceof Class) {
243            return (T) Utils.newInstance((Class<T>) val);
244        } else {
245            throw new KafkaException("Unexpected type '" + val.getClass() + "' for '" + key + "'");
246        }
247    }
248
249    private static PrivateKey createRSAPrivateKey(byte[] encodedKey) throws NoSuchAlgorithmException, InvalidKeySpecException {
250        if (encodedKey == null || encodedKey.length == 0) {
251            throw new IllegalArgumentException("Key bytes must not be null or empty");
252        }
253
254        PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(encodedKey);
255        KeyFactory kf = KeyFactory.getInstance(RSA);
256        return kf.generatePrivate(spec);
257    }
258
259    private static SecretKey createAESSecretKey(byte[] encodedKey) {
260        if (encodedKey == null || encodedKey.length == 0) {
261            throw new IllegalArgumentException("Key bytes must not be null or empty");
262        }
263
264        return new SecretKeySpec(encodedKey, AES);
265    }
266
267    private static PublicKey createRSAPublicKey(byte[] encodedKey) throws NoSuchAlgorithmException, InvalidKeySpecException {
268        if (encodedKey == null || encodedKey.length == 0) {
269            throw new IllegalArgumentException("Key bytes must not be null or empty");
270        }
271
272        X509EncodedKeySpec spec = new X509EncodedKeySpec(encodedKey);
273        KeyFactory kf = KeyFactory.getInstance(RSA);
274        return kf.generatePublic(spec);
275    }
276
277    private static byte[] readBytesFromFile(String filename) throws IOException {
278        if (filename == null) {
279            throw new IllegalArgumentException("Filename must not be null");
280        }
281
282        File f = new File(filename);
283        DataInputStream dis = new DataInputStream(new FileInputStream(f));
284        byte[] bytes = new byte[(int) f.length()];
285        dis.readFully(bytes);
286        dis.close();
287        return bytes;
288    }
289
290    private byte[] hash(byte[] toHash) {
291        try {
292            MessageDigest md = MessageDigest.getInstance(hashMethod);
293            md.update(toHash);
294            return md.digest();
295        } catch (Exception e) {
296            throw new KafkaException(e);
297        }
298    }
299
300    private static byte[] crypt(Cipher c, byte[] plain) throws IllegalBlockSizeException, BadPaddingException {
301        return c.doFinal(plain);
302    }
303
304    private static byte[] crypt(Cipher c, byte[] plain, int offset, int len) throws IllegalBlockSizeException, BadPaddingException {
305        return c.doFinal(plain, offset, len);
306    }
307    
308    public static byte[] concatenate(byte[] a, byte[] b, byte[] c, byte[] d, byte[] e, byte[] f) {
309        if (a != null && b != null && c != null && d != null && e != null && f != null) {
310            byte[] rv = new byte[a.length + b.length + c.length + d.length + e.length + f.length];
311            System.arraycopy(a, 0, rv, 0, a.length);
312            System.arraycopy(b, 0, rv, a.length, b.length);
313            System.arraycopy(c, 0, rv, a.length + b.length, c.length);
314            System.arraycopy(d, 0, rv, a.length + b.length + c.length, d.length);
315            System.arraycopy(e, 0, rv, a.length + b.length + c.length + d.length, e.length);
316            System.arraycopy(f, 0, rv, a.length + b.length + c.length + d.length + e.length, f.length);
317            return rv;
318        } else {
319            throw new IllegalArgumentException("arrays must not be null");
320        }
321    }
322}