Ratchet: Acks and callbacks

- Store callbacks and ES acks in OutboundSession
- Calls from engine to SKM for callbacks and acks
- Pass key ID and remote key back in SessionKeyAndNonce
- Implmenent multiple acks in ACK block
This commit is contained in:
zzz
2020-03-28 13:22:32 +00:00
parent 23634afbc9
commit eeb7ea4cae
6 changed files with 263 additions and 35 deletions

View File

@ -178,7 +178,7 @@ public final class ECIESAEADEngine {
if (state == null) {
if (shouldDebug)
_log.debug("Decrypting ES with tag: " + st.toBase64() + ": key: " + key.toBase64() + ": " + data.length + " bytes");
decrypted = decryptExistingSession(tag, data, key, targetPrivateKey);
decrypted = decryptExistingSession(tag, data, key, targetPrivateKey, keyManager);
} else if (data.length >= MIN_NSR_SIZE) {
if (shouldDebug)
_log.debug("Decrypting NSR with tag: " + st.toBase64() + ": key: " + key.toBase64() + ": " + data.length + " bytes");
@ -457,14 +457,16 @@ public final class ECIESAEADEngine {
*
* @param tag 8 bytes for ad, same as first 8 bytes of data
* @param data 24 bytes minimum, first 8 bytes will be skipped
*
* @param keyManager for ack callbacks
* @return decrypted data or null on failure
*
*/
private CloveSet decryptExistingSession(byte[] tag, byte[] data, SessionKeyAndNonce key, PrivateKey targetPrivateKey)
private CloveSet decryptExistingSession(byte[] tag, byte[] data, SessionKeyAndNonce key,
PrivateKey targetPrivateKey, RatchetSKM keyManager)
throws DataFormatException {
// TODO decrypt in place?
byte decrypted[] = decryptAEADBlock(tag, data, TAGLEN, data.length - TAGLEN, key, key.getNonce());
int nonce = key.getNonce();
byte decrypted[] = decryptAEADBlock(tag, data, TAGLEN, data.length - TAGLEN, key, nonce);
if (decrypted == null) {
if (_log.shouldWarn())
_log.warn("Decrypt of ES failed");
@ -475,7 +477,8 @@ public final class ECIESAEADEngine {
_log.warn("Zero length payload in ES");
return null;
}
PLCallback pc = new PLCallback();
PublicKey remote = key.getRemoteKey();
PLCallback pc = new PLCallback(keyManager, remote);
try {
int blocks = RatchetPayload.processPayload(_context, pc, decrypted, 0, decrypted.length, false);
if (_log.shouldDebug())
@ -489,6 +492,9 @@ public final class ECIESAEADEngine {
if (_log.shouldWarn())
_log.warn("No garlic block in ES payload");
}
if (pc.ackRequested) {
keyManager.ackRequested(remote, key.getID(), nonce);
}
int num = pc.cloveSet.size();
// return non-null even if zero cloves
GarlicClove[] arr = new GarlicClove[num];
@ -608,7 +614,7 @@ public final class ECIESAEADEngine {
}
if (_log.shouldDebug())
_log.debug("Encrypting as ES to " + target + " with key " + re.key + " and tag " + re.tag.toBase64());
byte rv[] = encryptExistingSession(cloves, target, re, replyDI, callback);
byte rv[] = encryptExistingSession(cloves, target, re, replyDI, callback, keyManager);
return rv;
}
@ -647,7 +653,7 @@ public final class ECIESAEADEngine {
if (_log.shouldDebug())
_log.debug("State before encrypt new session: " + state);
byte[] payload = createPayload(cloves, cloves.getExpiration(), replyDI, null);
byte[] payload = createPayload(cloves, cloves.getExpiration(), replyDI, null, null);
byte[] enc = new byte[KEYLEN + KEYLEN + MACLEN + payload.length + MACLEN];
try {
@ -707,7 +713,7 @@ public final class ECIESAEADEngine {
if (_log.shouldDebug())
_log.debug("State after mixhash tag before encrypt new session reply: " + state);
byte[] payload = createPayload(cloves, 0, replyDI, null);
byte[] payload = createPayload(cloves, 0, replyDI, null, null);
// part 1 - tag and empty payload
byte[] enc = new byte[TAGLEN + KEYLEN + MACLEN + payload.length + MACLEN];
@ -771,17 +777,19 @@ public final class ECIESAEADEngine {
* @return encrypted data or null on failure
*/
private byte[] encryptExistingSession(CloveSet cloves, PublicKey target, RatchetEntry re,
DeliveryInstructions replyDI, ReplyCallback callback) {
DeliveryInstructions replyDI, ReplyCallback callback,
RatchetSKM keyManager) {
//
if (ACKREQ_IN_ES && replyDI == null)
replyDI = new DeliveryInstructions();
byte rawTag[] = re.tag.getData();
byte[] payload = createPayload(cloves, 0, replyDI, re.nextKey);
byte[] payload = createPayload(cloves, 0, replyDI, re.nextKey, re.acksToSend);
SessionKeyAndNonce key = re.key;
byte encr[] = encryptAEADBlock(rawTag, payload, key, key.getNonce());
int nonce = key.getNonce();
byte encr[] = encryptAEADBlock(rawTag, payload, key, nonce);
System.arraycopy(rawTag, 0, encr, 0, TAGLEN);
if (callback != null) {
// TODO
keyManager.registerCallback(target, re.keyID, nonce, callback);
}
return encr;
}
@ -826,10 +834,30 @@ public final class ECIESAEADEngine {
private class PLCallback implements RatchetPayload.PayloadCallback {
public final List<GarlicClove> cloveSet = new ArrayList<GarlicClove>(3);
private final RatchetSKM skm;
private final PublicKey remote;
public long datetime;
public NextSessionKey nextKey;
public boolean ackRequested;
/**
* NS/NSR
*/
public PLCallback() {
this(null, null);
}
/**
* ES
* @param keyManager only for ES, otherwise null
* @param remoteKey only for ES, otherwise null
* @since 0.9.46
*/
public PLCallback(RatchetSKM keyManager, PublicKey remoteKey) {
skm = keyManager;
remote = remoteKey;
}
public void gotDateTime(long time) {
if (_log.shouldDebug())
_log.debug("Got DATE block: " + DataHelper.formatTime(time));
@ -858,6 +886,10 @@ public final class ECIESAEADEngine {
public void gotAck(int id, int n) {
if (_log.shouldDebug())
_log.debug("Got ACK block: " + id + " / " + n);
if (skm != null)
skm.receivedACK(remote, id, n);
else if (_log.shouldWarn())
_log.warn("ACK in NS/NSR?");
}
public void gotAckRequest(int id, DeliveryInstructions di) {
@ -885,9 +917,11 @@ public final class ECIESAEADEngine {
/**
* @param expiration if greater than zero, add a DateTime block
* @param replyDI non-null to request an ack, or null
* @param acksTOSend may be null
*/
private byte[] createPayload(CloveSet cloves, long expiration,
DeliveryInstructions replyDI, NextSessionKey nextKey) {
DeliveryInstructions replyDI, NextSessionKey nextKey,
List<Integer> acksToSend) {
int count = cloves.getCloveCount();
int numblocks = count + 1;
if (expiration > 0)
@ -896,6 +930,8 @@ public final class ECIESAEADEngine {
numblocks++;
if (nextKey != null)
numblocks++;
if (acksToSend != null)
numblocks++;
int len = 0;
List<Block> blocks = new ArrayList<Block>(numblocks);
if (expiration > 0) {
@ -921,6 +957,11 @@ public final class ECIESAEADEngine {
blocks.add(block);
len += block.getTotalLength();
}
if (acksToSend != null) {
Block block = new AckBlock(acksToSend);
blocks.add(block);
len += block.getTotalLength();
}
int padlen = 1 + _context.random().nextInt(MAXPAD);
// random data
//Block block = new PaddingBlock(_context, padlen);

View File

@ -1,5 +1,7 @@
package net.i2p.router.crypto.ratchet;
import java.util.List;
import net.i2p.data.SessionKey;
/**
@ -15,19 +17,21 @@ class RatchetEntry {
public final int keyID;
public final int pn;
public final NextSessionKey nextKey;
public final List<Integer> acksToSend;
/** outbound - calculated key */
public RatchetEntry(RatchetSessionTag tag, SessionKeyAndNonce key, int keyID, int pn) {
this(tag, key, keyID, pn, null);
this(tag, key, keyID, pn, null, null);
}
public RatchetEntry(RatchetSessionTag tag, SessionKeyAndNonce key, int keyID, int pn,
NextSessionKey nextKey) {
NextSessionKey nextKey, List<Integer> acksToSend) {
this.tag = tag;
this.key = key;
this.keyID = keyID;
this.pn = pn;
this.nextKey = nextKey;
this.acksToSend = acksToSend;
}
@Override

View File

@ -379,8 +379,21 @@ class RatchetPayload {
DataHelper.toLong(data, 2, 2, n);
}
/**
* @param acks each is id &lt;&lt; 16 | n
*/
public AckBlock(List<Integer> acks) {
super(BLOCK_ACKKEY);
data = new byte[4 * acks.size()];
int i = 0;
for (Integer a : acks) {
toInt4(data, i, a.intValue());
i += 4;
}
}
public int getDataLength() {
return 4;
return data.length;
}
public int writeData(byte[] tgt, int off) {
@ -472,4 +485,18 @@ class RatchetPayload {
value >>= 8;
}
}
/**
* Big endian.
* Same as DataHelper.toLong(target, offset, 4, value) but allows negative value
*
* @throws ArrayIndexOutOfBoundsException
* @since 0.9.46
*/
private static void toInt4(byte target[], int offset, int value) {
for (int i = offset + 3; i >= offset; i--) {
target[i] = (byte) value;
value >>= 8;
}
}
}

View File

@ -15,6 +15,7 @@ import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import com.southernstorm.noise.protocol.HandshakeState;
@ -621,23 +622,22 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
removed++;
}
}
if (removed > 0 && _log.shouldInfo())
_log.info("Expired inbound: " + removed);
// outbound
int oremoved = 0;
int cremoved = 0;
exp = now - (SESSION_LIFETIME_MAX_MS / 2);
for (Iterator<OutboundSession> iter = _outboundSessions.values().iterator(); iter.hasNext();) {
OutboundSession sess = iter.next();
oremoved += sess.expireTags();
oremoved += sess.expireTags(now);
cremoved += sess.expireCallbacks(now);
if (sess.getLastUsedDate() < exp) {
iter.remove();
oremoved++;
}
}
if (oremoved > 0 && _log.shouldInfo())
_log.info("Expired outbound: " + oremoved);
// pending outbound
int premoved = 0;
exp = now - SESSION_PENDING_DURATION_MS;
synchronized (_pendingOutboundSessions) {
@ -645,6 +645,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
List<OutboundSession> pending = iter.next();
for (Iterator<OutboundSession> liter = pending.iterator(); liter.hasNext();) {
OutboundSession sess = liter.next();
cremoved += sess.expireCallbacks(now);
if (sess.getEstablishedDate() < exp) {
liter.remove();
premoved++;
@ -654,8 +655,9 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
iter.remove();
}
}
if (premoved > 0 && _log.shouldInfo())
_log.info("Expired pending: " + premoved);
if ((removed > 0 || oremoved > 0 || premoved > 0 || cremoved > 0) && _log.shouldInfo())
_log.info("Expired inbound: " + removed + ", outbound: " + oremoved +
", pending: " + premoved + ", callbacks: " + cremoved);
return removed + oremoved + premoved;
}
@ -680,6 +682,48 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
/// end SessionTagListener ///
/// ACKS ///
/**
* @since 0.9.46
*/
void registerCallback(PublicKey target, int id, int n, ReplyCallback callback) {
if (_log.shouldInfo())
_log.info("Register callback tgt " + target + " id=" + id + " n=" + n + " callback " + callback);
OutboundSession sess = getSession(target);
if (sess != null)
sess.registerCallback(id, n, callback);
else if (_log.shouldWarn())
_log.warn("no session found for register callback");
}
/**
* @since 0.9.46
*/
void receivedACK(PublicKey target, int id, int n) {
OutboundSession sess = getSession(target);
if (sess != null)
sess.receivedACK(id, n);
else if (_log.shouldWarn())
_log.warn("no session found for received ack");
}
/**
* @since 0.9.46
*/
void ackRequested(PublicKey target, int id, int n) {
if (_log.shouldInfo())
_log.info("rcvd ACK REQUEST id=" + id + " n=" + n);
OutboundSession sess = getSession(target);
if (sess != null)
sess.ackRequested(id, n);
else if (_log.shouldWarn())
_log.warn("no session found for ack req");
}
/// end ACKS ///
/**
* Return a map of session key to a set of inbound RatchetTagSets for that SessionKey
*/
@ -827,6 +871,8 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
* In order, earliest first.
*/
private final List<RatchetTagSet> _tagSets;
private final ConcurrentHashMap<Integer, ReplyCallback> _callbacks;
private final LinkedBlockingQueue<Integer> _acksToSend;
/**
* Set to true after first tagset is acked.
* Upon repeated failures, we may revert back to false.
@ -841,6 +887,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
private int _consecutiveFailures;
private static final int MAX_FAILS = 2;
private static final int MAX_SEND_ACKS = 8;
private static final int DEBUG_OB_NSR = 0x10001;
private static final int DEBUG_IB_NSR = 0x10002;
@ -856,6 +903,8 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
_lastUsed = _established;
_unackedTagSets = new HashSet<RatchetTagSet>(4);
_tagSets = new ArrayList<RatchetTagSet>(6);
_callbacks = new ConcurrentHashMap<Integer, ReplyCallback>();
_acksToSend = new LinkedBlockingQueue<Integer>();
// generate expected tagset
byte[] ck = state.getChainingKey();
byte[] tagsetkey = new byte[32];
@ -1076,8 +1125,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
/**
* Expire old tags, returning the number of tag sets removed
*/
public int expireTags() {
long now = _context.clock().now();
public int expireTags(long now) {
int removed = 0;
synchronized (_tagSets) {
for (Iterator<RatchetTagSet> iter = _tagSets.iterator(); iter.hasNext(); ) {
@ -1113,9 +1161,8 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
if (tag != null) {
set.setDate(now);
SessionKeyAndNonce skn = set.consumeNextKey();
// TODO key ID and PN
// TODO next key
return new RatchetEntry(tag, skn, 0, 0);
// TODO PN
return new RatchetEntry(tag, skn, set.getID(), 0, set.getNextKey(), getAcksToSend());
} else if (_log.shouldInfo()) {
_log.info("Removing empty " + set);
}
@ -1180,5 +1227,77 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
public boolean getAckReceived() {
return _acked;
}
/**
* @since 0.9.46
*/
public void registerCallback(int id, int n, ReplyCallback callback) {
Integer key = Integer.valueOf((id << 16) | n);
ReplyCallback old = _callbacks.putIfAbsent(key, callback);
if (old != null) {
if (old.getExpiration() < _context.clock().now())
_callbacks.put(key, callback);
else if (_log.shouldWarn())
_log.warn("Not replacing callback: " + old);
}
}
/**
* @since 0.9.46
*/
public void receivedACK(int id, int n) {
Integer key = Integer.valueOf((id << 16) | n);
ReplyCallback callback = _callbacks.remove(key);
if (callback != null) {
if (_log.shouldInfo())
_log.info("ACK rcvd ID " + id + " n=" + n + " callback " + callback);
callback.onReply();
} else {
if (_log.shouldInfo())
_log.info("ACK rcvd ID " + id + " n=" + n + ", no callback");
}
}
/**
* @since 0.9.46
*/
public void ackRequested(int id, int n) {
Integer key = Integer.valueOf((id << 16) | n);
_acksToSend.offer(key);
}
/**
* @return the acks to send, non empty, or null
* @since 0.9.46
*/
private List<Integer> getAcksToSend() {
if (_acksToSend == null)
return null;
int sz = _acksToSend.size();
if (sz == 0)
return null;
List<Integer> rv = new ArrayList<Integer>(Math.min(sz, MAX_SEND_ACKS));
_acksToSend.drainTo(rv, MAX_SEND_ACKS);
if (rv.isEmpty())
return null;
return rv;
}
/**
* @since 0.9.46
*/
public int expireCallbacks(long now) {
if (_callbacks.isEmpty())
return 0;
int rv = 0;
for (Iterator<ReplyCallback> iter = _callbacks.values().iterator(); iter.hasNext();) {
ReplyCallback cb = iter.next();
if (cb.getExpiration() < now) {
iter.remove();
rv++;
}
}
return rv;
}
}
}

View File

@ -40,8 +40,10 @@ class RatchetTagSet implements TagSetHandle {
private final PublicKey _remoteKey;
private final SessionKey _key;
private final HandshakeState _state;
// inbound only, else null
// We use object for tags because we must do indexOfValueByValue()
private final SparseArray<RatchetSessionTag> _sessionTags;
// inbound ES only, else null
// We use byte[] for key to save space, because we don't need indexOfValueByValue()
private final SparseArray<byte[]> _sessionKeys;
private final HKDF hkdf;
@ -335,7 +337,7 @@ class RatchetTagSet implements TagSetHandle {
byte[] rv = _sessionKeys.valueAt(kidx);
_sessionKeys.removeAt(kidx);
addTags(tagnum);
return new SessionKeyAndNonce(rv, tagnum);
return new SessionKeyAndNonce(rv, _id, tagnum, _remoteKey);
} else if (tagnum > _lastKey) {
// if there's any gaps, catch up and store
for (int i = _lastKey + 1; i < tagnum; i++) {
@ -409,8 +411,8 @@ class RatchetTagSet implements TagSetHandle {
}
/**
* For outbound only.
* Call after consumeNextTag();
* For outbound, call after consumeNextTag().
* Also called by consume() to catch up for inbound.
*
* @return a key and nonce, non-null
*/
@ -422,7 +424,8 @@ class RatchetTagSet implements TagSetHandle {
byte[] key = new byte[32];
hkdf.calculate(_symmkey_ck, _symmkey_constant, INFO_5, _symmkey_ck, key, 0);
_lastKey++;
return new SessionKeyAndNonce(key, _lastKey);
// fill in ID and remoteKey as this may be for inbound
return new SessionKeyAndNonce(key, _id, _lastKey, _remoteKey);
}
/**
@ -442,7 +445,7 @@ class RatchetTagSet implements TagSetHandle {
}
@Override
public String toString() {
public synchronized String toString() {
StringBuilder buf = new StringBuilder(256);
if (_sessionTags != null)
buf.append("Inbound ");

View File

@ -2,24 +2,39 @@ package net.i2p.router.crypto.ratchet;
import com.southernstorm.noise.protocol.HandshakeState;
import net.i2p.data.PublicKey;
import net.i2p.data.SessionKey;
/**
* A session key is 32 bytes of data.
* Nonce should be 65535 or less.
*
* This is what is returned from RatchetTagSet.consume().
* RatchetSKM puts it in a RatchetEntry and returns it to ECIESAEADEngine.
*
* @since 0.9.44
*/
class SessionKeyAndNonce extends SessionKey {
private final int _nonce;
private final int _id, _nonce;
private final HandshakeState _state;
private final PublicKey _remoteKey;
/**
* For Existing Session
* For outbound Existing Session
*/
public SessionKeyAndNonce(byte data[], int nonce) {
this(data, 0, nonce, null);
}
/**
* For inbound Existing Session
* @since 0.9.46
*/
public SessionKeyAndNonce(byte data[], int id, int nonce, PublicKey remoteKey) {
super(data);
_id = id;
_nonce = nonce;
_remoteKey = remoteKey;
_state = null;
}
@ -28,7 +43,9 @@ class SessionKeyAndNonce extends SessionKey {
*/
public SessionKeyAndNonce(HandshakeState state) {
super();
_id = 0;
_nonce = 0;
_remoteKey = null;
_state = state;
}
@ -39,6 +56,23 @@ class SessionKeyAndNonce extends SessionKey {
return _nonce;
}
/**
* For inbound ES, else 0
* @since 0.9.46
*/
public int getID() {
return _id;
}
/**
* For inbound ES, else null.
* For NSR, use getHansdhakeState().getRemotePublicKey().getPublicKey().
* @since 0.9.46
*/
public PublicKey getRemoteKey() {
return _remoteKey;
}
/**
* For inbound NSR only, else null.
* MUST be cloned before processing NSR.