diff --git a/stuncore/stunbuilder.cpp b/stuncore/stunbuilder.cpp index 2cc456e..dbc3651 100644 --- a/stuncore/stunbuilder.cpp +++ b/stuncore/stunbuilder.cpp @@ -135,14 +135,31 @@ HRESULT CStunMessageBuilder::AddRandomTransactionId(StunTransactionId* pTransId) srand(entropy); - // the first four bytes of the transaction id is always the magic cookie - // followed by 12 bytes of the real transaction id - memcpy(transid.id, &stun_cookie_nbo, sizeof(stun_cookie_nbo)); - for (int x = 4; x < (STUN_TRANSACTION_ID_LENGTH-4); x++) + int x = 0; + if (!_fLegacyMode) + { + // the first four bytes of the transaction id is always the magic cookie + // followed by 12 bytes of the real transaction id + memcpy(transid.id, &stun_cookie_nbo, sizeof(stun_cookie_nbo)); + x = 4; + } + + for (; x < STUN_TRANSACTION_ID_LENGTH; x++) { transid.id[x] = (uint8_t)(rand() % 256); } + if (_fLegacyMode) + { + // rfc3489 (legacy mode) does not use magic cookie + // if the generated txn id happens to start with the magic cookie, keep + // re-generating the 1st byte until it doesn't + while (memcmp(transid.id, &stun_cookie_nbo, sizeof(stun_cookie_nbo)) == 0) + { + transid.id[0] = (uint8_t)(rand() % 256); + } + } + if (pTransId) { *pTransId = transid; diff --git a/testcode/testbuilder.cpp b/testcode/testbuilder.cpp index f928cf8..3e07224 100644 --- a/testcode/testbuilder.cpp +++ b/testcode/testbuilder.cpp @@ -27,6 +27,8 @@ HRESULT CTestBuilder::Run() HRESULT hr = S_OK; Chk(Test1()) Chk(Test2()); + Chk(Test3()); + Chk(Test4()); Cleanup: return hr; } @@ -138,4 +140,44 @@ HRESULT CTestBuilder::Test2() return hr; } +// This test validates the non-legacy mode transaction id generated by builder. +HRESULT CTestBuilder::Test3() +{ + HRESULT hr = S_OK; + CStunMessageBuilder builder; + CStunMessageReader reader; + StunTransactionId transid = {}; + CRefCountedBuffer spBuffer; + + builder.SetLegacyMode(false); + ChkA(builder.AddBindingRequestHeader()); + ChkA(builder.AddRandomTransactionId(&transid)); + ChkA(builder.GetResult(&spBuffer)); + + ChkIfA(CStunMessageReader::BodyValidated != reader.AddBytes(spBuffer->GetData(), spBuffer->GetSize()), E_FAIL); + ChkIf(reader.IsMessageLegacyFormat() == true, E_FAIL); + +Cleanup: + return hr; +} +// This test validates the legacy mode transaction id generated by builder. +HRESULT CTestBuilder::Test4() +{ + HRESULT hr = S_OK; + CStunMessageBuilder builder; + CStunMessageReader reader; + StunTransactionId transid = {}; + CRefCountedBuffer spBuffer; + + builder.SetLegacyMode(true); + ChkA(builder.AddBindingRequestHeader()); + ChkA(builder.AddRandomTransactionId(&transid)); + ChkA(builder.GetResult(&spBuffer)); + + ChkIfA(CStunMessageReader::BodyValidated != reader.AddBytes(spBuffer->GetData(), spBuffer->GetSize()), E_FAIL); + ChkIf(reader.IsMessageLegacyFormat() == false, E_FAIL); + +Cleanup: + return hr; +} \ No newline at end of file diff --git a/testcode/testbuilder.h b/testcode/testbuilder.h index 59c216b..2163104 100644 --- a/testcode/testbuilder.h +++ b/testcode/testbuilder.h @@ -26,6 +26,8 @@ class CTestBuilder : public IUnitTest public: HRESULT Test1(); HRESULT Test2(); + HRESULT Test3(); + HRESULT Test4(); virtual HRESULT Run();