diff --git a/src/ServiceDiscovery/ServicesBackend.cpp b/src/ServiceDiscovery/ServicesBackend.cpp index 4dfe980..eb052f6 100644 --- a/src/ServiceDiscovery/ServicesBackend.cpp +++ b/src/ServiceDiscovery/ServicesBackend.cpp @@ -2,6 +2,7 @@ namespace { const uint32_t MAX_UDP_PACKET_SIZE = 655355; + const uint32_t MAX_DECOMPRESSED_MSG_SIZE = 655355; } using namespace ToolFramework; @@ -140,8 +141,9 @@ bool ServicesBackend::Initialise(Store &variables_in){ } if(msg_compression){ - zstd_ctx = ZSTD_createCCtx(); + zstd_cctx = ZSTD_createCCtx(); compressed_msg_buf = new char[ZSTD_compressBound(MAX_UDP_PACKET_SIZE)]; + zstd_dctx = ZSTD_createDCtx(); } // initialise the message IDs based on the current time in unix seconds @@ -388,9 +390,9 @@ bool ServicesBackend::SendMulticast(MulticastType type, std::string command, std // compress the message if applicable msg_to_send=nullptr; std::unique_lock locker(msg_buf_mtx, std::defer_lock); - if(zstd_ctx){ + if(zstd_cctx){ locker.lock(); - bytes_to_send = ZSTD_compressCCtx(zstd_ctx, compressed_msg_buf, MAX_UDP_PACKET_SIZE, command.data(), command.size(), compression_level); + bytes_to_send = ZSTD_compressCCtx(zstd_cctx, compressed_msg_buf, MAX_UDP_PACKET_SIZE, command.data(), command.size(), compression_level); if(ZSTD_isError(bytes_to_send)){ locker.unlock(); std::string errmsg = std::string{"Warning: error compressing multicast message "}+ZSTD_getErrorName(bytes_to_send); @@ -462,9 +464,9 @@ bool ServicesBackend::SendCommand(const std::string& topic, const std::string& c // compress the message if applicable msg_to_send=nullptr; std::unique_lock locker(msg_buf_mtx, std::defer_lock); - if(zstd_ctx){ + if(zstd_cctx){ locker.lock(); - bytes_to_send = ZSTD_compressCCtx(zstd_ctx, compressed_msg_buf, MAX_UDP_PACKET_SIZE, command.data(), command.size(), compression_level); + bytes_to_send = ZSTD_compressCCtx(zstd_cctx, compressed_msg_buf, MAX_UDP_PACKET_SIZE, command.data(), command.size(), compression_level); if(ZSTD_isError(bytes_to_send)){ locker.unlock(); std::string errmsg = std::string{"Warning: error compressing multicast message "}+ZSTD_getErrorName(bytes_to_send); @@ -714,12 +716,44 @@ bool ServicesBackend::GetNextResponse(){ // if we also had further parts, fetch those // if the command failed the response contains an error message (which will only ever be one part) for(unsigned int i=2; i(response.at(i).data()))); - std::string resp(response.at(i).size(),'\0'); - memcpy((void*)resp.data(), response.at(i).data(), response.at(i).size()); - resp = resp.substr(0,resp.find('\0')); - if(cmd.success) cmd.response.push_back(resp); - else cmd.err = resp; + + if(zstd_dctx && response.at(i).size() && ((char*)(response.at(i).data()))[0]=='('){ + + // compressed - decompress it + next_bytes = ZSTD_getFrameContentSize(response.at(i).data(), response.at(i).size()); + if(next_bytes==ZSTD_CONTENTSIZE_UNKNOWN || next_bytes==ZSTD_CONTENTSIZE_ERROR){ + // bad response + cmd.success = false; + cmd.err="Received corrupt zstd response size"; + Log(cmd.err,v_warning,verbosity); + break; + } + if(next_bytes > MAX_DECOMPRESSED_MSG_SIZE){ + cmd.success = false; + cmd.err="Received oversized zstd response: "+std::to_string(next_bytes)+" bytes"; + Log(cmd.err,v_warning,verbosity); + break; + } + decompress_buffer.resize(next_bytes); + next_bytes = ZSTD_decompressDCtx(zstd_dctx,(void*)decompress_buffer.data(),next_bytes, response.at(i).data(), response.at(i).size()); + if(ZSTD_isError(next_bytes)){ + cmd.success = false; + cmd.err=std::string{"zstd error decompressing response: "}+ZSTD_getErrorName(next_bytes); + Log(cmd.err,v_warning,verbosity); + break; + } + + next_part = decompress_buffer.data(); + + } else { + + next_bytes = response.at(i).size(); + next_part = (const char*)response.at(i).data(); + + } + + if(cmd.success) cmd.response.emplace_back(next_part, next_bytes); + else cmd.err.assign(next_part, next_bytes); } @@ -843,7 +877,8 @@ bool ServicesBackend::Finalise(){ waiting_recipients.clear(); // cleanup zmq compression context - if(zstd_ctx) ZSTD_freeCCtx(zstd_ctx); + if(zstd_cctx) ZSTD_freeCCtx(zstd_cctx); + if(zstd_dctx) ZSTD_freeDCtx(zstd_dctx); // can't use 'Log' since we may have deleted the Logging class if(verbosity>3) std::cout<<"ServicesBackend finalise done"<