diff --git a/conf/config.ini b/conf/config.ini index 21931088..5a468bd7 100644 --- a/conf/config.ini +++ b/conf/config.ini @@ -291,8 +291,8 @@ timeoutSec=5 #该端口是多线程的,同时支持客户端网络切换导致的连接迁移 port=9000 -#srt 协议中延迟缓存的估算参数,在握手阶段估算rtt ,然后lantencyMul*rtt 为最大缓存时长,此参数越大,表示等待重传的时长就越大 -lantencyMul=4 +#srt 协议中延迟缓存的估算参数,在握手阶段估算rtt ,然后latencyMul*rtt 为最大缓存时长,此参数越大,表示等待重传的时长就越大 +latencyMul=4 [rtsp] diff --git a/srt/Ack.cpp b/srt/Ack.cpp index 7be5a08d..f0afc017 100644 --- a/srt/Ack.cpp +++ b/srt/Ack.cpp @@ -2,8 +2,9 @@ #include "Common.hpp" namespace SRT { -bool ACKPacket::loadFromData(uint8_t *buf, size_t len) { - if(len < ACK_CIF_SIZE + ControlPacket::HEADER_SIZE){ + +bool ACKPacket::loadFromData(uint8_t *buf, size_t len) { + if (len < ACK_CIF_SIZE + ControlPacket::HEADER_SIZE) { return false; } @@ -11,7 +12,7 @@ bool ACKPacket::loadFromData(uint8_t *buf, size_t len) { _data->assign((char *)(buf), len); ControlPacket::loadHeader(); ack_number = loadUint32(type_specific_info); - uint8_t* ptr = (uint8_t*)_data->data()+ControlPacket::HEADER_SIZE; + uint8_t *ptr = (uint8_t *)_data->data() + ControlPacket::HEADER_SIZE; last_ack_pkt_seq_number = loadUint32(ptr); ptr += 4; @@ -32,52 +33,53 @@ bool ACKPacket::loadFromData(uint8_t *buf, size_t len) { ptr += 4; recv_rate = loadUint32(ptr); - ptr += 4; + ptr += 4; return true; } -bool ACKPacket::storeToData() { + +bool ACKPacket::storeToData() { _data = BufferRaw::create(); _data->setCapacity(HEADER_SIZE + ACK_CIF_SIZE); _data->setSize(HEADER_SIZE + ACK_CIF_SIZE); control_type = ControlPacket::ACK; sub_type = 0; - storeUint32(type_specific_info,ack_number); + storeUint32(type_specific_info, ack_number); storeToHeader(); - uint8_t* ptr = (uint8_t*)_data->data()+ControlPacket::HEADER_SIZE; - - storeUint32(ptr,last_ack_pkt_seq_number); + uint8_t *ptr = (uint8_t *)_data->data() + ControlPacket::HEADER_SIZE; + + storeUint32(ptr, last_ack_pkt_seq_number); ptr += 4; - storeUint32(ptr,rtt); + storeUint32(ptr, rtt); ptr += 4; - storeUint32(ptr,rtt_variance); + storeUint32(ptr, rtt_variance); ptr += 4; - storeUint32(ptr,pkt_recv_rate); + storeUint32(ptr, pkt_recv_rate); ptr += 4; - storeUint32(ptr,available_buf_size); + storeUint32(ptr, available_buf_size); ptr += 4; - storeUint32(ptr,estimated_link_capacity); + storeUint32(ptr, estimated_link_capacity); ptr += 4; - storeUint32(ptr,recv_rate); + storeUint32(ptr, recv_rate); ptr += 4; return true; } -std::string ACKPacket::dump(){ +std::string ACKPacket::dump() { _StrPrinter printer; - printer << "last_ack_pkt_seq_number="<; ACKPacket() = default; ~ACKPacket() = default; - enum{ - ACK_CIF_SIZE = 7*4 - }; + enum { ACK_CIF_SIZE = 7 * 4 }; std::string dump(); ///////ControlPacket override/////// bool loadFromData(uint8_t *buf, size_t len) override; @@ -59,15 +55,14 @@ public: uint32_t recv_rate; }; - -class ACKACKPacket : public ControlPacket{ +class ACKACKPacket : public ControlPacket { public: using Ptr = std::shared_ptr; ACKACKPacket() = default; ~ACKACKPacket() = default; ///////ControlPacket override/////// - bool loadFromData(uint8_t *buf, size_t len) override{ - if(len < ControlPacket::HEADER_SIZE){ + bool loadFromData(uint8_t *buf, size_t len) override { + if (len < ControlPacket::HEADER_SIZE) { return false; } _data = BufferRaw::create(); @@ -76,21 +71,20 @@ public: ack_number = loadUint32(type_specific_info); return true; } - bool storeToData() override{ + bool storeToData() override { _data = BufferRaw::create(); _data->setCapacity(HEADER_SIZE); - _data->setSize(HEADER_SIZE ); + _data->setSize(HEADER_SIZE); control_type = ControlPacket::ACKACK; sub_type = 0; - storeUint32(type_specific_info,ack_number); + storeUint32(type_specific_info, ack_number); storeToHeader(); return true; } uint32_t ack_number; - }; -} //namespace SRT +} // namespace SRT #endif // ZLMEDIAKIT_SRT_ACK_H \ No newline at end of file diff --git a/srt/Common.hpp b/srt/Common.hpp index d0ba6207..418a00a3 100644 --- a/srt/Common.hpp +++ b/srt/Common.hpp @@ -2,83 +2,73 @@ #define ZLMEDIAKIT_SRT_COMMON_H #include -namespace SRT -{ +namespace SRT { + using SteadyClock = std::chrono::steady_clock; using TimePoint = std::chrono::time_point; using Microseconds = std::chrono::microseconds; using Milliseconds = std::chrono::milliseconds; -inline int64_t DurationCountMicroseconds( SteadyClock::duration dur){ +static inline int64_t DurationCountMicroseconds(SteadyClock::duration dur) { return std::chrono::duration_cast(dur).count(); } -inline uint32_t loadUint32(uint8_t *ptr) { +static inline uint32_t loadUint32(uint8_t *ptr) { return ptr[0] << 24 | ptr[1] << 16 | ptr[2] << 8 | ptr[3]; } -inline uint16_t loadUint16(uint8_t *ptr) { + +static inline uint16_t loadUint16(uint8_t *ptr) { return ptr[0] << 8 | ptr[1]; } -inline void storeUint32(uint8_t *buf, uint32_t val) { +static inline void storeUint32(uint8_t *buf, uint32_t val) { buf[0] = val >> 24; buf[1] = (val >> 16) & 0xff; buf[2] = (val >> 8) & 0xff; buf[3] = val & 0xff; } -inline void storeUint16(uint8_t *buf, uint16_t val) { +static inline void storeUint16(uint8_t *buf, uint16_t val) { buf[0] = (val >> 8) & 0xff; buf[1] = val & 0xff; } -inline void storeUint32LE(uint8_t *buf, uint32_t val) { +static inline void storeUint32LE(uint8_t *buf, uint32_t val) { buf[0] = val & 0xff; buf[1] = (val >> 8) & 0xff; buf[2] = (val >> 16) & 0xff; - buf[3] = (val >>24) & 0xff; + buf[3] = (val >> 24) & 0xff; } -inline void storeUint16LE(uint8_t *buf, uint16_t val) { +static inline void storeUint16LE(uint8_t *buf, uint16_t val) { buf[0] = val & 0xff; - buf[1] = (val>>8) & 0xff; + buf[1] = (val >> 8) & 0xff; } -inline uint32_t srtVersion(int major, int minor, int patch) -{ - return patch + minor*0x100 + major*0x10000; +static inline uint32_t srtVersion(int major, int minor, int patch) { + return patch + minor * 0x100 + major * 0x10000; } class UTicker { public: - UTicker() { - _created = _begin = SteadyClock::now(); - } - - ~UTicker() { - } + UTicker() { _created = _begin = SteadyClock::now(); } + ~UTicker() = default; /** * 获取创建时间,单位微妙 */ - int64_t elapsedTime(TimePoint now) const { - return DurationCountMicroseconds(now - _begin); - } + int64_t elapsedTime(TimePoint now) const { return DurationCountMicroseconds(now - _begin); } /** * 获取上次resetTime后至今的时间,单位毫秒 */ - int64_t createdTime(TimePoint now) const { - return DurationCountMicroseconds(now - _created); - } + int64_t createdTime(TimePoint now) const { return DurationCountMicroseconds(now - _created); } /** * 重置计时器 */ - void resetTime(TimePoint now) { - _begin = now; - } + void resetTime(TimePoint now) { _begin = now; } private: TimePoint _begin; @@ -87,4 +77,4 @@ private: } // namespace SRT -#endif //ZLMEDIAKIT_SRT_COMMON_H \ No newline at end of file +#endif // ZLMEDIAKIT_SRT_COMMON_H \ No newline at end of file diff --git a/srt/HSExt.cpp b/srt/HSExt.cpp index ab4d4ff2..d12b2b3c 100644 --- a/srt/HSExt.cpp +++ b/srt/HSExt.cpp @@ -1,20 +1,21 @@ #include "HSExt.hpp" namespace SRT { + bool HSExtMessage::loadFromData(uint8_t *buf, size_t len) { - if(buf == NULL || len != HSEXT_MSG_SIZE){ + if (buf == NULL || len != HSEXT_MSG_SIZE) { return false; } _data = BufferRaw::create(); - _data->assign((char*)buf,len); + _data->assign((char *)buf, len); extension_length = 3; HSExt::loadHeader(); assert(extension_type == SRT_CMD_HSREQ || extension_type == SRT_CMD_HSRSP); - uint8_t* ptr = (uint8_t*)_data->data()+4; - srt_version = loadUint32(ptr); + uint8_t *ptr = (uint8_t *)_data->data() + 4; + srt_version = loadUint32(ptr); ptr += 4; srt_flag = loadUint32(ptr); @@ -27,105 +28,107 @@ bool HSExtMessage::loadFromData(uint8_t *buf, size_t len) { ptr += 2; return true; +} - } - std::string HSExtMessage::dump(){ - _StrPrinter printer; - printer << "srt version : "<data(); } return nullptr; - }; + } size_t size() const override { if (_data) { return _data->size(); } return 0; - }; + } protected: void loadHeader() { @@ -116,7 +117,7 @@ public: */ class HSExtStreamID : public HSExt { public: - using Ptr = std::shared_ptr; + using Ptr = std::shared_ptr; HSExtStreamID() = default; ~HSExtStreamID() = default; bool loadFromData(uint8_t *buf, size_t len) override; diff --git a/srt/Packet.cpp b/srt/Packet.cpp index 18ad279c..51bc5ab0 100644 --- a/srt/Packet.cpp +++ b/srt/Packet.cpp @@ -10,19 +10,13 @@ #include #endif // defined(_WIN32) - - #include #include "Util/logger.h" #include "Util/MD5.h" - #include "Packet.hpp" - - namespace SRT { - const size_t DataPacket::HEADER_SIZE; const size_t ControlPacket::HEADER_SIZE; const size_t HandshakePacket::HS_CONTENT_MIN_SIZE; @@ -38,7 +32,7 @@ bool DataPacket::isDataPacket(uint8_t *buf, size_t len) { return false; } -uint32_t DataPacket::getSocketID(uint8_t *buf, size_t len){ +uint32_t DataPacket::getSocketID(uint8_t *buf, size_t len) { uint8_t *ptr = buf; ptr += 12; return loadUint32(ptr); @@ -51,7 +45,7 @@ bool DataPacket::loadFromData(uint8_t *buf, size_t len) { } uint8_t *ptr = buf; f = ptr[0] >> 7; - packet_seq_number = loadUint32(ptr)&0x7fffffff; + packet_seq_number = loadUint32(ptr) & 0x7fffffff; ptr += 4; PP = ptr[0] >> 6; @@ -71,7 +65,8 @@ bool DataPacket::loadFromData(uint8_t *buf, size_t len) { _data->assign((char *)(buf), len); return true; } -bool DataPacket::storeToHeader(){ + +bool DataPacket::storeToHeader() { if (!_data || _data->size() < HEADER_SIZE) { WarnL << "data size less " << HEADER_SIZE; return false; @@ -101,6 +96,7 @@ bool DataPacket::storeToHeader(){ ptr += 4; return true; } + bool DataPacket::storeToData(uint8_t *buf, size_t len) { _data = BufferRaw::create(); _data->setCapacity(len + HEADER_SIZE); @@ -139,6 +135,7 @@ char *DataPacket::data() const { return nullptr; return _data->data(); } + size_t DataPacket::size() const { if (!_data) { return 0; @@ -151,6 +148,7 @@ char *DataPacket::payloadData() { return nullptr; return _data->data() + HEADER_SIZE; } + size_t DataPacket::payloadSize() { if (!_data) { return 0; @@ -158,8 +156,6 @@ size_t DataPacket::payloadSize() { return _data->size() - HEADER_SIZE; } - - bool ControlPacket::isControlPacket(uint8_t *buf, size_t len) { if (len < HEADER_SIZE) { WarnL << "data size" << len << " less " << HEADER_SIZE; @@ -199,6 +195,7 @@ bool ControlPacket::loadHeader() { ptr += 4; return true; } + bool ControlPacket::storeToHeader() { uint8_t *ptr = (uint8_t *)_data->data(); ptr[0] = 0x80; @@ -228,17 +225,20 @@ char *ControlPacket::data() const { return nullptr; return _data->data(); } + size_t ControlPacket::size() const { if (!_data) { return 0; } return _data->size(); } -uint32_t ControlPacket::getSocketID(uint8_t *buf, size_t len){ - return loadUint32(buf+12); + +uint32_t ControlPacket::getSocketID(uint8_t *buf, size_t len) { + return loadUint32(buf + 12); } + bool HandshakePacket::loadFromData(uint8_t *buf, size_t len) { - if(HEADER_SIZE+HS_CONTENT_MIN_SIZE > len){ + if (HEADER_SIZE + HS_CONTENT_MIN_SIZE > len) { ErrorL << "size too smalle " << encryption_field; return false; } @@ -282,79 +282,75 @@ bool HandshakePacket::loadFromData(uint8_t *buf, size_t len) { ErrorL << "not support encryption " << encryption_field; } - if(extension_field == 0){ + if (extension_field == 0) { return true; } - if(len == HEADER_SIZE+HS_CONTENT_MIN_SIZE){ - //ErrorL << "extension filed not exist " << extension_field; + if (len == HEADER_SIZE + HS_CONTENT_MIN_SIZE) { + // ErrorL << "extension filed not exist " << extension_field; return true; } - return loadExtMessage(ptr,len-HS_CONTENT_MIN_SIZE-HEADER_SIZE); + return loadExtMessage(ptr, len - HS_CONTENT_MIN_SIZE - HEADER_SIZE); } -bool HandshakePacket::loadExtMessage(uint8_t *buf,size_t len){ - uint8_t* ptr = buf; - ext_list.clear(); - uint16_t type; - uint16_t length; - HSExt::Ptr ext; - while(ptr(); - break; - case HSExt::SRT_CMD_SID: - ext = std::make_shared(); - break; + +bool HandshakePacket::loadExtMessage(uint8_t *buf, size_t len) { + uint8_t *ptr = buf; + ext_list.clear(); + uint16_t type; + uint16_t length; + HSExt::Ptr ext; + while (ptr < buf + len) { + type = loadUint16(ptr); + length = loadUint16(ptr + 2); + switch (type) { + case HSExt::SRT_CMD_HSREQ: + case HSExt::SRT_CMD_HSRSP: ext = std::make_shared(); break; + case HSExt::SRT_CMD_SID: ext = std::make_shared(); break; default: - WarnL<<"not support ext "<loadFromData(ptr,length*4+4)){ - ext_list.push_back(std::move(ext)); - }else{ - WarnL<<"parse HS EXT failed type="<assign((char*)buf,len); - + _data->assign((char *)buf, len); + return loadHeader(); } -bool KeepLivePacket::storeToData(){ +bool KeepLivePacket::storeToData() { control_type = ControlPacket::KEEPALIVE; sub_type = 0; @@ -506,22 +503,21 @@ bool NAKPacket::loadFromData(uint8_t *buf, size_t len) { return false; } _data = BufferRaw::create(); - _data->assign((char*)buf,len); + _data->assign((char *)buf, len); loadHeader(); - uint8_t* ptr = (uint8_t*)_data->data()+HEADER_SIZE; - uint8_t* end = (uint8_t*)_data->data()+_data->size(); + uint8_t *ptr = (uint8_t *)_data->data() + HEADER_SIZE; + uint8_t *end = (uint8_t *)_data->data() + _data->size(); LostPair lost; - while (ptrsetCapacity(HEADER_SIZE+cif_size); - _data->setSize(HEADER_SIZE+cif_size); + _data->setCapacity(HEADER_SIZE + cif_size); + _data->setSize(HEADER_SIZE + cif_size); storeToHeader(); - uint8_t* ptr = (uint8_t*)_data->data()+HEADER_SIZE; + uint8_t *ptr = (uint8_t *)_data->data() + HEADER_SIZE; - for(auto it : lost_list){ - if(it.first+1 ==it.second){ - storeUint32(ptr,it.first); - ptr[0] = ptr[0]&0x7f; - ptr += 4; - }else{ - storeUint32(ptr,it.first); - ptr[0] |= 0x80; + for (auto it : lost_list) { + if (it.first + 1 == it.second) { + storeUint32(ptr, it.first); + ptr[0] = ptr[0] & 0x7f; + ptr += 4; + } else { + storeUint32(ptr, it.first); + ptr[0] |= 0x80; - storeUint32(ptr+4,it.second-1); - //ptr[4] = ptr[4]&0x7f; + storeUint32(ptr + 4, it.second - 1); + // ptr[4] = ptr[4]&0x7f; - ptr += 8; + ptr += 8; } } return true; } -size_t NAKPacket::getCIFSize(){ +size_t NAKPacket::getCIFSize() { size_t size = 0; - for(auto it : lost_list){ - if(it.first+1 ==it.second){ + for (auto it : lost_list) { + if (it.first + 1 == it.second) { size += 4; - }else{ + } else { size += 8; } } return size; } -std::string NAKPacket::dump(){ +std::string NAKPacket::dump() { _StrPrinter printer; for (auto it : lost_list) { - printer<<"[ "<assign((char*)buf,len); + _data->assign((char *)buf, len); loadHeader(); - uint8_t* ptr = (uint8_t*)_data->data()+HEADER_SIZE; + uint8_t *ptr = (uint8_t *)_data->data() + HEADER_SIZE; first_pkt_seq_num = loadUint32(ptr); ptr += 4; @@ -602,17 +598,17 @@ bool MsgDropReqPacket::storeToData() { control_type = DROPREQ; sub_type = 0; _data = BufferRaw::create(); - _data->setCapacity(HEADER_SIZE+8); - _data->setSize(HEADER_SIZE+8); + _data->setCapacity(HEADER_SIZE + 8); + _data->setSize(HEADER_SIZE + 8); storeToHeader(); - uint8_t* ptr = (uint8_t*)_data->data()+HEADER_SIZE; + uint8_t *ptr = (uint8_t *)_data->data() + HEADER_SIZE; - storeUint32(ptr,first_pkt_seq_num); + storeUint32(ptr, first_pkt_seq_num); ptr += 4; - storeUint32(ptr,last_pkt_seq_num); + storeUint32(ptr, last_pkt_seq_num); ptr += 4; return true; } diff --git a/srt/Packet.hpp b/srt/Packet.hpp index bab6df0c..668885da 100644 --- a/srt/Packet.hpp +++ b/srt/Packet.hpp @@ -5,6 +5,7 @@ #include #include "Network/Buffer.h" +#include "Network/sockutil.h" #include "Util/logger.h" #include "Common.hpp" @@ -171,7 +172,7 @@ class HandshakePacket : public ControlPacket { public: using Ptr = std::shared_ptr; enum { NO_ENCRYPTION = 0, AES_128 = 1, AES_196 = 2, AES_256 = 3 }; - static const size_t HS_CONTENT_MIN_SIZE = 48; + static const size_t HS_CONTENT_MIN_SIZE = 48; enum { HS_TYPE_DONE = 0xFFFFFFFD, HS_TYPE_AGREEMENT = 0xFFFFFFFE, @@ -181,18 +182,16 @@ public: }; enum { HS_EXT_FILED_HSREQ = 0x00000001, HS_EXT_FILED_KMREQ = 0x00000002, HS_EXT_FILED_CONFIG = 0x00000004 }; - - - + HandshakePacket() = default; ~HandshakePacket() = default; static bool isHandshakePacket(uint8_t *buf, size_t len); static uint32_t getHandshakeType(uint8_t *buf, size_t len); static uint32_t getSynCookie(uint8_t *buf, size_t len); - static uint32_t generateSynCookie(struct sockaddr_storage* addr,TimePoint ts,uint32_t current_cookie = 0, int correction = 0); + static uint32_t generateSynCookie(struct sockaddr_storage *addr, TimePoint ts, uint32_t current_cookie = 0, int correction = 0); - void assignPeerIP(struct sockaddr_storage* addr); + void assignPeerIP(struct sockaddr_storage *addr); ///////ControlPacket override/////// bool loadFromData(uint8_t *buf, size_t len) override; bool storeToData() override; @@ -209,8 +208,9 @@ public: uint8_t peer_ip_addr[16]; std::vector ext_list; + private: - bool loadExtMessage(uint8_t *buf,size_t len); + bool loadExtMessage(uint8_t *buf, size_t len); bool storeExtMessage(); size_t getExtSize(); }; @@ -229,13 +229,12 @@ private: Figure 12: Keep-Alive control packet https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-keep-alive */ -class KeepLivePacket : public ControlPacket -{ +class KeepLivePacket : public ControlPacket { public: using Ptr = std::shared_ptr; KeepLivePacket() = default; ~KeepLivePacket() = default; - ///////ControlPacket override/////// + ///////ControlPacket override/////// bool loadFromData(uint8_t *buf, size_t len) override; bool storeToData() override; }; @@ -265,11 +264,10 @@ An SRT NAK packet is formatted as follows: Figure 14: NAK control packet https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-nak-control-packet */ -class NAKPacket : public ControlPacket -{ +class NAKPacket : public ControlPacket { public: using Ptr = std::shared_ptr; - using LostPair = std::pair; + using LostPair = std::pair; NAKPacket() = default; ~NAKPacket() = default; std::string dump(); @@ -278,11 +276,11 @@ public: bool storeToData() override; std::list lost_list; + private: size_t getCIFSize(); }; - /* 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 @@ -302,9 +300,8 @@ private: Figure 18: Drop Request control packet https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-message-drop-request */ -class MsgDropReqPacket : public ControlPacket -{ - public: +class MsgDropReqPacket : public ControlPacket { +public: using Ptr = std::shared_ptr; MsgDropReqPacket() = default; ~MsgDropReqPacket() = default; @@ -332,13 +329,13 @@ class MsgDropReqPacket : public ControlPacket https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-shutdown */ -class ShutDownPacket : public ControlPacket -{ +class ShutDownPacket : public ControlPacket { public: using Ptr = std::shared_ptr; ShutDownPacket() = default; ~ShutDownPacket() = default; - ///////ControlPacket override/////// + + ///////ControlPacket override/////// bool loadFromData(uint8_t *buf, size_t len) override { if (len < HEADER_SIZE) { WarnL << "data size" << len << " less " << HEADER_SIZE; @@ -348,7 +345,7 @@ public: _data->assign((char *)buf, len); return loadHeader(); - }; + } bool storeToData() override { control_type = ControlPacket::SHUTDOWN; sub_type = 0; @@ -356,8 +353,9 @@ public: _data->setCapacity(HEADER_SIZE); _data->setSize(HEADER_SIZE); return storeToHeader(); - }; + } }; + } // namespace SRT #endif //ZLMEDIAKIT_SRT_PACKET_H \ No newline at end of file diff --git a/srt/PacketQueue.cpp b/srt/PacketQueue.cpp index 7cfb4781..001585d1 100644 --- a/srt/PacketQueue.cpp +++ b/srt/PacketQueue.cpp @@ -2,80 +2,86 @@ namespace SRT { -#define MAX_SEQ 0x7fffffff -#define MAX_TS 0xffffffff -inline uint32_t genExpectedSeq(uint32_t seq){ +#define MAX_SEQ 0x7fffffff +#define MAX_TS 0xffffffff + +static inline uint32_t genExpectedSeq(uint32_t seq) { return MAX_SEQ & seq; } -inline bool isSeqEdge(uint32_t seq,uint32_t cap){ - if(seq >(MAX_SEQ - cap)){ + +static inline bool isSeqEdge(uint32_t seq, uint32_t cap) { + if (seq > (MAX_SEQ - cap)) { return true; } return false; } -inline bool isTSCycle(uint32_t first,uint32_t second){ +static inline bool isTSCycle(uint32_t first, uint32_t second) { uint32_t diff; - if(first>second){ + if (first > second) { diff = first - second; - }else{ + } else { diff = second - first; } - if(diff > (MAX_TS>>1)){ + if (diff > (MAX_TS >> 1)) { return true; - }else{ + } else { return false; } } -PacketQueue::PacketQueue(uint32_t max_size, uint32_t init_seq, uint32_t lantency) - : _pkt_expected_seq(init_seq) - , _pkt_cap(max_size) - , _pkt_lantency(lantency) { -} -void PacketQueue::tryInsertPkt(DataPacket::Ptr pkt){ +PacketQueue::PacketQueue(uint32_t max_size, uint32_t init_seq, uint32_t latency) + : _pkt_cap(max_size) + , _pkt_latency(latency) + , _pkt_expected_seq(init_seq) {} + +void PacketQueue::tryInsertPkt(DataPacket::Ptr pkt) { if (_pkt_expected_seq <= pkt->packet_seq_number) { auto diff = pkt->packet_seq_number - _pkt_expected_seq; - if(diff >= (MAX_SEQ>>1)){ - TraceL << "drop packet too later for cycle "<< "expected seq=" << _pkt_expected_seq << " pkt seq=" << pkt->packet_seq_number; + if (diff >= (MAX_SEQ >> 1)) { + TraceL << "drop packet too later for cycle " + << "expected seq=" << _pkt_expected_seq << " pkt seq=" << pkt->packet_seq_number; return; - }else{ + } else { _pkt_map.emplace(pkt->packet_seq_number, pkt); } } else { auto diff = _pkt_expected_seq - pkt->packet_seq_number; - if(diff >= (MAX_SEQ>>1)){ - _pkt_map.emplace(pkt->packet_seq_number, pkt); - TraceL<<" cycle packet "<<"expected seq=" << _pkt_expected_seq << " pkt seq=" << pkt->packet_seq_number; - }else{ - //TraceL << "drop packet too later "<< "expected seq=" << _pkt_expected_seq << " pkt seq=" << pkt->packet_seq_number; + if (diff >= (MAX_SEQ >> 1)) { + _pkt_map.emplace(pkt->packet_seq_number, pkt); + TraceL << " cycle packet " + << "expected seq=" << _pkt_expected_seq << " pkt seq=" << pkt->packet_seq_number; + } else { + // TraceL << "drop packet too later "<< "expected seq=" << _pkt_expected_seq << " pkt seq=" << + // pkt->packet_seq_number; } } } -bool PacketQueue::inputPacket(DataPacket::Ptr pkt,std::list& out) { + +bool PacketQueue::inputPacket(DataPacket::Ptr pkt, std::list &out) { tryInsertPkt(pkt); auto it = _pkt_map.find(_pkt_expected_seq); - while ( it != _pkt_map.end()) { + while (it != _pkt_map.end()) { out.push_back(it->second); _pkt_map.erase(it); - _pkt_expected_seq = genExpectedSeq(_pkt_expected_seq+1); + _pkt_expected_seq = genExpectedSeq(_pkt_expected_seq + 1); it = _pkt_map.find(_pkt_expected_seq); } while (_pkt_map.size() > _pkt_cap) { // 防止回环 it = _pkt_map.find(_pkt_expected_seq); - if(it != _pkt_map.end()){ + if (it != _pkt_map.end()) { out.push_back(it->second); _pkt_map.erase(it); } _pkt_expected_seq = genExpectedSeq(_pkt_expected_seq + 1); } - while (timeLantency() > _pkt_lantency) { + while (timeLatency() > _pkt_latency) { it = _pkt_map.find(_pkt_expected_seq); - if(it != _pkt_map.end()){ + if (it != _pkt_map.end()) { out.push_back(it->second); _pkt_map.erase(it); } @@ -85,22 +91,22 @@ bool PacketQueue::inputPacket(DataPacket::Ptr pkt,std::list& ou return true; } -bool PacketQueue::drop(uint32_t first, uint32_t last,std::list& out){ - uint32_t end = genExpectedSeq(last+1); +bool PacketQueue::drop(uint32_t first, uint32_t last, std::list &out) { + uint32_t end = genExpectedSeq(last + 1); decltype(_pkt_map.end()) it; - for(uint32_t i =_pkt_expected_seq;i< end;){ - it = _pkt_map.find(i); - if(it != _pkt_map.end()){ - out.push_back(it->second); - _pkt_map.erase(it); - } - i = genExpectedSeq(i+1); + for (uint32_t i = _pkt_expected_seq; i < end;) { + it = _pkt_map.find(i); + if (it != _pkt_map.end()) { + out.push_back(it->second); + _pkt_map.erase(it); + } + i = genExpectedSeq(i + 1); } _pkt_expected_seq = end; return true; } -uint32_t PacketQueue::timeLantency() { +uint32_t PacketQueue::timeLatency() { if (_pkt_map.empty()) { return 0; } @@ -108,15 +114,15 @@ uint32_t PacketQueue::timeLantency() { auto first = _pkt_map.begin()->second->timestamp; auto last = _pkt_map.rbegin()->second->timestamp; uint32_t dur; - if(last>first){ + if (last > first) { dur = last - first; - }else{ + } else { dur = first - last; } - if(dur > 0x80000000){ + if (dur > 0x80000000) { dur = MAX_TS - dur; - WarnL<<"cycle dur "< PacketQueue::getLostSeq() { std::list re; - if(_pkt_map.empty()){ + if (_pkt_map.empty()) { return re; } - - if(getExpectedSize() == getSize()){ + + if (getExpectedSize() == getSize()) { return re; } uint32_t end = 0; - uint32_t first,last; + uint32_t first, last; first = _pkt_map.begin()->second->packet_seq_number; last = _pkt_map.rbegin()->second->packet_seq_number; @@ -149,71 +155,76 @@ std::list PacketQueue::getLostSeq() { uint32_t i = _pkt_expected_seq; bool finish = true; - for(i = _pkt_expected_seq;i<=end;){ - if(_pkt_map.find(i) == _pkt_map.end()){ - if(finish){ + for (i = _pkt_expected_seq; i <= end;) { + if (_pkt_map.find(i) == _pkt_map.end()) { + if (finish) { finish = false; lost.first = i; - lost.second = i+1; - }else{ - lost.second = i+1; + lost.second = i + 1; + } else { + lost.second = i + 1; } - }else{ - if(!finish){ + } else { + if (!finish) { finish = true; re.push_back(lost); } } - i = genExpectedSeq(i+1); + i = genExpectedSeq(i + 1); } return re; } -size_t PacketQueue::getSize(){ +size_t PacketQueue::getSize() { return _pkt_map.size(); } size_t PacketQueue::getExpectedSize() { - if(_pkt_map.empty()){ + if (_pkt_map.empty()) { return 0; } uint32_t max = _pkt_map.rbegin()->first; uint32_t min = _pkt_map.begin()->first; - if((max-min)>=(MAX_SEQ>>1)){ - TraceL<<"cycle "<<"expected seq "<<_pkt_expected_seq<<" min "<= (MAX_SEQ >> 1)) { + TraceL << "cycle " + << "expected seq " << _pkt_expected_seq << " min " << min << " max " << max << " size " + << _pkt_map.size(); + return MAX_SEQ - _pkt_expected_seq + min + 1; + } else { + return max - _pkt_expected_seq + 1; } } -size_t PacketQueue::getAvailableBufferSize(){ - auto size = getExpectedSize(); - if(_pkt_cap > size){ +size_t PacketQueue::getAvailableBufferSize() { + auto size = getExpectedSize(); + if (_pkt_cap > size) { return _pkt_cap - size; } - if(_pkt_cap > _pkt_map.size()){ + if (_pkt_cap > _pkt_map.size()) { return _pkt_cap - _pkt_map.size(); } - WarnL<<" cap "<<_pkt_cap<<" expected size "<second->packet_seq_number; - printer<<" last:"<<_pkt_map.rbegin()->second->packet_seq_number; - printer<<" latency:"<second->packet_seq_number; + printer << " last:" << _pkt_map.rbegin()->second->packet_seq_number; + printer << " latency:" << timeLatency() / 1e3; + } + return std::move(printer); } + } // namespace SRT \ No newline at end of file diff --git a/srt/PacketQueue.hpp b/srt/PacketQueue.hpp index 66702d65..58d5f18a 100644 --- a/srt/PacketQueue.hpp +++ b/srt/PacketQueue.hpp @@ -3,8 +3,8 @@ #include "Packet.hpp" #include #include -#include #include +#include #include #include @@ -16,11 +16,11 @@ public: using Ptr = std::shared_ptr; using LostPair = std::pair; - PacketQueue(uint32_t max_size, uint32_t init_seq, uint32_t lantency); + PacketQueue(uint32_t max_size, uint32_t init_seq, uint32_t latency); ~PacketQueue() = default; - bool inputPacket(DataPacket::Ptr pkt,std::list& out); + bool inputPacket(DataPacket::Ptr pkt, std::list &out); - uint32_t timeLantency(); + uint32_t timeLatency(); std::list getLostSeq(); size_t getSize(); @@ -28,18 +28,17 @@ public: size_t getAvailableBufferSize(); uint32_t getExpectedSeq(); - bool drop(uint32_t first, uint32_t last,std::list& out); - std::string dump(); + bool drop(uint32_t first, uint32_t last, std::list &out); + private: void tryInsertPkt(DataPacket::Ptr pkt); + private: - - std::map _pkt_map; - - uint32_t _pkt_expected_seq = 0; uint32_t _pkt_cap; - uint32_t _pkt_lantency; + uint32_t _pkt_latency; + uint32_t _pkt_expected_seq = 0; + std::map _pkt_map; }; } // namespace SRT diff --git a/srt/PacketSendQueue.cpp b/srt/PacketSendQueue.cpp index 9ea11aaf..654b1194 100644 --- a/srt/PacketSendQueue.cpp +++ b/srt/PacketSendQueue.cpp @@ -2,9 +2,10 @@ namespace SRT { -PacketSendQueue::PacketSendQueue(uint32_t max_size, uint32_t lantency) +PacketSendQueue::PacketSendQueue(uint32_t max_size, uint32_t latency) : _pkt_cap(max_size) - , _pkt_lantency(lantency) {} + , _pkt_latency(latency) {} + bool PacketSendQueue::drop(uint32_t num) { decltype(_pkt_cache.begin()) it; for (it = _pkt_cache.begin(); it != _pkt_cache.end(); ++it) { @@ -17,12 +18,13 @@ bool PacketSendQueue::drop(uint32_t num) { } return true; } + bool PacketSendQueue::inputPacket(DataPacket::Ptr pkt) { _pkt_cache.push_back(pkt); while (_pkt_cache.size() > _pkt_cap) { _pkt_cache.pop_front(); } - while (timeLantency() > _pkt_lantency) { + while (timeLatency() > _pkt_latency) { _pkt_cache.pop_front(); } return true; @@ -53,7 +55,7 @@ std::list PacketSendQueue::findPacketBySeq(uint32_t start, uint return re; } -uint32_t PacketSendQueue::timeLantency() { +uint32_t PacketSendQueue::timeLatency() { if (_pkt_cache.empty()) { return 0; } @@ -67,7 +69,7 @@ uint32_t PacketSendQueue::timeLantency() { dur = first - last; } if (dur > (0x01 << 31)) { - TraceL << "cycle timeLantency " << dur; + TraceL << "cycle timeLatency " << dur; dur = 0xffffffff - dur; } diff --git a/srt/PacketSendQueue.hpp b/srt/PacketSendQueue.hpp index 86fa86f1..43227aca 100644 --- a/srt/PacketSendQueue.hpp +++ b/srt/PacketSendQueue.hpp @@ -1,5 +1,6 @@ #ifndef ZLMEDIAKIT_SRT_PACKET_SEND_QUEUE_H #define ZLMEDIAKIT_SRT_PACKET_SEND_QUEUE_H + #include "Packet.hpp" #include #include @@ -7,23 +8,30 @@ #include #include #include + namespace SRT { + class PacketSendQueue { public: using Ptr = std::shared_ptr; using LostPair = std::pair; - PacketSendQueue(uint32_t max_size, uint32_t lantency); + + PacketSendQueue(uint32_t max_size, uint32_t latency); ~PacketSendQueue() = default; + bool drop(uint32_t num); bool inputPacket(DataPacket::Ptr pkt); - std::list findPacketBySeq(uint32_t start,uint32_t end); + std::list findPacketBySeq(uint32_t start, uint32_t end); + private: - uint32_t timeLantency(); + uint32_t timeLatency(); + private: - std::list _pkt_cache; uint32_t _pkt_cap; - uint32_t _pkt_lantency; + uint32_t _pkt_latency; + std::list _pkt_cache; }; + } // namespace SRT #endif // ZLMEDIAKIT_SRT_PACKET_SEND_QUEUE_H \ No newline at end of file diff --git a/srt/SrtSession.cpp b/srt/SrtSession.cpp index fe8cfc13..d62cb7cd 100644 --- a/srt/SrtSession.cpp +++ b/srt/SrtSession.cpp @@ -10,10 +10,10 @@ using namespace mediakit; SrtSession::SrtSession(const Socket::Ptr &sock) : UdpSession(sock) { socklen_t addr_len = sizeof(_peer_addr); - memset(&_peer_addr,0,addr_len); - //TraceL<<"before addr len "<rawFD(), (struct sockaddr *)&_peer_addr, &addr_len); - //TraceL<<"after addr len "<data(); + uint8_t *data = (uint8_t *)buffer->data(); size_t size = buffer->size(); - if(DataPacket::isDataPacket(data,size)){ - uint32_t socket_id = DataPacket::getSocketID(data,size); + if (DataPacket::isDataPacket(data, size)) { + uint32_t socket_id = DataPacket::getSocketID(data, size); auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id)); return trans ? trans->getPoller() : nullptr; } - if(HandshakePacket::isHandshakePacket(data,size)){ - auto type = HandshakePacket::getHandshakeType(data,size); - if(type == HandshakePacket::HS_TYPE_INDUCTION){ + if (HandshakePacket::isHandshakePacket(data, size)) { + auto type = HandshakePacket::getHandshakeType(data, size); + if (type == HandshakePacket::HS_TYPE_INDUCTION) { // 握手第一阶段 return nullptr; - }else if(type == HandshakePacket::HS_TYPE_CONCLUSION){ + } else if (type == HandshakePacket::HS_TYPE_CONCLUSION) { // 握手第二阶段 - uint32_t sync_cookie = HandshakePacket::getSynCookie(data,size); + uint32_t sync_cookie = HandshakePacket::getSynCookie(data, size); auto trans = SrtTransportManager::Instance().getHandshakeItem(std::to_string(sync_cookie)); return trans ? trans->getPoller() : nullptr; - }else{ - WarnL<<" not reach there"; + } else { + WarnL << " not reach there"; } - }else{ - uint32_t socket_id = ControlPacket::getSocketID(data,size); + } else { + uint32_t socket_id = ControlPacket::getSocketID(data, size); auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id)); return trans ? trans->getPoller() : nullptr; } return nullptr; } -void SrtSession::attachServer(const toolkit::Server &server){ - SockUtil::setRecvBuf(getSock()->rawFD(),1024 * 1024); + +void SrtSession::attachServer(const toolkit::Server &server) { + SockUtil::setRecvBuf(getSock()->rawFD(), 1024 * 1024); } + void SrtSession::onRecv(const Buffer::Ptr &buffer) { - uint8_t* data = (uint8_t*)buffer->data(); + uint8_t *data = (uint8_t *)buffer->data(); size_t size = buffer->size(); if (_find_transport) { @@ -64,10 +66,10 @@ void SrtSession::onRecv(const Buffer::Ptr &buffer) { if (DataPacket::isDataPacket(data, size)) { uint32_t socket_id = DataPacket::getSocketID(data, size); auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id)); - if(trans){ + if (trans) { _transport = std::move(trans); - }else{ - WarnL<<" data packet not find transport "; + } else { + WarnL << " data packet not find transport "; } } @@ -92,24 +94,24 @@ void SrtSession::onRecv(const Buffer::Ptr &buffer) { } else { uint32_t socket_id = ControlPacket::getSocketID(data, size); auto trans = SrtTransportManager::Instance().getItem(std::to_string(socket_id)); - if(trans){ + if (trans) { _transport = std::move(trans); - }else{ + } else { WarnL << " not find transport"; } } - if(_transport){ + if (_transport) { _transport->setSession(shared_from_this()); } InfoP(this); } _ticker.resetTime(); - if(_transport){ - _transport->inputSockData(data,size,&_peer_addr); - }else{ - //WarnL<< "ingore data"; + if (_transport) { + _transport->inputSockData(data, size, &_peer_addr); + } else { + // WarnL<< "ingore data"; } } @@ -122,18 +124,20 @@ void SrtSession::onError(const SockException &err) { if (!_transport) { return; } - + // 防止互相引用导致不释放 auto transport = std::move(_transport); - getPoller()->async([transport,err] { - //延时减引用,防止使用transport对象时,销毁对象 - transport->onShutdown(err); - }, false); + getPoller()->async( + [transport, err] { + //延时减引用,防止使用transport对象时,销毁对象 + transport->onShutdown(err); + }, + false); } void SrtSession::onManager() { GET_CONFIG(float, timeoutSec, kTimeOutSec); - if (_ticker.elapsedTime() > timeoutSec*1000) { + if (_ticker.elapsedTime() > timeoutSec * 1000) { shutdown(SockException(Err_timeout, "srt connection timeout")); return; } diff --git a/srt/SrtSession.hpp b/srt/SrtSession.hpp index 401aae3c..342a4a91 100644 --- a/srt/SrtSession.hpp +++ b/srt/SrtSession.hpp @@ -24,8 +24,7 @@ private: Ticker _ticker; struct sockaddr_storage _peer_addr; SrtTransport::Ptr _transport; - }; } // namespace SRT -#endif //ZLMEDIAKIT_SRT_SESSION_H \ No newline at end of file +#endif // ZLMEDIAKIT_SRT_SESSION_H \ No newline at end of file diff --git a/srt/SrtTransport.cpp b/srt/SrtTransport.cpp index e7a84bf3..cce3e910 100644 --- a/srt/SrtTransport.cpp +++ b/srt/SrtTransport.cpp @@ -1,32 +1,33 @@ #include +#include "Ack.hpp" +#include "Packet.hpp" +#include "SrtTransport.hpp" #include "Util/onceToken.h" -#include "SrtTransport.hpp" -#include "Packet.hpp" -#include "Ack.hpp" namespace SRT { #define SRT_FIELD "srt." -//srt 超时时间 -const std::string kTimeOutSec = SRT_FIELD"timeoutSec"; -//srt 单端口udp服务器 -const std::string kPort = SRT_FIELD"port"; +// srt 超时时间 +const std::string kTimeOutSec = SRT_FIELD "timeoutSec"; +// srt 单端口udp服务器 +const std::string kPort = SRT_FIELD "port"; +const std::string kLatencyMul = SRT_FIELD "latencyMul"; -const std::string kLantencyMul = SRT_FIELD"lantencyMul"; +static std::atomic s_srt_socket_id_generate { 125 }; -static std::atomic s_srt_socket_id_generate{125}; //////////// SrtTransport ////////////////////////// SrtTransport::SrtTransport(const EventPoller::Ptr &poller) : _poller(poller) { - _start_timestamp = SteadyClock::now(); - _socket_id = s_srt_socket_id_generate.fetch_add(1);\ - _pkt_recv_rate_context = std::make_shared(_start_timestamp); - _recv_rate_context = std::make_shared(_start_timestamp); - _estimated_link_capacity_context = std::make_shared(_start_timestamp); - } - -SrtTransport::~SrtTransport(){ - TraceL<<" "; + _start_timestamp = SteadyClock::now(); + _socket_id = s_srt_socket_id_generate.fetch_add(1); + _pkt_recv_rate_context = std::make_shared(_start_timestamp); + _recv_rate_context = std::make_shared(_start_timestamp); + _estimated_link_capacity_context = std::make_shared(_start_timestamp); } + +SrtTransport::~SrtTransport() { + TraceL << " "; +} + const EventPoller::Ptr &SrtTransport::getPoller() const { return _poller; } @@ -40,24 +41,25 @@ void SrtTransport::setSession(Session::Ptr session) { } _selected_session = session; } + const Session::Ptr &SrtTransport::getSession() const { return _selected_session; } -void SrtTransport::switchToOtherTransport(uint8_t *buf, int len,uint32_t socketid, struct sockaddr_storage *addr){ +void SrtTransport::switchToOtherTransport(uint8_t *buf, int len, uint32_t socketid, struct sockaddr_storage *addr) { BufferRaw::Ptr tmp = BufferRaw::create(); struct sockaddr_storage tmp_addr = *addr; - tmp->assign((char*)buf,len); + tmp->assign((char *)buf, len); auto trans = SrtTransportManager::Instance().getItem(std::to_string(socketid)); - if(trans){ - trans->getPoller()->async([tmp,tmp_addr,trans]{ - trans->inputSockData((uint8_t*)tmp->data(),tmp->size(),(struct sockaddr_storage*)&tmp_addr); + if (trans) { + trans->getPoller()->async([tmp, tmp_addr, trans] { + trans->inputSockData((uint8_t *)tmp->data(), tmp->size(), (struct sockaddr_storage *)&tmp_addr); }); } } void SrtTransport::inputSockData(uint8_t *buf, int len, struct sockaddr_storage *addr) { - using srt_control_handler = void (SrtTransport::*)(uint8_t* buf,int len,struct sockaddr_storage *addr); + using srt_control_handler = void (SrtTransport::*)(uint8_t * buf, int len, struct sockaddr_storage *addr); static std::unordered_map s_control_functions; static onceToken token([]() { s_control_functions.emplace(ControlPacket::HANDSHAKE, &SrtTransport::handleHandshake); @@ -74,23 +76,23 @@ void SrtTransport::inputSockData(uint8_t *buf, int len, struct sockaddr_storage _now = SteadyClock::now(); // 处理srt数据 if (DataPacket::isDataPacket(buf, len)) { - uint32_t socketId = DataPacket::getSocketID(buf,len); - if(socketId == _socket_id){ + uint32_t socketId = DataPacket::getSocketID(buf, len); + if (socketId == _socket_id) { _pkt_recv_rate_context->inputPacket(_now); _estimated_link_capacity_context->inputPacket(_now); _recv_rate_context->inputPacket(_now, len); handleDataPacket(buf, len, addr); - }else{ - switchToOtherTransport(buf,len,socketId,addr); + } else { + switchToOtherTransport(buf, len, socketId, addr); } } else { if (ControlPacket::isControlPacket(buf, len)) { - uint32_t socketId = ControlPacket::getSocketID(buf,len); - uint16_t type = ControlPacket::getControlType(buf,len); - if(type != ControlPacket::HANDSHAKE && socketId != _socket_id && _socket_id != 0){ + uint32_t socketId = ControlPacket::getSocketID(buf, len); + uint16_t type = ControlPacket::getControlType(buf, len); + if (type != ControlPacket::HANDSHAKE && socketId != _socket_id && _socket_id != 0) { // socket id not same - switchToOtherTransport(buf,len,socketId,addr); + switchToOtherTransport(buf, len, socketId, addr); return; } _pkt_recv_rate_context->inputPacket(_now); @@ -99,10 +101,10 @@ void SrtTransport::inputSockData(uint8_t *buf, int len, struct sockaddr_storage auto it = s_control_functions.find(type); if (it == s_control_functions.end()) { - WarnL<<" not support type ignore" << ControlPacket::getControlType(buf,len); + WarnL << " not support type ignore" << ControlPacket::getControlType(buf, len); return; - }else{ - (this->*(it->second))(buf,len,addr); + } else { + (this->*(it->second))(buf, len, addr); } } else { // not reach @@ -119,7 +121,7 @@ void SrtTransport::handleHandshakeInduction(HandshakePacket &pkt, struct sockadd sendControlPacket(_handleshake_res, true); return; } - _induction_ts = _now; + _induction_ts = _now; _start_timestamp = _now; _init_seq_number = pkt.initial_packet_sequence_number; _max_window_size = pkt.max_flow_window_size; @@ -146,9 +148,10 @@ void SrtTransport::handleHandshakeInduction(HandshakePacket &pkt, struct sockadd registerSelfHandshake(); sendControlPacket(res, true); } + void SrtTransport::handleHandshakeConclusion(HandshakePacket &pkt, struct sockaddr_storage *addr) { - if(!_handleshake_res){ - ErrorL<<"must Induction Phase for handleshake "; + if (!_handleshake_res) { + ErrorL << "must Induction Phase for handleshake "; return; } @@ -157,21 +160,21 @@ void SrtTransport::handleHandshakeConclusion(HandshakePacket &pkt, struct sockad HSExtMessage::Ptr req; HSExtStreamID::Ptr sid; uint32_t srt_flag = 0xbf; - uint16_t delay = DurationCountMicroseconds(_now - _induction_ts)*getLantencyMul()/1000; + uint16_t delay = DurationCountMicroseconds(_now - _induction_ts) * getLatencyMul() / 1000; for (auto ext : pkt.ext_list) { - //TraceL << getIdentifier() << " ext " << ext->dump(); + // TraceL << getIdentifier() << " ext " << ext->dump(); if (!req) { req = std::dynamic_pointer_cast(ext); } - if(!sid){ + if (!sid) { sid = std::dynamic_pointer_cast(ext); } } - if(sid){ + if (sid) { _stream_id = sid->streamid; } - if(req){ + if (req) { srt_flag = req->srt_flag; delay = delay <= req->recv_tsbpd_delay ? req->recv_tsbpd_delay : delay; } @@ -200,167 +203,174 @@ void SrtTransport::handleHandshakeConclusion(HandshakePacket &pkt, struct sockad unregisterSelfHandshake(); registerSelf(); sendControlPacket(res, true); - TraceL<<" buf size = "<max_flow_window_size<<" init seq ="<<_init_seq_number<<" lantency="<(res->max_flow_window_size,_init_seq_number, delay*1e3); - _send_buf = std::make_shared(res->max_flow_window_size, delay*1e3); + TraceL << " buf size = " << res->max_flow_window_size << " init seq =" << _init_seq_number << " latency=" << delay; + _recv_buf = std::make_shared(res->max_flow_window_size, _init_seq_number, delay * 1e3); + _send_buf = std::make_shared(res->max_flow_window_size, delay * 1e3); _send_packet_seq_number = _init_seq_number; _buf_delay = delay; - onHandShakeFinished(_stream_id,addr); + onHandShakeFinished(_stream_id, addr); } else { TraceL << getIdentifier() << " CONCLUSION handle repeate "; sendControlPacket(_handleshake_res, true); } _last_ack_pkt_seq_num = _init_seq_number; } -void SrtTransport::handleHandshake(uint8_t *buf, int len, struct sockaddr_storage *addr){ - HandshakePacket pkt; - assert(pkt.loadFromData(buf,len)); - if(pkt.handshake_type == HandshakePacket::HS_TYPE_INDUCTION){ - handleHandshakeInduction(pkt,addr); - }else if(pkt.handshake_type == HandshakePacket::HS_TYPE_CONCLUSION){ - handleHandshakeConclusion(pkt,addr); - }else{ - WarnL<<" not support handshake type = "<< pkt.handshake_type; +void SrtTransport::handleHandshake(uint8_t *buf, int len, struct sockaddr_storage *addr) { + HandshakePacket pkt; + assert(pkt.loadFromData(buf, len)); + + if (pkt.handshake_type == HandshakePacket::HS_TYPE_INDUCTION) { + handleHandshakeInduction(pkt, addr); + } else if (pkt.handshake_type == HandshakePacket::HS_TYPE_CONCLUSION) { + handleHandshakeConclusion(pkt, addr); + } else { + WarnL << " not support handshake type = " << pkt.handshake_type; } _ack_ticker.resetTime(_now); _nak_ticker.resetTime(_now); } -void SrtTransport::handleKeeplive(uint8_t *buf, int len, struct sockaddr_storage *addr){ - //TraceL; + +void SrtTransport::handleKeeplive(uint8_t *buf, int len, struct sockaddr_storage *addr) { + // TraceL; sendKeepLivePacket(); } - void SrtTransport::sendKeepLivePacket(){ +void SrtTransport::sendKeepLivePacket() { KeepLivePacket::Ptr pkt = std::make_shared(); pkt->dst_socket_id = _peer_socket_id; - pkt->timestamp = DurationCountMicroseconds(_now -_start_timestamp); + pkt->timestamp = DurationCountMicroseconds(_now - _start_timestamp); pkt->storeToData(); - sendControlPacket(pkt,true); - } -void SrtTransport::handleACK(uint8_t *buf, int len, struct sockaddr_storage *addr){ - //TraceL; + sendControlPacket(pkt, true); +} + +void SrtTransport::handleACK(uint8_t *buf, int len, struct sockaddr_storage *addr) { + // TraceL; ACKPacket ack; - if(!ack.loadFromData(buf,len)){ + if (!ack.loadFromData(buf, len)) { return; } ACKACKPacket::Ptr pkt = std::make_shared(); pkt->dst_socket_id = _peer_socket_id; - pkt->timestamp = DurationCountMicroseconds(_now -_start_timestamp); + pkt->timestamp = DurationCountMicroseconds(_now - _start_timestamp); pkt->ack_number = ack.ack_number; pkt->storeToData(); _send_buf->drop(ack.last_ack_pkt_seq_number); - sendControlPacket(pkt,true); - //TraceL<<"ack number "<(); pkt->dst_socket_id = _peer_socket_id; - pkt->timestamp = DurationCountMicroseconds(_now -_start_timestamp); + pkt->timestamp = DurationCountMicroseconds(_now - _start_timestamp); pkt->first_pkt_seq_num = first; pkt->last_pkt_seq_num = last; pkt->storeToData(); - sendControlPacket(pkt,true); + sendControlPacket(pkt, true); } -void SrtTransport::handleNAK(uint8_t *buf, int len, struct sockaddr_storage *addr){ - //TraceL; + +void SrtTransport::handleNAK(uint8_t *buf, int len, struct sockaddr_storage *addr) { + // TraceL; NAKPacket pkt; - pkt.loadFromData(buf,len); + pkt.loadFromData(buf, len); bool empty = false; bool flush = false; - for(auto it : pkt.lost_list){ - if(pkt.lost_list.back() == it){ + for (auto it : pkt.lost_list) { + if (pkt.lost_list.back() == it) { flush = true; } empty = true; - auto re_list = _send_buf->findPacketBySeq(it.first,it.second-1); - for(auto pkt : re_list){ + auto re_list = _send_buf->findPacketBySeq(it.first, it.second - 1); + for (auto pkt : re_list) { pkt->R = 1; pkt->storeToHeader(); - sendPacket(pkt,flush); + sendPacket(pkt, flush); empty = false; } - if(empty){ - sendMsgDropReq(it.first,it.second-1); + if (empty) { + sendMsgDropReq(it.first, it.second - 1); } } } -void SrtTransport::handleCongestionWarning(uint8_t *buf, int len, struct sockaddr_storage *addr){ + +void SrtTransport::handleCongestionWarning(uint8_t *buf, int len, struct sockaddr_storage *addr) { TraceL; } -void SrtTransport::handleShutDown(uint8_t *buf, int len, struct sockaddr_storage *addr){ + +void SrtTransport::handleShutDown(uint8_t *buf, int len, struct sockaddr_storage *addr) { TraceL; onShutdown(SockException(Err_shutdown, "peer close connection")); } -void SrtTransport::handleDropReq(uint8_t *buf, int len, struct sockaddr_storage *addr){ + +void SrtTransport::handleDropReq(uint8_t *buf, int len, struct sockaddr_storage *addr) { MsgDropReqPacket pkt; - pkt.loadFromData(buf,len); + pkt.loadFromData(buf, len); std::list list; - //TraceL<<"drop "<drop(pkt.first_pkt_seq_num,pkt.last_pkt_seq_num,list); - if(list.empty()){ + // TraceL<<"drop "<drop(pkt.first_pkt_seq_num, pkt.last_pkt_seq_num, list); + if (list.empty()) { return; } - for(auto data : list){ + for (auto data : list) { onSRTData(std::move(data)); } - auto nak_interval = (_rtt+_rtt_variance*4)/2; - if(nak_interval <= 20*1000){ - nak_interval = 20*1000; + auto nak_interval = (_rtt + _rtt_variance * 4) / 2; + if (nak_interval <= 20 * 1000) { + nak_interval = 20 * 1000; } - if(_nak_ticker.elapsedTime(_now)>nak_interval){ + if (_nak_ticker.elapsedTime(_now) > nak_interval) { auto lost = _recv_buf->getLostSeq(); - if(!lost.empty()){ - sendNAKPacket(lost); + if (!lost.empty()) { + sendNAKPacket(lost); } _nak_ticker.resetTime(_now); } - if(_ack_ticker.elapsedTime(_now)>10*1000){ + if (_ack_ticker.elapsedTime(_now) > 10 * 1000) { _light_ack_pkt_count = 0; _ack_ticker.resetTime(_now); - // send a ack per 10 ms for receiver + // send a ack per 10 ms for receiver sendACKPacket(); - }else{ - if(_light_ack_pkt_count >= 64){ + } else { + if (_light_ack_pkt_count >= 64) { // for high bitrate stream send light ack - // TODO + // TODO sendLightACKPacket(); - TraceL<<"send light ack"; + TraceL << "send light ack"; } _light_ack_pkt_count = 0; } _light_ack_pkt_count++; - } -void SrtTransport::handleUserDefinedType(uint8_t *buf, int len, struct sockaddr_storage *addr){ + +void SrtTransport::handleUserDefinedType(uint8_t *buf, int len, struct sockaddr_storage *addr) { TraceL; } -void SrtTransport::handleACKACK(uint8_t *buf, int len, struct sockaddr_storage *addr){ - //TraceL; +void SrtTransport::handleACKACK(uint8_t *buf, int len, struct sockaddr_storage *addr) { + // TraceL; ACKACKPacket::Ptr pkt = std::make_shared(); - pkt->loadFromData(buf,len); + pkt->loadFromData(buf, len); uint32_t rtt = DurationCountMicroseconds(_now - _ack_send_timestamp[pkt->ack_number]); - _rtt_variance = (3*_rtt_variance+abs((long)_rtt - (long)rtt))/4; - _rtt = (7*rtt+_rtt)/8; + _rtt_variance = (3 * _rtt_variance + abs((long)_rtt - (long)rtt)) / 4; + _rtt = (7 * rtt + _rtt) / 8; - - //TraceL<<" rtt:"<<_rtt<<" rtt variance:"<<_rtt_variance; + // TraceL<<" rtt:"<<_rtt<<" rtt variance:"<<_rtt_variance; _ack_send_timestamp.erase(pkt->ack_number); } -void SrtTransport::handlePeerError(uint8_t *buf, int len, struct sockaddr_storage *addr){ +void SrtTransport::handlePeerError(uint8_t *buf, int len, struct sockaddr_storage *addr) { TraceL; } void SrtTransport::sendACKPacket() { - ACKPacket::Ptr pkt=std::make_shared(); + ACKPacket::Ptr pkt = std::make_shared(); pkt->dst_socket_id = _peer_socket_id; pkt->timestamp = DurationCountMicroseconds(_now - _start_timestamp); pkt->ack_number = ++_ack_number_count; @@ -373,13 +383,14 @@ void SrtTransport::sendACKPacket() { pkt->recv_rate = _recv_rate_context->getRecvRate(); pkt->storeToData(); _ack_send_timestamp[pkt->ack_number] = _now; - _last_ack_pkt_seq_num = pkt->last_ack_pkt_seq_number; - sendControlPacket(pkt,true); - //TraceL<<"send ack "<dump(); + _last_ack_pkt_seq_num = pkt->last_ack_pkt_seq_number; + sendControlPacket(pkt, true); + // TraceL<<"send ack "<dump(); } + void SrtTransport::sendLightACKPacket() { - ACKPacket::Ptr pkt=std::make_shared(); - + ACKPacket::Ptr pkt = std::make_shared(); + pkt->dst_socket_id = _peer_socket_id; pkt->timestamp = DurationCountMicroseconds(_now - _start_timestamp); pkt->ack_number = 0; @@ -392,11 +403,11 @@ void SrtTransport::sendLightACKPacket() { pkt->recv_rate = 0; pkt->storeToData(); _last_ack_pkt_seq_num = pkt->last_ack_pkt_seq_number; - sendControlPacket(pkt,true); - TraceL<<"send ack "<dump(); + sendControlPacket(pkt, true); + TraceL << "send ack " << pkt->dump(); } -void SrtTransport::sendNAKPacket(std::list& lost_list){ +void SrtTransport::sendNAKPacket(std::list &lost_list) { NAKPacket::Ptr pkt = std::make_shared(); pkt->dst_socket_id = _peer_socket_id; @@ -405,112 +416,118 @@ void SrtTransport::sendNAKPacket(std::list& lost_list){ pkt->storeToData(); - //TraceL<<"send NAK "<dump(); - sendControlPacket(pkt,true); + // TraceL<<"send NAK "<dump(); + sendControlPacket(pkt, true); } -void SrtTransport::sendShutDown(){ +void SrtTransport::sendShutDown() { ShutDownPacket::Ptr pkt = std::make_shared(); pkt->dst_socket_id = _peer_socket_id; pkt->timestamp = DurationCountMicroseconds(_now - _start_timestamp); pkt->storeToData(); - sendControlPacket(pkt,true); + sendControlPacket(pkt, true); } -void SrtTransport::handleDataPacket(uint8_t *buf, int len, struct sockaddr_storage *addr){ + +void SrtTransport::handleDataPacket(uint8_t *buf, int len, struct sockaddr_storage *addr) { DataPacket::Ptr pkt = std::make_shared(); - pkt->loadFromData(buf,len); + pkt->loadFromData(buf, len); pkt->get_ts = _now; std::list list; //TraceL<<" seq="<< pkt->packet_seq_number<<" ts="<timestamp<<" size="<payloadSize()<<\ //" PP="<<(int)pkt->PP<<" O="<<(int)pkt->O<<" kK="<<(int)pkt->KK<<" R="<<(int)pkt->R; - _recv_buf->inputPacket(pkt,list); - for(auto data : list){ + _recv_buf->inputPacket(pkt, list); + for (auto data : list) { onSRTData(std::move(data)); } - auto nak_interval = (_rtt+_rtt_variance*4)/2; - if(nak_interval <= 20*1000){ - nak_interval = 20*1000; + auto nak_interval = (_rtt + _rtt_variance * 4) / 2; + if (nak_interval <= 20 * 1000) { + nak_interval = 20 * 1000; } - if(list.empty()){ - //TraceL<<_recv_buf->dump()<<" nake interval:"<dump()<<" nake interval:"<nak_interval){ + if (_nak_ticker.elapsedTime(_now) > nak_interval) { auto lost = _recv_buf->getLostSeq(); - if(!lost.empty()){ - sendNAKPacket(lost); - //TraceL<<"send NAK"; - }else{ - //TraceL<<"lost is empty"; + if (!lost.empty()) { + sendNAKPacket(lost); + // TraceL<<"send NAK"; + } else { + // TraceL<<"lost is empty"; } _nak_ticker.resetTime(_now); } - if(_ack_ticker.elapsedTime(_now)>10*1000){ + if (_ack_ticker.elapsedTime(_now) > 10 * 1000) { _light_ack_pkt_count = 0; _ack_ticker.resetTime(_now); - // send a ack per 10 ms for receiver + // send a ack per 10 ms for receiver sendACKPacket(); - }else{ - if(_light_ack_pkt_count >= 64){ + } else { + if (_light_ack_pkt_count >= 64) { // for high bitrate stream send light ack - // TODO + // TODO sendLightACKPacket(); - TraceL<<"send light ack"; + TraceL << "send light ack"; } _light_ack_pkt_count = 0; } _light_ack_pkt_count++; - //bufCheckInterval(); + // bufCheckInterval(); } -void SrtTransport::sendDataPacket(DataPacket::Ptr pkt,char* buf,int len, bool flush) { - pkt->storeToData((uint8_t*)buf,len); - sendPacket(pkt,flush); +void SrtTransport::sendDataPacket(DataPacket::Ptr pkt, char *buf, int len, bool flush) { + pkt->storeToData((uint8_t *)buf, len); + sendPacket(pkt, flush); _send_buf->inputPacket(pkt); } -void SrtTransport::sendControlPacket(ControlPacket::Ptr pkt, bool flush) { - sendPacket(pkt,flush); + +void SrtTransport::sendControlPacket(ControlPacket::Ptr pkt, bool flush) { + sendPacket(pkt, flush); } -void SrtTransport::sendPacket(Buffer::Ptr pkt,bool flush){ - if(_selected_session){ - auto tmp = _packet_pool.obtain2(); - tmp->assign(pkt->data(),pkt->size()); - _selected_session->setSendFlushFlag(flush); - _selected_session->send(std::move(tmp)); - }else{ - WarnL<<"not reach this"; + +void SrtTransport::sendPacket(Buffer::Ptr pkt, bool flush) { + if (_selected_session) { + auto tmp = _packet_pool.obtain2(); + tmp->assign(pkt->data(), pkt->size()); + _selected_session->setSendFlushFlag(flush); + _selected_session->send(std::move(tmp)); + } else { + WarnL << "not reach this"; } } -std::string SrtTransport::getIdentifier(){ + +std::string SrtTransport::getIdentifier() { return _selected_session ? _selected_session->getIdentifier() : ""; } -void SrtTransport::registerSelfHandshake() { - SrtTransportManager::Instance().addHandshakeItem(std::to_string(_sync_cookie),shared_from_this()); +void SrtTransport::registerSelfHandshake() { + SrtTransportManager::Instance().addHandshakeItem(std::to_string(_sync_cookie), shared_from_this()); } -void SrtTransport::unregisterSelfHandshake() { - if(_sync_cookie == 0){ + +void SrtTransport::unregisterSelfHandshake() { + if (_sync_cookie == 0) { return; } SrtTransportManager::Instance().removeHandshakeItem(std::to_string(_sync_cookie)); } void SrtTransport::registerSelf() { - if(_socket_id == 0){ + if (_socket_id == 0) { return; } - SrtTransportManager::Instance().addItem(std::to_string(_socket_id),shared_from_this()); - + SrtTransportManager::Instance().addItem(std::to_string(_socket_id), shared_from_this()); } -void SrtTransport::unregisterSelf() { + +void SrtTransport::unregisterSelf() { SrtTransportManager::Instance().removeItem(std::to_string(_socket_id)); } -void SrtTransport::onShutdown(const SockException &ex){ +void SrtTransport::onShutdown(const SockException &ex) { sendShutDown(); WarnL << ex.what(); unregisterSelfHandshake(); @@ -522,23 +539,25 @@ void SrtTransport::onShutdown(const SockException &ex){ } } } -size_t SrtTransport::getPayloadSize(){ - size_t ret = (_mtu - 28 -16)/188*188; + +size_t SrtTransport::getPayloadSize() { + size_t ret = (_mtu - 28 - 16) / 188 * 188; return ret; } -void SrtTransport::onSendTSData(const Buffer::Ptr &buffer, bool flush){ - //TraceL; + +void SrtTransport::onSendTSData(const Buffer::Ptr &buffer, bool flush) { + // TraceL; DataPacket::Ptr pkt; size_t payloadSize = getPayloadSize(); - size_t size = buffer->size(); - char* ptr = buffer->data(); - char* end = buffer->data()+size; + size_t size = buffer->size(); + char *ptr = buffer->data(); + char *end = buffer->data() + size; - while(ptr < end && size >=payloadSize){ + while (ptr < end && size >= payloadSize) { pkt = std::make_shared(); pkt->f = 0; - pkt->packet_seq_number = _send_packet_seq_number&0x7fffffff; - _send_packet_seq_number = (_send_packet_seq_number+1)&0x7fffffff; + pkt->packet_seq_number = _send_packet_seq_number & 0x7fffffff; + _send_packet_seq_number = (_send_packet_seq_number + 1) & 0x7fffffff; pkt->PP = 3; pkt->O = 0; pkt->KK = 0; @@ -546,16 +565,16 @@ void SrtTransport::onSendTSData(const Buffer::Ptr &buffer, bool flush){ pkt->msg_number = _send_msg_number++; pkt->dst_socket_id = _peer_socket_id; pkt->timestamp = DurationCountMicroseconds(SteadyClock::now() - _start_timestamp); - sendDataPacket(pkt,ptr,(int)payloadSize,flush); + sendDataPacket(pkt, ptr, (int)payloadSize, flush); ptr += payloadSize; size -= payloadSize; } - if(size >0 && ptr 0 && ptr < end) { pkt = std::make_shared(); pkt->f = 0; - pkt->packet_seq_number = _send_packet_seq_number&0x7fffffff; - _send_packet_seq_number = (_send_packet_seq_number+1)&0x7fffffff; + pkt->packet_seq_number = _send_packet_seq_number & 0x7fffffff; + _send_packet_seq_number = (_send_packet_seq_number + 1) & 0x7fffffff; pkt->PP = 3; pkt->O = 0; pkt->KK = 0; @@ -563,11 +582,12 @@ void SrtTransport::onSendTSData(const Buffer::Ptr &buffer, bool flush){ pkt->msg_number = _send_msg_number++; pkt->dst_socket_id = _peer_socket_id; pkt->timestamp = DurationCountMicroseconds(SteadyClock::now() - _start_timestamp); - sendDataPacket(pkt,ptr,(int)size,flush); + sendDataPacket(pkt, ptr, (int)size, flush); } - } + //////////// SrtTransportManager ////////////////////////// + SrtTransportManager &SrtTransportManager::Instance() { static SrtTransportManager s_instance; return s_instance; @@ -599,10 +619,12 @@ void SrtTransportManager::addHandshakeItem(const std::string &key, const SrtTran std::lock_guard lck(_handshake_mtx); _handshake_map[key] = ptr; } + void SrtTransportManager::removeHandshakeItem(const std::string &key) { - std::lock_guard lck(_handshake_mtx); + std::lock_guard lck(_handshake_mtx); _handshake_map.erase(key); } + SrtTransport::Ptr SrtTransportManager::getHandshakeItem(const std::string &key) { if (key.empty()) { return nullptr; @@ -615,5 +637,4 @@ SrtTransport::Ptr SrtTransportManager::getHandshakeItem(const std::string &key) return it->second.lock(); } - } // namespace SRT \ No newline at end of file diff --git a/srt/SrtTransport.hpp b/srt/SrtTransport.hpp index c055dd62..49d05677 100644 --- a/srt/SrtTransport.hpp +++ b/srt/SrtTransport.hpp @@ -1,10 +1,10 @@ #ifndef ZLMEDIAKIT_SRT_TRANSPORT_H #define ZLMEDIAKIT_SRT_TRANSPORT_H -#include +#include #include #include -#include +#include #include "Network/Session.h" #include "Poller/EventPoller.h" @@ -17,11 +17,12 @@ #include "Statistic.hpp" namespace SRT { + using namespace toolkit; extern const std::string kPort; extern const std::string kTimeOutSec; -extern const std::string kLantencyMul; +extern const std::string kLatencyMul; class SrtTransport : public std::enable_shared_from_this { public: @@ -33,6 +34,7 @@ public: const EventPoller::Ptr &getPoller() const; void setSession(Session::Ptr session); const Session::Ptr &getSession() const; + /** * socket收到udp数据 * @param buf 数据指针 @@ -43,26 +45,26 @@ public: virtual void onSendTSData(const Buffer::Ptr &buffer, bool flush); std::string getIdentifier(); - - void unregisterSelfHandshake(); void unregisterSelf(); + void unregisterSelfHandshake(); + protected: - virtual void onHandShakeFinished(std::string& streamid,struct sockaddr_storage *addr){}; - virtual void onSRTData(DataPacket::Ptr pkt){}; + virtual bool isPusher() { return true; }; + virtual void onSRTData(DataPacket::Ptr pkt) {}; virtual void onShutdown(const SockException &ex); - virtual bool isPusher(){ - return true; - }; + virtual void onHandShakeFinished(std::string &streamid, struct sockaddr_storage *addr) {}; + virtual void sendPacket(Buffer::Ptr pkt, bool flush = true); + virtual int getLatencyMul() { return 4; }; private: - void registerSelfHandshake(); void registerSelf(); + void registerSelfHandshake(); - void switchToOtherTransport(uint8_t *buf, int len,uint32_t socketid, struct sockaddr_storage *addr); + void switchToOtherTransport(uint8_t *buf, int len, uint32_t socketid, struct sockaddr_storage *addr); void handleHandshake(uint8_t *buf, int len, struct sockaddr_storage *addr); - void handleHandshakeInduction(HandshakePacket& pkt,struct sockaddr_storage *addr); - void handleHandshakeConclusion(HandshakePacket& pkt,struct sockaddr_storage *addr); + void handleHandshakeInduction(HandshakePacket &pkt, struct sockaddr_storage *addr); + void handleHandshakeConclusion(HandshakePacket &pkt, struct sockaddr_storage *addr); void handleKeeplive(uint8_t *buf, int len, struct sockaddr_storage *addr); void handleACK(uint8_t *buf, int len, struct sockaddr_storage *addr); @@ -74,27 +76,25 @@ private: void handleUserDefinedType(uint8_t *buf, int len, struct sockaddr_storage *addr); void handlePeerError(uint8_t *buf, int len, struct sockaddr_storage *addr); void handleDataPacket(uint8_t *buf, int len, struct sockaddr_storage *addr); - - void sendNAKPacket(std::list& lost_list); + + void sendNAKPacket(std::list &lost_list); void sendACKPacket(); void sendLightACKPacket(); void sendKeepLivePacket(); void sendShutDown(); - void sendMsgDropReq(uint32_t first ,uint32_t last); + void sendMsgDropReq(uint32_t first, uint32_t last); size_t getPayloadSize(); + protected: - void sendDataPacket(DataPacket::Ptr pkt,char* buf,int len,bool flush = false); - void sendControlPacket(ControlPacket::Ptr pkt,bool flush = true); - virtual void sendPacket(Buffer::Ptr pkt,bool flush = true); - virtual int getLantencyMul(){ - return 4; - }; + void sendDataPacket(DataPacket::Ptr pkt, char *buf, int len, bool flush = false); + void sendControlPacket(ControlPacket::Ptr pkt, bool flush = true); + private: //当前选中的udp链接 Session::Ptr _selected_session; //链接迁移前后使用过的udp链接 - std::unordered_map > _history_sessions; + std::unordered_map> _history_sessions; EventPoller::Ptr _poller; @@ -109,7 +109,7 @@ private: uint32_t _mtu = 1500; uint32_t _max_window_size = 8192; - uint32_t _init_seq_number = 0; + uint32_t _init_seq_number = 0; std::string _stream_id; uint32_t _sync_cookie = 0; @@ -119,13 +119,13 @@ private: PacketSendQueue::Ptr _send_buf; uint32_t _buf_delay = 120; PacketQueue::Ptr _recv_buf; - uint32_t _rtt = 100*1000; - uint32_t _rtt_variance =50*1000; + uint32_t _rtt = 100 * 1000; + uint32_t _rtt_variance = 50 * 1000; uint32_t _light_ack_pkt_count = 0; uint32_t _ack_number_count = 0; uint32_t _last_ack_pkt_seq_num = 0; UTicker _ack_ticker; - std::map _ack_send_timestamp; + std::map _ack_send_timestamp; std::shared_ptr _pkt_recv_rate_context; std::shared_ptr _estimated_link_capacity_context; @@ -137,7 +137,6 @@ private: HandshakePacket::Ptr _handleshake_res; ResourcePool _packet_pool; - }; class SrtTransportManager { @@ -150,6 +149,7 @@ public: void addHandshakeItem(const std::string &key, const SrtTransport::Ptr &ptr); void removeHandshakeItem(const std::string &key); SrtTransport::Ptr getHandshakeItem(const std::string &key); + private: SrtTransportManager() = default; diff --git a/srt/SrtTransportImp.cpp b/srt/SrtTransportImp.cpp index 9425e2e7..a838a175 100644 --- a/srt/SrtTransportImp.cpp +++ b/srt/SrtTransportImp.cpp @@ -4,8 +4,7 @@ #include "SrtTransportImp.hpp" namespace SRT { -SrtTransportImp::SrtTransportImp(const EventPoller::Ptr &poller) - : SrtTransport(poller) {} +SrtTransportImp::SrtTransportImp(const EventPoller::Ptr &poller) : SrtTransport(poller) {} SrtTransportImp::~SrtTransportImp() { InfoP(this); @@ -23,52 +22,56 @@ SrtTransportImp::~SrtTransportImp() { } } -void SrtTransportImp::onHandShakeFinished(std::string &streamid,struct sockaddr_storage *addr) { - - // TODO parse streamid like this zlmediakit.com/live/test?token=1213444&type=push - if(!_addr){ +void SrtTransportImp::onHandShakeFinished(std::string &streamid, struct sockaddr_storage *addr) { + // TODO parse stream id like this zlmediakit.com/live/test?token=1213444&type=push + if (!_addr) { _addr.reset(new sockaddr_storage(*((sockaddr_storage *)addr))); } - _is_pusher = false; - TraceL<<" stream id "<input(reinterpret_cast(pkt->payloadData()), pkt->payloadSize()); - }else{ - WarnP(this)<<" not reach this"; + } else { + WarnP(this) << " not reach this"; } } + void SrtTransportImp::onShutdown(const SockException &ex) { SrtTransport::onShutdown(ex); } -bool SrtTransportImp::close(mediakit::MediaSource &sender, bool force){ - if (!force && totalReaderCount(sender)) { +bool SrtTransportImp::close(mediakit::MediaSource &sender, bool force) { + if (!force && totalReaderCount(sender)) { return false; } - std::string err = StrPrinter << "close media:" << sender.getSchema() << "/" << sender.getVhost() << "/" - << sender.getApp() << "/" << sender.getId() << " " << force; + std::string err = StrPrinter << "close media:" << sender.getSchema() << "/" + << sender.getVhost() << "/" + << sender.getApp() << "/" + << sender.getId() << " " << force; + weak_ptr weak_self = static_pointer_cast(shared_from_this()); getPoller()->async([weak_self, err]() { auto strong_self = weak_self.lock(); @@ -80,21 +83,25 @@ bool SrtTransportImp::close(mediakit::MediaSource &sender, bool force){ }); return true; } + // 播放总人数 -int SrtTransportImp::totalReaderCount(mediakit::MediaSource &sender){ +int SrtTransportImp::totalReaderCount(mediakit::MediaSource &sender) { return _muxer ? _muxer->totalReaderCount() : sender.readerCount(); } + // 获取媒体源类型 -mediakit::MediaOriginType SrtTransportImp::getOriginType(mediakit::MediaSource &sender) const{ +mediakit::MediaOriginType SrtTransportImp::getOriginType(mediakit::MediaSource &sender) const { return MediaOriginType::srt_push; } + // 获取媒体源url或者文件路径 -std::string SrtTransportImp::getOriginUrl(mediakit::MediaSource &sender) const{ +std::string SrtTransportImp::getOriginUrl(mediakit::MediaSource &sender) const { return _media_info._full_url; } + // 获取媒体源客户端相关信息 -std::shared_ptr SrtTransportImp::getOriginSock(mediakit::MediaSource &sender) const{ - return static_pointer_cast(getSession()); +std::shared_ptr SrtTransportImp::getOriginSock(mediakit::MediaSource &sender) const { + return static_pointer_cast(getSession()); } void SrtTransportImp::emitOnPublish() { @@ -114,7 +121,7 @@ void SrtTransportImp::emitOnPublish() { InfoP(strong_self) << "允许 srt 推流"; } else { WarnP(strong_self) << "禁止 srt 推流:" << err; - strong_self->onShutdown(SockException(Err_refused,err)); + strong_self->onShutdown(SockException(Err_refused, err)); } }; @@ -126,47 +133,46 @@ void SrtTransportImp::emitOnPublish() { } } - -void SrtTransportImp::emitOnPlay(){ +void SrtTransportImp::emitOnPlay() { std::weak_ptr weak_self = static_pointer_cast(shared_from_this()); - Broadcast::AuthInvoker invoker = [weak_self](const string &err){ + Broadcast::AuthInvoker invoker = [weak_self](const string &err) { auto strong_self = weak_self.lock(); if (!strong_self) { return; } - strong_self->getPoller()->async([strong_self,err]{ - if(err != ""){ - strong_self->onShutdown(SockException(Err_refused,err)); - }else{ + strong_self->getPoller()->async([strong_self, err] { + if (err != "") { + strong_self->onShutdown(SockException(Err_refused, err)); + } else { strong_self->doPlay(); } }); }; auto flag = NoticeCenter::Instance().emitEvent(Broadcast::kBroadcastMediaPlayed, _media_info, invoker, static_cast(*this)); - if(!flag){ + if (!flag) { doPlay(); } } -void SrtTransportImp::doPlay(){ - //异步查找直播流 - std::weak_ptr weak_self = static_pointer_cast(shared_from_this()); +void SrtTransportImp::doPlay() { + //异步查找直播流 MediaInfo info = _media_info; info._schema = TS_SCHEMA; + std::weak_ptr weak_self = static_pointer_cast(shared_from_this()); MediaSource::findAsync(info, getSession(), [weak_self](const MediaSource::Ptr &src) { auto strong_self = weak_self.lock(); if (!strong_self) { //本对象已经销毁 - TraceL<<"本对象已经销毁"; + TraceL << "本对象已经销毁"; return; } if (!src) { //未找到该流 - TraceL<<"未找到该流"; + TraceL << "未找到该流"; strong_self->onShutdown(SockException(Err_shutdown)); } else { - TraceL<<"找到该流"; + TraceL << "找到该流"; auto ts_src = dynamic_pointer_cast(src); assert(ts_src); ts_src->pause(false); @@ -189,9 +195,10 @@ void SrtTransportImp::doPlay(){ auto size = ts_list->size(); ts_list->for_each([&](const TSPacket::Ptr &ts) { strong_self->onSendTSData(ts, ++i == size); }); }); - }; + } }); } + std::string SrtTransportImp::get_peer_ip() { if (!_addr) { return "::"; @@ -215,7 +222,7 @@ std::string SrtTransportImp::get_local_ip() { } uint16_t SrtTransportImp::get_local_port() { - auto s = getSession(); + auto s = getSession(); if (s) { return s->get_local_port(); } @@ -236,9 +243,7 @@ bool SrtTransportImp::inputFrame(const Frame::Ptr &frame) { } auto frame_cached = Frame::getCacheAbleFrame(frame); lock_guard lck(_func_mtx); - _cached_func.emplace_back([this, frame_cached]() { - _muxer->inputFrame(frame_cached); - }); + _cached_func.emplace_back([this, frame_cached]() { _muxer->inputFrame(frame_cached); }); return true; } @@ -248,9 +253,7 @@ bool SrtTransportImp::addTrack(const Track::Ptr &track) { } lock_guard lck(_func_mtx); - _cached_func.emplace_back([this, track]() { - _muxer->addTrack(track); - }); + _cached_func.emplace_back([this, track]() { _muxer->addTrack(track); }); return true; } @@ -259,9 +262,7 @@ void SrtTransportImp::addTrackCompleted() { _muxer->addTrackCompleted(); } else { lock_guard lck(_func_mtx); - _cached_func.emplace_back([this]() { - _muxer->addTrackCompleted(); - }); + _cached_func.emplace_back([this]() { _muxer->addTrackCompleted(); }); } } @@ -273,10 +274,9 @@ void SrtTransportImp::doCachedFunc() { _cached_func.clear(); } -int SrtTransportImp::getLantencyMul(){ - GET_CONFIG(int, lantencyMul, kLantencyMul); - return lantencyMul; +int SrtTransportImp::getLatencyMul() { + GET_CONFIG(int, latencyMul, kLatencyMul); + return latencyMul; } - } // namespace SRT \ No newline at end of file diff --git a/srt/SrtTransportImp.hpp b/srt/SrtTransportImp.hpp index 5e259632..5dc4b228 100644 --- a/srt/SrtTransportImp.hpp +++ b/srt/SrtTransportImp.hpp @@ -1,16 +1,18 @@ #ifndef ZLMEDIAKIT_SRT_TRANSPORT_IMP_H #define ZLMEDIAKIT_SRT_TRANSPORT_IMP_H -#include -#include "Common/MultiMediaSourceMuxer.h" -#include "Rtp/Decoder.h" -#include "TS/TSMediaSource.h" -#include "SrtTransport.hpp" +#include +#include "Rtp/Decoder.h" +#include "SrtTransport.hpp" +#include "TS/TSMediaSource.h" +#include "Common/MultiMediaSourceMuxer.h" namespace SRT { - using namespace toolkit; - using namespace mediakit; - using namespace std; + +using namespace std; +using namespace toolkit; +using namespace mediakit; + class SrtTransportImp : public SrtTransport , public toolkit::SockInfo @@ -19,13 +21,13 @@ class SrtTransportImp public: SrtTransportImp(const EventPoller::Ptr &poller); ~SrtTransportImp(); - void inputSockData(uint8_t *buf, int len, struct sockaddr_storage *addr){ - SrtTransport::inputSockData(buf,len,addr); + + void inputSockData(uint8_t *buf, int len, struct sockaddr_storage *addr) override { + SrtTransport::inputSockData(buf, len, addr); _total_bytes += len; } - void onSendTSData(const Buffer::Ptr &buffer, bool flush){ - SrtTransport::onSendTSData(buffer,flush); - } + void onSendTSData(const Buffer::Ptr &buffer, bool flush) override { SrtTransport::onSendTSData(buffer, flush); } + /// SockInfo override std::string get_local_ip() override; uint16_t get_local_port() override; @@ -35,20 +37,18 @@ public: protected: ///////SrtTransport override/////// - void onHandShakeFinished(std::string& streamid,struct sockaddr_storage *addr) override; + int getLatencyMul() override; void onSRTData(DataPacket::Ptr pkt) override; void onShutdown(const SockException &ex) override; - int getLantencyMul() override; + void onHandShakeFinished(std::string &streamid, struct sockaddr_storage *addr) override; - void sendPacket(Buffer::Ptr pkt,bool flush = true) override{ + void sendPacket(Buffer::Ptr pkt, bool flush = true) override { _total_bytes += pkt->size(); - SrtTransport::sendPacket(pkt,flush); - }; - - bool isPusher() override{ - return _is_pusher; + SrtTransport::sendPacket(pkt, flush); } + bool isPusher() override { return _is_pusher; } + ///////MediaSourceEvent override/////// // 关闭 bool close(mediakit::MediaSource &sender, bool force) override; @@ -61,10 +61,11 @@ protected: // 获取媒体源客户端相关信息 std::shared_ptr getOriginSock(mediakit::MediaSource &sender) const override; - bool inputFrame(const Frame::Ptr &frame) override; - bool addTrack(const Track::Ptr & track) override; - void addTrackCompleted() override; + ///////MediaSinkInterface override/////// void resetTracks() override {}; + void addTrackCompleted() override; + bool addTrack(const Track::Ptr &track) override; + bool inputFrame(const Frame::Ptr &frame) override; private: void emitOnPublish(); @@ -76,12 +77,12 @@ private: private: bool _is_pusher = true; MediaInfo _media_info; - uint64_t _total_bytes = 0; + uint64_t _total_bytes = 0; Ticker _alive_ticker; std::unique_ptr _addr; - // for player + // for player TSMediaSource::RingType::RingReader::Ptr _ts_reader; - // for pusher + // for pusher MultiMediaSourceMuxer::Ptr _muxer; DecoderImp::Ptr _decoder; std::recursive_mutex _func_mtx; diff --git a/srt/Statistic.cpp b/srt/Statistic.cpp index 9fc13cd5..446fc6fe 100644 --- a/srt/Statistic.cpp +++ b/srt/Statistic.cpp @@ -1,14 +1,17 @@ #include #include "Statistic.hpp" + namespace SRT { -void PacketRecvRateContext::inputPacket(TimePoint& ts) { - if(_pkt_map.size()>100){ - _pkt_map.erase(_pkt_map.begin()); - } - auto tmp = DurationCountMicroseconds(ts - _start); - _pkt_map.emplace(tmp,tmp); + +void PacketRecvRateContext::inputPacket(TimePoint &ts) { + if (_pkt_map.size() > 100) { + _pkt_map.erase(_pkt_map.begin()); + } + auto tmp = DurationCountMicroseconds(ts - _start); + _pkt_map.emplace(tmp, tmp); } + uint32_t PacketRecvRateContext::getPacketRecvRate() { if (_pkt_map.size() < 2) { return 50000; @@ -17,79 +20,78 @@ uint32_t PacketRecvRateContext::getPacketRecvRate() { for (auto it = _pkt_map.begin(); it != _pkt_map.end(); ++it) { auto next = it; ++next; - if (next != _pkt_map.end()) { - if ((next->first - it->first) < dur) { - dur = next->first - it->first; - } - } else { + if (next == _pkt_map.end()) { break; } + + if ((next->first - it->first) < dur) { + dur = next->first - it->first; + } } double rate = 1e6 / (double)dur; - if(rate <=1000){ + if (rate <= 1000) { return 50000; } return rate; } -void EstimatedLinkCapacityContext::inputPacket(TimePoint& ts) { +void EstimatedLinkCapacityContext::inputPacket(TimePoint &ts) { if (_pkt_map.size() > 16) { _pkt_map.erase(_pkt_map.begin()); } auto tmp = DurationCountMicroseconds(ts - _start); _pkt_map.emplace(tmp, tmp); } + uint32_t EstimatedLinkCapacityContext::getEstimatedLinkCapacity() { - decltype(_pkt_map.begin()) next; - std::vector tmp; + decltype(_pkt_map.begin()) next; + std::vector tmp; - for(auto it = _pkt_map.begin();it != _pkt_map.end();++it){ - next = it; - ++next; - if(next != _pkt_map.end()){ - tmp.push_back(next->first -it->first); - }else{ - break; - } - } - std::sort(tmp.begin(),tmp.end()); - if(tmp.empty()){ - return 1000; - } + for (auto it = _pkt_map.begin(); it != _pkt_map.end(); ++it) { + next = it; + ++next; + if (next != _pkt_map.end()) { + tmp.push_back(next->first - it->first); + } else { + break; + } + } + std::sort(tmp.begin(), tmp.end()); + if (tmp.empty()) { + return 1000; + } - if(tmp.size()<16){ - return 1000; - } - - double dur =tmp[0]/1e6; - - return (uint32_t)(1.0/dur); + if (tmp.size() < 16) { + return 1000; + } + double dur = tmp[0] / 1e6; + return (uint32_t)(1.0 / dur); } -void RecvRateContext::inputPacket(TimePoint& ts, size_t size ) { +void RecvRateContext::inputPacket(TimePoint &ts, size_t size) { if (_pkt_map.size() > 100) { _pkt_map.erase(_pkt_map.begin()); } - auto tmp = DurationCountMicroseconds(ts - _start); - + auto tmp = DurationCountMicroseconds(ts - _start); _pkt_map.emplace(tmp, tmp); } + uint32_t RecvRateContext::getRecvRate() { - if(_pkt_map.size()<2){ + if (_pkt_map.size() < 2) { return 0; } auto first = _pkt_map.begin(); auto last = _pkt_map.rbegin(); - double dur = (last->first - first->first)/1000000.0; + double dur = (last->first - first->first) / 1000000.0; size_t bytes = 0; - for(auto it : _pkt_map){ + for (auto it : _pkt_map) { bytes += it.second; } - double rate = (double)bytes/dur; + double rate = (double)bytes / dur; return (uint32_t)rate; } diff --git a/srt/Statistic.hpp b/srt/Statistic.hpp index 4524aebe..a2e9f68c 100644 --- a/srt/Statistic.hpp +++ b/srt/Statistic.hpp @@ -6,41 +6,42 @@ #include "Packet.hpp" namespace SRT { + class PacketRecvRateContext { public: - PacketRecvRateContext(TimePoint start):_start(start){}; + PacketRecvRateContext(TimePoint start): _start(start) {}; ~PacketRecvRateContext() = default; - void inputPacket(TimePoint& ts); + void inputPacket(TimePoint &ts); uint32_t getPacketRecvRate(); + private: - std::map _pkt_map; TimePoint _start; - + std::map _pkt_map; }; class EstimatedLinkCapacityContext { public: - EstimatedLinkCapacityContext(TimePoint start):_start(start){}; + EstimatedLinkCapacityContext(TimePoint start) : _start(start) {}; ~EstimatedLinkCapacityContext() = default; - void inputPacket(TimePoint& ts); + void inputPacket(TimePoint &ts); uint32_t getEstimatedLinkCapacity(); + private: - std::map _pkt_map; TimePoint _start; + std::map _pkt_map; }; class RecvRateContext { public: - RecvRateContext(TimePoint start):_start(start){}; + RecvRateContext(TimePoint start): _start(start) {}; ~RecvRateContext() = default; - void inputPacket(TimePoint& ts,size_t size); + void inputPacket(TimePoint &ts, size_t size); uint32_t getRecvRate(); + private: - std::map _pkt_map; - TimePoint _start; + TimePoint _start; + std::map _pkt_map; }; - - } // namespace SRT #endif // ZLMEDIAKIT_SRT_STATISTIC_H \ No newline at end of file