Ratchet: Simplify OB Session; there can be only one active OB TS

Fix bugs handling of out-of-order nextkeys
Expire unacked tagsets every time through
Remore unused OB session methods
This commit is contained in:
zzz
2020-04-06 20:27:47 +00:00
parent 14b33a1e4c
commit f6b5a2d493

View File

@ -850,16 +850,15 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
* Before the first ack, all tagsets go here. These are never expired, we rely
* on the callers to call failTags() or ackTags() to remove them from this list.
* Actually we now do a failsafe expire.
* Synch on _tagSets to access this.
* Unsynchronized, sync to use.
* No particular order.
*/
private final Set<RatchetTagSet> _unackedTagSets;
/**
* As tagsets are acked, they go here.
* After the first ack, new tagsets go here (i.e. presumed acked)
* In order, earliest first.
* There is only one active outbound tagset.
* Synch on _unackedTagSets to access this.
*/
private final List<RatchetTagSet> _tagSets;
private RatchetTagSet _tagSet;
private final ConcurrentHashMap<Integer, ReplyCallback> _callbacks;
private final LinkedBlockingQueue<Integer> _acksToSend;
/**
@ -869,11 +868,6 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
* to deliver the next set of tags.
*/
private volatile boolean _acked;
/**
* Fail count
* Synch on _tagSets to access this.
*/
private int _consecutiveFailures;
// next key
private int _myOBKeyID = -1;
@ -908,7 +902,6 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
_established = _context.clock().now();
_lastUsed = _established;
_unackedTagSets = new HashSet<RatchetTagSet>(4);
_tagSets = new ArrayList<RatchetTagSet>(6);
_callbacks = new ConcurrentHashMap<Integer, ReplyCallback>();
_acksToSend = new LinkedBlockingQueue<Integer>();
// generate expected tagset
@ -924,7 +917,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
RatchetTagSet tagset = new RatchetTagSet(_hkdf, state,
rk, tk,
_established);
_tagSets.add(tagset);
_tagSet = tagset;
_state = null;
if (_log.shouldDebug())
_log.debug("New OB Session, rk = " + rk + " tk = " + tk + " 1st tagset:\n" + tagset);
@ -972,7 +965,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
_log.debug("Update IB Session, rk = " + rk + " tk = " + Base64.encode(k_ab) + " ES tagset:\n" + tagset_ab);
_log.debug("Pending OB Session, rk = " + rk + " tk = " + Base64.encode(k_ba) + " ES tagset:\n" + tagset_ba);
}
synchronized (_tagSets) {
synchronized (_unackedTagSets) {
_unackedTagSets.add(tagset_ba);
_NSRcallback = callback;
}
@ -989,16 +982,8 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
_log.debug("Update OB Session, rk = " + rk + " tk = " + Base64.encode(k_ab) + " ES tagset:\n" + tagset_ab);
_log.debug("Update IB Session, rk = " + rk + " tk = " + Base64.encode(k_ba) + " ES tagset:\n" + tagset_ba);
}
synchronized (_tagSets) {
for (Iterator<RatchetTagSet> iter = _tagSets.iterator(); iter.hasNext(); ) {
RatchetTagSet set = iter.next();
if (set.getID() == RatchetTagSet.DEBUG_OB_NSR) {
iter.remove();
if (_log.shouldDebug())
_log.debug("Removed OB NSR tagset:\n" + set);
}
}
_tagSets.add(tagset_ab);
synchronized (_unackedTagSets) {
_tagSet = tagset_ab;
_unackedTagSets.clear();
}
// We can't destroy the original state, as more NSRs may come in
@ -1019,7 +1004,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
boolean isRequest = key.isRequest();
boolean hasKey = key.getData() != null;
int id = key.getID();
synchronized (_tagSets) {
synchronized (_unackedTagSets) {
if (isReverse) {
// this is about my outbound tag set,
// and is an ack of new key sent
@ -1038,7 +1023,8 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
hisLastIBKeyID = -1;
else
hisLastIBKeyID = _hisIBKey.getID();
_hisIBKey = key;
// save as it may be replaced below; will be stored after all error checks complete
NextSessionKey receivedKey = key;
if (hisLastIBKeyID != id) {
// got a new key, use it
if (hisLastIBKeyID != id - 1) {
@ -1057,7 +1043,8 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
if (hasKey) {
// got a old key id but new data?
if (_hisIBKeyWithData != null && _log.shouldWarn())
_log.warn("Got nextkey for OB with data, didn't match previous " + key);
_log.warn("Got nextkey for OB with data: " + key + " didn't match previous " + _hisIBKey + " / " + _hisIBKeyWithData);
return;
} else {
if (_hisIBKeyWithData == null ||
_hisIBKeyWithData.getID() != key.getID()) {
@ -1076,11 +1063,9 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
else
oldtsID = 1 + _myOBKeyID + hisLastIBKeyID;
RatchetTagSet oldts = null;
for (RatchetTagSet ts : _tagSets) {
if (ts.getID() == oldtsID) {
oldts = ts;
break;
}
if (_tagSet != null) {
if (_tagSet.getID() == oldtsID)
oldts = _tagSet;
}
if (oldts == null) {
if (_log.shouldWarn())
@ -1101,6 +1086,8 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
_myOBKeys = nextKeys;
_myOBKeyID++;
}
_hisIBKey = receivedKey;
// create new OB TS, delete old one
PublicKey pub = nextKeys.getPublic();
PrivateKey priv = nextKeys.getPrivate();
@ -1111,8 +1098,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
int newtsID = oldtsID + 1;
RatchetTagSet ts = new RatchetTagSet(_hkdf, oldts.getNextRootKey(), ssk,
_context.clock().now(), newtsID, _myOBKeyID);
_tagSets.add(ts);
_tagSets.remove(oldts);
_tagSet = ts;
_currentOBTagSetID = newtsID;
if (_log.shouldWarn())
_log.warn("Got nextkey " + key + " ratchet to new OB ES TS:\n" + ts);
@ -1128,7 +1114,8 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
hisLastOBKeyID = -1;
else
hisLastOBKeyID = _hisOBKey.getID();
_hisOBKey = key;
// save as it may be replaced below; will be stored after all error checks complete
NextSessionKey receivedKey = key;
if (hisLastOBKeyID != id) {
// got a new key, use it
if (hisLastOBKeyID != id - 1) {
@ -1147,7 +1134,8 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
if (hasKey) {
// got a old key id but new data?
if (_hisOBKeyWithData != null && _log.shouldWarn())
_log.warn("Got nextkey for IB with data, didn't match previous " + key);
_log.warn("Got nextkey for IB with data: " + key + " didn't match previous " + _hisOBKey + " / " + _hisOBKeyWithData);
return;
} else {
if (_hisOBKeyWithData == null ||
_hisOBKeyWithData.getID() != key.getID()) {
@ -1165,6 +1153,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
_log.warn("Got nextkey for IB but we don't have next root key " + key);
return;
}
int oldtsID;
if (_myIBKeyID == -1 && hisLastOBKeyID == -1)
oldtsID = 0;
@ -1190,6 +1179,8 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
_log.warn("Got reverse with request, using old key anyway " + key);
_myIBKey = new NextSessionKey(_myIBKeyID, true, false);
}
_hisOBKey = receivedKey;
PrivateKey sharedSecret = ECIESAEADEngine.doDH(_myIBKeys.getPrivate(), key);
int newtsID = oldtsID + 1;
_currentIBTagSetID = newtsID;
@ -1213,12 +1204,14 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
* @since 0.9.46
*/
private NextSessionKey getReverseSendKey() {
if (_myIBKey == null)
return null;
if (_myIBKeySendCount > MAX_SEND_REVERSE_KEY)
return null;
_myIBKeySendCount++;
return _myIBKey;
synchronized (_unackedTagSets) {
if (_myIBKey == null)
return null;
if (_myIBKeySendCount > MAX_SEND_REVERSE_KEY)
return null;
_myIBKeySendCount++;
return _myIBKey;
}
}
/**
@ -1230,7 +1223,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
*/
void firstTagConsumed(RatchetTagSet set) {
SessionKey sk = set.getAssociatedKey();
synchronized (_tagSets) {
synchronized (_unackedTagSets) {
// save next root key
_nextIBRootKey = set.getNextRootKey();
for (RatchetTagSet obSet : _unackedTagSets) {
@ -1239,8 +1232,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
_log.debug("First tag received from IB ES\n" + set +
"\npromoting OB ES " + obSet);
_unackedTagSets.clear();
_tagSets.clear();
_tagSets.add(obSet);
_tagSet = obSet;
if (_NSRcallback != null) {
_NSRcallback.onReply();
_NSRcallback = null;
@ -1252,7 +1244,7 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
if (_log.shouldDebug())
_log.debug("First tag received from IB ES\n" + set +
" but no corresponding OB ES set found, unacked size: " + _unackedTagSets.size() +
" acked size: " + _tagSets.size());
" acked size: " + ((_tagSet != null) ? 1 : 0));
}
}
@ -1263,21 +1255,14 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
*/
List<RatchetTagSet> getTagSets() {
List<RatchetTagSet> rv;
synchronized (_tagSets) {
synchronized (_unackedTagSets) {
rv = new ArrayList<RatchetTagSet>(_unackedTagSets);
rv.addAll(_tagSets);
if (_tagSet != null)
rv.add(_tagSet);
}
return rv;
}
/** didn't get an ack for these tags */
void failTags(RatchetTagSet set) {
synchronized (_tagSets) {
_unackedTagSets.remove(set);
_tagSets.remove(set);
}
}
public PublicKey getTarget() {
return _target;
}
@ -1309,52 +1294,43 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
*/
public int expireTags(long now) {
int removed = 0;
synchronized (_tagSets) {
for (Iterator<RatchetTagSet> iter = _tagSets.iterator(); iter.hasNext(); ) {
synchronized (_unackedTagSets) {
if (_tagSet != null) {
if (_tagSet.getExpiration() <= now) {
_tagSet = null;
removed++;
}
}
for (Iterator<RatchetTagSet> iter = _unackedTagSets.iterator(); iter.hasNext(); ) {
RatchetTagSet set = iter.next();
if (set.getExpiration() <= now) {
iter.remove();
removed++;
}
}
// failsafe, sometimes these are sticking around, not sure why, so clean them periodically
if ((now & 0x0f) == 0) {
for (Iterator<RatchetTagSet> iter = _unackedTagSets.iterator(); iter.hasNext(); ) {
RatchetTagSet set = iter.next();
if (set.getExpiration() <= now) {
iter.remove();
removed++;
}
}
}
}
return removed;
}
public RatchetEntry consumeNext() {
long now = _context.clock().now();
synchronized (_tagSets) {
while (!_tagSets.isEmpty()) {
RatchetTagSet set = _tagSets.get(0);
synchronized(set) {
if (set.getExpiration() > now) {
RatchetSessionTag tag = set.consumeNext();
if (tag != null) {
_lastUsed = now;
set.setDate(now);
SessionKeyAndNonce skn = set.consumeNextKey();
// TODO PN
return new RatchetEntry(tag, skn, set.getID(), 0, set.getNextKey(),
getReverseSendKey(), getAcksToSend());
} else if (_log.shouldInfo()) {
_log.info("Removing empty " + set);
}
} else {
if (_log.shouldInfo())
_log.info("Expired " + set);
synchronized (_unackedTagSets) {
if (_tagSet != null) {
synchronized(_tagSet) {
// use even if expired, this will reset the expiration
RatchetSessionTag tag = _tagSet.consumeNext();
if (tag != null) {
_lastUsed = now;
_tagSet.setDate(now);
SessionKeyAndNonce skn = _tagSet.consumeNextKey();
// TODO PN
return new RatchetEntry(tag, skn, _tagSet.getID(), 0, _tagSet.getNextKey(),
getReverseSendKey(), getAcksToSend());
} else if (_log.shouldInfo()) {
_log.info("Removing empty " + _tagSet);
}
}
_tagSets.remove(0);
_tagSet = null;
}
}
return null;
@ -1362,21 +1338,16 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
/** @return the total number of tags in acked RatchetTagSets */
public int availableTags() {
int tags = 0;
long now = _context.clock().now();
synchronized (_tagSets) {
for (int i = 0; i < _tagSets.size(); i++) {
RatchetTagSet set = _tagSets.get(i);
if (!set.getAcked())
continue;
if (set.getExpiration() > now) {
// or just add fixed number?
int sz = set.remaining();
tags += sz;
synchronized (_unackedTagSets) {
if (_tagSet != null) {
synchronized(_tagSet) {
if (_tagSet.getExpiration() > now)
return _tagSet.remaining();
}
}
}
return tags;
return 0;
}
/**
@ -1385,29 +1356,13 @@ public class RatchetSKM extends SessionKeyManager implements SessionTagListener
*
*/
public long getLastExpirationDate() {
long last = 0;
synchronized (_tagSets) {
for (RatchetTagSet set : _tagSets) {
long exp = set.getExpiration();
if (exp > last && set.remaining() > 0)
last = exp;
}
synchronized (_unackedTagSets) {
if (_tagSet != null)
return _tagSet.getExpiration();
}
if (last > 0)
return last;
return -1;
}
/**
* Put the RatchetTagSet on the unacked list.
*/
public void addTags(RatchetTagSet set) {
_lastUsed = _context.clock().now();
synchronized (_tagSets) {
_unackedTagSets.add(set);
}
}
public boolean getAckReceived() {
return _acked;
}