Ratchet: Simplify OB Session; there can be only one active OB TS

Fix bugs handling of out-of-order nextkeys
Expire unacked tagsets every time through
Remore unused OB session methods
This commit is contained in:
zzz
2020-04-06 20:27:47 +00:00
parent 14b33a1e4c
commit f6b5a2d493

View File

@ -850,16 +850,15 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
* Before the first ack, all tagsets go here. These are never expired, we rely * Before the first ack, all tagsets go here. These are never expired, we rely
* on the callers to call failTags() or ackTags() to remove them from this list. * on the callers to call failTags() or ackTags() to remove them from this list.
* Actually we now do a failsafe expire. * Actually we now do a failsafe expire.
* Synch on _tagSets to access this. * Unsynchronized, sync to use.
* No particular order. * No particular order.
*/ */
private final Set<RatchetTagSet> _unackedTagSets; private final Set<RatchetTagSet> _unackedTagSets;
/** /**
* As tagsets are acked, they go here. * There is only one active outbound tagset.
* After the first ack, new tagsets go here (i.e. presumed acked) * Synch on _unackedTagSets to access this.
* In order, earliest first.
*/ */
private final List<RatchetTagSet> _tagSets; private RatchetTagSet _tagSet;
private final ConcurrentHashMap<Integer, ReplyCallback> _callbacks; private final ConcurrentHashMap<Integer, ReplyCallback> _callbacks;
private final LinkedBlockingQueue<Integer> _acksToSend; private final LinkedBlockingQueue<Integer> _acksToSend;
/** /**
@ -869,11 +868,6 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
* to deliver the next set of tags. * to deliver the next set of tags.
*/ */
private volatile boolean _acked; private volatile boolean _acked;
/**
* Fail count
* Synch on _tagSets to access this.
*/
private int _consecutiveFailures;
// next key // next key
private int _myOBKeyID = -1; private int _myOBKeyID = -1;
@ -908,7 +902,6 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
_established = _context.clock().now(); _established = _context.clock().now();
_lastUsed = _established; _lastUsed = _established;
_unackedTagSets = new HashSet<RatchetTagSet>(4); _unackedTagSets = new HashSet<RatchetTagSet>(4);
_tagSets = new ArrayList<RatchetTagSet>(6);
_callbacks = new ConcurrentHashMap<Integer, ReplyCallback>(); _callbacks = new ConcurrentHashMap<Integer, ReplyCallback>();
_acksToSend = new LinkedBlockingQueue<Integer>(); _acksToSend = new LinkedBlockingQueue<Integer>();
// generate expected tagset // generate expected tagset
@ -924,7 +917,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
RatchetTagSet tagset = new RatchetTagSet(_hkdf, state, RatchetTagSet tagset = new RatchetTagSet(_hkdf, state,
rk, tk, rk, tk,
_established); _established);
_tagSets.add(tagset); _tagSet = tagset;
_state = null; _state = null;
if (_log.shouldDebug()) if (_log.shouldDebug())
_log.debug("New OB Session, rk = " + rk + " tk = " + tk + " 1st tagset:\n" + tagset); _log.debug("New OB Session, rk = " + rk + " tk = " + tk + " 1st tagset:\n" + tagset);
@ -972,7 +965,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
_log.debug("Update IB Session, rk = " + rk + " tk = " + Base64.encode(k_ab) + " ES tagset:\n" + tagset_ab); _log.debug("Update IB Session, rk = " + rk + " tk = " + Base64.encode(k_ab) + " ES tagset:\n" + tagset_ab);
_log.debug("Pending OB Session, rk = " + rk + " tk = " + Base64.encode(k_ba) + " ES tagset:\n" + tagset_ba); _log.debug("Pending OB Session, rk = " + rk + " tk = " + Base64.encode(k_ba) + " ES tagset:\n" + tagset_ba);
} }
synchronized (_tagSets) { synchronized (_unackedTagSets) {
_unackedTagSets.add(tagset_ba); _unackedTagSets.add(tagset_ba);
_NSRcallback = callback; _NSRcallback = callback;
} }
@ -989,16 +982,8 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
_log.debug("Update OB Session, rk = " + rk + " tk = " + Base64.encode(k_ab) + " ES tagset:\n" + tagset_ab); _log.debug("Update OB Session, rk = " + rk + " tk = " + Base64.encode(k_ab) + " ES tagset:\n" + tagset_ab);
_log.debug("Update IB Session, rk = " + rk + " tk = " + Base64.encode(k_ba) + " ES tagset:\n" + tagset_ba); _log.debug("Update IB Session, rk = " + rk + " tk = " + Base64.encode(k_ba) + " ES tagset:\n" + tagset_ba);
} }
synchronized (_tagSets) { synchronized (_unackedTagSets) {
for (Iterator<RatchetTagSet> iter = _tagSets.iterator(); iter.hasNext(); ) { _tagSet = tagset_ab;
RatchetTagSet set = iter.next();
if (set.getID() == RatchetTagSet.DEBUG_OB_NSR) {
iter.remove();
if (_log.shouldDebug())
_log.debug("Removed OB NSR tagset:\n" + set);
}
}
_tagSets.add(tagset_ab);
_unackedTagSets.clear(); _unackedTagSets.clear();
} }
// We can't destroy the original state, as more NSRs may come in // We can't destroy the original state, as more NSRs may come in
@ -1019,7 +1004,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
boolean isRequest = key.isRequest(); boolean isRequest = key.isRequest();
boolean hasKey = key.getData() != null; boolean hasKey = key.getData() != null;
int id = key.getID(); int id = key.getID();
synchronized (_tagSets) { synchronized (_unackedTagSets) {
if (isReverse) { if (isReverse) {
// this is about my outbound tag set, // this is about my outbound tag set,
// and is an ack of new key sent // and is an ack of new key sent
@ -1038,7 +1023,8 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
hisLastIBKeyID = -1; hisLastIBKeyID = -1;
else else
hisLastIBKeyID = _hisIBKey.getID(); hisLastIBKeyID = _hisIBKey.getID();
_hisIBKey = key; // save as it may be replaced below; will be stored after all error checks complete
NextSessionKey receivedKey = key;
if (hisLastIBKeyID != id) { if (hisLastIBKeyID != id) {
// got a new key, use it // got a new key, use it
if (hisLastIBKeyID != id - 1) { if (hisLastIBKeyID != id - 1) {
@ -1057,7 +1043,8 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
if (hasKey) { if (hasKey) {
// got a old key id but new data? // got a old key id but new data?
if (_hisIBKeyWithData != null && _log.shouldWarn()) if (_hisIBKeyWithData != null && _log.shouldWarn())
_log.warn("Got nextkey for OB with data, didn't match previous " + key); _log.warn("Got nextkey for OB with data: " + key + " didn't match previous " + _hisIBKey + " / " + _hisIBKeyWithData);
return;
} else { } else {
if (_hisIBKeyWithData == null || if (_hisIBKeyWithData == null ||
_hisIBKeyWithData.getID() != key.getID()) { _hisIBKeyWithData.getID() != key.getID()) {
@ -1076,11 +1063,9 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
else else
oldtsID = 1 + _myOBKeyID + hisLastIBKeyID; oldtsID = 1 + _myOBKeyID + hisLastIBKeyID;
RatchetTagSet oldts = null; RatchetTagSet oldts = null;
for (RatchetTagSet ts : _tagSets) { if (_tagSet != null) {
if (ts.getID() == oldtsID) { if (_tagSet.getID() == oldtsID)
oldts = ts; oldts = _tagSet;
break;
}
} }
if (oldts == null) { if (oldts == null) {
if (_log.shouldWarn()) if (_log.shouldWarn())
@ -1101,6 +1086,8 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
_myOBKeys = nextKeys; _myOBKeys = nextKeys;
_myOBKeyID++; _myOBKeyID++;
} }
_hisIBKey = receivedKey;
// create new OB TS, delete old one // create new OB TS, delete old one
PublicKey pub = nextKeys.getPublic(); PublicKey pub = nextKeys.getPublic();
PrivateKey priv = nextKeys.getPrivate(); PrivateKey priv = nextKeys.getPrivate();
@ -1111,8 +1098,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
int newtsID = oldtsID + 1; int newtsID = oldtsID + 1;
RatchetTagSet ts = new RatchetTagSet(_hkdf, oldts.getNextRootKey(), ssk, RatchetTagSet ts = new RatchetTagSet(_hkdf, oldts.getNextRootKey(), ssk,
_context.clock().now(), newtsID, _myOBKeyID); _context.clock().now(), newtsID, _myOBKeyID);
_tagSets.add(ts); _tagSet = ts;
_tagSets.remove(oldts);
_currentOBTagSetID = newtsID; _currentOBTagSetID = newtsID;
if (_log.shouldWarn()) if (_log.shouldWarn())
_log.warn("Got nextkey " + key + " ratchet to new OB ES TS:\n" + ts); _log.warn("Got nextkey " + key + " ratchet to new OB ES TS:\n" + ts);
@ -1128,7 +1114,8 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
hisLastOBKeyID = -1; hisLastOBKeyID = -1;
else else
hisLastOBKeyID = _hisOBKey.getID(); hisLastOBKeyID = _hisOBKey.getID();
_hisOBKey = key; // save as it may be replaced below; will be stored after all error checks complete
NextSessionKey receivedKey = key;
if (hisLastOBKeyID != id) { if (hisLastOBKeyID != id) {
// got a new key, use it // got a new key, use it
if (hisLastOBKeyID != id - 1) { if (hisLastOBKeyID != id - 1) {
@ -1147,7 +1134,8 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
if (hasKey) { if (hasKey) {
// got a old key id but new data? // got a old key id but new data?
if (_hisOBKeyWithData != null && _log.shouldWarn()) if (_hisOBKeyWithData != null && _log.shouldWarn())
_log.warn("Got nextkey for IB with data, didn't match previous " + key); _log.warn("Got nextkey for IB with data: " + key + " didn't match previous " + _hisOBKey + " / " + _hisOBKeyWithData);
return;
} else { } else {
if (_hisOBKeyWithData == null || if (_hisOBKeyWithData == null ||
_hisOBKeyWithData.getID() != key.getID()) { _hisOBKeyWithData.getID() != key.getID()) {
@ -1165,6 +1153,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
_log.warn("Got nextkey for IB but we don't have next root key " + key); _log.warn("Got nextkey for IB but we don't have next root key " + key);
return; return;
} }
int oldtsID; int oldtsID;
if (_myIBKeyID == -1 && hisLastOBKeyID == -1) if (_myIBKeyID == -1 && hisLastOBKeyID == -1)
oldtsID = 0; oldtsID = 0;
@ -1190,6 +1179,8 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
_log.warn("Got reverse with request, using old key anyway " + key); _log.warn("Got reverse with request, using old key anyway " + key);
_myIBKey = new NextSessionKey(_myIBKeyID, true, false); _myIBKey = new NextSessionKey(_myIBKeyID, true, false);
} }
_hisOBKey = receivedKey;
PrivateKey sharedSecret = ECIESAEADEngine.doDH(_myIBKeys.getPrivate(), key); PrivateKey sharedSecret = ECIESAEADEngine.doDH(_myIBKeys.getPrivate(), key);
int newtsID = oldtsID + 1; int newtsID = oldtsID + 1;
_currentIBTagSetID = newtsID; _currentIBTagSetID = newtsID;
@ -1213,12 +1204,14 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
* @since 0.9.46 * @since 0.9.46
*/ */
private NextSessionKey getReverseSendKey() { private NextSessionKey getReverseSendKey() {
if (_myIBKey == null) synchronized (_unackedTagSets) {
return null; if (_myIBKey == null)
if (_myIBKeySendCount > MAX_SEND_REVERSE_KEY) return null;
return null; if (_myIBKeySendCount > MAX_SEND_REVERSE_KEY)
_myIBKeySendCount++; return null;
return _myIBKey; _myIBKeySendCount++;
return _myIBKey;
}
} }
/** /**
@ -1230,7 +1223,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
*/ */
void firstTagConsumed(RatchetTagSet set) { void firstTagConsumed(RatchetTagSet set) {
SessionKey sk = set.getAssociatedKey(); SessionKey sk = set.getAssociatedKey();
synchronized (_tagSets) { synchronized (_unackedTagSets) {
// save next root key // save next root key
_nextIBRootKey = set.getNextRootKey(); _nextIBRootKey = set.getNextRootKey();
for (RatchetTagSet obSet : _unackedTagSets) { for (RatchetTagSet obSet : _unackedTagSets) {
@ -1239,8 +1232,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
_log.debug("First tag received from IB ES\n" + set + _log.debug("First tag received from IB ES\n" + set +
"\npromoting OB ES " + obSet); "\npromoting OB ES " + obSet);
_unackedTagSets.clear(); _unackedTagSets.clear();
_tagSets.clear(); _tagSet = obSet;
_tagSets.add(obSet);
if (_NSRcallback != null) { if (_NSRcallback != null) {
_NSRcallback.onReply(); _NSRcallback.onReply();
_NSRcallback = null; _NSRcallback = null;
@ -1252,7 +1244,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
if (_log.shouldDebug()) if (_log.shouldDebug())
_log.debug("First tag received from IB ES\n" + set + _log.debug("First tag received from IB ES\n" + set +
" but no corresponding OB ES set found, unacked size: " + _unackedTagSets.size() + " but no corresponding OB ES set found, unacked size: " + _unackedTagSets.size() +
" acked size: " + _tagSets.size()); " acked size: " + ((_tagSet != null) ? 1 : 0));
} }
} }
@ -1263,21 +1255,14 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
*/ */
List<RatchetTagSet> getTagSets() { List<RatchetTagSet> getTagSets() {
List<RatchetTagSet> rv; List<RatchetTagSet> rv;
synchronized (_tagSets) { synchronized (_unackedTagSets) {
rv = new ArrayList<RatchetTagSet>(_unackedTagSets); rv = new ArrayList<RatchetTagSet>(_unackedTagSets);
rv.addAll(_tagSets); if (_tagSet != null)
rv.add(_tagSet);
} }
return rv; return rv;
} }
/** didn't get an ack for these tags */
void failTags(RatchetTagSet set) {
synchronized (_tagSets) {
_unackedTagSets.remove(set);
_tagSets.remove(set);
}
}
public PublicKey getTarget() { public PublicKey getTarget() {
return _target; return _target;
} }
@ -1309,52 +1294,43 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
*/ */
public int expireTags(long now) { public int expireTags(long now) {
int removed = 0; int removed = 0;
synchronized (_tagSets) { synchronized (_unackedTagSets) {
for (Iterator<RatchetTagSet> iter = _tagSets.iterator(); iter.hasNext(); ) { if (_tagSet != null) {
if (_tagSet.getExpiration() <= now) {
_tagSet = null;
removed++;
}
}
for (Iterator<RatchetTagSet> iter = _unackedTagSets.iterator(); iter.hasNext(); ) {
RatchetTagSet set = iter.next(); RatchetTagSet set = iter.next();
if (set.getExpiration() <= now) { if (set.getExpiration() <= now) {
iter.remove(); iter.remove();
removed++; removed++;
} }
} }
// failsafe, sometimes these are sticking around, not sure why, so clean them periodically
if ((now & 0x0f) == 0) {
for (Iterator<RatchetTagSet> iter = _unackedTagSets.iterator(); iter.hasNext(); ) {
RatchetTagSet set = iter.next();
if (set.getExpiration() <= now) {
iter.remove();
removed++;
}
}
}
} }
return removed; return removed;
} }
public RatchetEntry consumeNext() { public RatchetEntry consumeNext() {
long now = _context.clock().now(); long now = _context.clock().now();
synchronized (_tagSets) { synchronized (_unackedTagSets) {
while (!_tagSets.isEmpty()) { if (_tagSet != null) {
RatchetTagSet set = _tagSets.get(0); synchronized(_tagSet) {
synchronized(set) { // use even if expired, this will reset the expiration
if (set.getExpiration() > now) { RatchetSessionTag tag = _tagSet.consumeNext();
RatchetSessionTag tag = set.consumeNext(); if (tag != null) {
if (tag != null) { _lastUsed = now;
_lastUsed = now; _tagSet.setDate(now);
set.setDate(now); SessionKeyAndNonce skn = _tagSet.consumeNextKey();
SessionKeyAndNonce skn = set.consumeNextKey(); // TODO PN
// TODO PN return new RatchetEntry(tag, skn, _tagSet.getID(), 0, _tagSet.getNextKey(),
return new RatchetEntry(tag, skn, set.getID(), 0, set.getNextKey(), getReverseSendKey(), getAcksToSend());
getReverseSendKey(), getAcksToSend()); } else if (_log.shouldInfo()) {
} else if (_log.shouldInfo()) { _log.info("Removing empty " + _tagSet);
_log.info("Removing empty " + set);
}
} else {
if (_log.shouldInfo())
_log.info("Expired " + set);
} }
} }
_tagSets.remove(0); _tagSet = null;
} }
} }
return null; return null;
@ -1362,21 +1338,16 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
/** @return the total number of tags in acked RatchetTagSets */ /** @return the total number of tags in acked RatchetTagSets */
public int availableTags() { public int availableTags() {
int tags = 0;
long now = _context.clock().now(); long now = _context.clock().now();
synchronized (_tagSets) { synchronized (_unackedTagSets) {
for (int i = 0; i < _tagSets.size(); i++) { if (_tagSet != null) {
RatchetTagSet set = _tagSets.get(i); synchronized(_tagSet) {
if (!set.getAcked()) if (_tagSet.getExpiration() > now)
continue; return _tagSet.remaining();
if (set.getExpiration() > now) {
// or just add fixed number?
int sz = set.remaining();
tags += sz;
} }
} }
} }
return tags; return 0;
} }
/** /**
@ -1385,29 +1356,13 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
* *
*/ */
public long getLastExpirationDate() { public long getLastExpirationDate() {
long last = 0; synchronized (_unackedTagSets) {
synchronized (_tagSets) { if (_tagSet != null)
for (RatchetTagSet set : _tagSets) { return _tagSet.getExpiration();
long exp = set.getExpiration();
if (exp > last && set.remaining() > 0)
last = exp;
}
} }
if (last > 0)
return last;
return -1; return -1;
} }
/**
* Put the RatchetTagSet on the unacked list.
*/
public void addTags(RatchetTagSet set) {
_lastUsed = _context.clock().now();
synchronized (_tagSets) {
_unackedTagSets.add(set);
}
}
public boolean getAckReceived() { public boolean getAckReceived() {
return _acked; return _acked;
} }