001package de.saly.kafka.crypto; 002 003import java.io.DataInputStream; 004import java.io.File; 005import java.io.FileInputStream; 006import java.io.IOException; 007import java.security.KeyFactory; 008import java.security.MessageDigest; 009import java.security.NoSuchAlgorithmException; 010import java.security.PrivateKey; 011import java.security.PublicKey; 012import java.security.SecureRandom; 013import java.security.spec.InvalidKeySpecException; 014import java.security.spec.PKCS8EncodedKeySpec; 015import java.security.spec.X509EncodedKeySpec; 016import java.util.Arrays; 017import java.util.HashMap; 018import java.util.Map; 019 020import javax.crypto.BadPaddingException; 021import javax.crypto.Cipher; 022import javax.crypto.IllegalBlockSizeException; 023import javax.crypto.SecretKey; 024import javax.crypto.spec.IvParameterSpec; 025import javax.crypto.spec.SecretKeySpec; 026import javax.xml.bind.DatatypeConverter; 027 028import org.apache.kafka.common.KafkaException; 029import org.apache.kafka.common.utils.Utils; 030 031public abstract class SerdeCryptoBase { 032 033 public static final String CRYPTO_RSA_PRIVATEKEY_FILEPATH = "crypto.rsa.privatekey.filepath"; //consumer 034 public static final String CRYPTO_RSA_PUBLICKEY_FILEPATH = "crypto.rsa.publickey.filepath"; //producer 035 public static final String CRYPTO_HASH_METHOD = "crypto.hash_method"; 036 public static final String CRYPTO_IGNORE_DECRYPT_FAILURES = "crypto.ignore_decrypt_failures"; 037 public static final String CRYPTO_AES_KEY_LEN = "crypto.aes.key_len"; 038 static final byte[] MAGIC_BYTES = new byte[] { (byte) 0xDF, (byte) 0xBB }; 039 protected static final String DEFAULT_TRANSFORMATION = "AES/CBC/PKCS5Padding"; //TODO allow other like GCM 040 private static final Map<String, byte[]> aesKeyCache = new HashMap<String, byte[]>(); 041 private static final int MAGIC_BYTES_LENGTH = MAGIC_BYTES.length; 042 private static final int HEADER_LENGTH = MAGIC_BYTES_LENGTH + 3; 043 private static final String AES = "AES"; 044 private static final String RSA = "RSA"; 045 private static final String RSA_TRANFORMATION = "RSA/ECB/OAEPWithSHA-256AndMGF1Padding"; 046 private static final int RSA_MULTIPLICATOR = 128; 047 private int opMode; 048 private String hashMethod = "SHA-256"; 049 private int aesKeyLen = 128; 050 private boolean ignoreDecryptFailures = false; 051 private ProducerCryptoBundle producerCryptoBundle = null; 052 private ConsumerCryptoBundle consumerCryptoBundle = null; 053 054 protected SerdeCryptoBase() { 055 056 } 057 058 //not thread safe 059 private class ConsumerCryptoBundle { 060 061 private Cipher rsaDecrypt; 062 final Cipher aesDecrypt = Cipher.getInstance(DEFAULT_TRANSFORMATION); 063 064 private ConsumerCryptoBundle(PrivateKey privateKey) throws Exception { 065 rsaDecrypt = Cipher.getInstance(RSA_TRANFORMATION); 066 rsaDecrypt.init(Cipher.DECRYPT_MODE, privateKey); 067 } 068 069 private byte[] aesDecrypt(byte[] encrypted) throws KafkaException { 070 try { 071 if (encrypted[0] == MAGIC_BYTES[0] && encrypted[1] == MAGIC_BYTES[1]) { 072 final byte hashLen = encrypted[2]; 073 final byte rsaFactor = encrypted[3]; 074 final byte ivLen = encrypted[4]; 075 final int offset = HEADER_LENGTH + hashLen + (rsaFactor * RSA_MULTIPLICATOR) + ivLen; 076 final String aesHash = DatatypeConverter.printHexBinary(Arrays.copyOfRange(encrypted, HEADER_LENGTH, HEADER_LENGTH + hashLen)); 077 final byte[] iv = Arrays.copyOfRange(encrypted, HEADER_LENGTH + hashLen + (rsaFactor * RSA_MULTIPLICATOR), 078 HEADER_LENGTH + hashLen + (rsaFactor * RSA_MULTIPLICATOR) + ivLen); 079 080 byte[] aesKey; 081 082 if ((aesKey = aesKeyCache.get(aesHash)) != null) { 083 aesDecrypt.init(Cipher.DECRYPT_MODE, createAESSecretKey(aesKey), new IvParameterSpec(iv)); 084 return crypt(aesDecrypt, encrypted, offset, encrypted.length - offset); 085 } else { 086 byte[] rsaEncryptedAesKey = Arrays.copyOfRange(encrypted, HEADER_LENGTH + hashLen, 087 HEADER_LENGTH + hashLen + (rsaFactor * RSA_MULTIPLICATOR)); 088 aesKey = crypt(rsaDecrypt, rsaEncryptedAesKey); 089 aesDecrypt.init(Cipher.DECRYPT_MODE, createAESSecretKey(aesKey), new IvParameterSpec(iv)); 090 aesKeyCache.put(aesHash, aesKey); 091 return crypt(aesDecrypt, encrypted, offset, encrypted.length - offset); 092 } 093 } else { 094 return encrypted; //not encrypted, just bypass decryption 095 } 096 } catch (Exception e) { 097 if(ignoreDecryptFailures) { 098 return encrypted; //Probably not encrypted, just bypass decryption 099 } 100 101 throw new KafkaException("Decrypt failed",e); 102 } 103 } 104 } 105 106 private class ThreadAwareKeyInfo { 107 private final SecretKey aesKey; 108 private final byte[] aesHash; 109 private final byte[] rsaEncyptedAesKey; 110 private final Cipher rsaCipher; 111 private final Cipher aesCipher; 112 private final SecureRandom random = new SecureRandom(); 113 114 protected ThreadAwareKeyInfo(PublicKey publicKey) throws Exception { 115 byte[] aesKeyBytes = new byte[aesKeyLen/8]; 116 random.nextBytes(aesKeyBytes); 117 aesCipher = Cipher.getInstance(DEFAULT_TRANSFORMATION); 118 aesKey = createAESSecretKey(aesKeyBytes); 119 aesHash = hash(aesKeyBytes); 120 rsaCipher = Cipher.getInstance(RSA_TRANFORMATION); 121 rsaCipher.init(Cipher.ENCRYPT_MODE, publicKey); 122 rsaEncyptedAesKey = crypt(rsaCipher, aesKeyBytes); 123 } 124 } 125 126 //threads safe 127 private class ProducerCryptoBundle { 128 129 private ThreadLocal<ThreadAwareKeyInfo> keyInfo = new ThreadLocal<ThreadAwareKeyInfo>() { 130 @Override 131 protected ThreadAwareKeyInfo initialValue() { 132 try { 133 return new ThreadAwareKeyInfo(publicKey); 134 } catch (Exception e) { 135 throw new KafkaException(e); 136 } 137 } 138 }; 139 private final PublicKey publicKey; 140 141 private ProducerCryptoBundle(PublicKey publicKey) throws Exception { 142 this.publicKey = publicKey; 143 } 144 145 private void newKey() throws Exception { 146 keyInfo.remove(); 147 } 148 149 private byte[] aesEncrypt(byte[] plain) throws KafkaException { 150 final ThreadAwareKeyInfo ki = keyInfo.get(); 151 152 try { 153 final byte[] aesIv = new byte[16]; 154 ki.random.nextBytes(aesIv); 155 ki.aesCipher.init(Cipher.ENCRYPT_MODE, ki.aesKey, new IvParameterSpec(aesIv)); 156 return concatenate(MAGIC_BYTES, new byte[] { (byte) ki.aesHash.length, 157 (byte) (ki.rsaEncyptedAesKey.length / RSA_MULTIPLICATOR), (byte) aesIv.length }, ki.aesHash, ki.rsaEncyptedAesKey, 158 aesIv, crypt(ki.aesCipher, plain)); 159 } catch (Exception e) { 160 throw new KafkaException(e); 161 } 162 } 163 } 164 165 protected void init(int opMode, Map<String, ?> configs, boolean isKey) throws KafkaException { 166 this.opMode = opMode; 167 168 final String hashMethodProperty = (String) configs.get(CRYPTO_HASH_METHOD); 169 170 if(hashMethodProperty != null && hashMethodProperty.length() != 0) { 171 hashMethod = hashMethodProperty; 172 } 173 174 final String ignoreDecryptFailuresProperty = (String) configs.get(CRYPTO_IGNORE_DECRYPT_FAILURES); 175 176 if(ignoreDecryptFailuresProperty != null && ignoreDecryptFailuresProperty.length() != 0) { 177 ignoreDecryptFailures = Boolean.parseBoolean(ignoreDecryptFailuresProperty); 178 } 179 180 final String aesKeyLenProperty = (String) configs.get(CRYPTO_AES_KEY_LEN); 181 182 if(aesKeyLenProperty != null && aesKeyLenProperty.length() != 0) { 183 aesKeyLen = Integer.parseInt(aesKeyLenProperty); 184 if(aesKeyLen < 128 || aesKeyLen % 8 != 0) { 185 throw new KafkaException("Invalid aes key size, should be 128, 192 or 256"); 186 } 187 } 188 189 try { 190 if (opMode == Cipher.DECRYPT_MODE) { 191 //Consumer 192 String rsaPrivateKeyFile = (String) configs.get(CRYPTO_RSA_PRIVATEKEY_FILEPATH); 193 consumerCryptoBundle = new ConsumerCryptoBundle(createRSAPrivateKey(readBytesFromFile(rsaPrivateKeyFile))); 194 } else { 195 //Producer 196 String rsaPublicKeyFile = (String) configs.get(CRYPTO_RSA_PUBLICKEY_FILEPATH); 197 producerCryptoBundle = new ProducerCryptoBundle(createRSAPublicKey(readBytesFromFile(rsaPublicKeyFile))); 198 } 199 } catch (Exception e) { 200 throw new KafkaException(e); 201 } 202 } 203 204 protected byte[] crypt(byte[] array) throws KafkaException { 205 if (array == null || array.length == 0) { 206 return array; 207 } 208 209 if (opMode == Cipher.DECRYPT_MODE) { 210 //Consumer 211 return consumerCryptoBundle.aesDecrypt(array); 212 } else { 213 //Producer 214 return producerCryptoBundle.aesEncrypt(array); 215 } 216 } 217 218 /** 219 * Generate new AES key for the current thread 220 */ 221 protected void newKey() { 222 try { 223 producerCryptoBundle.newKey(); 224 } catch (Exception e) { 225 throw new KafkaException(e); 226 } 227 } 228 229 //Hereafter there are only helper methods 230 231 @SuppressWarnings("unchecked") 232 protected <T> T newInstance(Map<String, ?> map, String key, Class<T> klass) throws KafkaException { 233 Object val = map.get(key); 234 if (val == null) { 235 throw new KafkaException("No value for '" + key + "' found"); 236 } else if (val instanceof String) { 237 try { 238 return (T) Utils.newInstance(Class.forName((String) val)); 239 } catch (Exception e) { 240 throw new KafkaException(e); 241 } 242 } else if (val instanceof Class) { 243 return (T) Utils.newInstance((Class<T>) val); 244 } else { 245 throw new KafkaException("Unexpected type '" + val.getClass() + "' for '" + key + "'"); 246 } 247 } 248 249 private static PrivateKey createRSAPrivateKey(byte[] encodedKey) throws NoSuchAlgorithmException, InvalidKeySpecException { 250 if (encodedKey == null || encodedKey.length == 0) { 251 throw new IllegalArgumentException("Key bytes must not be null or empty"); 252 } 253 254 PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(encodedKey); 255 KeyFactory kf = KeyFactory.getInstance(RSA); 256 return kf.generatePrivate(spec); 257 } 258 259 private static SecretKey createAESSecretKey(byte[] encodedKey) { 260 if (encodedKey == null || encodedKey.length == 0) { 261 throw new IllegalArgumentException("Key bytes must not be null or empty"); 262 } 263 264 return new SecretKeySpec(encodedKey, AES); 265 } 266 267 private static PublicKey createRSAPublicKey(byte[] encodedKey) throws NoSuchAlgorithmException, InvalidKeySpecException { 268 if (encodedKey == null || encodedKey.length == 0) { 269 throw new IllegalArgumentException("Key bytes must not be null or empty"); 270 } 271 272 X509EncodedKeySpec spec = new X509EncodedKeySpec(encodedKey); 273 KeyFactory kf = KeyFactory.getInstance(RSA); 274 return kf.generatePublic(spec); 275 } 276 277 private static byte[] readBytesFromFile(String filename) throws IOException { 278 if (filename == null) { 279 throw new IllegalArgumentException("Filename must not be null"); 280 } 281 282 File f = new File(filename); 283 DataInputStream dis = new DataInputStream(new FileInputStream(f)); 284 byte[] bytes = new byte[(int) f.length()]; 285 dis.readFully(bytes); 286 dis.close(); 287 return bytes; 288 } 289 290 private byte[] hash(byte[] toHash) { 291 try { 292 MessageDigest md = MessageDigest.getInstance(hashMethod); 293 md.update(toHash); 294 return md.digest(); 295 } catch (Exception e) { 296 throw new KafkaException(e); 297 } 298 } 299 300 private static byte[] crypt(Cipher c, byte[] plain) throws IllegalBlockSizeException, BadPaddingException { 301 return c.doFinal(plain); 302 } 303 304 private static byte[] crypt(Cipher c, byte[] plain, int offset, int len) throws IllegalBlockSizeException, BadPaddingException { 305 return c.doFinal(plain, offset, len); 306 } 307 308 public static byte[] concatenate(byte[] a, byte[] b, byte[] c, byte[] d, byte[] e, byte[] f) { 309 if (a != null && b != null && c != null && d != null && e != null && f != null) { 310 byte[] rv = new byte[a.length + b.length + c.length + d.length + e.length + f.length]; 311 System.arraycopy(a, 0, rv, 0, a.length); 312 System.arraycopy(b, 0, rv, a.length, b.length); 313 System.arraycopy(c, 0, rv, a.length + b.length, c.length); 314 System.arraycopy(d, 0, rv, a.length + b.length + c.length, d.length); 315 System.arraycopy(e, 0, rv, a.length + b.length + c.length + d.length, e.length); 316 System.arraycopy(f, 0, rv, a.length + b.length + c.length + d.length + e.length, f.length); 317 return rv; 318 } else { 319 throw new IllegalArgumentException("arrays must not be null"); 320 } 321 } 322}