diff --git a/3rdpart/ZLToolKit b/3rdpart/ZLToolKit index 57901f9d..7d52e11a 160000 --- a/3rdpart/ZLToolKit +++ b/3rdpart/ZLToolKit @@ -1 +1 @@ -Subproject commit 57901f9d341478378b1526f7efe99ebc79b2ddb5 +Subproject commit 7d52e11ae4e8d6d2c20aa0349dbfea8a1d82a968 diff --git a/AUTHORS b/AUTHORS index 85f82479..6941936e 100644 --- a/AUTHORS +++ b/AUTHORS @@ -65,4 +65,5 @@ WuPeng [PioLing](https://github.com/PioLing) [KevinZang](https://github.com/ZSC714725) [gongluck](https://github.com/gongluck) -[a-ucontrol](https://github.com/a-ucontrol) \ No newline at end of file +[a-ucontrol](https://github.com/a-ucontrol) +[TalusL](https://github.com/TalusL) \ No newline at end of file diff --git a/README.md b/README.md index e88fbc33..ede1a61c 100644 --- a/README.md +++ b/README.md @@ -291,6 +291,7 @@ bash build_docker_images.sh [KevinZang](https://github.com/ZSC714725) [gongluck](https://github.com/gongluck) [a-ucontrol](https://github.com/a-ucontrol) +[TalusL](https://github.com/TalusL) ## 使用案例 diff --git a/api/CMakeLists.txt b/api/CMakeLists.txt index 42061925..c7962a0f 100644 --- a/api/CMakeLists.txt +++ b/api/CMakeLists.txt @@ -81,4 +81,5 @@ install(FILES ${API_HEADER_LIST} DESTINATION ${INSTALL_PATH_INCLUDE}) install(TARGETS mk_api ARCHIVE DESTINATION ${INSTALL_PATH_LIB} - LIBRARY DESTINATION ${INSTALL_PATH_LIB}) + LIBRARY DESTINATION ${INSTALL_PATH_LIB} + RUNTIME DESTINATION ${INSTALL_PATH_RUNTIME}) diff --git a/api/include/mk_events_objects.h b/api/include/mk_events_objects.h index 7641cb3e..f8d3c571 100644 --- a/api/include/mk_events_objects.h +++ b/api/include/mk_events_objects.h @@ -18,7 +18,7 @@ extern "C" { ///////////////////////////////////////////MP4Info///////////////////////////////////////////// //MP4Info对象的C映射 -typedef void* mk_mp4_info; +typedef struct mk_mp4_info_t *mk_mp4_info; // GMT 标准时间,单位秒 API_EXPORT uint64_t API_CALL mk_mp4_info_get_start_time(const mk_mp4_info ctx); // 录像长度,单位秒 @@ -42,7 +42,7 @@ API_EXPORT const char* API_CALL mk_mp4_info_get_stream(const mk_mp4_info ctx); ///////////////////////////////////////////Parser///////////////////////////////////////////// //Parser对象的C映射 -typedef void* mk_parser; +typedef struct mk_parser_t *mk_parser; //Parser::Method(),获取命令字,譬如GET/POST API_EXPORT const char* API_CALL mk_parser_get_method(const mk_parser ctx); //Parser::Url(),获取HTTP的访问url(不包括?后面的参数) @@ -60,7 +60,7 @@ API_EXPORT const char* API_CALL mk_parser_get_content(const mk_parser ctx, size_ ///////////////////////////////////////////MediaInfo///////////////////////////////////////////// //MediaInfo对象的C映射 -typedef void* mk_media_info; +typedef struct mk_media_info_t *mk_media_info; //MediaInfo::_param_strs API_EXPORT const char* API_CALL mk_media_info_get_params(const mk_media_info ctx); //MediaInfo::_schema @@ -79,7 +79,7 @@ API_EXPORT uint16_t API_CALL mk_media_info_get_port(const mk_media_info ctx); ///////////////////////////////////////////MediaSource///////////////////////////////////////////// //MediaSource对象的C映射 -typedef void* mk_media_source; +typedef struct mk_media_source_t *mk_media_source; //查找MediaSource的回调函数 typedef void(API_CALL *on_mk_media_source_find_cb)(void *user_data, const mk_media_source ctx); @@ -138,7 +138,7 @@ API_EXPORT void API_CALL mk_media_source_for_each(void *user_data, on_mk_media_s ///////////////////////////////////////////HttpBody///////////////////////////////////////////// //HttpBody对象的C映射 -typedef void* mk_http_body; +typedef struct mk_http_body_t *mk_http_body; /** * 生成HttpStringBody * @param str 字符串指针 @@ -173,7 +173,7 @@ API_EXPORT void API_CALL mk_http_body_release(mk_http_body ctx); ///////////////////////////////////////////HttpResponseInvoker///////////////////////////////////////////// //HttpSession::HttpResponseInvoker对象的C映射 -typedef void* mk_http_response_invoker; +typedef struct mk_http_response_invoker_t *mk_http_response_invoker; /** * HttpSession::HttpResponseInvoker(const string &codeOut, const StrCaseMap &headerOut, const HttpBody::Ptr &body); @@ -219,7 +219,8 @@ API_EXPORT void API_CALL mk_http_response_invoker_clone_release(const mk_http_re ///////////////////////////////////////////HttpAccessPathInvoker///////////////////////////////////////////// //HttpSession::HttpAccessPathInvoker对象的C映射 -typedef void* mk_http_access_path_invoker; +typedef struct mk_http_access_path_invoker_t *mk_http_access_path_invoker; + /** * HttpSession::HttpAccessPathInvoker(const string &errMsg,const string &accessPath, int cookieLifeSecond); * @param err_msg 如果为空,则代表鉴权通过,否则为错误提示,可以为null @@ -244,7 +245,7 @@ API_EXPORT void API_CALL mk_http_access_path_invoker_clone_release(const mk_http ///////////////////////////////////////////RtspSession::onGetRealm///////////////////////////////////////////// //RtspSession::onGetRealm对象的C映射 -typedef void* mk_rtsp_get_realm_invoker; +typedef struct mk_rtsp_get_realm_invoker_t *mk_rtsp_get_realm_invoker; /** * 执行RtspSession::onGetRealm * @param realm 该rtsp流是否需要开启rtsp专属鉴权,至null或空字符串则不鉴权 @@ -265,7 +266,7 @@ API_EXPORT void API_CALL mk_rtsp_get_realm_invoker_clone_release(const mk_rtsp_g ///////////////////////////////////////////RtspSession::onAuth///////////////////////////////////////////// //RtspSession::onAuth对象的C映射 -typedef void* mk_rtsp_auth_invoker; +typedef struct mk_rtsp_auth_invoker_t *mk_rtsp_auth_invoker; /** * 执行RtspSession::onAuth @@ -289,7 +290,7 @@ API_EXPORT void API_CALL mk_rtsp_auth_invoker_clone_release(const mk_rtsp_auth_i ///////////////////////////////////////////Broadcast::PublishAuthInvoker///////////////////////////////////////////// //Broadcast::PublishAuthInvoker对象的C映射 -typedef void* mk_publish_auth_invoker; +typedef struct mk_publish_auth_invoker_t *mk_publish_auth_invoker; /** * 执行Broadcast::PublishAuthInvoker @@ -315,7 +316,7 @@ API_EXPORT void API_CALL mk_publish_auth_invoker_clone_release(const mk_publish_ ///////////////////////////////////////////Broadcast::AuthInvoker///////////////////////////////////////////// //Broadcast::AuthInvoker对象的C映射 -typedef void* mk_auth_invoker; +typedef struct mk_auth_invoker_t *mk_auth_invoker; /** * 执行Broadcast::AuthInvoker diff --git a/api/include/mk_frame.h b/api/include/mk_frame.h index a5120ebb..40ba4a4d 100644 --- a/api/include/mk_frame.h +++ b/api/include/mk_frame.h @@ -39,7 +39,7 @@ API_EXPORT extern const int MKCodecVP9; API_EXPORT extern const int MKCodecAV1; API_EXPORT extern const int MKCodecJPEG; -typedef void *mk_frame; +typedef struct mk_frame_t *mk_frame; // 用户自定义free回调函数 typedef void(API_CALL *on_mk_frame_data_release)(void *user_data, char *ptr); diff --git a/api/include/mk_h264_splitter.h b/api/include/mk_h264_splitter.h index 62232414..24db3599 100644 --- a/api/include/mk_h264_splitter.h +++ b/api/include/mk_h264_splitter.h @@ -17,7 +17,7 @@ extern "C" { #endif -typedef void *mk_h264_splitter; +typedef struct mk_h264_splitter_t *mk_h264_splitter; /** * h264 分帧器输出回调函数 diff --git a/api/include/mk_httpclient.h b/api/include/mk_httpclient.h index 2405e5dd..c76d4ce0 100755 --- a/api/include/mk_httpclient.h +++ b/api/include/mk_httpclient.h @@ -20,7 +20,7 @@ extern "C" { ///////////////////////////////////////////HttpDownloader///////////////////////////////////////////// -typedef void *mk_http_downloader; +typedef struct mk_http_downloader_t *mk_http_downloader; /** * @param user_data 用户数据指针 @@ -54,7 +54,7 @@ API_EXPORT void API_CALL mk_http_downloader_start(mk_http_downloader ctx, const API_EXPORT void API_CALL mk_http_downloader_start2(mk_http_downloader ctx, const char *url, const char *file, on_mk_download_complete cb, void *user_data, on_user_data_free user_data_free); ///////////////////////////////////////////HttpRequester///////////////////////////////////////////// -typedef void *mk_http_requester; +typedef struct mk_http_requester_t *mk_http_requester; /** * http请求结果回调 diff --git a/api/include/mk_media.h b/api/include/mk_media.h index 5280c904..dfefda1b 100755 --- a/api/include/mk_media.h +++ b/api/include/mk_media.h @@ -22,7 +22,7 @@ extern "C" { #endif -typedef void *mk_media; +typedef struct mk_media_t *mk_media; /** * 创建一个媒体源 diff --git a/api/include/mk_player.h b/api/include/mk_player.h index b57809f7..5ab0eea5 100755 --- a/api/include/mk_player.h +++ b/api/include/mk_player.h @@ -19,7 +19,7 @@ extern "C" { #endif -typedef void* mk_player; +typedef struct mk_player_t *mk_player; /** * 播放结果或播放中断事件的回调 diff --git a/api/include/mk_proxyplayer.h b/api/include/mk_proxyplayer.h index f25bf9b7..bb7afac0 100644 --- a/api/include/mk_proxyplayer.h +++ b/api/include/mk_proxyplayer.h @@ -17,7 +17,7 @@ extern "C" { #endif -typedef void *mk_proxy_player; +typedef struct mk_proxy_player_t *mk_proxy_player; /** * 创建一个代理播放器 diff --git a/api/include/mk_pusher.h b/api/include/mk_pusher.h index fc7b067c..adb905e5 100644 --- a/api/include/mk_pusher.h +++ b/api/include/mk_pusher.h @@ -18,7 +18,7 @@ extern "C" { #endif -typedef void* mk_pusher; +typedef struct mk_pusher_t *mk_pusher; /** * 推流结果或推流中断事件的回调 diff --git a/api/include/mk_recorder.h b/api/include/mk_recorder.h index 2653fd50..c9053db5 100644 --- a/api/include/mk_recorder.h +++ b/api/include/mk_recorder.h @@ -19,7 +19,7 @@ extern "C" { ///////////////////////////////////////////flv录制///////////////////////////////////////////// -typedef void* mk_flv_recorder; +typedef struct mk_flv_recorder_t *mk_flv_recorder; /** * 创建flv录制器 diff --git a/api/include/mk_rtp_server.h b/api/include/mk_rtp_server.h index f50f438a..add2104d 100644 --- a/api/include/mk_rtp_server.h +++ b/api/include/mk_rtp_server.h @@ -14,7 +14,7 @@ extern "C" { #endif -typedef void* mk_rtp_server; +typedef struct mk_rtp_server_t *mk_rtp_server; /** * 创建GB28181 RTP 服务器 diff --git a/api/include/mk_tcp.h b/api/include/mk_tcp.h index 8ba87d1f..84deed8e 100644 --- a/api/include/mk_tcp.h +++ b/api/include/mk_tcp.h @@ -19,7 +19,7 @@ extern "C" { ///////////////////////////////////////////Buffer::Ptr///////////////////////////////////////////// -typedef void *mk_buffer; +typedef struct mk_buffer_t *mk_buffer; typedef void(API_CALL *on_mk_buffer_free)(void *user_data, void *data); /** @@ -39,7 +39,7 @@ API_EXPORT size_t API_CALL mk_buffer_get_size(mk_buffer buffer); ///////////////////////////////////////////SockInfo///////////////////////////////////////////// //SockInfo对象的C映射 -typedef void* mk_sock_info; +typedef struct mk_sock_info_t *mk_sock_info; //SockInfo::get_peer_ip() API_EXPORT const char* API_CALL mk_sock_info_peer_ip(const mk_sock_info ctx, char *buf); @@ -66,8 +66,8 @@ API_EXPORT uint16_t API_CALL mk_sock_info_local_port(const mk_sock_info ctx); #endif ///////////////////////////////////////////TcpSession///////////////////////////////////////////// //TcpSession对象的C映射 -typedef void* mk_tcp_session; -typedef void* mk_tcp_session_ref; +typedef struct mk_tcp_session_t *mk_tcp_session; +typedef struct mk_tcp_session_ref_t *mk_tcp_session_ref; //获取基类指针以便获取其网络相关信息 API_EXPORT mk_sock_info API_CALL mk_tcp_session_get_sock_info(const mk_tcp_session ctx); @@ -168,7 +168,7 @@ API_EXPORT void API_CALL mk_tcp_server_events_listen(const mk_tcp_session_events ///////////////////////////////////////////自定义tcp客户端///////////////////////////////////////////// -typedef void* mk_tcp_client; +typedef struct mk_tcp_client_t *mk_tcp_client; //获取基类指针以便获取其网络相关信息 API_EXPORT mk_sock_info API_CALL mk_tcp_client_get_sock_info(const mk_tcp_client ctx); diff --git a/api/include/mk_thread.h b/api/include/mk_thread.h index 959626a3..79994154 100644 --- a/api/include/mk_thread.h +++ b/api/include/mk_thread.h @@ -20,7 +20,7 @@ extern "C" { #endif ///////////////////////////////////////////事件线程///////////////////////////////////////////// -typedef void* mk_thread; +typedef struct mk_thread_t *mk_thread; /** * 获取tcp会话对象所在事件线程 @@ -52,7 +52,7 @@ API_EXPORT mk_thread API_CALL mk_thread_from_pool(); */ API_EXPORT mk_thread API_CALL mk_thread_from_pool_work(); -typedef void* mk_thread_pool; +typedef struct mk_thread_pool_t *mk_thread_pool; /** * 创建线程池 @@ -108,7 +108,7 @@ API_EXPORT void API_CALL mk_async_do_delay2(mk_thread ctx, size_t ms, on_mk_asyn API_EXPORT void API_CALL mk_sync_do(mk_thread ctx, on_mk_async cb, void *user_data); ///////////////////////////////////////////定时器///////////////////////////////////////////// -typedef void* mk_timer; +typedef struct mk_timer_t *mk_timer; /** * 定时器触发事件 @@ -135,7 +135,7 @@ API_EXPORT void API_CALL mk_timer_release(mk_timer ctx); ///////////////////////////////////////////信号量///////////////////////////////////////////// -typedef void* mk_sem; +typedef struct mk_sem_t *mk_sem; /** * 创建信号量 diff --git a/api/include/mk_track.h b/api/include/mk_track.h index ff8ddf1c..b7a01b7f 100644 --- a/api/include/mk_track.h +++ b/api/include/mk_track.h @@ -19,7 +19,7 @@ extern "C" { #endif //音视频轨道 -typedef void* mk_track; +typedef struct mk_track_t *mk_track; //输出frame回调 typedef void(API_CALL *on_mk_frame_out)(void *user_data, mk_frame frame); diff --git a/api/include/mk_transcode.h b/api/include/mk_transcode.h index be97fd15..1c4bd0ce 100644 --- a/api/include/mk_transcode.h +++ b/api/include/mk_transcode.h @@ -20,11 +20,11 @@ extern "C" { #endif //解码器对象 -typedef void *mk_decoder; +typedef struct mk_decoder_t *mk_decoder; //解码后的frame -typedef void *mk_frame_pix; +typedef struct mk_frame_pix_t *mk_frame_pix; //SwsContext的包装 -typedef void *mk_swscale; +typedef struct mk_swscale_t *mk_swscale; //FFmpeg原始解码帧对象 typedef struct AVFrame AVFrame; //FFmpeg编解码器对象 diff --git a/api/include/mk_util.h b/api/include/mk_util.h index 52a8d776..bd5b29ce 100644 --- a/api/include/mk_util.h +++ b/api/include/mk_util.h @@ -58,7 +58,7 @@ API_EXPORT char* API_CALL mk_util_get_current_time_string(const char *fmt); API_EXPORT char* API_CALL mk_util_hex_dump(const void *buf, int len); ///////////////////////////////////////////mk ini///////////////////////////////////////////// -typedef void* mk_ini; +typedef struct mk_ini_t *mk_ini; /** * 创建ini配置对象 diff --git a/api/source/mk_events_objects.cpp b/api/source/mk_events_objects.cpp index 7cd379cb..4b641a2b 100644 --- a/api/source/mk_events_objects.cpp +++ b/api/source/mk_events_objects.cpp @@ -245,14 +245,14 @@ API_EXPORT void API_CALL mk_media_source_find(const char *schema, on_mk_media_source_find_cb cb) { assert(schema && vhost && app && stream && cb); auto src = MediaSource::find(schema, vhost, app, stream, from_mp4); - cb(user_data, src.get()); + cb(user_data, (mk_media_source)src.get()); } API_EXPORT void API_CALL mk_media_source_for_each(void *user_data, on_mk_media_source_find_cb cb, const char *schema, const char *vhost, const char *app, const char *stream) { assert(cb); MediaSource::for_each_media([&](const MediaSource::Ptr &src) { - cb(user_data, src.get()); + cb(user_data, (mk_media_source)src.get()); }, schema ? schema : "", vhost ? vhost : "", app ? app : "", stream ? stream : ""); } @@ -263,17 +263,17 @@ API_EXPORT mk_http_body API_CALL mk_http_body_from_string(const char *str, size_ if(!len){ len = strlen(str); } - return new HttpBody::Ptr(new HttpStringBody(std::string(str, len))); + return (mk_http_body)new HttpBody::Ptr(new HttpStringBody(std::string(str, len))); } API_EXPORT mk_http_body API_CALL mk_http_body_from_buffer(mk_buffer buffer) { assert(buffer); - return new HttpBody::Ptr(new HttpBufferBody(*((Buffer::Ptr *) buffer))); + return (mk_http_body)new HttpBody::Ptr(new HttpBufferBody(*((Buffer::Ptr *) buffer))); } API_EXPORT mk_http_body API_CALL mk_http_body_from_file(const char *file_path){ assert(file_path); - return new HttpBody::Ptr(new HttpFileBody(file_path)); + return (mk_http_body)new HttpBody::Ptr(new HttpFileBody(file_path)); } template @@ -294,7 +294,7 @@ static C get_http_header( const char *response_header[]){ API_EXPORT mk_http_body API_CALL mk_http_body_from_multi_form(const char *key_val[],const char *file_path){ assert(key_val && file_path); - return new HttpBody::Ptr(new HttpMultiFormBody(get_http_header(key_val),file_path)); + return (mk_http_body)new HttpBody::Ptr(new HttpMultiFormBody(get_http_header(key_val),file_path)); } API_EXPORT void API_CALL mk_http_body_release(mk_http_body ctx){ @@ -338,7 +338,7 @@ API_EXPORT void API_CALL mk_http_response_invoker_do(const mk_http_response_invo API_EXPORT mk_http_response_invoker API_CALL mk_http_response_invoker_clone(const mk_http_response_invoker ctx){ assert(ctx); HttpSession::HttpResponseInvoker *invoker = (HttpSession::HttpResponseInvoker *)ctx; - return new HttpSession::HttpResponseInvoker (*invoker); + return (mk_http_response_invoker)new HttpSession::HttpResponseInvoker (*invoker); } API_EXPORT void API_CALL mk_http_response_invoker_clone_release(const mk_http_response_invoker ctx){ @@ -362,7 +362,7 @@ API_EXPORT void API_CALL mk_http_access_path_invoker_do(const mk_http_access_pat API_EXPORT mk_http_access_path_invoker API_CALL mk_http_access_path_invoker_clone(const mk_http_access_path_invoker ctx){ assert(ctx); HttpSession::HttpAccessPathInvoker *invoker = (HttpSession::HttpAccessPathInvoker *)ctx; - return new HttpSession::HttpAccessPathInvoker(*invoker); + return (mk_http_access_path_invoker)new HttpSession::HttpAccessPathInvoker(*invoker); } API_EXPORT void API_CALL mk_http_access_path_invoker_clone_release(const mk_http_access_path_invoker ctx){ @@ -382,7 +382,7 @@ API_EXPORT void API_CALL mk_rtsp_get_realm_invoker_do(const mk_rtsp_get_realm_in API_EXPORT mk_rtsp_get_realm_invoker API_CALL mk_rtsp_get_realm_invoker_clone(const mk_rtsp_get_realm_invoker ctx){ assert(ctx); RtspSession::onGetRealm *invoker = (RtspSession::onGetRealm *)ctx; - return new RtspSession::onGetRealm (*invoker); + return (mk_rtsp_get_realm_invoker)new RtspSession::onGetRealm (*invoker); } API_EXPORT void API_CALL mk_rtsp_get_realm_invoker_clone_release(const mk_rtsp_get_realm_invoker ctx){ @@ -403,7 +403,7 @@ API_EXPORT void API_CALL mk_rtsp_auth_invoker_do(const mk_rtsp_auth_invoker ctx, API_EXPORT mk_rtsp_auth_invoker API_CALL mk_rtsp_auth_invoker_clone(const mk_rtsp_auth_invoker ctx){ assert(ctx); RtspSession::onAuth *invoker = (RtspSession::onAuth *)ctx; - return new RtspSession::onAuth(*invoker); + return (mk_rtsp_auth_invoker)new RtspSession::onAuth(*invoker); } API_EXPORT void API_CALL mk_rtsp_auth_invoker_clone_release(const mk_rtsp_auth_invoker ctx){ @@ -428,7 +428,7 @@ API_EXPORT void API_CALL mk_publish_auth_invoker_do(const mk_publish_auth_invoke API_EXPORT mk_publish_auth_invoker API_CALL mk_publish_auth_invoker_clone(const mk_publish_auth_invoker ctx){ assert(ctx); Broadcast::PublishAuthInvoker *invoker = (Broadcast::PublishAuthInvoker *)ctx; - return new Broadcast::PublishAuthInvoker(*invoker); + return (mk_publish_auth_invoker)new Broadcast::PublishAuthInvoker(*invoker); } API_EXPORT void API_CALL mk_publish_auth_invoker_clone_release(const mk_publish_auth_invoker ctx){ @@ -447,7 +447,7 @@ API_EXPORT void API_CALL mk_auth_invoker_do(const mk_auth_invoker ctx, const cha API_EXPORT mk_auth_invoker API_CALL mk_auth_invoker_clone(const mk_auth_invoker ctx){ assert(ctx); Broadcast::AuthInvoker *invoker = (Broadcast::AuthInvoker *)ctx; - return new Broadcast::AuthInvoker(*invoker); + return (mk_auth_invoker)new Broadcast::AuthInvoker(*invoker); } API_EXPORT void API_CALL mk_auth_invoker_clone_release(const mk_auth_invoker ctx){ diff --git a/api/source/mk_frame.cpp b/api/source/mk_frame.cpp index 4b026577..8c2e4ff6 100644 --- a/api/source/mk_frame.cpp +++ b/api/source/mk_frame.cpp @@ -74,13 +74,13 @@ static mk_frame mk_frame_create_complex(int codec_id, uint64_t dts, uint64_t pts char *data, size_t size, on_mk_frame_data_release cb, std::shared_ptr user_data) { switch (codec_id) { case CodecH264: - return new Frame::Ptr(new H264FrameHelper( + return (mk_frame)new Frame::Ptr(new H264FrameHelper( cb, frame_flags, cb, std::move(user_data), (CodecId)codec_id, data, size, dts, pts, prefix_size)); case CodecH265: - return new Frame::Ptr(new H265FrameHelper( + return (mk_frame)new Frame::Ptr(new H265FrameHelper( cb, frame_flags, cb, std::move(user_data), (CodecId)codec_id, data, size, dts, pts, prefix_size)); default: - return new Frame::Ptr(new FrameFromPtrForC( + return (mk_frame)new Frame::Ptr(new FrameFromPtrForC( cb, frame_flags, cb, std::move(user_data), (CodecId)codec_id, data, size, dts, pts, prefix_size)); } } @@ -117,7 +117,7 @@ API_EXPORT void API_CALL mk_frame_unref(mk_frame frame) { API_EXPORT mk_frame API_CALL mk_frame_ref(mk_frame frame) { assert(frame); - return new Frame::Ptr(Frame::getCacheAbleFrame(*((Frame::Ptr *) frame))); + return (mk_frame)new Frame::Ptr(Frame::getCacheAbleFrame(*((Frame::Ptr *) frame))); } API_EXPORT int API_CALL mk_frame_codec_id(mk_frame frame) { diff --git a/api/source/mk_httpclient.cpp b/api/source/mk_httpclient.cpp index 4f9b7963..9c52cd33 100755 --- a/api/source/mk_httpclient.cpp +++ b/api/source/mk_httpclient.cpp @@ -48,7 +48,7 @@ API_EXPORT void API_CALL mk_http_downloader_start2(mk_http_downloader ctx, const ///////////////////////////////////////////HttpRequester///////////////////////////////////////////// API_EXPORT mk_http_requester API_CALL mk_http_requester_create(){ HttpRequester::Ptr *ret = new HttpRequester::Ptr(new HttpRequester); - return ret; + return (mk_http_requester)ret; } API_EXPORT void API_CALL mk_http_requester_clear(mk_http_requester ctx){ diff --git a/api/source/mk_media.cpp b/api/source/mk_media.cpp index 95178d50..754905ce 100755 --- a/api/source/mk_media.cpp +++ b/api/source/mk_media.cpp @@ -99,7 +99,7 @@ protected: void onRegist(MediaSource &sender, bool regist) override{ if (_on_regist) { - _on_regist(_on_regist_data.get(), &sender, regist); + _on_regist(_on_regist_data.get(), (mk_media_source)&sender, regist); } } @@ -270,9 +270,9 @@ API_EXPORT int API_CALL mk_media_input_aac(mk_media ctx, const void *data, int l } API_EXPORT int API_CALL mk_media_input_pcm(mk_media ctx, void *data , int len, uint64_t pts){ - assert(ctx && data && len > 0); - MediaHelper::Ptr* obj = (MediaHelper::Ptr*) ctx; - return (*obj)->getChannel()->inputPCM((char*)data, len, pts); + assert(ctx && data && len > 0); + MediaHelper::Ptr* obj = (MediaHelper::Ptr*) ctx; + return (*obj)->getChannel()->inputPCM((char*)data, len, pts); } API_EXPORT int API_CALL mk_media_input_audio(mk_media ctx, const void* data, int len, uint64_t dts){ @@ -320,5 +320,5 @@ API_EXPORT void API_CALL mk_media_stop_send_rtp(mk_media ctx, const char *ssrc){ API_EXPORT mk_thread API_CALL mk_media_get_owner_thread(mk_media ctx) { MediaHelper::Ptr *obj = (MediaHelper::Ptr *)ctx; - return (*obj)->getChannel()->getOwnerPoller(MediaSource::NullMediaSource()).get(); + return (mk_thread)(*obj)->getChannel()->getOwnerPoller(MediaSource::NullMediaSource()).get(); } \ No newline at end of file diff --git a/api/source/mk_player.cpp b/api/source/mk_player.cpp index 104b272d..078f7140 100755 --- a/api/source/mk_player.cpp +++ b/api/source/mk_player.cpp @@ -105,7 +105,7 @@ private: API_EXPORT mk_player API_CALL mk_player_create() { MediaPlayerForC::Ptr *obj = new MediaPlayerForC::Ptr(new MediaPlayerForC()); (*obj)->setup(); - return obj; + return (mk_player)obj; } API_EXPORT void API_CALL mk_player_release(mk_player ctx) { assert(ctx); diff --git a/api/source/mk_pusher.cpp b/api/source/mk_pusher.cpp index b44a8d89..48956f7d 100644 --- a/api/source/mk_pusher.cpp +++ b/api/source/mk_pusher.cpp @@ -18,14 +18,14 @@ using namespace mediakit; API_EXPORT mk_pusher API_CALL mk_pusher_create(const char *schema,const char *vhost,const char *app, const char *stream){ assert(schema && vhost && app && schema); MediaPusher::Ptr *obj = new MediaPusher::Ptr(new MediaPusher(schema,vhost,app,stream)); - return obj; + return (mk_pusher)obj; } API_EXPORT mk_pusher API_CALL mk_pusher_create_src(mk_media_source ctx){ assert(ctx); MediaSource *src = (MediaSource *)ctx; MediaPusher::Ptr *obj = new MediaPusher::Ptr(new MediaPusher(src->shared_from_this())); - return obj; + return (mk_pusher)obj; } API_EXPORT void API_CALL mk_pusher_release(mk_pusher ctx){ diff --git a/api/source/mk_recorder.cpp b/api/source/mk_recorder.cpp index 8bacbb0b..e16a64c4 100644 --- a/api/source/mk_recorder.cpp +++ b/api/source/mk_recorder.cpp @@ -18,7 +18,7 @@ using namespace mediakit; API_EXPORT mk_flv_recorder API_CALL mk_flv_recorder_create(){ FlvRecorder::Ptr *ret = new FlvRecorder::Ptr(new FlvRecorder); - return ret; + return (mk_flv_recorder)ret; } API_EXPORT void API_CALL mk_flv_recorder_release(mk_flv_recorder ctx){ assert(ctx); diff --git a/api/source/mk_rtp_server.cpp b/api/source/mk_rtp_server.cpp index 79517263..3f5f2a09 100644 --- a/api/source/mk_rtp_server.cpp +++ b/api/source/mk_rtp_server.cpp @@ -19,7 +19,7 @@ using namespace mediakit; API_EXPORT mk_rtp_server API_CALL mk_rtp_server_create(uint16_t port, int tcp_mode, const char *stream_id) { RtpServer::Ptr *server = new RtpServer::Ptr(new RtpServer); (*server)->start(port, stream_id, (RtpServer::TcpMode)tcp_mode); - return server; + return (mk_rtp_server)server; } API_EXPORT void API_CALL mk_rtp_server_connect(mk_rtp_server ctx, const char *dst_url, uint16_t dst_port, on_mk_rtp_server_connected cb, void *user_data) { diff --git a/api/source/mk_tcp.cpp b/api/source/mk_tcp.cpp index 62f58376..86ea8389 100644 --- a/api/source/mk_tcp.cpp +++ b/api/source/mk_tcp.cpp @@ -65,12 +65,12 @@ API_EXPORT mk_buffer API_CALL mk_buffer_from_char(const char *data, size_t len, API_EXPORT mk_buffer API_CALL mk_buffer_from_char2(const char *data, size_t len, on_mk_buffer_free cb, void *user_data, on_user_data_free user_data_free) { assert(data); std::shared_ptr ptr(user_data, user_data_free ? user_data_free : [](void *) {}); - return new Buffer::Ptr(std::make_shared(data, len, cb, std::move(ptr))); + return (mk_buffer)new Buffer::Ptr(std::make_shared(data, len, cb, std::move(ptr))); } API_EXPORT mk_buffer API_CALL mk_buffer_ref(mk_buffer buffer) { assert(buffer); - return new Buffer::Ptr(*((Buffer::Ptr *) buffer)); + return (mk_buffer)new Buffer::Ptr(*((Buffer::Ptr *) buffer)); } API_EXPORT void API_CALL mk_buffer_unref(mk_buffer buffer) { @@ -115,7 +115,7 @@ API_EXPORT uint16_t API_CALL mk_sock_info_local_port(const mk_sock_info ctx){ API_EXPORT mk_sock_info API_CALL mk_tcp_session_get_sock_info(const mk_tcp_session ctx){ assert(ctx); SessionForC *session = (SessionForC *)ctx; - return (SockInfo *)session; + return (mk_sock_info)session; } API_EXPORT void API_CALL mk_tcp_session_shutdown(const mk_tcp_session ctx,int err,const char *err_msg){ @@ -155,7 +155,7 @@ API_EXPORT void API_CALL mk_tcp_session_send_buffer_safe(const mk_tcp_session ct API_EXPORT mk_tcp_session_ref API_CALL mk_tcp_session_ref_from(const mk_tcp_session ctx) { auto ref = ((SessionForC *) ctx)->shared_from_this(); - return new std::shared_ptr(std::dynamic_pointer_cast(ref)); + return (mk_tcp_session_ref)new std::shared_ptr(std::dynamic_pointer_cast(ref)); } API_EXPORT void mk_tcp_session_ref_release(const mk_tcp_session_ref ref) { @@ -163,7 +163,7 @@ API_EXPORT void mk_tcp_session_ref_release(const mk_tcp_session_ref ref) { } API_EXPORT mk_tcp_session mk_tcp_session_from_ref(const mk_tcp_session_ref ref) { - return ((std::shared_ptr *) ref)->get(); + return (mk_tcp_session)((std::shared_ptr *) ref)->get(); } API_EXPORT void API_CALL mk_tcp_session_send_safe(const mk_tcp_session ctx, const char *data, size_t len) { @@ -179,25 +179,25 @@ static mk_tcp_session_events s_events_server = {0}; SessionForC::SessionForC(const Socket::Ptr &pSock) : Session(pSock) { _local_port = get_local_port(); if (s_events_server.on_mk_tcp_session_create) { - s_events_server.on_mk_tcp_session_create(_local_port,this); + s_events_server.on_mk_tcp_session_create(_local_port, (mk_tcp_session) this); } } void SessionForC::onRecv(const Buffer::Ptr &buffer) { if (s_events_server.on_mk_tcp_session_data) { - s_events_server.on_mk_tcp_session_data(_local_port, this, (mk_buffer)&buffer); + s_events_server.on_mk_tcp_session_data(_local_port, (mk_tcp_session)this, (mk_buffer)&buffer); } } void SessionForC::onError(const SockException &err) { if (s_events_server.on_mk_tcp_session_disconnect) { - s_events_server.on_mk_tcp_session_disconnect(_local_port,this, err.getErrCode(), err.what()); + s_events_server.on_mk_tcp_session_disconnect(_local_port, (mk_tcp_session)this, err.getErrCode(), err.what()); } } void SessionForC::onManager() { if (s_events_server.on_mk_tcp_session_manager) { - s_events_server.on_mk_tcp_session_manager(_local_port,this); + s_events_server.on_mk_tcp_session_manager(_local_port, (mk_tcp_session)this); } } @@ -320,13 +320,13 @@ TcpClientForC::Ptr *mk_tcp_client_create_l(mk_tcp_client_events *events, mk_tcp_ API_EXPORT mk_sock_info API_CALL mk_tcp_client_get_sock_info(const mk_tcp_client ctx){ assert(ctx); TcpClientForC::Ptr *client = (TcpClientForC::Ptr *)ctx; - return (SockInfo *)client->get(); + return (mk_sock_info)(SockInfo *)client->get(); } API_EXPORT mk_tcp_client API_CALL mk_tcp_client_create(mk_tcp_client_events *events, mk_tcp_type type){ auto ret = mk_tcp_client_create_l(events,type); - (*ret)->setClient(ret); - return ret; + (*ret)->setClient((mk_tcp_client)ret); + return (mk_tcp_client)ret; } API_EXPORT void API_CALL mk_tcp_client_release(mk_tcp_client ctx){ diff --git a/api/source/mk_thread.cpp b/api/source/mk_thread.cpp index d783ff5e..56b5be2a 100644 --- a/api/source/mk_thread.cpp +++ b/api/source/mk_thread.cpp @@ -19,21 +19,21 @@ using namespace toolkit; API_EXPORT mk_thread API_CALL mk_thread_from_tcp_session(mk_tcp_session ctx){ assert(ctx); SessionForC *obj = (SessionForC *)ctx; - return obj->getPoller().get(); + return (mk_thread)(obj->getPoller().get()); } API_EXPORT mk_thread API_CALL mk_thread_from_tcp_client(mk_tcp_client ctx){ assert(ctx); TcpClientForC::Ptr *client = (TcpClientForC::Ptr *)ctx; - return (*client)->getPoller().get(); + return (mk_thread)((*client)->getPoller().get()); } API_EXPORT mk_thread API_CALL mk_thread_from_pool(){ - return EventPollerPool::Instance().getPoller().get(); + return (mk_thread)(EventPollerPool::Instance().getPoller().get()); } API_EXPORT mk_thread API_CALL mk_thread_from_pool_work(){ - return WorkThreadPool::Instance().getPoller().get(); + return (mk_thread)(WorkThreadPool::Instance().getPoller().get()); } API_EXPORT void API_CALL mk_async_do(mk_thread ctx,on_mk_async cb, void *user_data){ @@ -123,7 +123,7 @@ API_EXPORT mk_timer API_CALL mk_timer_create2(mk_thread ctx, uint64_t delay_ms, std::shared_ptr ptr(user_data, user_data_free ? user_data_free : [](void *) {}); TimerForC::Ptr *ret = new TimerForC::Ptr(new TimerForC(cb, ptr)); (*ret)->start(delay_ms,*poller); - return ret; + return (mk_timer)ret; } API_EXPORT void API_CALL mk_timer_release(mk_timer ctx){ @@ -148,7 +148,7 @@ public: }; API_EXPORT mk_thread_pool API_CALL mk_thread_pool_create(const char *name, size_t n_thread, int priority) { - return new WorkThreadPoolForC(name, n_thread, priority); + return (mk_thread_pool)new WorkThreadPoolForC(name, n_thread, priority); } API_EXPORT int API_CALL mk_thread_pool_release(mk_thread_pool pool) { @@ -159,11 +159,11 @@ API_EXPORT int API_CALL mk_thread_pool_release(mk_thread_pool pool) { API_EXPORT mk_thread API_CALL mk_thread_from_thread_pool(mk_thread_pool pool) { assert(pool); - return ((WorkThreadPoolForC *) pool)->getPoller().get(); + return (mk_thread)(((WorkThreadPoolForC *) pool)->getPoller().get()); } API_EXPORT mk_sem API_CALL mk_sem_create() { - return new toolkit::semaphore; + return (mk_sem)new toolkit::semaphore; } API_EXPORT void API_CALL mk_sem_release(mk_sem sem) { diff --git a/api/source/mk_track.cpp b/api/source/mk_track.cpp index 578a5258..d20b01d0 100644 --- a/api/source/mk_track.cpp +++ b/api/source/mk_track.cpp @@ -82,8 +82,8 @@ public: API_EXPORT mk_track API_CALL mk_track_create(int codec_id, codec_args *args) { switch (getTrackType((CodecId) codec_id)) { - case TrackVideo: return new Track::Ptr(std::make_shared(codec_id, args)); - case TrackAudio: return new Track::Ptr(std::make_shared(codec_id, args)); + case TrackVideo: return (mk_track)new Track::Ptr(std::make_shared(codec_id, args)); + case TrackAudio: return (mk_track)new Track::Ptr(std::make_shared(codec_id, args)); default: WarnL << "unrecognized codec:" << codec_id; return nullptr; } } @@ -95,7 +95,7 @@ API_EXPORT void API_CALL mk_track_unref(mk_track track) { API_EXPORT mk_track API_CALL mk_track_ref(mk_track track) { assert(track); - return new Track::Ptr(*( (Track::Ptr *)track)); + return (mk_track)new Track::Ptr(*( (Track::Ptr *)track)); } API_EXPORT int API_CALL mk_track_codec_id(mk_track track) { diff --git a/api/source/mk_transcode.cpp b/api/source/mk_transcode.cpp index 294db8eb..6d118325 100644 --- a/api/source/mk_transcode.cpp +++ b/api/source/mk_transcode.cpp @@ -29,12 +29,12 @@ std::vector toCodecList(const char *codec_name_list[]) { API_EXPORT mk_decoder API_CALL mk_decoder_create(mk_track track, int thread_num) { assert(track); - return new FFmpegDecoder(*((Track::Ptr *) track), thread_num); + return (mk_decoder)new FFmpegDecoder(*((Track::Ptr *) track), thread_num); } API_EXPORT mk_decoder API_CALL mk_decoder_create2(mk_track track, int thread_num, const char *codec_name_list[]) { assert(track && codec_name_list); - return new FFmpegDecoder(*((Track::Ptr *) track), thread_num, toCodecList(codec_name_list)); + return (mk_decoder)new FFmpegDecoder(*((Track::Ptr *) track), thread_num, toCodecList(codec_name_list)); } API_EXPORT void API_CALL mk_decoder_release(mk_decoder ctx, int flush_frame) { @@ -77,12 +77,12 @@ API_EXPORT const AVCodecContext *API_CALL mk_decoder_get_context(mk_decoder ctx) API_EXPORT mk_frame_pix API_CALL mk_frame_pix_ref(mk_frame_pix frame) { assert(frame); - return new FFmpegFrame::Ptr(*(FFmpegFrame::Ptr *) frame); + return (mk_frame_pix)new FFmpegFrame::Ptr(*(FFmpegFrame::Ptr *) frame); } API_EXPORT mk_frame_pix API_CALL mk_frame_pix_from_av_frame(AVFrame *frame) { assert(frame); - return new FFmpegFrame::Ptr(std::make_shared(std::shared_ptr(av_frame_clone(frame), [](AVFrame *frame){ + return (mk_frame_pix)new FFmpegFrame::Ptr(std::make_shared(std::shared_ptr(av_frame_clone(frame), [](AVFrame *frame){ av_frame_free(&frame); }))); } @@ -99,7 +99,7 @@ API_EXPORT mk_frame_pix API_CALL mk_frame_pix_from_buffer(mk_buffer plane_data[] frame->linesize[i] = line_size[i]; buffer_array.emplace_back(buffer); } - return new FFmpegFrame::Ptr(new FFmpegFrame(std::move(frame)), [buffer_array](FFmpegFrame *frame) { + return (mk_frame_pix)new FFmpegFrame::Ptr(new FFmpegFrame(std::move(frame)), [buffer_array](FFmpegFrame *frame) { for (auto &buffer : buffer_array) { mk_buffer_unref(buffer); } @@ -120,7 +120,7 @@ API_EXPORT AVFrame *API_CALL mk_frame_pix_get_av_frame(mk_frame_pix frame) { ///////////////////////////////////////////////////////////////////////////////////////////// API_EXPORT mk_swscale mk_swscale_create(int output, int width, int height) { - return new FFmpegSws((AVPixelFormat) output, width, height); + return (mk_swscale)new FFmpegSws((AVPixelFormat) output, width, height); } API_EXPORT void mk_swscale_release(mk_swscale ctx) { @@ -132,7 +132,7 @@ API_EXPORT int mk_swscale_input_frame(mk_swscale ctx, mk_frame_pix frame, uint8_ } API_EXPORT mk_frame_pix mk_swscale_input_frame2(mk_swscale ctx, mk_frame_pix frame){ - return new FFmpegFrame::Ptr(((FFmpegSws *) ctx)->inputFrame(*(FFmpegFrame::Ptr *) frame)); + return (mk_frame_pix)new FFmpegFrame::Ptr(((FFmpegSws *) ctx)->inputFrame(*(FFmpegFrame::Ptr *) frame)); } API_EXPORT uint8_t **API_CALL mk_get_av_frame_data(AVFrame *frame) { diff --git a/api/source/mk_util.cpp b/api/source/mk_util.cpp index e6265000..6c775529 100644 --- a/api/source/mk_util.cpp +++ b/api/source/mk_util.cpp @@ -55,11 +55,11 @@ API_EXPORT char* API_CALL mk_util_hex_dump(const void *buf, int len){ } API_EXPORT mk_ini API_CALL mk_ini_create() { - return new mINI; + return (mk_ini)new mINI; } API_EXPORT mk_ini API_CALL mk_ini_default() { - return &(mINI::Instance()); + return (mk_ini)&(mINI::Instance()); } static void emit_ini_file_reload(mk_ini ini) { diff --git a/api/tests/pusher.c b/api/tests/pusher.c index 08c7661e..a6f3cd2d 100644 --- a/api/tests/pusher.c +++ b/api/tests/pusher.c @@ -32,7 +32,7 @@ void release_player(mk_player *ptr) { } } -void release_pusher(mk_media *ptr) { +void release_pusher(mk_pusher *ptr) { if (ptr && *ptr) { mk_pusher_release(*ptr); *ptr = NULL; diff --git a/src/Extension/SPSParser.c b/src/Extension/SPSParser.c index 6f6995cf..a23c515c 100644 --- a/src/Extension/SPSParser.c +++ b/src/Extension/SPSParser.c @@ -252,9 +252,9 @@ static inline int getBitsLeft(void *pvHandle) *functions ********************************************/ /** - * @brief Function getOneBit() ¶Á1¸öbit + * @brief Function getOneBit() get next bit * @param[in] h T_GetBitContext structrue - * @retval 0: success, -1 : failure + * @retval other : success, -1 : failure * @pre * @post */ @@ -291,10 +291,10 @@ exit: /** - * @brief Function getBits() ¶Án¸öbits£¬n²»Äܳ¬¹ý32 + * @brief Function getBits() get next bits * @param[in] h T_GetBitContext structrue * @param[in] n how many bits you want? - * @retval 0: success, -1 : failure + * @retval other : success, -1 : failure * @pre * @post */ @@ -446,7 +446,7 @@ static inline unsigned int showBitsLong(void *pvHandle, int iN) /** - * @brief Function parseCodenum() Ö¸Êý¸çÂײ¼±àÂë½âÎö£¬²Î¿¼h264±ê×¼µÚ9½Ú + * @brief Function parseCodenum() * @param[in] buf * @retval u32CodeNum * @pre @@ -469,7 +469,7 @@ static int parseCodenum(void *pvBuf) } /** - * @brief Function parseUe() Ö¸Êý¸çÂײ¼±àÂë½âÎö ue(),²Î¿¼h264±ê×¼µÚ9½Ú + * @brief Function parseUe() * @param[in] buf sps_pps parse buf * @retval u32CodeNum * @pre @@ -482,7 +482,7 @@ static int parseUe(void *pvBuf) /** - * @brief Function parseSe() Ö¸Êý¸çÂײ¼±àÂë½âÎö se(), ²Î¿¼h264±ê×¼µÚ9½Ú + * @brief Function parseSe() * @param[in] buf sps_pps parse buf * @retval u32CodeNum * @pre @@ -502,7 +502,7 @@ static int parseSe(void *pvBuf) /** - * @brief Function getBitContextFree() ÉêÇëµÄget_bit_context½á¹¹ÄÚ´æÊÍ·Å + * @brief Function getBitContextFree() * @param[in] buf T_GetBitContext buf * @retval none * @pre @@ -527,18 +527,13 @@ static void getBitContextFree(void *pvBuf) /** - * @brief Function deEmulationPrevention() ½â¾ºÕù´úÂë + * @brief Function deEmulationPrevention() * @param[in] buf T_GetBitContext buf * @retval none * @pre * @post * @note: - * µ÷ÊÔʱ×ÜÊÇ·¢ÏÖvui.time_scaleÖµÌØ±ðÆæ¹Ö£¬×ÜÊÇ16777216£¬ºóÀ´²éѯԭÒòÈçÏÂ: * http://www.cnblogs.com/eustoma/archive/2012/02/13/2415764.html - * H.264±àÂëʱ£¬ÔÚÿ¸öNALǰÌí¼ÓÆðʼÂë 0x000001£¬½âÂëÆ÷ÔÚÂëÁ÷Öмì²âµ½ÆðʼÂ룬µ±Ç°NAL½áÊø¡£ - * ΪÁË·ÀÖ¹NALÄÚ²¿³öÏÖ0x000001µÄÊý¾Ý£¬h.264ÓÖÌá³ö'·ÀÖ¹¾ºÕù emulation prevention"»úÖÆ£¬ - * ÔÚ±àÂëÍêÒ»¸öNALʱ£¬Èç¹û¼ì²â³öÓÐÁ¬ÐøÁ½¸ö0x00×Ö½Ú£¬¾ÍÔÚºóÃæ²åÈëÒ»¸ö0x03¡£ - * µ±½âÂëÆ÷ÔÚNALÄÚ²¿¼ì²âµ½0x000003µÄÊý¾Ý£¬¾Í°Ñ0x03Åׯú£¬»Ö¸´Ô­Ê¼Êý¾Ý¡£ * 0x000000 >>>>>> 0x00000300 * 0x000001 >>>>>> 0x00000301 * 0x000002 >>>>>> 0x00000302 @@ -581,22 +576,20 @@ static void *deEmulationPrevention(void *pvBuf) tmp_buf_size = ptPtr->iBufSize; for(i=0; i<(tmp_buf_size-2); i++) { - /*¼ì²â0x000003*/ + iVal = (pu8TmpPtr[i]^0x00) + (pu8TmpPtr[i+1]^0x00) + (pu8TmpPtr[i+2]^0x03); if(iVal == 0) { - /*ÌÞ³ý0x03*/ + for(j=i+2; jiBufSize--; } } - - /*ÖØÐ¼ÆËãtotal_bit*/ ptPtr->iTotalBit = ptPtr->iBufSize << 3; return (void *)ptPtr; diff --git a/src/Extension/SPSParser.h b/src/Extension/SPSParser.h index 6cdb6005..1943b757 100644 --- a/src/Extension/SPSParser.h +++ b/src/Extension/SPSParser.h @@ -439,11 +439,11 @@ typedef struct T_HEVCSPS { typedef struct T_GetBitContext{ - uint8_t *pu8Buf; /*Ö¸ÏòSPS start*/ - int iBufSize; /*SPS ³¤¶È*/ - int iBitPos; /*bitÒѶÁȡλÖÃ*/ - int iTotalBit; /*bit×ܳ¤¶È*/ - int iCurBitPos; /*µ±Ç°¶ÁȡλÖÃ*/ + uint8_t *pu8Buf; // buf + int iBufSize; // buf size + int iBitPos; // bit position + int iTotalBit; // bit number + int iCurBitPos; // current bit position }T_GetBitContext; diff --git a/src/Http/HttpFileManager.cpp b/src/Http/HttpFileManager.cpp index 7e71186c..b690801a 100644 --- a/src/Http/HttpFileManager.cpp +++ b/src/Http/HttpFileManager.cpp @@ -496,6 +496,10 @@ void HttpFileManager::onAccessPath(Session &sender, Parser &parser, const HttpFi auto fullUrl = string(HTTP_SCHEMA) + "://" + parser["Host"] + parser.FullUrl(); MediaInfo media_info(fullUrl); auto file_path = getFilePath(parser, media_info, sender); + if (file_path.size() == 0) { + sendNotFound(cb); + return; + } //访问的是文件夹 if (File::is_dir(file_path.data())) { auto indexFile = searchIndexFile(file_path); diff --git a/src/Rtp/RtpProcess.cpp b/src/Rtp/RtpProcess.cpp index 0673931e..fac1bd2b 100644 --- a/src/Rtp/RtpProcess.cpp +++ b/src/Rtp/RtpProcess.cpp @@ -71,6 +71,10 @@ RtpProcess::~RtpProcess() { } bool RtpProcess::inputRtp(bool is_udp, const Socket::Ptr &sock, const char *data, size_t len, const struct sockaddr *addr, uint64_t *dts_out) { + if (!isRtp(data, len)) { + WarnP(this) << "Not rtp packet"; + return false; + } if (_sock != sock) { // 第一次运行本函数 bool first = !_sock; diff --git a/src/Rtp/RtpSelector.cpp b/src/Rtp/RtpSelector.cpp index 8ac620ac..f2ab0d41 100644 --- a/src/Rtp/RtpSelector.cpp +++ b/src/Rtp/RtpSelector.cpp @@ -42,7 +42,7 @@ RtpProcess::Ptr RtpSelector::getProcess(const string &stream_id,bool makeNew) { } if (it != _map_rtp_process.end() && makeNew) { //已经被其他线程持有了,不得再被持有,否则会存在线程安全的问题 - throw std::runtime_error(StrPrinter << "RtpProcess(" << stream_id << ") already existed"); + throw ProcessExisted(StrPrinter << "RtpProcess(" << stream_id << ") already existed"); } RtpProcessHelper::Ptr &ref = _map_rtp_process[stream_id]; if (!ref) { diff --git a/src/Rtp/RtpSelector.h b/src/Rtp/RtpSelector.h index d5d6dda1..db0683e8 100644 --- a/src/Rtp/RtpSelector.h +++ b/src/Rtp/RtpSelector.h @@ -44,6 +44,13 @@ public: RtpSelector() = default; ~RtpSelector() = default; + class ProcessExisted : public std::runtime_error { + public: + template + ProcessExisted(T && ...args) : std::runtime_error(std::forward(args)...) {} + ~ProcessExisted() override = default; + }; + static bool getSSRC(const char *data,size_t data_len, uint32_t &ssrc); static RtpSelector &Instance(); diff --git a/src/Rtp/RtpSession.cpp b/src/Rtp/RtpSession.cpp index f99734c8..e66e6f42 100644 --- a/src/Rtp/RtpSession.cpp +++ b/src/Rtp/RtpSession.cpp @@ -12,6 +12,7 @@ #include "RtpSession.h" #include "RtpSelector.h" #include "Network/TcpServer.h" +#include "Rtsp/Rtsp.h" #include "Rtsp/RtpReceiver.h" #include "Common/config.h" @@ -75,6 +76,15 @@ void RtpSession::onManager() { } void RtpSession::onRtpPacket(const char *data, size_t len) { + if (_delay_close) { + // 正在延时关闭中,忽略所有数据 + return; + } + if (!isRtp(data, len)) { + // 忽略非rtp数据 + WarnP(this) << "Not rtp packet"; + return; + } if (!_is_udp) { if (_search_rtp) { //搜索上下文期间,数据丢弃 @@ -101,8 +111,18 @@ void RtpSession::onRtpPacket(const char *data, size_t len) { //未指定流id就使用ssrc为流id _stream_id = printSSRC(_ssrc); } - //tcp情况下,一个tcp链接只可能是一路流,不需要通过多个ssrc来区分,所以不需要频繁getProcess - _process = RtpSelector::Instance().getProcess(_stream_id, true); + try { + _process = RtpSelector::Instance().getProcess(_stream_id, true); + } catch (RtpSelector::ProcessExisted &ex) { + if (!_is_udp) { + // tcp情况下立即断开连接 + throw; + } + // udp情况下延时断开连接(等待超时自动关闭),防止频繁创建销毁RtpSession对象 + WarnP(this) << ex.what(); + _delay_close = true; + return; + } _process->setOnlyAudio(_only_audio); _process->setDelegate(dynamic_pointer_cast(shared_from_this())); } diff --git a/src/Rtp/RtpSession.h b/src/Rtp/RtpSession.h index e6019a61..d966fe5a 100644 --- a/src/Rtp/RtpSession.h +++ b/src/Rtp/RtpSession.h @@ -44,6 +44,7 @@ protected: const char *onSearchPacketTail(const char *data, size_t len) override; private: + bool _delay_close = false; bool _is_udp = false; bool _search_rtp = false; bool _search_rtp_finished = false; diff --git a/src/Rtsp/RtpReceiver.cpp b/src/Rtsp/RtpReceiver.cpp index 96b077d7..363a3175 100644 --- a/src/Rtsp/RtpReceiver.cpp +++ b/src/Rtsp/RtpReceiver.cpp @@ -14,7 +14,7 @@ namespace mediakit { RtpTrack::RtpTrack() { - setOnSort([this](uint16_t seq, RtpPacket::Ptr &packet) { + setOnSort([this](uint16_t seq, RtpPacket::Ptr packet) { onRtpSorted(std::move(packet)); }); } @@ -114,7 +114,7 @@ void RtpTrack::setNtpStamp(uint32_t rtp_stamp, uint64_t ntp_stamp_ms) { } } -void RtpTrack::setPT(uint8_t pt){ +void RtpTrack::setPayloadType(uint8_t pt) { _pt = pt; } diff --git a/src/Rtsp/RtpReceiver.h b/src/Rtsp/RtpReceiver.h index 392df104..5a2e4da6 100644 --- a/src/Rtsp/RtpReceiver.h +++ b/src/Rtsp/RtpReceiver.h @@ -18,42 +18,37 @@ #include "Extension/Frame.h" // for NtpStamp #include "Common/Stamp.h" +#include "Util/TimeTicker.h" namespace mediakit { -template +template class PacketSortor { public: + static constexpr SEQ SEQ_MAX = (std::numeric_limits::max)(); PacketSortor() = default; ~PacketSortor() = default; - void setOnSort(std::function cb) { - _cb = std::move(cb); - } + void setOnSort(std::function cb) { _cb = std::move(cb); } /** * 清空状态 */ void clear() { + _started = false; _seq_cycle_count = 0; _pkt_sort_cache_map.clear(); - _next_seq_out = 0; - _max_sort_size = kMin; } /** * 获取排序缓存长度 */ - size_t getJitterSize() const{ - return _pkt_sort_cache_map.size(); - } + size_t getJitterSize() const { return _pkt_sort_cache_map.size(); } /** * 获取seq回环次数 */ - size_t getCycleCount() const{ - return _seq_cycle_count; - } + size_t getCycleCount() const { return _seq_cycle_count; } /** * 输入并排序 @@ -61,110 +56,106 @@ public: * @param packet 包负载 */ void sortPacket(SEQ seq, T packet) { - if(!_is_inited && _next_seq_out == 0){ - _next_seq_out = seq; - _is_inited = true; + if (!_started) { + // 记录第一个seq + _started = true; + _last_seq_out = seq - 1; } - if (seq < _next_seq_out) { - if (_next_seq_out < seq + kMax) { - //过滤seq回退包(回环包除外) - return; - } - } else if (_next_seq_out && seq - _next_seq_out > ((std::numeric_limits::max)() >> 1)) { - //过滤seq跳变非常大的包(防止回环时乱序时收到非常大的seq) + if (seq == static_cast(_last_seq_out + 1)) { + // 收到下一个seq + output(seq, std::move(packet)); + return; + } + + if (seq < _last_seq_out && _last_seq_out != SEQ_MAX && seq < 1024 && _last_seq_out > SEQ_MAX - 1024) { + // seq回环,清空回环前缓存 + flush(); + _last_seq_out = SEQ_MAX; + ++_seq_cycle_count; + sortPacket(seq, std::move(packet)); + return; + } + + if (seq <= _last_seq_out && _last_seq_out != SEQ_MAX) { + // 这个回退包已经不再等待 return; } - //放入排序缓存 _pkt_sort_cache_map.emplace(seq, std::move(packet)); - //尝试输出排序后的包 - tryPopPacket(); + auto it_min = _pkt_sort_cache_map.begin(); + auto it_max = _pkt_sort_cache_map.rbegin(); + if (it_max->first - it_min->first > (SEQ_MAX >> 1)) { + // 回环后,收到回环前的大值seq, 忽略掉 + _pkt_sort_cache_map.erase((++it_max).base()); + return; + } + + tryFlushFrontPacket(); + + if (_pkt_sort_cache_map.size() > _max_buffer_size || (_ticker.elapsedTime() > _max_buffer_ms && !_pkt_sort_cache_map.empty())) { + // buffer太长,强行减小 + WarnL << "packet dropped: " << static_cast(_last_seq_out + 1) << " -> " + << static_cast(_pkt_sort_cache_map.begin()->first - 1) + << ", jitter buffer size: " << _pkt_sort_cache_map.size() + << ", jitter buffer ms: " << _ticker.elapsedTime(); + popIterator(_pkt_sort_cache_map.begin()); + } } - void flush(){ - //清空缓存 + void flush() { + // 清空缓存 while (!_pkt_sort_cache_map.empty()) { popIterator(_pkt_sort_cache_map.begin()); } } private: - void popPacket() { - auto it = _pkt_sort_cache_map.begin(); - if (it->first >= _next_seq_out) { - //过滤回跳包 - popIterator(it); - return; - } - - if (_next_seq_out - it->first > (0xFFFF >> 1)) { - //产生回环了 - if (_pkt_sort_cache_map.size() < 2 * kMin) { - //等足够多的数据后才处理回环, 因为后面还可能出现大的SEQ - return; + void tryFlushFrontPacket() { + while (!_pkt_sort_cache_map.empty()) { + auto it = _pkt_sort_cache_map.begin(); + auto next_seq = static_cast(_last_seq_out + 1); + if (it->first < next_seq) { + _pkt_sort_cache_map.erase(it); + continue; } - ++_seq_cycle_count; - //找到大的SEQ并清空掉,然后从小的SEQ重新开始排序 - auto hit = _pkt_sort_cache_map.upper_bound((SEQ) (_next_seq_out - _pkt_sort_cache_map.size())); - while (hit != _pkt_sort_cache_map.end()) { - //回环前,清空剩余的大的SEQ的数据 - _cb(hit->first, hit->second); - hit = _pkt_sort_cache_map.erase(hit); + if (it->first == next_seq) { + // 连续的seq + popIterator(it); + continue; } - //下一个回环的数据 - popIterator(_pkt_sort_cache_map.begin()); - return; + break; } - //删除回跳的数据包 - _pkt_sort_cache_map.erase(it); } void popIterator(typename std::map::iterator it) { auto seq = it->first; auto data = std::move(it->second); _pkt_sort_cache_map.erase(it); - _next_seq_out = seq + 1; - _cb(seq, data); + output(seq, std::move(data)); } - void tryPopPacket() { - int count = 0; - while ((!_pkt_sort_cache_map.empty() && _pkt_sort_cache_map.begin()->first == _next_seq_out)) { - //找到下个包,直接输出 - popPacket(); - ++count; - } - - if (count) { - setSortSize(); - } else if (_pkt_sort_cache_map.size() > _max_sort_size) { - //排序缓存溢出,不再继续排序 - popPacket(); - setSortSize(); - } - } - - void setSortSize() { - _max_sort_size = kMin + _pkt_sort_cache_map.size(); - if (_max_sort_size > kMax) { - _max_sort_size = kMax; - } + void output(SEQ seq, T packet) { + _last_seq_out = seq; + _cb(seq, std::move(packet)); + _ticker.resetTime(); } private: - //第一个包是已经进入 - bool _is_inited = false; - + bool _started = false; + //排序缓存最大保存数据长度,单位毫秒 + size_t _max_buffer_ms = 3000; + //排序缓存最大保存数据个数 + size_t _max_buffer_size = 1024; + //记录上次output至今的时间 + toolkit::Ticker _ticker; //下次应该输出的SEQ - SEQ _next_seq_out = 0; + SEQ _last_seq_out = 0; //seq回环次数计数 size_t _seq_cycle_count = 0; - //排序缓存长度 - size_t _max_sort_size = kMin; //pkt排序缓存,根据seq排序 std::map _pkt_sort_cache_map; //回调 - std::function _cb; + std::function _cb; }; class RtpTrack : private PacketSortor { @@ -183,7 +174,7 @@ public: uint32_t getSSRC() const; RtpPacket::Ptr inputRtp(TrackType type, int sample_rate, uint8_t *ptr, size_t len); void setNtpStamp(uint32_t rtp_stamp, uint64_t ntp_stamp_ms); - void setPT(uint8_t pt); + void setPayloadType(uint8_t pt); protected: virtual void onRtpSorted(RtpPacket::Ptr rtp) {} @@ -261,9 +252,9 @@ public: _track[index].setNtpStamp(rtp_stamp, ntp_stamp_ms); } - void setPT(int index, uint8_t pt){ + void setPayloadType(int index, uint8_t pt){ assert(index < kCount && index >= 0); - _track[index].setPT(pt); + _track[index].setPayloadType(pt); } void clear() { diff --git a/src/Rtsp/Rtsp.cpp b/src/Rtsp/Rtsp.cpp index 251b7f65..b80a88f7 100644 --- a/src/Rtsp/Rtsp.cpp +++ b/src/Rtsp/Rtsp.cpp @@ -442,6 +442,22 @@ string printSSRC(uint32_t ui32Ssrc) { return tmp; } +bool isRtp(const char *buf, size_t size) { + if (size < 2) { + return false; + } + RtpHeader *header = (RtpHeader *)buf; + return ((header->pt < 64) || (header->pt >= 96)); +} + +bool isRtcp(const char *buf, size_t size) { + if (size < 2) { + return false; + } + RtpHeader *header = (RtpHeader *)buf; + return ((header->pt >= 64) && (header->pt < 96)); +} + Buffer::Ptr makeRtpOverTcpPrefix(uint16_t size, uint8_t interleaved) { auto rtp_tcp = BufferRaw::create(); rtp_tcp->setCapacity(RtpPacket::kRtpTcpHeaderSize); diff --git a/src/Rtsp/Rtsp.h b/src/Rtsp/Rtsp.h index 9a581b95..5e7370c6 100644 --- a/src/Rtsp/Rtsp.h +++ b/src/Rtsp/Rtsp.h @@ -337,5 +337,8 @@ void makeSockPair(std::pair &pair, c //十六进制方式打印ssrc std::string printSSRC(uint32_t ui32Ssrc); +bool isRtp(const char *buf, size_t size); +bool isRtcp(const char *buf, size_t size); + } //namespace mediakit #endif //RTSP_RTSP_H_ diff --git a/src/Rtsp/RtspPlayer.cpp b/src/Rtsp/RtspPlayer.cpp index e99e2591..2d7ab925 100644 --- a/src/Rtsp/RtspPlayer.cpp +++ b/src/Rtsp/RtspPlayer.cpp @@ -225,7 +225,7 @@ void RtspPlayer::handleResDESCRIBE(const Parser& parser) { _rtcp_context.clear(); for (auto &track : _sdp_track) { if(track->_pt != 0xff){ - setPT(_rtcp_context.size(),track->_pt); + setPayloadType(_rtcp_context.size(),track->_pt); } _rtcp_context.emplace_back(std::make_shared()); } diff --git a/tests/test_rtcp_nack.cpp b/tests/test_rtcp_nack.cpp index f3f8a6c2..864c3b2f 100644 --- a/tests/test_rtcp_nack.cpp +++ b/tests/test_rtcp_nack.cpp @@ -15,20 +15,38 @@ using namespace std; using namespace toolkit; using namespace mediakit; -extern void testFCI(); - int main() { Logger::Instance().add(std::make_shared()); + Logger::Instance().setWriter(std::make_shared()); - srand((unsigned) time(NULL)); - + srand((unsigned)time(NULL)); NackContext ctx; - for (int i = 1; i < 1000; ++i) { - if (i % (1 + (rand() % 30)) == 0) { - DebugL << "drop:" << i; + ctx.setOnNack([](const FCI_NACK &nack){ + InfoL << nack.dumpString(); + }); + auto drop_start = 0; + auto drop_len = 0; + uint16_t offset = 0xFFFF - 200 - 50; + for (int i = 1; i < 10000; ++i) { + if (i % 100 == 0) { + drop_start = i + rand() % 16; + drop_len = 4 + rand() % 16; + InfoL << "start drop:" << (uint16_t)(drop_start + offset) << " -> " + << (uint16_t)(drop_start + offset + drop_len); + } + uint16_t seq = i + offset; + if ((i >= drop_start && i <= drop_start + drop_len) || seq == 65535 || seq == 0 || seq == 1) { + TraceL << "drop:" << (uint16_t)(i + offset); } else { - ctx.received(i); - + static auto last_seq = seq; + if (seq - last_seq > 16) { + ctx.received(last_seq); + ctx.received(seq); + DebugL << "seq reduce:" << last_seq; + last_seq = seq; + } else { + ctx.received(seq); + } } } sleep(1); diff --git a/tests/test_sortor.cpp b/tests/test_sortor.cpp index be153b9c..921f248c 100644 --- a/tests/test_sortor.cpp +++ b/tests/test_sortor.cpp @@ -102,7 +102,7 @@ void test_real() { PacketSortor sortor; list sorted_list; - sortor.setOnSort([&](uint16_t seq, const uint16_t &packet) { + sortor.setOnSort([&](uint16_t seq, uint16_t packet) { sorted_list.push_back(seq); }); diff --git a/webrtc/DtlsTransport.cpp b/webrtc/DtlsTransport.cpp index 0e2f160a..66183f8e 100644 --- a/webrtc/DtlsTransport.cpp +++ b/webrtc/DtlsTransport.cpp @@ -33,1453 +33,1452 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. using namespace std; #define LOG_OPENSSL_ERROR(desc) \ - do \ - { \ - if (ERR_peek_error() == 0) \ - MS_ERROR("OpenSSL error [desc:'%s']", desc); \ - else \ - { \ - int64_t err; \ - while ((err = ERR_get_error()) != 0) \ - { \ - MS_ERROR("OpenSSL error [desc:'%s', error:'%s']", desc, ERR_error_string(err, nullptr)); \ - } \ - ERR_clear_error(); \ - } \ - } while (false) + do \ + { \ + if (ERR_peek_error() == 0) \ + MS_ERROR("OpenSSL error [desc:'%s']", desc); \ + else \ + { \ + int64_t err; \ + while ((err = ERR_get_error()) != 0) \ + { \ + MS_ERROR("OpenSSL error [desc:'%s', error:'%s']", desc, ERR_error_string(err, nullptr)); \ + } \ + ERR_clear_error(); \ + } \ + } while (false) /* Static methods for OpenSSL callbacks. */ inline static int onSslCertificateVerify(int /*preverifyOk*/, X509_STORE_CTX* /*ctx*/) { - MS_TRACE(); + MS_TRACE(); - // Always valid since DTLS certificates are self-signed. - return 1; + // Always valid since DTLS certificates are self-signed. + return 1; } inline static unsigned int onSslDtlsTimer(SSL* /*ssl*/, unsigned int timerUs) { - if (timerUs == 0) - return 100000; - else if (timerUs >= 4000000) - return 4000000; - else - return 2 * timerUs; + if (timerUs == 0) + return 100000; + else if (timerUs >= 4000000) + return 4000000; + else + return 2 * timerUs; } namespace RTC { - /* Static. */ + /* Static. */ - // clang-format off - static constexpr int DtlsMtu{ 1350 }; - // AES-HMAC: http://tools.ietf.org/html/rfc3711 - static constexpr size_t SrtpMasterKeyLength{ 16 }; - static constexpr size_t SrtpMasterSaltLength{ 14 }; - static constexpr size_t SrtpMasterLength{ SrtpMasterKeyLength + SrtpMasterSaltLength }; - // AES-GCM: http://tools.ietf.org/html/rfc7714 - static constexpr size_t SrtpAesGcm256MasterKeyLength{ 32 }; - static constexpr size_t SrtpAesGcm256MasterSaltLength{ 12 }; - static constexpr size_t SrtpAesGcm256MasterLength{ SrtpAesGcm256MasterKeyLength + SrtpAesGcm256MasterSaltLength }; - static constexpr size_t SrtpAesGcm128MasterKeyLength{ 16 }; - static constexpr size_t SrtpAesGcm128MasterSaltLength{ 12 }; - static constexpr size_t SrtpAesGcm128MasterLength{ SrtpAesGcm128MasterKeyLength + SrtpAesGcm128MasterSaltLength }; - // clang-format on + // clang-format off + static constexpr int DtlsMtu{ 1350 }; + // AES-HMAC: http://tools.ietf.org/html/rfc3711 + static constexpr size_t SrtpMasterKeyLength{ 16 }; + static constexpr size_t SrtpMasterSaltLength{ 14 }; + static constexpr size_t SrtpMasterLength{ SrtpMasterKeyLength + SrtpMasterSaltLength }; + // AES-GCM: http://tools.ietf.org/html/rfc7714 + static constexpr size_t SrtpAesGcm256MasterKeyLength{ 32 }; + static constexpr size_t SrtpAesGcm256MasterSaltLength{ 12 }; + static constexpr size_t SrtpAesGcm256MasterLength{ SrtpAesGcm256MasterKeyLength + SrtpAesGcm256MasterSaltLength }; + static constexpr size_t SrtpAesGcm128MasterKeyLength{ 16 }; + static constexpr size_t SrtpAesGcm128MasterSaltLength{ 12 }; + static constexpr size_t SrtpAesGcm128MasterLength{ SrtpAesGcm128MasterKeyLength + SrtpAesGcm128MasterSaltLength }; + // clang-format on - /* Class variables. */ - // clang-format off - std::map DtlsTransport::string2FingerprintAlgorithm = - { - { "sha-1", DtlsTransport::FingerprintAlgorithm::SHA1 }, - { "sha-224", DtlsTransport::FingerprintAlgorithm::SHA224 }, - { "sha-256", DtlsTransport::FingerprintAlgorithm::SHA256 }, - { "sha-384", DtlsTransport::FingerprintAlgorithm::SHA384 }, - { "sha-512", DtlsTransport::FingerprintAlgorithm::SHA512 } - }; - std::map DtlsTransport::fingerprintAlgorithm2String = - { - { DtlsTransport::FingerprintAlgorithm::SHA1, "sha-1" }, - { DtlsTransport::FingerprintAlgorithm::SHA224, "sha-224" }, - { DtlsTransport::FingerprintAlgorithm::SHA256, "sha-256" }, - { DtlsTransport::FingerprintAlgorithm::SHA384, "sha-384" }, - { DtlsTransport::FingerprintAlgorithm::SHA512, "sha-512" } - }; - std::map DtlsTransport::string2Role = - { - { "auto", DtlsTransport::Role::AUTO }, - { "client", DtlsTransport::Role::CLIENT }, - { "server", DtlsTransport::Role::SERVER } - }; - std::vector DtlsTransport::srtpCryptoSuites = - { - { RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM, "SRTP_AEAD_AES_256_GCM" }, - { RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM, "SRTP_AEAD_AES_128_GCM" }, - { RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80, "SRTP_AES128_CM_SHA1_80" }, - { RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32, "SRTP_AES128_CM_SHA1_32" } - }; - // clang-format on + /* Class variables. */ + // clang-format off + std::map DtlsTransport::string2FingerprintAlgorithm = + { + { "sha-1", DtlsTransport::FingerprintAlgorithm::SHA1 }, + { "sha-224", DtlsTransport::FingerprintAlgorithm::SHA224 }, + { "sha-256", DtlsTransport::FingerprintAlgorithm::SHA256 }, + { "sha-384", DtlsTransport::FingerprintAlgorithm::SHA384 }, + { "sha-512", DtlsTransport::FingerprintAlgorithm::SHA512 } + }; + std::map DtlsTransport::fingerprintAlgorithm2String = + { + { DtlsTransport::FingerprintAlgorithm::SHA1, "sha-1" }, + { DtlsTransport::FingerprintAlgorithm::SHA224, "sha-224" }, + { DtlsTransport::FingerprintAlgorithm::SHA256, "sha-256" }, + { DtlsTransport::FingerprintAlgorithm::SHA384, "sha-384" }, + { DtlsTransport::FingerprintAlgorithm::SHA512, "sha-512" } + }; + std::map DtlsTransport::string2Role = + { + { "auto", DtlsTransport::Role::AUTO }, + { "client", DtlsTransport::Role::CLIENT }, + { "server", DtlsTransport::Role::SERVER } + }; + std::vector DtlsTransport::srtpCryptoSuites = + { + { RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM, "SRTP_AEAD_AES_256_GCM" }, + { RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM, "SRTP_AEAD_AES_128_GCM" }, + { RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80, "SRTP_AES128_CM_SHA1_80" }, + { RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32, "SRTP_AES128_CM_SHA1_32" } + }; + // clang-format on - INSTANCE_IMP(DtlsTransport::DtlsEnvironment); + INSTANCE_IMP(DtlsTransport::DtlsEnvironment); - /* Class methods. */ + /* Class methods. */ DtlsTransport::DtlsEnvironment::DtlsEnvironment() - { - MS_TRACE(); + { + MS_TRACE(); - // Generate a X509 certificate and private key (unless PEM files are provided). - if (true /* - Settings::configuration.dtlsCertificateFile.empty() || - Settings::configuration.dtlsPrivateKeyFile.empty()*/) - { - GenerateCertificateAndPrivateKey(); - } - else - { - ReadCertificateAndPrivateKeyFromFiles(); - } + // Generate a X509 certificate and private key (unless PEM files are provided). + if (true /* + Settings::configuration.dtlsCertificateFile.empty() || + Settings::configuration.dtlsPrivateKeyFile.empty()*/) + { + GenerateCertificateAndPrivateKey(); + } + else + { + ReadCertificateAndPrivateKeyFromFiles(); + } - // Create a global SSL_CTX. - CreateSslCtx(); + // Create a global SSL_CTX. + CreateSslCtx(); - // Generate certificate fingerprints. - GenerateFingerprints(); - } + // Generate certificate fingerprints. + GenerateFingerprints(); + } DtlsTransport::DtlsEnvironment::~DtlsEnvironment() - { - MS_TRACE(); + { + MS_TRACE(); - if (privateKey) - EVP_PKEY_free(privateKey); - if (certificate) - X509_free(certificate); - if (sslCtx) - SSL_CTX_free(sslCtx); - } + if (privateKey) + EVP_PKEY_free(privateKey); + if (certificate) + X509_free(certificate); + if (sslCtx) + SSL_CTX_free(sslCtx); + } - void DtlsTransport::DtlsEnvironment::GenerateCertificateAndPrivateKey() - { - MS_TRACE(); + void DtlsTransport::DtlsEnvironment::GenerateCertificateAndPrivateKey() + { + MS_TRACE(); - int ret{ 0 }; - EC_KEY* ecKey{ nullptr }; - X509_NAME* certName{ nullptr }; - std::string subject = - std::string("mediasoup") + to_string(rand() % 999999 + 100000); + int ret{ 0 }; + EC_KEY* ecKey{ nullptr }; + X509_NAME* certName{ nullptr }; + std::string subject = + std::string("mediasoup") + to_string(rand() % 999999 + 100000); - // Create key with curve. - ecKey = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + // Create key with curve. + ecKey = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); - if (!ecKey) - { - LOG_OPENSSL_ERROR("EC_KEY_new_by_curve_name() failed"); + if (!ecKey) + { + LOG_OPENSSL_ERROR("EC_KEY_new_by_curve_name() failed"); - goto error; - } + goto error; + } - EC_KEY_set_asn1_flag(ecKey, OPENSSL_EC_NAMED_CURVE); + EC_KEY_set_asn1_flag(ecKey, OPENSSL_EC_NAMED_CURVE); - // NOTE: This can take some time. - ret = EC_KEY_generate_key(ecKey); + // NOTE: This can take some time. + ret = EC_KEY_generate_key(ecKey); - if (ret == 0) - { - LOG_OPENSSL_ERROR("EC_KEY_generate_key() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("EC_KEY_generate_key() failed"); - goto error; - } + goto error; + } - // Create a private key object. - privateKey = EVP_PKEY_new(); + // Create a private key object. + privateKey = EVP_PKEY_new(); - if (!privateKey) - { - LOG_OPENSSL_ERROR("EVP_PKEY_new() failed"); + if (!privateKey) + { + LOG_OPENSSL_ERROR("EVP_PKEY_new() failed"); - goto error; - } + goto error; + } - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast) - ret = EVP_PKEY_assign_EC_KEY(privateKey, ecKey); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast) + ret = EVP_PKEY_assign_EC_KEY(privateKey, ecKey); - if (ret == 0) - { - LOG_OPENSSL_ERROR("EVP_PKEY_assign_EC_KEY() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("EVP_PKEY_assign_EC_KEY() failed"); - goto error; - } + goto error; + } - // The EC key now belongs to the private key, so don't clean it up separately. - ecKey = nullptr; + // The EC key now belongs to the private key, so don't clean it up separately. + ecKey = nullptr; - // Create the X509 certificate. - certificate = X509_new(); + // Create the X509 certificate. + certificate = X509_new(); - if (!certificate) - { - LOG_OPENSSL_ERROR("X509_new() failed"); + if (!certificate) + { + LOG_OPENSSL_ERROR("X509_new() failed"); - goto error; - } + goto error; + } - // Set version 3 (note that 0 means version 1). - X509_set_version(certificate, 2); + // Set version 3 (note that 0 means version 1). + X509_set_version(certificate, 2); - // Set serial number (avoid default 0). - ASN1_INTEGER_set( - X509_get_serialNumber(certificate), - static_cast(rand() % 999999 + 100000)); + // Set serial number (avoid default 0). + ASN1_INTEGER_set( + X509_get_serialNumber(certificate), + static_cast(rand() % 999999 + 100000)); - // Set valid period. - X509_gmtime_adj(X509_get_notBefore(certificate), -315360000); // -10 years. - X509_gmtime_adj(X509_get_notAfter(certificate), 315360000); // 10 years. + // Set valid period. + X509_gmtime_adj(X509_get_notBefore(certificate), -315360000); // -10 years. + X509_gmtime_adj(X509_get_notAfter(certificate), 315360000); // 10 years. - // Set the public key for the certificate using the key. - ret = X509_set_pubkey(certificate, privateKey); + // Set the public key for the certificate using the key. + ret = X509_set_pubkey(certificate, privateKey); - if (ret == 0) - { - LOG_OPENSSL_ERROR("X509_set_pubkey() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("X509_set_pubkey() failed"); - goto error; - } + goto error; + } - // Set certificate fields. - certName = X509_get_subject_name(certificate); + // Set certificate fields. + certName = X509_get_subject_name(certificate); - if (!certName) - { - LOG_OPENSSL_ERROR("X509_get_subject_name() failed"); + if (!certName) + { + LOG_OPENSSL_ERROR("X509_get_subject_name() failed"); - goto error; - } + goto error; + } - X509_NAME_add_entry_by_txt( - certName, "O", MBSTRING_ASC, reinterpret_cast(subject.c_str()), -1, -1, 0); - X509_NAME_add_entry_by_txt( - certName, "CN", MBSTRING_ASC, reinterpret_cast(subject.c_str()), -1, -1, 0); + X509_NAME_add_entry_by_txt( + certName, "O", MBSTRING_ASC, reinterpret_cast(subject.c_str()), -1, -1, 0); + X509_NAME_add_entry_by_txt( + certName, "CN", MBSTRING_ASC, reinterpret_cast(subject.c_str()), -1, -1, 0); - // It is self-signed so set the issuer name to be the same as the subject. - ret = X509_set_issuer_name(certificate, certName); + // It is self-signed so set the issuer name to be the same as the subject. + ret = X509_set_issuer_name(certificate, certName); - if (ret == 0) - { - LOG_OPENSSL_ERROR("X509_set_issuer_name() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("X509_set_issuer_name() failed"); - goto error; - } + goto error; + } - // Sign the certificate with its own private key. - ret = X509_sign(certificate, privateKey, EVP_sha1()); + // Sign the certificate with its own private key. + ret = X509_sign(certificate, privateKey, EVP_sha1()); - if (ret == 0) - { - LOG_OPENSSL_ERROR("X509_sign() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("X509_sign() failed"); - goto error; - } + goto error; + } - return; + return; - error: + error: - if (ecKey) - EC_KEY_free(ecKey); + if (ecKey) + EC_KEY_free(ecKey); - if (privateKey) - EVP_PKEY_free(privateKey); // NOTE: This also frees the EC key. + if (privateKey) + EVP_PKEY_free(privateKey); // NOTE: This also frees the EC key. - if (certificate) - X509_free(certificate); + if (certificate) + X509_free(certificate); - MS_THROW_ERROR("DTLS certificate and private key generation failed"); - } + MS_THROW_ERROR("DTLS certificate and private key generation failed"); + } - void DtlsTransport::DtlsEnvironment::ReadCertificateAndPrivateKeyFromFiles() - { + void DtlsTransport::DtlsEnvironment::ReadCertificateAndPrivateKeyFromFiles() + { #if 0 - MS_TRACE(); + MS_TRACE(); - FILE* file{ nullptr }; + FILE* file{ nullptr }; - file = fopen(Settings::configuration.dtlsCertificateFile.c_str(), "r"); + file = fopen(Settings::configuration.dtlsCertificateFile.c_str(), "r"); - if (!file) - { - MS_ERROR("error reading DTLS certificate file: %s", std::strerror(errno)); + if (!file) + { + MS_ERROR("error reading DTLS certificate file: %s", std::strerror(errno)); - goto error; - } + goto error; + } - certificate = PEM_read_X509(file, nullptr, nullptr, nullptr); + certificate = PEM_read_X509(file, nullptr, nullptr, nullptr); - if (!certificate) - { - LOG_OPENSSL_ERROR("PEM_read_X509() failed"); + if (!certificate) + { + LOG_OPENSSL_ERROR("PEM_read_X509() failed"); - goto error; - } + goto error; + } - fclose(file); + fclose(file); - file = fopen(Settings::configuration.dtlsPrivateKeyFile.c_str(), "r"); + file = fopen(Settings::configuration.dtlsPrivateKeyFile.c_str(), "r"); - if (!file) - { - MS_ERROR("error reading DTLS private key file: %s", std::strerror(errno)); + if (!file) + { + MS_ERROR("error reading DTLS private key file: %s", std::strerror(errno)); - goto error; - } + goto error; + } - privateKey = PEM_read_PrivateKey(file, nullptr, nullptr, nullptr); + privateKey = PEM_read_PrivateKey(file, nullptr, nullptr, nullptr); - if (!privateKey) - { - LOG_OPENSSL_ERROR("PEM_read_PrivateKey() failed"); + if (!privateKey) + { + LOG_OPENSSL_ERROR("PEM_read_PrivateKey() failed"); - goto error; - } + goto error; + } - fclose(file); + fclose(file); - return; + return; - error: + error: - MS_THROW_ERROR("error reading DTLS certificate and private key PEM files"); + MS_THROW_ERROR("error reading DTLS certificate and private key PEM files"); #endif - } + } - void DtlsTransport::DtlsEnvironment::CreateSslCtx() - { - MS_TRACE(); + void DtlsTransport::DtlsEnvironment::CreateSslCtx() + { + MS_TRACE(); - std::string dtlsSrtpCryptoSuites; - int ret; + std::string dtlsSrtpCryptoSuites; + int ret; - /* Set the global DTLS context. */ + /* Set the global DTLS context. */ - // Both DTLS 1.0 and 1.2 (requires OpenSSL >= 1.1.0). - sslCtx = SSL_CTX_new(DTLS_method()); + // Both DTLS 1.0 and 1.2 (requires OpenSSL >= 1.1.0). + sslCtx = SSL_CTX_new(DTLS_method()); - if (!sslCtx) - { - LOG_OPENSSL_ERROR("SSL_CTX_new() failed"); + if (!sslCtx) + { + LOG_OPENSSL_ERROR("SSL_CTX_new() failed"); - goto error; - } + goto error; + } - ret = SSL_CTX_use_certificate(sslCtx, certificate); + ret = SSL_CTX_use_certificate(sslCtx, certificate); - if (ret == 0) - { - LOG_OPENSSL_ERROR("SSL_CTX_use_certificate() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("SSL_CTX_use_certificate() failed"); - goto error; - } + goto error; + } - ret = SSL_CTX_use_PrivateKey(sslCtx, privateKey); + ret = SSL_CTX_use_PrivateKey(sslCtx, privateKey); - if (ret == 0) - { - LOG_OPENSSL_ERROR("SSL_CTX_use_PrivateKey() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("SSL_CTX_use_PrivateKey() failed"); - goto error; - } + goto error; + } - ret = SSL_CTX_check_private_key(sslCtx); + ret = SSL_CTX_check_private_key(sslCtx); - if (ret == 0) - { - LOG_OPENSSL_ERROR("SSL_CTX_check_private_key() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("SSL_CTX_check_private_key() failed"); - goto error; - } + goto error; + } - // Set options. - SSL_CTX_set_options( - sslCtx, - SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_NO_TICKET | SSL_OP_SINGLE_ECDH_USE | - SSL_OP_NO_QUERY_MTU); + // Set options. + SSL_CTX_set_options( + sslCtx, + SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_NO_TICKET | SSL_OP_SINGLE_ECDH_USE | + SSL_OP_NO_QUERY_MTU); - // Don't use sessions cache. - SSL_CTX_set_session_cache_mode(sslCtx, SSL_SESS_CACHE_OFF); + // Don't use sessions cache. + SSL_CTX_set_session_cache_mode(sslCtx, SSL_SESS_CACHE_OFF); - // Read always as much into the buffer as possible. - // NOTE: This is the default for DTLS, but a bug in non latest OpenSSL - // versions makes this call required. - SSL_CTX_set_read_ahead(sslCtx, 1); + // Read always as much into the buffer as possible. + // NOTE: This is the default for DTLS, but a bug in non latest OpenSSL + // versions makes this call required. + SSL_CTX_set_read_ahead(sslCtx, 1); - SSL_CTX_set_verify_depth(sslCtx, 4); + SSL_CTX_set_verify_depth(sslCtx, 4); - // Require certificate from peer. - SSL_CTX_set_verify( - sslCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, onSslCertificateVerify); + // Require certificate from peer. + SSL_CTX_set_verify( + sslCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, onSslCertificateVerify); - // Set SSL info callback. - SSL_CTX_set_info_callback(sslCtx, [](const SSL* ssl, int where, int ret){ + // Set SSL info callback. + SSL_CTX_set_info_callback(sslCtx, [](const SSL* ssl, int where, int ret){ static_cast(SSL_get_ex_data(ssl, 0))->OnSslInfo(where, ret); }); - // Set ciphers. - ret = SSL_CTX_set_cipher_list( - sslCtx, "DEFAULT:!NULL:!aNULL:!SHA256:!SHA384:!aECDH:!AESGCM+AES256:!aPSK"); + // Set ciphers. + ret = SSL_CTX_set_cipher_list( + sslCtx, "DEFAULT:!NULL:!aNULL:!SHA256:!SHA384:!aECDH:!AESGCM+AES256:!aPSK"); - if (ret == 0) - { - LOG_OPENSSL_ERROR("SSL_CTX_set_cipher_list() failed"); + if (ret == 0) + { + LOG_OPENSSL_ERROR("SSL_CTX_set_cipher_list() failed"); - goto error; - } + goto error; + } - // Enable ECDH ciphers. - // DOC: http://en.wikibooks.org/wiki/OpenSSL/Diffie-Hellman_parameters - // NOTE: https://code.google.com/p/chromium/issues/detail?id=406458 - // NOTE: https://bugs.ruby-lang.org/issues/12324 + // Enable ECDH ciphers. + // DOC: http://en.wikibooks.org/wiki/OpenSSL/Diffie-Hellman_parameters + // NOTE: https://code.google.com/p/chromium/issues/detail?id=406458 + // NOTE: https://bugs.ruby-lang.org/issues/12324 - // For OpenSSL >= 1.0.2. - SSL_CTX_set_ecdh_auto(sslCtx, 1); + // For OpenSSL >= 1.0.2. + SSL_CTX_set_ecdh_auto(sslCtx, 1); - // Set the "use_srtp" DTLS extension. - for (auto it = DtlsTransport::srtpCryptoSuites.begin(); - it != DtlsTransport::srtpCryptoSuites.end(); - ++it) - { - if (it != DtlsTransport::srtpCryptoSuites.begin()) - dtlsSrtpCryptoSuites += ":"; + // Set the "use_srtp" DTLS extension. + for (auto it = DtlsTransport::srtpCryptoSuites.begin(); + it != DtlsTransport::srtpCryptoSuites.end(); + ++it) + { + if (it != DtlsTransport::srtpCryptoSuites.begin()) + dtlsSrtpCryptoSuites += ":"; - SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(*it); - dtlsSrtpCryptoSuites += cryptoSuiteEntry->name; - } + SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(*it); + dtlsSrtpCryptoSuites += cryptoSuiteEntry->name; + } - MS_DEBUG_2TAGS(dtls, srtp, "setting SRTP cryptoSuites for DTLS: %s", dtlsSrtpCryptoSuites.c_str()); + MS_DEBUG_2TAGS(dtls, srtp, "setting SRTP cryptoSuites for DTLS: %s", dtlsSrtpCryptoSuites.c_str()); - // NOTE: This function returns 0 on success. - ret = SSL_CTX_set_tlsext_use_srtp(sslCtx, dtlsSrtpCryptoSuites.c_str()); + // NOTE: This function returns 0 on success. + ret = SSL_CTX_set_tlsext_use_srtp(sslCtx, dtlsSrtpCryptoSuites.c_str()); - if (ret != 0) - { - MS_ERROR( - "SSL_CTX_set_tlsext_use_srtp() failed when entering '%s'", dtlsSrtpCryptoSuites.c_str()); - LOG_OPENSSL_ERROR("SSL_CTX_set_tlsext_use_srtp() failed"); + if (ret != 0) + { + MS_ERROR( + "SSL_CTX_set_tlsext_use_srtp() failed when entering '%s'", dtlsSrtpCryptoSuites.c_str()); + LOG_OPENSSL_ERROR("SSL_CTX_set_tlsext_use_srtp() failed"); - goto error; - } + goto error; + } - return; + return; - error: + error: - if (sslCtx) - { - SSL_CTX_free(sslCtx); - sslCtx = nullptr; - } + if (sslCtx) + { + SSL_CTX_free(sslCtx); + sslCtx = nullptr; + } - MS_THROW_ERROR("SSL context creation failed"); - } + MS_THROW_ERROR("SSL context creation failed"); + } - void DtlsTransport::DtlsEnvironment::GenerateFingerprints() - { - MS_TRACE(); + void DtlsTransport::DtlsEnvironment::GenerateFingerprints() + { + MS_TRACE(); - for (auto& kv : DtlsTransport::string2FingerprintAlgorithm) - { - const std::string& algorithmString = kv.first; - FingerprintAlgorithm algorithm = kv.second; - uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; - unsigned int size{ 0 }; - char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; - const EVP_MD* hashFunction; - int ret; + for (auto& kv : DtlsTransport::string2FingerprintAlgorithm) + { + const std::string& algorithmString = kv.first; + FingerprintAlgorithm algorithm = kv.second; + uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; + unsigned int size{ 0 }; + char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; + const EVP_MD* hashFunction; + int ret; - switch (algorithm) - { - case FingerprintAlgorithm::SHA1: - hashFunction = EVP_sha1(); - break; + switch (algorithm) + { + case FingerprintAlgorithm::SHA1: + hashFunction = EVP_sha1(); + break; - case FingerprintAlgorithm::SHA224: - hashFunction = EVP_sha224(); - break; + case FingerprintAlgorithm::SHA224: + hashFunction = EVP_sha224(); + break; - case FingerprintAlgorithm::SHA256: - hashFunction = EVP_sha256(); - break; + case FingerprintAlgorithm::SHA256: + hashFunction = EVP_sha256(); + break; - case FingerprintAlgorithm::SHA384: - hashFunction = EVP_sha384(); - break; + case FingerprintAlgorithm::SHA384: + hashFunction = EVP_sha384(); + break; - case FingerprintAlgorithm::SHA512: - hashFunction = EVP_sha512(); - break; + case FingerprintAlgorithm::SHA512: + hashFunction = EVP_sha512(); + break; - default: - MS_THROW_ERROR("unknown algorithm"); - } + default: + MS_THROW_ERROR("unknown algorithm"); + } - ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size); + ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size); - if (ret == 0) - { - MS_ERROR("X509_digest() failed"); - MS_THROW_ERROR("Fingerprints generation failed"); - } + if (ret == 0) + { + MS_ERROR("X509_digest() failed"); + MS_THROW_ERROR("Fingerprints generation failed"); + } - // Convert to hexadecimal format in uppercase with colons. - for (unsigned int i{ 0 }; i < size; ++i) - { - std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); - } - hexFingerprint[(size * 3) - 1] = '\0'; + // Convert to hexadecimal format in uppercase with colons. + for (unsigned int i{ 0 }; i < size; ++i) + { + std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); + } + hexFingerprint[(size * 3) - 1] = '\0'; - MS_DEBUG_TAG(dtls, "%-7s fingerprint: %s", algorithmString.c_str(), hexFingerprint); + MS_DEBUG_TAG(dtls, "%-7s fingerprint: %s", algorithmString.c_str(), hexFingerprint); - // Store it in the vector. - DtlsTransport::Fingerprint fingerprint; + // Store it in the vector. + DtlsTransport::Fingerprint fingerprint; - fingerprint.algorithm = DtlsTransport::GetFingerprintAlgorithm(algorithmString); - fingerprint.value = hexFingerprint; + fingerprint.algorithm = DtlsTransport::GetFingerprintAlgorithm(algorithmString); + fingerprint.value = hexFingerprint; - localFingerprints.push_back(fingerprint); - } - } + localFingerprints.push_back(fingerprint); + } + } - /* Instance methods. */ + /* Instance methods. */ - DtlsTransport::DtlsTransport(EventPoller::Ptr poller,Listener* listener) : poller(std::move(poller)), listener(listener) - { - MS_TRACE(); + DtlsTransport::DtlsTransport(EventPoller::Ptr poller,Listener* listener) : poller(std::move(poller)), listener(listener) + { + MS_TRACE(); env = DtlsEnvironment::Instance().shared_from_this(); - /* Set SSL. */ + /* Set SSL. */ - this->ssl = SSL_new(env->sslCtx); + this->ssl = SSL_new(env->sslCtx); - if (!this->ssl) - { - LOG_OPENSSL_ERROR("SSL_new() failed"); + if (!this->ssl) + { + LOG_OPENSSL_ERROR("SSL_new() failed"); - goto error; - } + goto error; + } - // Set this as custom data. - SSL_set_ex_data(this->ssl, 0, static_cast(this)); + // Set this as custom data. + SSL_set_ex_data(this->ssl, 0, static_cast(this)); - this->sslBioFromNetwork = BIO_new(BIO_s_mem()); + this->sslBioFromNetwork = BIO_new(BIO_s_mem()); - if (!this->sslBioFromNetwork) - { - LOG_OPENSSL_ERROR("BIO_new() failed"); + if (!this->sslBioFromNetwork) + { + LOG_OPENSSL_ERROR("BIO_new() failed"); - SSL_free(this->ssl); + SSL_free(this->ssl); - goto error; - } - - this->sslBioToNetwork = BIO_new(BIO_s_mem()); - - if (!this->sslBioToNetwork) - { - LOG_OPENSSL_ERROR("BIO_new() failed"); - - BIO_free(this->sslBioFromNetwork); - SSL_free(this->ssl); - - goto error; - } - - SSL_set_bio(this->ssl, this->sslBioFromNetwork, this->sslBioToNetwork); - - // Set the MTU so that we don't send packets that are too large with no fragmentation. - SSL_set_mtu(this->ssl, DtlsMtu); - DTLS_set_link_mtu(this->ssl, DtlsMtu); - - // Set callback handler for setting DTLS timer interval. - DTLS_set_timer_cb(this->ssl, onSslDtlsTimer); - - return; - - error: - - // NOTE: At this point SSL_set_bio() was not called so we must free BIOs as - // well. - if (this->sslBioFromNetwork) - BIO_free(this->sslBioFromNetwork); - - if (this->sslBioToNetwork) - BIO_free(this->sslBioToNetwork); - - if (this->ssl) - SSL_free(this->ssl); - - // NOTE: If this is not catched by the caller the program will abort, but - // this should never happen. - MS_THROW_ERROR("DtlsTransport instance creation failed"); - } - - DtlsTransport::~DtlsTransport() - { - MS_TRACE(); - - if (IsRunning()) - { - // Send close alert to the peer. - SSL_shutdown(this->ssl); - SendPendingOutgoingDtlsData(); - } - - if (this->ssl) - { - SSL_free(this->ssl); - - this->ssl = nullptr; - this->sslBioFromNetwork = nullptr; - this->sslBioToNetwork = nullptr; - } - - // Close the DTLS timer. - this->timer = nullptr; - } - - void DtlsTransport::Dump() const - { - MS_TRACE(); - - std::string state{ "new" }; - std::string role{ "none " }; - - switch (this->state) - { - case DtlsState::CONNECTING: - state = "connecting"; - break; - case DtlsState::CONNECTED: - state = "connected"; - break; - case DtlsState::FAILED: - state = "failed"; - break; - case DtlsState::CLOSED: - state = "closed"; - break; - default:; - } - - switch (this->localRole) - { - case Role::AUTO: - role = "auto"; - break; - case Role::SERVER: - role = "server"; - break; - case Role::CLIENT: - role = "client"; - break; - default:; - } - - MS_DUMP(""); - MS_DUMP(" state : %s", state.c_str()); - MS_DUMP(" role : %s", role.c_str()); - MS_DUMP(" handshake done: : %s", this->handshakeDone ? "yes" : "no"); - MS_DUMP(""); - } - - void DtlsTransport::Run(Role localRole) - { - MS_TRACE(); - - MS_ASSERT( - localRole == Role::CLIENT || localRole == Role::SERVER, - "local DTLS role must be 'client' or 'server'"); - - Role previousLocalRole = this->localRole; - - if (localRole == previousLocalRole) - { - MS_ERROR("same local DTLS role provided, doing nothing"); - - return; - } - - // If the previous local DTLS role was 'client' or 'server' do reset. - if (previousLocalRole == Role::CLIENT || previousLocalRole == Role::SERVER) - { - MS_DEBUG_TAG(dtls, "resetting DTLS due to local role change"); - - Reset(); - } - - // Update local role. - this->localRole = localRole; - - // Set state and notify the listener. - this->state = DtlsState::CONNECTING; - this->listener->OnDtlsTransportConnecting(this); - - switch (this->localRole) - { - case Role::CLIENT: - { - MS_DEBUG_TAG(dtls, "running [role:client]"); - - SSL_set_connect_state(this->ssl); - SSL_do_handshake(this->ssl); - SendPendingOutgoingDtlsData(); - SetTimeout(); - - break; - } - - case Role::SERVER: - { - MS_DEBUG_TAG(dtls, "running [role:server]"); - - SSL_set_accept_state(this->ssl); - SSL_do_handshake(this->ssl); - - break; - } - - default: - { - MS_ABORT("invalid local DTLS role"); - } - } - } - - bool DtlsTransport::SetRemoteFingerprint(Fingerprint fingerprint) - { - MS_TRACE(); - - MS_ASSERT( - fingerprint.algorithm != FingerprintAlgorithm::NONE, "no fingerprint algorithm provided"); - - this->remoteFingerprint = fingerprint; - - // The remote fingerpring may have been set after DTLS handshake was done, - // so we may need to process it now. - if (this->handshakeDone && this->state != DtlsState::CONNECTED) - { - MS_DEBUG_TAG(dtls, "handshake already done, processing it right now"); - - return ProcessHandshake(); - } - - return true; - } - - void DtlsTransport::ProcessDtlsData(const uint8_t* data, size_t len) - { - MS_TRACE(); - - int written; - int read; + goto error; + } + + this->sslBioToNetwork = BIO_new(BIO_s_mem()); + + if (!this->sslBioToNetwork) + { + LOG_OPENSSL_ERROR("BIO_new() failed"); + + BIO_free(this->sslBioFromNetwork); + SSL_free(this->ssl); + + goto error; + } + + SSL_set_bio(this->ssl, this->sslBioFromNetwork, this->sslBioToNetwork); + + // Set the MTU so that we don't send packets that are too large with no fragmentation. + SSL_set_mtu(this->ssl, DtlsMtu); + DTLS_set_link_mtu(this->ssl, DtlsMtu); + + // Set callback handler for setting DTLS timer interval. + DTLS_set_timer_cb(this->ssl, onSslDtlsTimer); + + return; + + error: + + // NOTE: At this point SSL_set_bio() was not called so we must free BIOs as + // well. + if (this->sslBioFromNetwork) + BIO_free(this->sslBioFromNetwork); + + if (this->sslBioToNetwork) + BIO_free(this->sslBioToNetwork); + + if (this->ssl) + SSL_free(this->ssl); + + // NOTE: If this is not catched by the caller the program will abort, but + // this should never happen. + MS_THROW_ERROR("DtlsTransport instance creation failed"); + } + + DtlsTransport::~DtlsTransport() + { + MS_TRACE(); + + if (IsRunning()) + { + // Send close alert to the peer. + SSL_shutdown(this->ssl); + SendPendingOutgoingDtlsData(); + } + + if (this->ssl) + { + SSL_free(this->ssl); + + this->ssl = nullptr; + this->sslBioFromNetwork = nullptr; + this->sslBioToNetwork = nullptr; + } + + // Close the DTLS timer. + this->timer = nullptr; + } + + void DtlsTransport::Dump() const + { + MS_TRACE(); + + std::string state{ "new" }; + std::string role{ "none " }; + + switch (this->state) + { + case DtlsState::CONNECTING: + state = "connecting"; + break; + case DtlsState::CONNECTED: + state = "connected"; + break; + case DtlsState::FAILED: + state = "failed"; + break; + case DtlsState::CLOSED: + state = "closed"; + break; + default:; + } + + switch (this->localRole) + { + case Role::AUTO: + role = "auto"; + break; + case Role::SERVER: + role = "server"; + break; + case Role::CLIENT: + role = "client"; + break; + default:; + } + + MS_DUMP(""); + MS_DUMP(" state : %s", state.c_str()); + MS_DUMP(" role : %s", role.c_str()); + MS_DUMP(" handshake done: : %s", this->handshakeDone ? "yes" : "no"); + MS_DUMP(""); + } + + void DtlsTransport::Run(Role localRole) + { + MS_TRACE(); + + MS_ASSERT( + localRole == Role::CLIENT || localRole == Role::SERVER, + "local DTLS role must be 'client' or 'server'"); + + Role previousLocalRole = this->localRole; + + if (localRole == previousLocalRole) + { + MS_ERROR("same local DTLS role provided, doing nothing"); + + return; + } + + // If the previous local DTLS role was 'client' or 'server' do reset. + if (previousLocalRole == Role::CLIENT || previousLocalRole == Role::SERVER) + { + MS_DEBUG_TAG(dtls, "resetting DTLS due to local role change"); + + Reset(); + } + + // Update local role. + this->localRole = localRole; + + // Set state and notify the listener. + this->state = DtlsState::CONNECTING; + this->listener->OnDtlsTransportConnecting(this); + + switch (this->localRole) + { + case Role::CLIENT: + { + MS_DEBUG_TAG(dtls, "running [role:client]"); + + SSL_set_connect_state(this->ssl); + SSL_do_handshake(this->ssl); + SendPendingOutgoingDtlsData(); + SetTimeout(); + + break; + } + + case Role::SERVER: + { + MS_DEBUG_TAG(dtls, "running [role:server]"); + + SSL_set_accept_state(this->ssl); + SSL_do_handshake(this->ssl); + + break; + } + + default: + { + MS_ABORT("invalid local DTLS role"); + } + } + } + + bool DtlsTransport::SetRemoteFingerprint(Fingerprint fingerprint) + { + MS_TRACE(); + + MS_ASSERT( + fingerprint.algorithm != FingerprintAlgorithm::NONE, "no fingerprint algorithm provided"); + + this->remoteFingerprint = fingerprint; + + // The remote fingerpring may have been set after DTLS handshake was done, + // so we may need to process it now. + if (this->handshakeDone && this->state != DtlsState::CONNECTED) + { + MS_DEBUG_TAG(dtls, "handshake already done, processing it right now"); + + return ProcessHandshake(); + } + + return true; + } + + void DtlsTransport::ProcessDtlsData(const uint8_t* data, size_t len) + { + MS_TRACE(); + + int written; + int read; - if (!IsRunning()) - { - MS_ERROR("cannot process data while not running"); + if (!IsRunning()) + { + MS_ERROR("cannot process data while not running"); - return; - } + return; + } - // Write the received DTLS data into the sslBioFromNetwork. - written = - BIO_write(this->sslBioFromNetwork, static_cast(data), static_cast(len)); + // Write the received DTLS data into the sslBioFromNetwork. + written = + BIO_write(this->sslBioFromNetwork, static_cast(data), static_cast(len)); - if (written != static_cast(len)) - { - MS_WARN_TAG( - dtls, - "OpenSSL BIO_write() wrote less (%zu bytes) than given data (%zu bytes)", - static_cast(written), - len); - } - - // Must call SSL_read() to process received DTLS data. - read = SSL_read(this->ssl, static_cast(DtlsTransport::sslReadBuffer), SslReadBufferSize); + if (written != static_cast(len)) + { + MS_WARN_TAG( + dtls, + "OpenSSL BIO_write() wrote less (%zu bytes) than given data (%zu bytes)", + static_cast(written), + len); + } + + // Must call SSL_read() to process received DTLS data. + read = SSL_read(this->ssl, static_cast(DtlsTransport::sslReadBuffer), SslReadBufferSize); - // Send data if it's ready. - SendPendingOutgoingDtlsData(); + // Send data if it's ready. + SendPendingOutgoingDtlsData(); - // Check SSL status and return if it is bad/closed. - if (!CheckStatus(read)) - return; + // Check SSL status and return if it is bad/closed. + if (!CheckStatus(read)) + return; - // Set/update the DTLS timeout. - if (!SetTimeout()) - return; + // Set/update the DTLS timeout. + if (!SetTimeout()) + return; - // Application data received. Notify to the listener. - if (read > 0) - { - // It is allowed to receive DTLS data even before validating remote fingerprint. - if (!this->handshakeDone) - { - MS_WARN_TAG(dtls, "ignoring application data received while DTLS handshake not done"); - - return; - } + // Application data received. Notify to the listener. + if (read > 0) + { + // It is allowed to receive DTLS data even before validating remote fingerprint. + if (!this->handshakeDone) + { + MS_WARN_TAG(dtls, "ignoring application data received while DTLS handshake not done"); + + return; + } - // Notify the listener. - this->listener->OnDtlsTransportApplicationDataReceived( - this, (uint8_t*)DtlsTransport::sslReadBuffer, static_cast(read)); - } - } + // Notify the listener. + this->listener->OnDtlsTransportApplicationDataReceived( + this, (uint8_t*)DtlsTransport::sslReadBuffer, static_cast(read)); + } + } - void DtlsTransport::SendApplicationData(const uint8_t* data, size_t len) - { - MS_TRACE(); + void DtlsTransport::SendApplicationData(const uint8_t* data, size_t len) + { + MS_TRACE(); - // We cannot send data to the peer if its remote fingerprint is not validated. - if (this->state != DtlsState::CONNECTED) - { - MS_WARN_TAG(dtls, "cannot send application data while DTLS is not fully connected"); + // We cannot send data to the peer if its remote fingerprint is not validated. + if (this->state != DtlsState::CONNECTED) + { + MS_WARN_TAG(dtls, "cannot send application data while DTLS is not fully connected"); - return; - } + return; + } - if (len == 0) - { - MS_WARN_TAG(dtls, "ignoring 0 length data"); + if (len == 0) + { + MS_WARN_TAG(dtls, "ignoring 0 length data"); - return; - } + return; + } - int written; + int written; - written = SSL_write(this->ssl, static_cast(data), static_cast(len)); + written = SSL_write(this->ssl, static_cast(data), static_cast(len)); - if (written < 0) - { - LOG_OPENSSL_ERROR("SSL_write() failed"); + if (written < 0) + { + LOG_OPENSSL_ERROR("SSL_write() failed"); - if (!CheckStatus(written)) - return; - } - else if (written != static_cast(len)) - { - MS_WARN_TAG( - dtls, "OpenSSL SSL_write() wrote less (%d bytes) than given data (%zu bytes)", written, len); - } + if (!CheckStatus(written)) + return; + } + else if (written != static_cast(len)) + { + MS_WARN_TAG( + dtls, "OpenSSL SSL_write() wrote less (%d bytes) than given data (%zu bytes)", written, len); + } - // Send data. - SendPendingOutgoingDtlsData(); - } - - void DtlsTransport::Reset() - { - MS_TRACE(); - - int ret; + // Send data. + SendPendingOutgoingDtlsData(); + } + + void DtlsTransport::Reset() + { + MS_TRACE(); + + int ret; - if (!IsRunning()) - return; - - MS_WARN_TAG(dtls, "resetting DTLS transport"); - - // Stop the DTLS timer. - this->timer = nullptr; - - // We need to reset the SSL instance so we need to "shutdown" it, but we - // don't want to send a Close Alert to the peer, so just don't call - // SendPendingOutgoingDTLSData(). - SSL_shutdown(this->ssl); - - this->localRole = Role::NONE; - this->state = DtlsState::NEW; - this->handshakeDone = false; - this->handshakeDoneNow = false; - - // Reset SSL status. - // NOTE: For this to properly work, SSL_shutdown() must be called before. - // NOTE: This may fail if not enough DTLS handshake data has been received, - // but we don't care so just clear the error queue. - ret = SSL_clear(this->ssl); - - if (ret == 0) - ERR_clear_error(); - } - - inline bool DtlsTransport::CheckStatus(int returnCode) - { - MS_TRACE(); - - int err; - bool wasHandshakeDone = this->handshakeDone; - - err = SSL_get_error(this->ssl, returnCode); - - switch (err) - { - case SSL_ERROR_NONE: - break; + if (!IsRunning()) + return; + + MS_WARN_TAG(dtls, "resetting DTLS transport"); + + // Stop the DTLS timer. + this->timer = nullptr; + + // We need to reset the SSL instance so we need to "shutdown" it, but we + // don't want to send a Close Alert to the peer, so just don't call + // SendPendingOutgoingDTLSData(). + SSL_shutdown(this->ssl); + + this->localRole = Role::NONE; + this->state = DtlsState::NEW; + this->handshakeDone = false; + this->handshakeDoneNow = false; + + // Reset SSL status. + // NOTE: For this to properly work, SSL_shutdown() must be called before. + // NOTE: This may fail if not enough DTLS handshake data has been received, + // but we don't care so just clear the error queue. + ret = SSL_clear(this->ssl); + + if (ret == 0) + ERR_clear_error(); + } + + inline bool DtlsTransport::CheckStatus(int returnCode) + { + MS_TRACE(); + + int err; + bool wasHandshakeDone = this->handshakeDone; + + err = SSL_get_error(this->ssl, returnCode); + + switch (err) + { + case SSL_ERROR_NONE: + break; - case SSL_ERROR_SSL: - LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SSL"); - break; + case SSL_ERROR_SSL: + LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SSL"); + break; - case SSL_ERROR_WANT_READ: - break; + case SSL_ERROR_WANT_READ: + break; - case SSL_ERROR_WANT_WRITE: - MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_WRITE"); - break; + case SSL_ERROR_WANT_WRITE: + MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_WRITE"); + break; - case SSL_ERROR_WANT_X509_LOOKUP: - MS_DEBUG_TAG(dtls, "SSL status: SSL_ERROR_WANT_X509_LOOKUP"); - break; + case SSL_ERROR_WANT_X509_LOOKUP: + MS_DEBUG_TAG(dtls, "SSL status: SSL_ERROR_WANT_X509_LOOKUP"); + break; - case SSL_ERROR_SYSCALL: - LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SYSCALL"); - break; + case SSL_ERROR_SYSCALL: + LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SYSCALL"); + break; - case SSL_ERROR_ZERO_RETURN: - break; + case SSL_ERROR_ZERO_RETURN: + break; - case SSL_ERROR_WANT_CONNECT: - MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_CONNECT"); - break; + case SSL_ERROR_WANT_CONNECT: + MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_CONNECT"); + break; - case SSL_ERROR_WANT_ACCEPT: - MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_ACCEPT"); - break; + case SSL_ERROR_WANT_ACCEPT: + MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_ACCEPT"); + break; - default: - MS_WARN_TAG(dtls, "SSL status: unknown error"); - } + default: + MS_WARN_TAG(dtls, "SSL status: unknown error"); + } - // Check if the handshake (or re-handshake) has been done right now. - if (this->handshakeDoneNow) - { - this->handshakeDoneNow = false; - this->handshakeDone = true; - - // Stop the timer. - this->timer = nullptr; + // Check if the handshake (or re-handshake) has been done right now. + if (this->handshakeDoneNow) + { + this->handshakeDoneNow = false; + this->handshakeDone = true; + + // Stop the timer. + this->timer = nullptr; - // Process the handshake just once (ignore if DTLS renegotiation). - if (!wasHandshakeDone && this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE) - return ProcessHandshake(); + // Process the handshake just once (ignore if DTLS renegotiation). + if (!wasHandshakeDone && this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE) + return ProcessHandshake(); - return true; - } - // Check if the peer sent close alert or a fatal error happened. - else if (((SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN) != 0) || err == SSL_ERROR_SSL || err == SSL_ERROR_SYSCALL) - { - if (this->state == DtlsState::CONNECTED) - { - MS_DEBUG_TAG(dtls, "disconnected"); + return true; + } + // Check if the peer sent close alert or a fatal error happened. + else if (((SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN) != 0) || err == SSL_ERROR_SSL || err == SSL_ERROR_SYSCALL) + { + if (this->state == DtlsState::CONNECTED) + { + MS_DEBUG_TAG(dtls, "disconnected"); - Reset(); + Reset(); - // Set state and notify the listener. - this->state = DtlsState::CLOSED; - this->listener->OnDtlsTransportClosed(this); - } - else - { - MS_WARN_TAG(dtls, "connection failed"); + // Set state and notify the listener. + this->state = DtlsState::CLOSED; + this->listener->OnDtlsTransportClosed(this); + } + else + { + MS_WARN_TAG(dtls, "connection failed"); - Reset(); - - // Set state and notify the listener. - this->state = DtlsState::FAILED; - this->listener->OnDtlsTransportFailed(this); - } - - return false; - } - else - { - return true; - } - } + Reset(); + + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); + } + + return false; + } + else + { + return true; + } + } - inline void DtlsTransport::SendPendingOutgoingDtlsData() - { - MS_TRACE(); + inline void DtlsTransport::SendPendingOutgoingDtlsData() + { + MS_TRACE(); - if (BIO_eof(this->sslBioToNetwork)) - return; + if (BIO_eof(this->sslBioToNetwork)) + return; - int64_t read; - char* data{ nullptr }; + int64_t read; + char* data{ nullptr }; - read = BIO_get_mem_data(this->sslBioToNetwork, &data); // NOLINT + read = BIO_get_mem_data(this->sslBioToNetwork, &data); // NOLINT - if (read <= 0) - return; - - MS_DEBUG_DEV("%" PRIu64 " bytes of DTLS data ready to sent to the peer", read); + if (read <= 0) + return; + + MS_DEBUG_DEV("%" PRIu64 " bytes of DTLS data ready to sent to the peer", read); - // Notify the listener. - this->listener->OnDtlsTransportSendData( - this, reinterpret_cast(data), static_cast(read)); + // Notify the listener. + this->listener->OnDtlsTransportSendData( + this, reinterpret_cast(data), static_cast(read)); - // Clear the BIO buffer. - // NOTE: the (void) avoids the -Wunused-value warning. - (void)BIO_reset(this->sslBioToNetwork); - } + // Clear the BIO buffer. + // NOTE: the (void) avoids the -Wunused-value warning. + (void)BIO_reset(this->sslBioToNetwork); + } - inline bool DtlsTransport::SetTimeout() - { - MS_TRACE(); + inline bool DtlsTransport::SetTimeout() + { + MS_TRACE(); - MS_ASSERT( - this->state == DtlsState::CONNECTING || this->state == DtlsState::CONNECTED, - "invalid DTLS state"); + MS_ASSERT( + this->state == DtlsState::CONNECTING || this->state == DtlsState::CONNECTED, + "invalid DTLS state"); - int64_t ret; + int64_t ret; struct timeval dtlsTimeout{ 0, 0 }; - uint64_t timeoutMs; + uint64_t timeoutMs; - // NOTE: If ret == 0 then ignore the value in dtlsTimeout. - // NOTE: No DTLSv_1_2_get_timeout() or DTLS_get_timeout() in OpenSSL 1.1.0-dev. - ret = DTLSv1_get_timeout(this->ssl, static_cast(&dtlsTimeout)); // NOLINT + // NOTE: If ret == 0 then ignore the value in dtlsTimeout. + // NOTE: No DTLSv_1_2_get_timeout() or DTLS_get_timeout() in OpenSSL 1.1.0-dev. + ret = DTLSv1_get_timeout(this->ssl, static_cast(&dtlsTimeout)); // NOLINT - if (ret == 0) - return true; + if (ret == 0) + return true; - timeoutMs = (dtlsTimeout.tv_sec * static_cast(1000)) + (dtlsTimeout.tv_usec / 1000); + timeoutMs = (dtlsTimeout.tv_sec * static_cast(1000)) + (dtlsTimeout.tv_usec / 1000); - if (timeoutMs == 0) - { - return true; - } - else if (timeoutMs < 30000) - { - MS_DEBUG_DEV("DTLS timer set in %" PRIu64 "ms", timeoutMs); + if (timeoutMs == 0) + { + return true; + } + else if (timeoutMs < 30000) + { + MS_DEBUG_DEV("DTLS timer set in %" PRIu64 "ms", timeoutMs); - weak_ptr weak_self = shared_from_this(); - this->timer = std::make_shared(timeoutMs / 1000.0f, [weak_self](){ - auto strong_self = weak_self.lock(); - if(strong_self){ + weak_ptr weak_self = shared_from_this(); + this->timer = std::make_shared(timeoutMs / 1000.0f, [weak_self](){ + auto strong_self = weak_self.lock(); + if(strong_self){ strong_self->OnTimer(); - } + } return true; - }, this->poller); + }, this->poller); - return true; - } - // NOTE: Don't start the timer again if the timeout is greater than 30 seconds. - else - { - MS_WARN_TAG(dtls, "DTLS timeout too high (%" PRIu64 "ms), resetting DLTS", timeoutMs); + return true; + } + // NOTE: Don't start the timer again if the timeout is greater than 30 seconds. + else + { + MS_WARN_TAG(dtls, "DTLS timeout too high (%" PRIu64 "ms), resetting DLTS", timeoutMs); - Reset(); + Reset(); - // Set state and notify the listener. - this->state = DtlsState::FAILED; - this->listener->OnDtlsTransportFailed(this); + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); - return false; - } - } + return false; + } + } - inline bool DtlsTransport::ProcessHandshake() - { - MS_TRACE(); + inline bool DtlsTransport::ProcessHandshake() + { + MS_TRACE(); - MS_ASSERT(this->handshakeDone, "handshake not done yet"); - MS_ASSERT( - this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set"); + MS_ASSERT(this->handshakeDone, "handshake not done yet"); + MS_ASSERT( + this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set"); - // Validate the remote fingerprint. - if (!CheckRemoteFingerprint()) - { - Reset(); + // Validate the remote fingerprint. + if (!CheckRemoteFingerprint()) + { + Reset(); - // Set state and notify the listener. - this->state = DtlsState::FAILED; - this->listener->OnDtlsTransportFailed(this); + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); - return false; - } + return false; + } - // Get the negotiated SRTP crypto suite. - RTC::SrtpSession::CryptoSuite srtpCryptoSuite = GetNegotiatedSrtpCryptoSuite(); + // Get the negotiated SRTP crypto suite. + RTC::SrtpSession::CryptoSuite srtpCryptoSuite = GetNegotiatedSrtpCryptoSuite(); - if (srtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE) - { - // Extract the SRTP keys (will notify the listener with them). - ExtractSrtpKeys(srtpCryptoSuite); + if (srtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE) + { + // Extract the SRTP keys (will notify the listener with them). + ExtractSrtpKeys(srtpCryptoSuite); - return true; - } + return true; + } - // NOTE: We assume that "use_srtp" DTLS extension is required even if - // there is no audio/video. - MS_WARN_2TAGS(dtls, srtp, "SRTP crypto suite not negotiated"); + // NOTE: We assume that "use_srtp" DTLS extension is required even if + // there is no audio/video. + MS_WARN_2TAGS(dtls, srtp, "SRTP crypto suite not negotiated"); - Reset(); + Reset(); - // Set state and notify the listener. - this->state = DtlsState::FAILED; - this->listener->OnDtlsTransportFailed(this); + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); - return false; - } + return false; + } - inline bool DtlsTransport::CheckRemoteFingerprint() - { - MS_TRACE(); + inline bool DtlsTransport::CheckRemoteFingerprint() + { + MS_TRACE(); - MS_ASSERT( - this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set"); + MS_ASSERT( + this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set"); - X509* certificate; - uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; - unsigned int size{ 0 }; - char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; - const EVP_MD* hashFunction; - int ret; + X509* certificate; + uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; + unsigned int size{ 0 }; + char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; + const EVP_MD* hashFunction; + int ret; - certificate = SSL_get_peer_certificate(this->ssl); + certificate = SSL_get_peer_certificate(this->ssl); - if (!certificate) - { - MS_WARN_TAG(dtls, "no certificate was provided by the peer"); + if (!certificate) + { + MS_WARN_TAG(dtls, "no certificate was provided by the peer"); - return false; - } + return false; + } - switch (this->remoteFingerprint.algorithm) - { - case FingerprintAlgorithm::SHA1: - hashFunction = EVP_sha1(); - break; + switch (this->remoteFingerprint.algorithm) + { + case FingerprintAlgorithm::SHA1: + hashFunction = EVP_sha1(); + break; - case FingerprintAlgorithm::SHA224: - hashFunction = EVP_sha224(); - break; + case FingerprintAlgorithm::SHA224: + hashFunction = EVP_sha224(); + break; - case FingerprintAlgorithm::SHA256: - hashFunction = EVP_sha256(); - break; + case FingerprintAlgorithm::SHA256: + hashFunction = EVP_sha256(); + break; - case FingerprintAlgorithm::SHA384: - hashFunction = EVP_sha384(); - break; + case FingerprintAlgorithm::SHA384: + hashFunction = EVP_sha384(); + break; - case FingerprintAlgorithm::SHA512: - hashFunction = EVP_sha512(); - break; + case FingerprintAlgorithm::SHA512: + hashFunction = EVP_sha512(); + break; - default: - MS_ABORT("unknown algorithm"); - } + default: + MS_ABORT("unknown algorithm"); + } - // Compare the remote fingerprint with the value given via signaling. - ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size); + // Compare the remote fingerprint with the value given via signaling. + ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size); - if (ret == 0) - { - MS_ERROR("X509_digest() failed"); + if (ret == 0) + { + MS_ERROR("X509_digest() failed"); - X509_free(certificate); + X509_free(certificate); - return false; - } + return false; + } - // Convert to hexadecimal format in uppercase with colons. - for (unsigned int i{ 0 }; i < size; ++i) - { - std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); - } - hexFingerprint[(size * 3) - 1] = '\0'; + // Convert to hexadecimal format in uppercase with colons. + for (unsigned int i{ 0 }; i < size; ++i) + { + std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); + } + hexFingerprint[(size * 3) - 1] = '\0'; - if (this->remoteFingerprint.value != hexFingerprint) - { - MS_WARN_TAG( - dtls, - "fingerprint in the remote certificate (%s) does not match the announced one (%s)", - hexFingerprint, - this->remoteFingerprint.value.c_str()); - X509_free(certificate); - return false; - } + if (this->remoteFingerprint.value != hexFingerprint) + { + MS_WARN_TAG( + dtls, + "fingerprint in the remote certificate (%s) does not match the announced one (%s)", + hexFingerprint, + this->remoteFingerprint.value.c_str()); + X509_free(certificate); + return false; + } - MS_DEBUG_TAG(dtls, "valid remote fingerprint"); + MS_DEBUG_TAG(dtls, "valid remote fingerprint"); - // Get the remote certificate in PEM format. + // Get the remote certificate in PEM format. - BIO* bio = BIO_new(BIO_s_mem()); + BIO* bio = BIO_new(BIO_s_mem()); - // Ensure the underlying BUF_MEM structure is also freed. - // NOTE: Avoid stupid "warning: value computed is not used [-Wunused-value]" since - // BIO_set_close() always returns 1. - (void)BIO_set_close(bio, BIO_CLOSE); + // Ensure the underlying BUF_MEM structure is also freed. + // NOTE: Avoid stupid "warning: value computed is not used [-Wunused-value]" since + // BIO_set_close() always returns 1. + (void)BIO_set_close(bio, BIO_CLOSE); - ret = PEM_write_bio_X509(bio, certificate); + ret = PEM_write_bio_X509(bio, certificate); - if (ret != 1) - { - LOG_OPENSSL_ERROR("PEM_write_bio_X509() failed"); + if (ret != 1) + { + LOG_OPENSSL_ERROR("PEM_write_bio_X509() failed"); - X509_free(certificate); - BIO_free(bio); - - return false; - } - - BUF_MEM* mem; + X509_free(certificate); + BIO_free(bio); + + return false; + } + + BUF_MEM* mem; - BIO_get_mem_ptr(bio, &mem); // NOLINT[cppcoreguidelines-pro-type-cstyle-cast] - - if (!mem || !mem->data || mem->length == 0u) - { - LOG_OPENSSL_ERROR("BIO_get_mem_ptr() failed"); - - X509_free(certificate); - BIO_free(bio); - - return false; - } - - this->remoteCert = std::string(mem->data, mem->length); - - X509_free(certificate); - BIO_free(bio); - - return true; - } - - inline void DtlsTransport::ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite) - { - MS_TRACE(); - - size_t srtpKeyLength{ 0 }; - size_t srtpSaltLength{ 0 }; - size_t srtpMasterLength{ 0 }; - - switch (srtpCryptoSuite) - { - case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80: - case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32: - { - srtpKeyLength = SrtpMasterKeyLength; - srtpSaltLength = SrtpMasterSaltLength; - srtpMasterLength = SrtpMasterLength; - - break; - } - - case RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM: - { - srtpKeyLength = SrtpAesGcm256MasterKeyLength; - srtpSaltLength = SrtpAesGcm256MasterSaltLength; - srtpMasterLength = SrtpAesGcm256MasterLength; - - break; - } - - case RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM: - { - srtpKeyLength = SrtpAesGcm128MasterKeyLength; - srtpSaltLength = SrtpAesGcm128MasterSaltLength; - srtpMasterLength = SrtpAesGcm128MasterLength; - - break; - } - - default: - { - MS_ABORT("unknown SRTP crypto suite"); - } - } - - auto* srtpMaterial = new uint8_t[srtpMasterLength * 2]; - uint8_t* srtpLocalKey{ nullptr }; - uint8_t* srtpLocalSalt{ nullptr }; - uint8_t* srtpRemoteKey{ nullptr }; - uint8_t* srtpRemoteSalt{ nullptr }; - auto* srtpLocalMasterKey = new uint8_t[srtpMasterLength]; - auto* srtpRemoteMasterKey = new uint8_t[srtpMasterLength]; - int ret; - - ret = SSL_export_keying_material( - this->ssl, srtpMaterial, srtpMasterLength * 2, "EXTRACTOR-dtls_srtp", 19, nullptr, 0, 0); - - MS_ASSERT(ret != 0, "SSL_export_keying_material() failed"); - - switch (this->localRole) - { - case Role::SERVER: - { - srtpRemoteKey = srtpMaterial; - srtpLocalKey = srtpRemoteKey + srtpKeyLength; - srtpRemoteSalt = srtpLocalKey + srtpKeyLength; - srtpLocalSalt = srtpRemoteSalt + srtpSaltLength; - - break; - } - - case Role::CLIENT: - { - srtpLocalKey = srtpMaterial; - srtpRemoteKey = srtpLocalKey + srtpKeyLength; - srtpLocalSalt = srtpRemoteKey + srtpKeyLength; - srtpRemoteSalt = srtpLocalSalt + srtpSaltLength; - - break; - } - - default: - { - MS_ABORT("no DTLS role set"); - } - } - - // Create the SRTP local master key. - std::memcpy(srtpLocalMasterKey, srtpLocalKey, srtpKeyLength); - std::memcpy(srtpLocalMasterKey + srtpKeyLength, srtpLocalSalt, srtpSaltLength); - // Create the SRTP remote master key. - std::memcpy(srtpRemoteMasterKey, srtpRemoteKey, srtpKeyLength); - std::memcpy(srtpRemoteMasterKey + srtpKeyLength, srtpRemoteSalt, srtpSaltLength); - - // Set state and notify the listener. - this->state = DtlsState::CONNECTED; - this->listener->OnDtlsTransportConnected( - this, - srtpCryptoSuite, - srtpLocalMasterKey, - srtpMasterLength, - srtpRemoteMasterKey, - srtpMasterLength, - this->remoteCert); - - delete[] srtpMaterial; - delete[] srtpLocalMasterKey; - delete[] srtpRemoteMasterKey; - } - - inline RTC::SrtpSession::CryptoSuite DtlsTransport::GetNegotiatedSrtpCryptoSuite() - { - MS_TRACE(); - - RTC::SrtpSession::CryptoSuite negotiatedSrtpCryptoSuite = RTC::SrtpSession::CryptoSuite::NONE; - - // Ensure that the SRTP crypto suite has been negotiated. - // NOTE: This is a OpenSSL type. - SRTP_PROTECTION_PROFILE* sslSrtpCryptoSuite = SSL_get_selected_srtp_profile(this->ssl); - - if (!sslSrtpCryptoSuite) - return negotiatedSrtpCryptoSuite; - - // Get the negotiated SRTP crypto suite. - for (auto& srtpCryptoSuite : DtlsTransport::srtpCryptoSuites) - { - SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(srtpCryptoSuite); - - if (std::strcmp(sslSrtpCryptoSuite->name, cryptoSuiteEntry->name) == 0) - { - MS_DEBUG_2TAGS(dtls, srtp, "chosen SRTP crypto suite: %s", cryptoSuiteEntry->name); - - negotiatedSrtpCryptoSuite = cryptoSuiteEntry->cryptoSuite; - } - } - - MS_ASSERT( - negotiatedSrtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE, - "chosen SRTP crypto suite is not an available one"); - - return negotiatedSrtpCryptoSuite; - } - - inline void DtlsTransport::OnSslInfo(int where, int ret) - { - MS_TRACE(); - - int w = where & -SSL_ST_MASK; - const char* role; - - if ((w & SSL_ST_CONNECT) != 0) - role = "client"; - else if ((w & SSL_ST_ACCEPT) != 0) - role = "server"; - else - role = "undefined"; - - if ((where & SSL_CB_LOOP) != 0) - { - MS_DEBUG_TAG(dtls, "[role:%s, action:'%s']", role, SSL_state_string_long(this->ssl)); - } - else if ((where & SSL_CB_ALERT) != 0) - { - const char* alertType; - - switch (*SSL_alert_type_string(ret)) - { - case 'W': - alertType = "warning"; - break; - - case 'F': - alertType = "fatal"; - break; - - default: - alertType = "undefined"; - } - - if ((where & SSL_CB_READ) != 0) - { - MS_WARN_TAG(dtls, "received DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); - } - else if ((where & SSL_CB_WRITE) != 0) - { - MS_DEBUG_TAG(dtls, "sending DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); - } - else - { - MS_DEBUG_TAG(dtls, "DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); - } - } - else if ((where & SSL_CB_EXIT) != 0) - { - if (ret == 0) - MS_DEBUG_TAG(dtls, "[role:%s, failed:'%s']", role, SSL_state_string_long(this->ssl)); - else if (ret < 0) - MS_DEBUG_TAG(dtls, "role: %s, waiting:'%s']", role, SSL_state_string_long(this->ssl)); - } - else if ((where & SSL_CB_HANDSHAKE_START) != 0) - { - MS_DEBUG_TAG(dtls, "DTLS handshake start"); - } - else if ((where & SSL_CB_HANDSHAKE_DONE) != 0) - { - MS_DEBUG_TAG(dtls, "DTLS handshake done"); - - this->handshakeDoneNow = true; - } - - // NOTE: checking SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN here upon - // receipt of a close alert does not work (the flag is set after this callback). - } - - inline void DtlsTransport::OnTimer() - { - MS_TRACE(); - - // Workaround for https://github.com/openssl/openssl/issues/7998. - if (this->handshakeDone) - { - MS_DEBUG_DEV("handshake is done so return"); - - return; - } - - DTLSv1_handle_timeout(this->ssl); - - // If required, send DTLS data. - SendPendingOutgoingDtlsData(); - - // Set the DTLS timer again. - SetTimeout(); - } + BIO_get_mem_ptr(bio, &mem); // NOLINT[cppcoreguidelines-pro-type-cstyle-cast] + + if (!mem || !mem->data || mem->length == 0u) + { + LOG_OPENSSL_ERROR("BIO_get_mem_ptr() failed"); + + X509_free(certificate); + BIO_free(bio); + + return false; + } + + this->remoteCert = std::string(mem->data, mem->length); + + X509_free(certificate); + BIO_free(bio); + + return true; + } + + inline void DtlsTransport::ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite) + { + MS_TRACE(); + + size_t srtpKeyLength{ 0 }; + size_t srtpSaltLength{ 0 }; + size_t srtpMasterLength{ 0 }; + + switch (srtpCryptoSuite) + { + case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80: + case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32: + { + srtpKeyLength = SrtpMasterKeyLength; + srtpSaltLength = SrtpMasterSaltLength; + srtpMasterLength = SrtpMasterLength; + + break; + } + + case RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM: + { + srtpKeyLength = SrtpAesGcm256MasterKeyLength; + srtpSaltLength = SrtpAesGcm256MasterSaltLength; + srtpMasterLength = SrtpAesGcm256MasterLength; + + break; + } + + case RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM: + { + srtpKeyLength = SrtpAesGcm128MasterKeyLength; + srtpSaltLength = SrtpAesGcm128MasterSaltLength; + srtpMasterLength = SrtpAesGcm128MasterLength; + + break; + } + + default: + { + MS_ABORT("unknown SRTP crypto suite"); + } + } + + auto* srtpMaterial = new uint8_t[srtpMasterLength * 2]; + uint8_t* srtpLocalKey{ nullptr }; + uint8_t* srtpLocalSalt{ nullptr }; + uint8_t* srtpRemoteKey{ nullptr }; + uint8_t* srtpRemoteSalt{ nullptr }; + auto* srtpLocalMasterKey = new uint8_t[srtpMasterLength]; + auto* srtpRemoteMasterKey = new uint8_t[srtpMasterLength]; + int ret; + + ret = SSL_export_keying_material( + this->ssl, srtpMaterial, srtpMasterLength * 2, "EXTRACTOR-dtls_srtp", 19, nullptr, 0, 0); + + MS_ASSERT(ret != 0, "SSL_export_keying_material() failed"); + + switch (this->localRole) + { + case Role::SERVER: + { + srtpRemoteKey = srtpMaterial; + srtpLocalKey = srtpRemoteKey + srtpKeyLength; + srtpRemoteSalt = srtpLocalKey + srtpKeyLength; + srtpLocalSalt = srtpRemoteSalt + srtpSaltLength; + + break; + } + + case Role::CLIENT: + { + srtpLocalKey = srtpMaterial; + srtpRemoteKey = srtpLocalKey + srtpKeyLength; + srtpLocalSalt = srtpRemoteKey + srtpKeyLength; + srtpRemoteSalt = srtpLocalSalt + srtpSaltLength; + + break; + } + + default: + { + MS_ABORT("no DTLS role set"); + } + } + + // Create the SRTP local master key. + std::memcpy(srtpLocalMasterKey, srtpLocalKey, srtpKeyLength); + std::memcpy(srtpLocalMasterKey + srtpKeyLength, srtpLocalSalt, srtpSaltLength); + // Create the SRTP remote master key. + std::memcpy(srtpRemoteMasterKey, srtpRemoteKey, srtpKeyLength); + std::memcpy(srtpRemoteMasterKey + srtpKeyLength, srtpRemoteSalt, srtpSaltLength); + + // Set state and notify the listener. + this->state = DtlsState::CONNECTED; + this->listener->OnDtlsTransportConnected( + this, + srtpCryptoSuite, + srtpLocalMasterKey, + srtpMasterLength, + srtpRemoteMasterKey, + srtpMasterLength, + this->remoteCert); + + delete[] srtpMaterial; + delete[] srtpLocalMasterKey; + delete[] srtpRemoteMasterKey; + } + + inline RTC::SrtpSession::CryptoSuite DtlsTransport::GetNegotiatedSrtpCryptoSuite() + { + MS_TRACE(); + + RTC::SrtpSession::CryptoSuite negotiatedSrtpCryptoSuite = RTC::SrtpSession::CryptoSuite::NONE; + + // Ensure that the SRTP crypto suite has been negotiated. + // NOTE: This is a OpenSSL type. + SRTP_PROTECTION_PROFILE* sslSrtpCryptoSuite = SSL_get_selected_srtp_profile(this->ssl); + + if (!sslSrtpCryptoSuite) + return negotiatedSrtpCryptoSuite; + + // Get the negotiated SRTP crypto suite. + for (auto& srtpCryptoSuite : DtlsTransport::srtpCryptoSuites) + { + SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(srtpCryptoSuite); + + if (std::strcmp(sslSrtpCryptoSuite->name, cryptoSuiteEntry->name) == 0) + { + MS_DEBUG_2TAGS(dtls, srtp, "chosen SRTP crypto suite: %s", cryptoSuiteEntry->name); + + negotiatedSrtpCryptoSuite = cryptoSuiteEntry->cryptoSuite; + } + } + + MS_ASSERT( + negotiatedSrtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE, + "chosen SRTP crypto suite is not an available one"); + + return negotiatedSrtpCryptoSuite; + } + + inline void DtlsTransport::OnSslInfo(int where, int ret) + { + MS_TRACE(); + + int w = where & -SSL_ST_MASK; + const char* role; + + if ((w & SSL_ST_CONNECT) != 0) + role = "client"; + else if ((w & SSL_ST_ACCEPT) != 0) + role = "server"; + else + role = "undefined"; + + if ((where & SSL_CB_LOOP) != 0) + { + MS_DEBUG_TAG(dtls, "[role:%s, action:'%s']", role, SSL_state_string_long(this->ssl)); + } + else if ((where & SSL_CB_ALERT) != 0) + { + const char* alertType; + + switch (*SSL_alert_type_string(ret)) + { + case 'W': + alertType = "warning"; + break; + + case 'F': + alertType = "fatal"; + break; + + default: + alertType = "undefined"; + } + + if ((where & SSL_CB_READ) != 0) + { + MS_WARN_TAG(dtls, "received DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); + } + else if ((where & SSL_CB_WRITE) != 0) + { + MS_DEBUG_TAG(dtls, "sending DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); + } + else + { + MS_DEBUG_TAG(dtls, "DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); + } + } + else if ((where & SSL_CB_EXIT) != 0) + { + if (ret == 0) + MS_DEBUG_TAG(dtls, "[role:%s, failed:'%s']", role, SSL_state_string_long(this->ssl)); + else if (ret < 0) + MS_DEBUG_TAG(dtls, "role: %s, waiting:'%s']", role, SSL_state_string_long(this->ssl)); + } + else if ((where & SSL_CB_HANDSHAKE_START) != 0) + { + MS_DEBUG_TAG(dtls, "DTLS handshake start"); + } + else if ((where & SSL_CB_HANDSHAKE_DONE) != 0) + { + MS_DEBUG_TAG(dtls, "DTLS handshake done"); + + this->handshakeDoneNow = true; + } + + // NOTE: checking SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN here upon + // receipt of a close alert does not work (the flag is set after this callback). + } + + inline void DtlsTransport::OnTimer() + { + MS_TRACE(); + + // Workaround for https://github.com/openssl/openssl/issues/7998. + if (this->handshakeDone) + { + // MS_DEBUG_DEV("handshake is done so return"); + return; + } + + DTLSv1_handle_timeout(this->ssl); + + // If required, send DTLS data. + SendPendingOutgoingDtlsData(); + + // Set the DTLS timer again. + SetTimeout(); + } } // namespace RTC diff --git a/webrtc/DtlsTransport.hpp b/webrtc/DtlsTransport.hpp index fb28a6a4..bf57d01d 100644 --- a/webrtc/DtlsTransport.hpp +++ b/webrtc/DtlsTransport.hpp @@ -33,50 +33,50 @@ using namespace toolkit; namespace RTC { class DtlsTransport : public std::enable_shared_from_this - { - public: - enum class DtlsState - { - NEW = 1, - CONNECTING, - CONNECTED, - FAILED, - CLOSED - }; + { + public: + enum class DtlsState + { + NEW = 1, + CONNECTING, + CONNECTED, + FAILED, + CLOSED + }; - public: - enum class Role - { - NONE = 0, - AUTO = 1, - CLIENT, - SERVER - }; + public: + enum class Role + { + NONE = 0, + AUTO = 1, + CLIENT, + SERVER + }; - public: - enum class FingerprintAlgorithm - { - NONE = 0, - SHA1 = 1, - SHA224, - SHA256, - SHA384, - SHA512 - }; + public: + enum class FingerprintAlgorithm + { + NONE = 0, + SHA1 = 1, + SHA224, + SHA256, + SHA384, + SHA512 + }; - public: - struct Fingerprint - { - FingerprintAlgorithm algorithm{ FingerprintAlgorithm::NONE }; - std::string value; - }; + public: + struct Fingerprint + { + FingerprintAlgorithm algorithm{ FingerprintAlgorithm::NONE }; + std::string value; + }; - private: - struct SrtpCryptoSuiteMapEntry - { - RTC::SrtpSession::CryptoSuite cryptoSuite; - const char* name; - }; + private: + struct SrtpCryptoSuiteMapEntry + { + RTC::SrtpSession::CryptoSuite cryptoSuite; + const char* name; + }; class DtlsEnvironment : public std::enable_shared_from_this { @@ -99,154 +99,154 @@ namespace RTC std::vector localFingerprints; }; - public: - class Listener - { - public: - // DTLS is in the process of negotiating a secure connection. Incoming - // media can flow through. - // NOTE: The caller MUST NOT call any method during this callback. - virtual void OnDtlsTransportConnecting(const RTC::DtlsTransport* dtlsTransport) = 0; - // DTLS has completed negotiation of a secure connection (including DTLS-SRTP - // and remote fingerprint verification). Outgoing media can now flow through. - // NOTE: The caller MUST NOT call any method during this callback. - virtual void OnDtlsTransportConnected( - const RTC::DtlsTransport* dtlsTransport, - RTC::SrtpSession::CryptoSuite srtpCryptoSuite, - uint8_t* srtpLocalKey, - size_t srtpLocalKeyLen, - uint8_t* srtpRemoteKey, - size_t srtpRemoteKeyLen, - std::string& remoteCert) = 0; - // The DTLS connection has been closed as the result of an error (such as a - // DTLS alert or a failure to validate the remote fingerprint). - virtual void OnDtlsTransportFailed(const RTC::DtlsTransport* dtlsTransport) = 0; - // The DTLS connection has been closed due to receipt of a close_notify alert. - virtual void OnDtlsTransportClosed(const RTC::DtlsTransport* dtlsTransport) = 0; - // Need to send DTLS data to the peer. - virtual void OnDtlsTransportSendData( - const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0; - // DTLS application data received. - virtual void OnDtlsTransportApplicationDataReceived( - const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0; - }; + public: + class Listener + { + public: + // DTLS is in the process of negotiating a secure connection. Incoming + // media can flow through. + // NOTE: The caller MUST NOT call any method during this callback. + virtual void OnDtlsTransportConnecting(const RTC::DtlsTransport* dtlsTransport) = 0; + // DTLS has completed negotiation of a secure connection (including DTLS-SRTP + // and remote fingerprint verification). Outgoing media can now flow through. + // NOTE: The caller MUST NOT call any method during this callback. + virtual void OnDtlsTransportConnected( + const RTC::DtlsTransport* dtlsTransport, + RTC::SrtpSession::CryptoSuite srtpCryptoSuite, + uint8_t* srtpLocalKey, + size_t srtpLocalKeyLen, + uint8_t* srtpRemoteKey, + size_t srtpRemoteKeyLen, + std::string& remoteCert) = 0; + // The DTLS connection has been closed as the result of an error (such as a + // DTLS alert or a failure to validate the remote fingerprint). + virtual void OnDtlsTransportFailed(const RTC::DtlsTransport* dtlsTransport) = 0; + // The DTLS connection has been closed due to receipt of a close_notify alert. + virtual void OnDtlsTransportClosed(const RTC::DtlsTransport* dtlsTransport) = 0; + // Need to send DTLS data to the peer. + virtual void OnDtlsTransportSendData( + const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0; + // DTLS application data received. + virtual void OnDtlsTransportApplicationDataReceived( + const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0; + }; - public: - static Role StringToRole(const std::string& role) - { - auto it = DtlsTransport::string2Role.find(role); + public: + static Role StringToRole(const std::string& role) + { + auto it = DtlsTransport::string2Role.find(role); - if (it != DtlsTransport::string2Role.end()) - return it->second; - else - return DtlsTransport::Role::NONE; - } - static FingerprintAlgorithm GetFingerprintAlgorithm(const std::string& fingerprint) - { - auto it = DtlsTransport::string2FingerprintAlgorithm.find(fingerprint); + if (it != DtlsTransport::string2Role.end()) + return it->second; + else + return DtlsTransport::Role::NONE; + } + static FingerprintAlgorithm GetFingerprintAlgorithm(const std::string& fingerprint) + { + auto it = DtlsTransport::string2FingerprintAlgorithm.find(fingerprint); - if (it != DtlsTransport::string2FingerprintAlgorithm.end()) - return it->second; - else - return DtlsTransport::FingerprintAlgorithm::NONE; - } - static std::string& GetFingerprintAlgorithmString(FingerprintAlgorithm fingerprint) - { - auto it = DtlsTransport::fingerprintAlgorithm2String.find(fingerprint); + if (it != DtlsTransport::string2FingerprintAlgorithm.end()) + return it->second; + else + return DtlsTransport::FingerprintAlgorithm::NONE; + } + static std::string& GetFingerprintAlgorithmString(FingerprintAlgorithm fingerprint) + { + auto it = DtlsTransport::fingerprintAlgorithm2String.find(fingerprint); - return it->second; - } - static bool IsDtls(const uint8_t* data, size_t len) - { - // clang-format off - return ( - // Minimum DTLS record length is 13 bytes. - (len >= 13) && - // DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes - (data[0] > 19 && data[0] < 64) - ); - // clang-format on - } - - private: - static std::map string2Role; - static std::map string2FingerprintAlgorithm; - static std::map fingerprintAlgorithm2String; - static std::vector srtpCryptoSuites; - - public: - DtlsTransport(EventPoller::Ptr poller, Listener* listener); - ~DtlsTransport(); - - public: - void Dump() const; - void Run(Role localRole); - std::vector& GetLocalFingerprints() const - { - return env->localFingerprints; - } - bool SetRemoteFingerprint(Fingerprint fingerprint); - void ProcessDtlsData(const uint8_t* data, size_t len); - DtlsState GetState() const - { - return this->state; - } - Role GetLocalRole() const - { - return this->localRole; - } - void SendApplicationData(const uint8_t* data, size_t len); - - private: - bool IsRunning() const - { - switch (this->state) - { - case DtlsState::NEW: - return false; - case DtlsState::CONNECTING: - case DtlsState::CONNECTED: - return true; - case DtlsState::FAILED: - case DtlsState::CLOSED: - return false; - } - - // Make GCC 4.9 happy. - return false; - } - void Reset(); - bool CheckStatus(int returnCode); - void SendPendingOutgoingDtlsData(); - bool SetTimeout(); - bool ProcessHandshake(); - bool CheckRemoteFingerprint(); - void ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite); - RTC::SrtpSession::CryptoSuite GetNegotiatedSrtpCryptoSuite(); + return it->second; + } + static bool IsDtls(const uint8_t* data, size_t len) + { + // clang-format off + return ( + // Minimum DTLS record length is 13 bytes. + (len >= 13) && + // DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes + (data[0] > 19 && data[0] < 64) + ); + // clang-format on + } private: - void OnSslInfo(int where, int ret); - void OnTimer(); + static std::map string2Role; + static std::map string2FingerprintAlgorithm; + static std::map fingerprintAlgorithm2String; + static std::vector srtpCryptoSuites; - private: + public: + DtlsTransport(EventPoller::Ptr poller, Listener* listener); + ~DtlsTransport(); + + public: + void Dump() const; + void Run(Role localRole); + std::vector& GetLocalFingerprints() const + { + return env->localFingerprints; + } + bool SetRemoteFingerprint(Fingerprint fingerprint); + void ProcessDtlsData(const uint8_t* data, size_t len); + DtlsState GetState() const + { + return this->state; + } + Role GetLocalRole() const + { + return this->localRole; + } + void SendApplicationData(const uint8_t* data, size_t len); + + private: + bool IsRunning() const + { + switch (this->state) + { + case DtlsState::NEW: + return false; + case DtlsState::CONNECTING: + case DtlsState::CONNECTED: + return true; + case DtlsState::FAILED: + case DtlsState::CLOSED: + return false; + } + + // Make GCC 4.9 happy. + return false; + } + void Reset(); + bool CheckStatus(int returnCode); + void SendPendingOutgoingDtlsData(); + bool SetTimeout(); + bool ProcessHandshake(); + bool CheckRemoteFingerprint(); + void ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite); + RTC::SrtpSession::CryptoSuite GetNegotiatedSrtpCryptoSuite(); + + private: + void OnSslInfo(int where, int ret); + void OnTimer(); + + private: DtlsEnvironment::Ptr env; EventPoller::Ptr poller; // Passed by argument. - Listener* listener{ nullptr }; - // Allocated by this. - SSL* ssl{ nullptr }; - BIO* sslBioFromNetwork{ nullptr }; // The BIO from which ssl reads. - BIO* sslBioToNetwork{ nullptr }; // The BIO in which ssl writes. - Timer::Ptr timer; - // Others. - DtlsState state{ DtlsState::NEW }; - Role localRole{ Role::NONE }; - Fingerprint remoteFingerprint; - bool handshakeDone{ false }; - bool handshakeDoneNow{ false }; - std::string remoteCert; - //最大不超过mtu - static constexpr int SslReadBufferSize{ 2000 }; + Listener* listener{ nullptr }; + // Allocated by this. + SSL* ssl{ nullptr }; + BIO* sslBioFromNetwork{ nullptr }; // The BIO from which ssl reads. + BIO* sslBioToNetwork{ nullptr }; // The BIO in which ssl writes. + Timer::Ptr timer; + // Others. + DtlsState state{ DtlsState::NEW }; + Role localRole{ Role::NONE }; + Fingerprint remoteFingerprint; + bool handshakeDone{ false }; + bool handshakeDoneNow{ false }; + std::string remoteCert; + //最大不超过mtu + static constexpr int SslReadBufferSize{ 2000 }; uint8_t sslReadBuffer[SslReadBufferSize]; }; } // namespace RTC diff --git a/webrtc/IceServer.cpp b/webrtc/IceServer.cpp index 48709ab2..f0f79358 100644 --- a/webrtc/IceServer.cpp +++ b/webrtc/IceServer.cpp @@ -24,503 +24,505 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. namespace RTC { - /* Static. */ - /* Instance methods. */ + /* Static. */ + /* Instance methods. */ - IceServer::IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password) - : listener(listener), usernameFragment(usernameFragment), password(password) - { - MS_TRACE(); - } + IceServer::IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password) + : listener(listener), usernameFragment(usernameFragment), password(password) + { + MS_TRACE(); + } - void IceServer::ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple) - { - MS_TRACE(); + void IceServer::ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple) + { + MS_TRACE(); - // Must be a Binding method. - if (packet->GetMethod() != RTC::StunPacket::Method::BINDING) - { - if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) - { - MS_WARN_TAG( - ice, - "unknown method %#.3x in STUN Request => 400", - static_cast(packet->GetMethod())); + // Must be a Binding method. + if (packet->GetMethod() != RTC::StunPacket::Method::BINDING) + { + if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) + { + MS_WARN_TAG( + ice, + "unknown method %#.3x in STUN Request => 400", + static_cast(packet->GetMethod())); - // Reply 400. - RTC::StunPacket* response = packet->CreateErrorResponse(400); + // Reply 400. + RTC::StunPacket* response = packet->CreateErrorResponse(400); - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); - delete response; - } - else - { - MS_WARN_TAG( - ice, - "ignoring STUN Indication or Response with unknown method %#.3x", - static_cast(packet->GetMethod())); - } + delete response; + } + else + { + MS_WARN_TAG( + ice, + "ignoring STUN Indication or Response with unknown method %#.3x", + static_cast(packet->GetMethod())); + } - return; - } + return; + } - // Must use FINGERPRINT (optional for ICE STUN indications). - if (!packet->HasFingerprint() && packet->GetClass() != RTC::StunPacket::Class::INDICATION) - { - if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) - { - MS_WARN_TAG(ice, "STUN Binding Request without FINGERPRINT => 400"); + // Must use FINGERPRINT (optional for ICE STUN indications). + if (!packet->HasFingerprint() && packet->GetClass() != RTC::StunPacket::Class::INDICATION) + { + if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) + { + MS_WARN_TAG(ice, "STUN Binding Request without FINGERPRINT => 400"); - // Reply 400. - RTC::StunPacket* response = packet->CreateErrorResponse(400); + // Reply 400. + RTC::StunPacket* response = packet->CreateErrorResponse(400); - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); - delete response; - } - else - { - MS_WARN_TAG(ice, "ignoring STUN Binding Response without FINGERPRINT"); - } + delete response; + } + else + { + MS_WARN_TAG(ice, "ignoring STUN Binding Response without FINGERPRINT"); + } - return; - } + return; + } - switch (packet->GetClass()) - { - case RTC::StunPacket::Class::REQUEST: - { - // USERNAME, MESSAGE-INTEGRITY and PRIORITY are required. - if (!packet->HasMessageIntegrity() || (packet->GetPriority() == 0u) || packet->GetUsername().empty()) - { - MS_WARN_TAG(ice, "mising required attributes in STUN Binding Request => 400"); + switch (packet->GetClass()) + { + case RTC::StunPacket::Class::REQUEST: + { + // USERNAME, MESSAGE-INTEGRITY and PRIORITY are required. + if (!packet->HasMessageIntegrity() || (packet->GetPriority() == 0u) || packet->GetUsername().empty()) + { + MS_WARN_TAG(ice, "mising required attributes in STUN Binding Request => 400"); - // Reply 400. - RTC::StunPacket* response = packet->CreateErrorResponse(400); + // Reply 400. + RTC::StunPacket* response = packet->CreateErrorResponse(400); - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); - delete response; + delete response; - return; - } + return; + } - // Check authentication. - switch (packet->CheckAuthentication(this->usernameFragment, this->password)) - { - case RTC::StunPacket::Authentication::OK: - { - if (!this->oldPassword.empty()) - { - MS_DEBUG_TAG(ice, "new ICE credentials applied"); + // Check authentication. + switch (packet->CheckAuthentication(this->usernameFragment, this->password)) + { + case RTC::StunPacket::Authentication::OK: + { + if (!this->oldPassword.empty()) + { + MS_DEBUG_TAG(ice, "new ICE credentials applied"); - this->oldUsernameFragment.clear(); - this->oldPassword.clear(); - } + this->oldUsernameFragment.clear(); + this->oldPassword.clear(); + } - break; - } + break; + } - case RTC::StunPacket::Authentication::UNAUTHORIZED: - { - // We may have changed our usernameFragment and password, so check - // the old ones. - // clang-format off - if ( - !this->oldUsernameFragment.empty() && - !this->oldPassword.empty() && - packet->CheckAuthentication(this->oldUsernameFragment, this->oldPassword) == RTC::StunPacket::Authentication::OK - ) - // clang-format on - { - MS_DEBUG_TAG(ice, "using old ICE credentials"); + case RTC::StunPacket::Authentication::UNAUTHORIZED: + { + // We may have changed our usernameFragment and password, so check + // the old ones. + // clang-format off + if ( + !this->oldUsernameFragment.empty() && + !this->oldPassword.empty() && + packet->CheckAuthentication(this->oldUsernameFragment, this->oldPassword) == RTC::StunPacket::Authentication::OK + ) + // clang-format on + { + MS_DEBUG_TAG(ice, "using old ICE credentials"); - break; - } + break; + } - MS_WARN_TAG(ice, "wrong authentication in STUN Binding Request => 401"); + MS_WARN_TAG(ice, "wrong authentication in STUN Binding Request => 401"); - // Reply 401. - RTC::StunPacket* response = packet->CreateErrorResponse(401); + // Reply 401. + RTC::StunPacket* response = packet->CreateErrorResponse(401); - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); - delete response; + delete response; - return; - } + return; + } - case RTC::StunPacket::Authentication::BAD_REQUEST: - { - MS_WARN_TAG(ice, "cannot check authentication in STUN Binding Request => 400"); + case RTC::StunPacket::Authentication::BAD_REQUEST: + { + MS_WARN_TAG(ice, "cannot check authentication in STUN Binding Request => 400"); - // Reply 400. - RTC::StunPacket* response = packet->CreateErrorResponse(400); + // Reply 400. + RTC::StunPacket* response = packet->CreateErrorResponse(400); - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); - delete response; + delete response; - return; - } - } + return; + } + } #if 0 - // The remote peer must be ICE controlling. - if (packet->GetIceControlled()) - { - MS_WARN_TAG(ice, "peer indicates ICE-CONTROLLED in STUN Binding Request => 487"); + // The remote peer must be ICE controlling. + if (packet->GetIceControlled()) + { + MS_WARN_TAG(ice, "peer indicates ICE-CONTROLLED in STUN Binding Request => 487"); - // Reply 487 (Role Conflict). - RTC::StunPacket* response = packet->CreateErrorResponse(487); + // Reply 487 (Role Conflict). + RTC::StunPacket* response = packet->CreateErrorResponse(487); - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); - delete response; + delete response; - return; - } + return; + } #endif - //MS_DEBUG_DEV( - // "processing STUN Binding Request [Priority:%" PRIu32 ", UseCandidate:%s]", - // static_cast(packet->GetPriority()), - // packet->HasUseCandidate() ? "true" : "false"); + //MS_DEBUG_DEV( + // "processing STUN Binding Request [Priority:%" PRIu32 ", UseCandidate:%s]", + // static_cast(packet->GetPriority()), + // packet->HasUseCandidate() ? "true" : "false"); - // Create a success response. - RTC::StunPacket* response = packet->CreateSuccessResponse(); - - // Add XOR-MAPPED-ADDRESS. - response->SetXorMappedAddress(tuple); + // Create a success response. + RTC::StunPacket* response = packet->CreateSuccessResponse(); - // Authenticate the response. - if (this->oldPassword.empty()) - response->Authenticate(this->password); - else - response->Authenticate(this->oldPassword); + sockaddr_storage peerAddr; + socklen_t addr_len = sizeof(peerAddr); + getpeername(tuple->getSock()->rawFD(), (struct sockaddr *)&peerAddr, &addr_len); + + // Add XOR-MAPPED-ADDRESS. + response->SetXorMappedAddress((struct sockaddr *)&peerAddr); - // Send back. - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); + // Authenticate the response. + if (this->oldPassword.empty()) + response->Authenticate(this->password); + else + response->Authenticate(this->oldPassword); - delete response; + // Send back. + response->Serialize(StunSerializeBuffer); + this->listener->OnIceServerSendStunPacket(this, response, tuple); - // Handle the tuple. - HandleTuple(tuple, packet->HasUseCandidate()); + delete response; - break; - } + // Handle the tuple. + HandleTuple(tuple, packet->HasUseCandidate()); - case RTC::StunPacket::Class::INDICATION: - { - MS_DEBUG_TAG(ice, "STUN Binding Indication processed"); + break; + } - break; - } - - case RTC::StunPacket::Class::SUCCESS_RESPONSE: - { - MS_DEBUG_TAG(ice, "STUN Binding Success Response processed"); - - break; - } + case RTC::StunPacket::Class::INDICATION: + { + MS_DEBUG_TAG(ice, "STUN Binding Indication processed"); - case RTC::StunPacket::Class::ERROR_RESPONSE: - { - MS_DEBUG_TAG(ice, "STUN Binding Error Response processed"); + break; + } + + case RTC::StunPacket::Class::SUCCESS_RESPONSE: + { + MS_DEBUG_TAG(ice, "STUN Binding Success Response processed"); + + break; + } - break; - } - } - } - - bool IceServer::IsValidTuple(const RTC::TransportTuple* tuple) const - { - MS_TRACE(); - - return HasTuple(tuple) != nullptr; - } - - void IceServer::RemoveTuple(RTC::TransportTuple* tuple) - { - MS_TRACE(); - - RTC::TransportTuple* removedTuple{ nullptr }; - - // Find the removed tuple. - auto it = this->tuples.begin(); - - for (; it != this->tuples.end(); ++it) - { - RTC::TransportTuple* storedTuple = std::addressof(*it); - - if (memcmp(storedTuple, tuple, sizeof (RTC::TransportTuple)) == 0) - { - removedTuple = storedTuple; - - break; - } - } - - // If not found, ignore. - if (!removedTuple) - return; - - // Remove from the list of tuples. - this->tuples.erase(it); - - // If this is not the selected tuple, stop here. - if (removedTuple != this->selectedTuple) - return; - - // Otherwise this was the selected tuple. - this->selectedTuple = nullptr; - - // Mark the first tuple as selected tuple (if any). - if (this->tuples.begin() != this->tuples.end()) - { - SetSelectedTuple(std::addressof(*this->tuples.begin())); - } - // Or just emit 'disconnected'. - else - { - // Update state. - this->state = IceState::DISCONNECTED; - // Notify the listener. - this->listener->OnIceServerDisconnected(this); - } - } - - void IceServer::ForceSelectedTuple(const RTC::TransportTuple* tuple) - { - MS_TRACE(); - - MS_ASSERT( - this->selectedTuple, "cannot force the selected tuple if there was not a selected tuple"); - - auto* storedTuple = HasTuple(tuple); - - MS_ASSERT( - storedTuple, - "cannot force the selected tuple if the given tuple was not already a valid tuple"); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - } - - void IceServer::HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate) - { - MS_TRACE(); - - switch (this->state) - { - case IceState::NEW: - { - // There should be no tuples. - MS_ASSERT( - this->tuples.empty(), "state is 'new' but there are %zu tuples", this->tuples.size()); - - // There shouldn't be a selected tuple. - MS_ASSERT(!this->selectedTuple, "state is 'new' but there is selected tuple"); - - if (!hasUseCandidate) - { - MS_DEBUG_TAG(ice, "transition from state 'new' to 'connected'"); - - // Store the tuple. - auto* storedTuple = AddTuple(tuple); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - // Update state. - this->state = IceState::CONNECTED; - // Notify the listener. - this->listener->OnIceServerConnected(this); - } - else - { - MS_DEBUG_TAG(ice, "transition from state 'new' to 'completed'"); - - // Store the tuple. - auto* storedTuple = AddTuple(tuple); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - // Update state. - this->state = IceState::COMPLETED; - // Notify the listener. - this->listener->OnIceServerCompleted(this); - } - - break; - } - - case IceState::DISCONNECTED: - { - // There should be no tuples. - MS_ASSERT( - this->tuples.empty(), - "state is 'disconnected' but there are %zu tuples", - this->tuples.size()); - - // There shouldn't be a selected tuple. - MS_ASSERT(!this->selectedTuple, "state is 'disconnected' but there is selected tuple"); - - if (!hasUseCandidate) - { - MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'connected'"); - - // Store the tuple. - auto* storedTuple = AddTuple(tuple); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - // Update state. - this->state = IceState::CONNECTED; - // Notify the listener. - this->listener->OnIceServerConnected(this); - } - else - { - MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'completed'"); - - // Store the tuple. - auto* storedTuple = AddTuple(tuple); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - // Update state. - this->state = IceState::COMPLETED; - // Notify the listener. - this->listener->OnIceServerCompleted(this); - } - - break; - } - - case IceState::CONNECTED: - { - // There should be some tuples. - MS_ASSERT(!this->tuples.empty(), "state is 'connected' but there are no tuples"); - - // There should be a selected tuple. - MS_ASSERT(this->selectedTuple, "state is 'connected' but there is not selected tuple"); - - if (!hasUseCandidate) - { - // If a new tuple store it. - if (!HasTuple(tuple)) - AddTuple(tuple); - } - else - { - MS_DEBUG_TAG(ice, "transition from state 'connected' to 'completed'"); - - auto* storedTuple = HasTuple(tuple); - - // If a new tuple store it. - if (!storedTuple) - storedTuple = AddTuple(tuple); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - // Update state. - this->state = IceState::COMPLETED; - // Notify the listener. - this->listener->OnIceServerCompleted(this); - } - - break; - } - - case IceState::COMPLETED: - { - // There should be some tuples. - MS_ASSERT(!this->tuples.empty(), "state is 'completed' but there are no tuples"); - - // There should be a selected tuple. - MS_ASSERT(this->selectedTuple, "state is 'completed' but there is not selected tuple"); - - if (!hasUseCandidate) - { - // If a new tuple store it. - if (!HasTuple(tuple)) - AddTuple(tuple); - } - else - { - auto* storedTuple = HasTuple(tuple); - - // If a new tuple store it. - if (!storedTuple) - storedTuple = AddTuple(tuple); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - } - - break; - } - } - } - - inline RTC::TransportTuple* IceServer::AddTuple(RTC::TransportTuple* tuple) - { - MS_TRACE(); - - // Add the new tuple at the beginning of the list. - this->tuples.push_front(*tuple); - - auto* storedTuple = std::addressof(*this->tuples.begin()); - - // Return the address of the inserted tuple. - return storedTuple; - } - - inline RTC::TransportTuple* IceServer::HasTuple(const RTC::TransportTuple* tuple) const - { - MS_TRACE(); - - // If there is no selected tuple yet then we know that the tuples list - // is empty. - if (!this->selectedTuple) - return nullptr; - - // Check the current selected tuple. - if (memcmp(selectedTuple, tuple, sizeof (RTC::TransportTuple)) == 0) - return this->selectedTuple; - - // Otherwise check other stored tuples. - for (const auto& it : this->tuples) - { - auto* storedTuple = const_cast(std::addressof(it)); - - if (memcmp(storedTuple, tuple, sizeof (RTC::TransportTuple)) == 0) - return storedTuple; - } - - return nullptr; - } - - inline void IceServer::SetSelectedTuple(RTC::TransportTuple* storedTuple) - { - MS_TRACE(); - - // If already the selected tuple do nothing. - if (storedTuple == this->selectedTuple) - return; - - this->selectedTuple = storedTuple; - - // Notify the listener. - this->listener->OnIceServerSelectedTuple(this, this->selectedTuple); - } + case RTC::StunPacket::Class::ERROR_RESPONSE: + { + MS_DEBUG_TAG(ice, "STUN Binding Error Response processed"); + + break; + } + } + } + + bool IceServer::IsValidTuple(const RTC::TransportTuple* tuple) const + { + MS_TRACE(); + + return HasTuple(tuple) != nullptr; + } + + void IceServer::RemoveTuple(RTC::TransportTuple* tuple) + { + MS_TRACE(); + + RTC::TransportTuple* removedTuple{ nullptr }; + + // Find the removed tuple. + auto it = this->tuples.begin(); + + for (; it != this->tuples.end(); ++it) + { + RTC::TransportTuple* storedTuple = *it; + + if (storedTuple == tuple) + { + removedTuple = storedTuple; + + break; + } + } + + // If not found, ignore. + if (!removedTuple) + return; + + // Remove from the list of tuples. + this->tuples.erase(it); + + // If this is not the selected tuple, stop here. + if (removedTuple != this->selectedTuple) + return; + + // Otherwise this was the selected tuple. + this->selectedTuple = nullptr; + + // Mark the first tuple as selected tuple (if any). + if (!this->tuples.empty()) + { + SetSelectedTuple(this->tuples.front()); + } + // Or just emit 'disconnected'. + else + { + // Update state. + this->state = IceState::DISCONNECTED; + // Notify the listener. + this->listener->OnIceServerDisconnected(this); + } + } + + void IceServer::ForceSelectedTuple(const RTC::TransportTuple* tuple) + { + MS_TRACE(); + + MS_ASSERT( + this->selectedTuple, "cannot force the selected tuple if there was not a selected tuple"); + + auto* storedTuple = HasTuple(tuple); + + MS_ASSERT( + storedTuple, + "cannot force the selected tuple if the given tuple was not already a valid tuple"); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + } + + void IceServer::HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate) + { + MS_TRACE(); + + switch (this->state) + { + case IceState::NEW: + { + // There should be no tuples. + MS_ASSERT( + this->tuples.empty(), "state is 'new' but there are %zu tuples", this->tuples.size()); + + // There shouldn't be a selected tuple. + MS_ASSERT(!this->selectedTuple, "state is 'new' but there is selected tuple"); + + if (!hasUseCandidate) + { + MS_DEBUG_TAG(ice, "transition from state 'new' to 'connected'"); + + // Store the tuple. + auto* storedTuple = AddTuple(tuple); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + // Update state. + this->state = IceState::CONNECTED; + // Notify the listener. + this->listener->OnIceServerConnected(this); + } + else + { + MS_DEBUG_TAG(ice, "transition from state 'new' to 'completed'"); + + // Store the tuple. + auto* storedTuple = AddTuple(tuple); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + // Update state. + this->state = IceState::COMPLETED; + // Notify the listener. + this->listener->OnIceServerCompleted(this); + } + + break; + } + + case IceState::DISCONNECTED: + { + // There should be no tuples. + MS_ASSERT( + this->tuples.empty(), + "state is 'disconnected' but there are %zu tuples", + this->tuples.size()); + + // There shouldn't be a selected tuple. + MS_ASSERT(!this->selectedTuple, "state is 'disconnected' but there is selected tuple"); + + if (!hasUseCandidate) + { + MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'connected'"); + + // Store the tuple. + auto* storedTuple = AddTuple(tuple); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + // Update state. + this->state = IceState::CONNECTED; + // Notify the listener. + this->listener->OnIceServerConnected(this); + } + else + { + MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'completed'"); + + // Store the tuple. + auto* storedTuple = AddTuple(tuple); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + // Update state. + this->state = IceState::COMPLETED; + // Notify the listener. + this->listener->OnIceServerCompleted(this); + } + + break; + } + + case IceState::CONNECTED: + { + // There should be some tuples. + MS_ASSERT(!this->tuples.empty(), "state is 'connected' but there are no tuples"); + + // There should be a selected tuple. + MS_ASSERT(this->selectedTuple, "state is 'connected' but there is not selected tuple"); + + if (!hasUseCandidate) + { + // If a new tuple store it. + if (!HasTuple(tuple)) + AddTuple(tuple); + } + else + { + MS_DEBUG_TAG(ice, "transition from state 'connected' to 'completed'"); + + auto* storedTuple = HasTuple(tuple); + + // If a new tuple store it. + if (!storedTuple) + storedTuple = AddTuple(tuple); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + // Update state. + this->state = IceState::COMPLETED; + // Notify the listener. + this->listener->OnIceServerCompleted(this); + } + + break; + } + + case IceState::COMPLETED: + { + // There should be some tuples. + MS_ASSERT(!this->tuples.empty(), "state is 'completed' but there are no tuples"); + + // There should be a selected tuple. + MS_ASSERT(this->selectedTuple, "state is 'completed' but there is not selected tuple"); + + if (!hasUseCandidate) + { + // If a new tuple store it. + if (!HasTuple(tuple)) + AddTuple(tuple); + } + else + { + auto* storedTuple = HasTuple(tuple); + + // If a new tuple store it. + if (!storedTuple) + storedTuple = AddTuple(tuple); + + // Mark it as selected tuple. + SetSelectedTuple(storedTuple); + } + + break; + } + } + } + + inline RTC::TransportTuple* IceServer::AddTuple(RTC::TransportTuple* tuple) + { + MS_TRACE(); + + // Add the new tuple at the beginning of the list. + this->tuples.push_front(tuple); + + // Return the address of the inserted tuple. + return tuple; + } + + inline RTC::TransportTuple* IceServer::HasTuple(const RTC::TransportTuple* tuple) const + { + MS_TRACE(); + + // If there is no selected tuple yet then we know that the tuples list + // is empty. + if (!this->selectedTuple) + return nullptr; + + // Check the current selected tuple. + if (selectedTuple == tuple) + return this->selectedTuple; + + // Otherwise check other stored tuples. + for (const auto& it : this->tuples) + { + auto& storedTuple = it; + if (storedTuple == tuple) + return storedTuple; + } + + return nullptr; + } + + inline void IceServer::SetSelectedTuple(RTC::TransportTuple* storedTuple) + { + MS_TRACE(); + + // If already the selected tuple do nothing. + if (storedTuple == this->selectedTuple) + return; + + this->selectedTuple = storedTuple; + this->lastSelectedTuple = storedTuple->shared_from_this(); + + // Notify the listener. + this->listener->OnIceServerSelectedTuple(this, this->selectedTuple); + } } // namespace RTC diff --git a/webrtc/IceServer.hpp b/webrtc/IceServer.hpp index 8b9742ad..316d32af 100644 --- a/webrtc/IceServer.hpp +++ b/webrtc/IceServer.hpp @@ -20,6 +20,7 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. #define MS_RTC_ICE_SERVER_HPP #include "StunPacket.hpp" +#include "Network/Session.h" #include "logger.h" #include "Utils.hpp" #include @@ -27,110 +28,111 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. #include #include -using _TransportTuple = struct sockaddr; - namespace RTC { - using TransportTuple = _TransportTuple; - class IceServer - { - public: - enum class IceState - { - NEW = 1, - CONNECTED, - COMPLETED, - DISCONNECTED - }; + using TransportTuple = toolkit::Session; + class IceServer + { + public: + enum class IceState + { + NEW = 1, + CONNECTED, + COMPLETED, + DISCONNECTED + }; - public: - class Listener - { - public: - virtual ~Listener() = default; + public: + class Listener + { + public: + virtual ~Listener() = default; - public: - /** - * These callbacks are guaranteed to be called before ProcessStunPacket() - * returns, so the given pointers are still usable. - */ - virtual void OnIceServerSendStunPacket( - const RTC::IceServer* iceServer, const RTC::StunPacket* packet, RTC::TransportTuple* tuple) = 0; - virtual void OnIceServerSelectedTuple( - const RTC::IceServer* iceServer, RTC::TransportTuple* tuple) = 0; - virtual void OnIceServerConnected(const RTC::IceServer* iceServer) = 0; - virtual void OnIceServerCompleted(const RTC::IceServer* iceServer) = 0; - virtual void OnIceServerDisconnected(const RTC::IceServer* iceServer) = 0; - }; + public: + /** + * These callbacks are guaranteed to be called before ProcessStunPacket() + * returns, so the given pointers are still usable. + */ + virtual void OnIceServerSendStunPacket( + const RTC::IceServer* iceServer, const RTC::StunPacket* packet, RTC::TransportTuple* tuple) = 0; + virtual void OnIceServerSelectedTuple( + const RTC::IceServer* iceServer, RTC::TransportTuple* tuple) = 0; + virtual void OnIceServerConnected(const RTC::IceServer* iceServer) = 0; + virtual void OnIceServerCompleted(const RTC::IceServer* iceServer) = 0; + virtual void OnIceServerDisconnected(const RTC::IceServer* iceServer) = 0; + }; - public: - IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password); + public: + IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password); - public: - void ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple); - const std::string& GetUsernameFragment() const - { - return this->usernameFragment; - } - const std::string& GetPassword() const - { - return this->password; - } - IceState GetState() const - { - return this->state; - } - RTC::TransportTuple* GetSelectedTuple() const - { - return this->selectedTuple; - } - void SetUsernameFragment(const std::string& usernameFragment) - { - this->oldUsernameFragment = this->usernameFragment; - this->usernameFragment = usernameFragment; - } - void SetPassword(const std::string& password) - { - this->oldPassword = this->password; - this->password = password; - } - bool IsValidTuple(const RTC::TransportTuple* tuple) const; - void RemoveTuple(RTC::TransportTuple* tuple); - // This should be just called in 'connected' or completed' state - // and the given tuple must be an already valid tuple. - void ForceSelectedTuple(const RTC::TransportTuple* tuple); + public: + void ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple); + const std::string& GetUsernameFragment() const + { + return this->usernameFragment; + } + const std::string& GetPassword() const + { + return this->password; + } + IceState GetState() const + { + return this->state; + } + RTC::TransportTuple* GetSelectedTuple(bool try_last_tuple = false) const + { + return try_last_tuple ? this->lastSelectedTuple.lock().get() : this->selectedTuple; + } + void SetUsernameFragment(const std::string& usernameFragment) + { + this->oldUsernameFragment = this->usernameFragment; + this->usernameFragment = usernameFragment; + } + void SetPassword(const std::string& password) + { + this->oldPassword = this->password; + this->password = password; + } + bool IsValidTuple(const RTC::TransportTuple* tuple) const; + void RemoveTuple(RTC::TransportTuple* tuple); + // This should be just called in 'connected' or completed' state + // and the given tuple must be an already valid tuple. + void ForceSelectedTuple(const RTC::TransportTuple* tuple); - private: - void HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate); - /** - * Store the given tuple and return its stored address. - */ - RTC::TransportTuple* AddTuple(RTC::TransportTuple* tuple); - /** - * If the given tuple exists return its stored address, nullptr otherwise. - */ - RTC::TransportTuple* HasTuple(const RTC::TransportTuple* tuple) const; - /** - * Set the given tuple as the selected tuple. - * NOTE: The given tuple MUST be already stored within the list. - */ - void SetSelectedTuple(RTC::TransportTuple* storedTuple); + const std::list& GetTuples() const { return tuples; } - private: - // Passed by argument. - Listener* listener{ nullptr }; - // Others. - std::string usernameFragment; - std::string password; - std::string oldUsernameFragment; - std::string oldPassword; - IceState state{ IceState::NEW }; - std::list tuples; - RTC::TransportTuple* selectedTuple{ nullptr }; - //最大不超过mtu + private: + void HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate); + /** + * Store the given tuple and return its stored address. + */ + RTC::TransportTuple* AddTuple(RTC::TransportTuple* tuple); + /** + * If the given tuple exists return its stored address, nullptr otherwise. + */ + RTC::TransportTuple* HasTuple(const RTC::TransportTuple* tuple) const; + /** + * Set the given tuple as the selected tuple. + * NOTE: The given tuple MUST be already stored within the list. + */ + void SetSelectedTuple(RTC::TransportTuple* storedTuple); + + private: + // Passed by argument. + Listener* listener{ nullptr }; + // Others. + std::string usernameFragment; + std::string password; + std::string oldUsernameFragment; + std::string oldPassword; + IceState state{ IceState::NEW }; + std::list tuples; + RTC::TransportTuple *selectedTuple { nullptr }; + std::weak_ptr lastSelectedTuple; + //最大不超过mtu static constexpr size_t StunSerializeBufferSize{ 1600 }; uint8_t StunSerializeBuffer[StunSerializeBufferSize]; - }; + }; } // namespace RTC #endif diff --git a/webrtc/Nack.cpp b/webrtc/Nack.cpp index 63e615b5..8bc51bf9 100644 --- a/webrtc/Nack.cpp +++ b/webrtc/Nack.cpp @@ -9,7 +9,7 @@ */ #include "Nack.h" - +#include using namespace std; using namespace toolkit; @@ -27,7 +27,7 @@ void NackList::pushBack(RtpPacket::Ptr rtp) { } _cache_ms_check = 0; while (getCacheMS() >= kMaxNackMS) { - //需要清除部分nack缓存 + // 需要清除部分nack缓存 popFront(); } } @@ -36,7 +36,7 @@ void NackList::forEach(const FCI_NACK &nack, const function= front_stamp) { return back_stamp - front_stamp; } - //很有可能回环了 + // 很有可能回环了 return back_stamp + (UINT32_MAX - front_stamp); } return 0; @@ -95,101 +95,119 @@ int64_t NackList::getRtpStamp(uint16_t seq) { //////////////////////////////////////////////////////////////////////////////////////////////// +NackContext::NackContext() { + setOnNack(nullptr); +} + void NackContext::received(uint16_t seq, bool is_rtx) { - if (!_last_max_seq && _seq.empty()) { - _last_max_seq = seq - 1; + if (!_started) { + // 记录第一个seq + _started = true; + _nack_seq = seq - 1; } - if (is_rtx || (seq < _last_max_seq && !(seq < 1024 && _last_max_seq > UINT16_MAX - 1024))) { - //重传包或 - // seq回退,且非回环,那么这个应该是重传包 - onRtx(seq); + + if (seq < _nack_seq && _nack_seq != UINT16_MAX && seq < 1024 && _nack_seq > UINT16_MAX - 1024) { + // seq回环,清空回环前状态 + makeNack(UINT16_MAX, true); + _seq.emplace(seq); + return; + } + + if (is_rtx || (seq < _nack_seq && _nack_seq != UINT16_MAX)) { + // seq非回环回退包,猜测其为重传包,清空其nack状态 + clearNackStatus(seq); + return; + } + + auto pr = _seq.emplace(seq); + if (!pr.second) { + // seq重复, 忽略 return; } - _seq.emplace(seq); auto max_seq = *_seq.rbegin(); auto min_seq = *_seq.begin(); auto diff = max_seq - min_seq; - if (!diff) { + if (diff > (UINT16_MAX >> 1)) { + // 回环后,收到回环前的大值seq, 忽略掉 + _seq.erase(max_seq); return; } - - if (diff > UINT16_MAX / 2) { - //回环 + if (min_seq == (uint16_t)(_nack_seq + 1) && _seq.size() == (size_t)diff + 1) { + // 都是连续的seq,未丢包 _seq.clear(); - _last_max_seq = min_seq; - _nack_send_status.clear(); - return; - } - - if (_seq.size() == (size_t)diff + 1 && _last_max_seq + 1 == min_seq) { - //都是连续的seq,未丢包 - _seq.clear(); - _last_max_seq = max_seq; + _nack_seq = max_seq; } else { // seq不连续,有丢包 - if (min_seq == _last_max_seq + 1) { - //前面部分seq是连续的,未丢包,移除之 - eraseFrontSeq(); - } + makeNack(max_seq, false); + } +} - //有丢包,丢包从_last_max_seq开始 - auto nack_rtp_count = FCI_NACK::kBitSize; - if (max_seq > nack_rtp_count + _last_max_seq) { - vector vec; - vec.resize(FCI_NACK::kBitSize, false); - for (size_t i = 0; i < nack_rtp_count; ++i) { - vec[i] = _seq.find(_last_max_seq + i + 2) == _seq.end(); - } - doNack(FCI_NACK(_last_max_seq + 1, vec), true); - _last_max_seq += nack_rtp_count + 1; - if (_last_max_seq >= max_seq) { - _seq.clear(); - } else { - auto it = _seq.emplace_hint(_seq.begin(), _last_max_seq + 1); - _seq.erase(_seq.begin(), it); - } +void NackContext::makeNack(uint16_t max_seq, bool flush) { + // 尝试移除前面部分连续的seq + eraseFrontSeq(); + // 最多生成5个nack包,防止seq大幅跳跃导致一直循环 + auto max_nack = 5u; + while (_nack_seq != max_seq && max_nack--) { + // 一次不能发送超过16+1个rtp的状态 + uint16_t nack_rtp_count = std::min(FCI_NACK::kBitSize, max_seq - (uint16_t)(_nack_seq + 1)); + if (!flush && nack_rtp_count < kNackRtpSize) { + // 非flush状态下,seq个数不足以发送一次nack + break; } + vector vec; + vec.resize(nack_rtp_count, false); + for (size_t i = 0; i < nack_rtp_count; ++i) { + vec[i] = _seq.find((uint16_t)(_nack_seq + i + 2)) == _seq.end(); + } + doNack(FCI_NACK(_nack_seq + 1, vec), true); + _nack_seq += nack_rtp_count + 1; + // 返回第一个比_last_max_seq大的元素 + auto it = _seq.upper_bound(_nack_seq); + // 移除 <=_last_max_seq 的seq + _seq.erase(_seq.begin(), it); } } void NackContext::setOnNack(onNack cb) { - _cb = std::move(cb); + if (cb) { + _cb = std::move(cb); + } else { + _cb = [](const FCI_NACK &nack) {}; + } } void NackContext::doNack(const FCI_NACK &nack, bool record_nack) { if (record_nack) { recordNack(nack); } - if (_cb) { - _cb(nack); - } + _cb(nack); } void NackContext::eraseFrontSeq() { - //前面部分seq是连续的,未丢包,移除之 + // 前面部分seq是连续的,未丢包,移除之 for (auto it = _seq.begin(); it != _seq.end();) { - if (*it != _last_max_seq + 1) { + if (*it != (uint16_t)(_nack_seq + 1)) { // seq不连续,丢包了 break; } - _last_max_seq = *it; + _nack_seq = *it; it = _seq.erase(it); } } -void NackContext::onRtx(uint16_t seq) { +void NackContext::clearNackStatus(uint16_t seq) { auto it = _nack_send_status.find(seq); if (it == _nack_send_status.end()) { return; } - auto rtt = getCurrentMillisecond() - it->second.update_stamp; + //收到重传包与第一个nack包间的时间约等于rtt时间 + auto rtt = getCurrentMillisecond() - it->second.first_stamp; _nack_send_status.erase(it); if (rtt >= 0) { - // rtt不肯小于0 + // rtt不能小于0 _rtt = rtt; - // InfoL << "rtt:" << rtt; } } @@ -205,7 +223,7 @@ void NackContext::recordNack(const FCI_NACK &nack) { } ++i; } - //记录太多了,移除一部分早期的记录 + // 记录太多了,移除一部分早期的记录 while (_nack_send_status.size() > kNackMaxSize) { _nack_send_status.erase(_nack_send_status.begin()); } @@ -216,18 +234,18 @@ uint64_t NackContext::reSendNack() { auto now = getCurrentMillisecond(); for (auto it = _nack_send_status.begin(); it != _nack_send_status.end();) { if (now - it->second.first_stamp > kNackMaxMS) { - //该rtp丢失太久了,不再要求重传 + // 该rtp丢失太久了,不再要求重传 it = _nack_send_status.erase(it); continue; } if (now - it->second.update_stamp < kNackIntervalRatio * _rtt) { - //距离上次nack不足2倍的rtt,不用再发送nack + // 距离上次nack不足2倍的rtt,不用再发送nack ++it; continue; } - //此rtp需要请求重传 + // 此rtp需要请求重传 nack_rtp.emplace(it->first); - //更新nack发送时间戳 + // 更新nack发送时间戳 it->second.update_stamp = now; if (++(it->second.nack_count) == kNackMaxCount) { // nack次数太多,移除之 @@ -237,11 +255,6 @@ uint64_t NackContext::reSendNack() { ++it; } - if (_nack_send_status.empty()) { - //不需要再发送nack - return 0; - } - int pid = -1; vector vec; for (auto it = nack_rtp.begin(); it != nack_rtp.end();) { @@ -253,12 +266,12 @@ uint64_t NackContext::reSendNack() { } auto inc = *it - pid; if (inc > (ssize_t)FCI_NACK::kBitSize) { - //新的nack包 + // 新的nack包 doNack(FCI_NACK(pid, vec), false); pid = -1; continue; } - //这个包丢了 + // 这个包丢了 vec[inc - 1] = true; ++it; } @@ -266,8 +279,8 @@ uint64_t NackContext::reSendNack() { doNack(FCI_NACK(pid, vec), false); } - //重传间隔不得低于5ms - return max(_rtt, 5); + // 没有任何包需要重传时返回0,否则返回下次重传间隔(不得低于5ms) + return _nack_send_status.empty() ? 0 : max(_rtt, 5); } } // namespace mediakit diff --git a/webrtc/Nack.h b/webrtc/Nack.h index 2b38e6a3..8780e27a 100644 --- a/webrtc/Nack.h +++ b/webrtc/Nack.h @@ -49,11 +49,15 @@ public: // rtp丢包状态最长保留时间 static constexpr auto kNackMaxMS = 3 * 1000; // nack最多请求重传10次 - static constexpr auto kNackMaxCount = 10; + static constexpr auto kNackMaxCount = 15; // nack重传频率,rtt的倍数 static constexpr auto kNackIntervalRatio = 1.0f; + // nack包中rtp个数,减小此值可以让nack包响应更灵敏 + static constexpr auto kNackRtpSize = 8; - NackContext() = default; + static_assert(kNackRtpSize >=0 && kNackRtpSize <= FCI_NACK::kBitSize, "NackContext::kNackRtpSize must between 0 and 16"); + + NackContext(); ~NackContext() = default; void received(uint16_t seq, bool is_rtx = false); @@ -64,13 +68,16 @@ private: void eraseFrontSeq(); void doNack(const FCI_NACK &nack, bool record_nack); void recordNack(const FCI_NACK &nack); - void onRtx(uint16_t seq); + void clearNackStatus(uint16_t seq); + void makeNack(uint16_t max, bool flush = false); private: + bool _started = false; int _rtt = 50; onNack _cb; std::set _seq; - uint16_t _last_max_seq = 0; + // 最新nack包中的rtp seq值 + uint16_t _nack_seq = 0; struct NackStatus { uint64_t first_stamp; diff --git a/webrtc/SctpAssociation.cpp b/webrtc/SctpAssociation.cpp index 0aec443b..84a2c04f 100644 --- a/webrtc/SctpAssociation.cpp +++ b/webrtc/SctpAssociation.cpp @@ -23,14 +23,14 @@ static constexpr uint16_t MaxSctpStreams{ 65535 }; /* clang-format off */ static constexpr uint16_t EventTypes[] = { - SCTP_ADAPTATION_INDICATION, - SCTP_ASSOC_CHANGE, - SCTP_ASSOC_RESET_EVENT, - SCTP_REMOTE_ERROR, - SCTP_SHUTDOWN_EVENT, - SCTP_SEND_FAILED_EVENT, - SCTP_STREAM_RESET_EVENT, - SCTP_STREAM_CHANGE_EVENT + SCTP_ADAPTATION_INDICATION, + SCTP_ASSOC_CHANGE, + SCTP_ASSOC_RESET_EVENT, + SCTP_REMOTE_ERROR, + SCTP_SHUTDOWN_EVENT, + SCTP_SEND_FAILED_EVENT, + SCTP_STREAM_RESET_EVENT, + SCTP_STREAM_CHANGE_EVENT }; /* clang-format on */ @@ -44,45 +44,45 @@ inline static int onRecvSctpData( int flags, void* ulpInfo) { - auto* sctpAssociation = static_cast(ulpInfo); + auto* sctpAssociation = static_cast(ulpInfo); - if (sctpAssociation == nullptr) - { - std::free(data); + if (sctpAssociation == nullptr) + { + std::free(data); - return 0; - } + return 0; + } - if (flags & MSG_NOTIFICATION) - { - sctpAssociation->OnUsrSctpReceiveSctpNotification( - static_cast(data), len); - } - else - { - uint16_t streamId = rcv.rcv_sid; - uint32_t ppid = ntohl(rcv.rcv_ppid); - uint16_t ssn = rcv.rcv_ssn; + if (flags & MSG_NOTIFICATION) + { + sctpAssociation->OnUsrSctpReceiveSctpNotification( + static_cast(data), len); + } + else + { + uint16_t streamId = rcv.rcv_sid; + uint32_t ppid = ntohl(rcv.rcv_ppid); + uint16_t ssn = rcv.rcv_ssn; - MS_DEBUG_TAG( - sctp, - "data chunk received [length:%zu, streamId:%" PRIu16 ", SSN:%" PRIu16 ", TSN:%" PRIu32 - ", PPID:%" PRIu32 ", context:%" PRIu32 ", flags:%d]", - len, - rcv.rcv_sid, - rcv.rcv_ssn, - rcv.rcv_tsn, - ntohl(rcv.rcv_ppid), - rcv.rcv_context, - flags); + MS_DEBUG_TAG( + sctp, + "data chunk received [length:%zu, streamId:%" PRIu16 ", SSN:%" PRIu16 ", TSN:%" PRIu32 + ", PPID:%" PRIu32 ", context:%" PRIu32 ", flags:%d]", + len, + rcv.rcv_sid, + rcv.rcv_ssn, + rcv.rcv_tsn, + ntohl(rcv.rcv_ppid), + rcv.rcv_context, + flags); - sctpAssociation->OnUsrSctpReceiveSctpData( - streamId, ssn, ppid, flags, static_cast(data), len); - } + sctpAssociation->OnUsrSctpReceiveSctpData( + streamId, ssn, ppid, flags, static_cast(data), len); + } - std::free(data); + std::free(data); - return 1; + return 1; } /* Static methods for usrsctp global callbacks. */ @@ -136,824 +136,824 @@ namespace RTC //////////////////////////////////////////////////////////////////////////////////// - /* Instance methods. */ + /* Instance methods. */ - SctpAssociation::SctpAssociation( - Listener* listener, uint16_t os, uint16_t mis, size_t maxSctpMessageSize, bool isDataChannel) - : listener(listener), os(os), mis(mis), maxSctpMessageSize(maxSctpMessageSize), - isDataChannel(isDataChannel) - { - MS_TRACE(); + SctpAssociation::SctpAssociation( + Listener* listener, uint16_t os, uint16_t mis, size_t maxSctpMessageSize, bool isDataChannel) + : listener(listener), os(os), mis(mis), maxSctpMessageSize(maxSctpMessageSize), + isDataChannel(isDataChannel) + { + MS_TRACE(); _env = SctpEnv::Instance().shared_from_this(); - // Register ourselves in usrsctp. - usrsctp_register_address(static_cast(this)); + // Register ourselves in usrsctp. + usrsctp_register_address(static_cast(this)); - int ret; + int ret; - this->socket = usrsctp_socket( - AF_CONN, SOCK_STREAM, IPPROTO_SCTP, onRecvSctpData, nullptr, 0, static_cast(this)); + this->socket = usrsctp_socket( + AF_CONN, SOCK_STREAM, IPPROTO_SCTP, onRecvSctpData, nullptr, 0, static_cast(this)); - if (this->socket == nullptr) - MS_THROW_ERROR("usrsctp_socket() failed: %s", std::strerror(errno)); + if (this->socket == nullptr) + MS_THROW_ERROR("usrsctp_socket() failed: %s", std::strerror(errno)); - usrsctp_set_ulpinfo(this->socket, static_cast(this)); + usrsctp_set_ulpinfo(this->socket, static_cast(this)); - // Make the socket non-blocking. - ret = usrsctp_set_non_blocking(this->socket, 1); + // Make the socket non-blocking. + ret = usrsctp_set_non_blocking(this->socket, 1); - if (ret < 0) - MS_THROW_ERROR("usrsctp_set_non_blocking() failed: %s", std::strerror(errno)); + if (ret < 0) + MS_THROW_ERROR("usrsctp_set_non_blocking() failed: %s", std::strerror(errno)); - // Set SO_LINGER. - // This ensures that the usrsctp close call deletes the association. This - // prevents usrsctp from calling the global send callback with references to - // this class as the address. - struct linger lingerOpt; // NOLINT(cppcoreguidelines-pro-type-member-init) + // Set SO_LINGER. + // This ensures that the usrsctp close call deletes the association. This + // prevents usrsctp from calling the global send callback with references to + // this class as the address. + struct linger lingerOpt; // NOLINT(cppcoreguidelines-pro-type-member-init) - lingerOpt.l_onoff = 1; - lingerOpt.l_linger = 0; + lingerOpt.l_onoff = 1; + lingerOpt.l_linger = 0; - ret = usrsctp_setsockopt(this->socket, SOL_SOCKET, SO_LINGER, &lingerOpt, sizeof(lingerOpt)); + ret = usrsctp_setsockopt(this->socket, SOL_SOCKET, SO_LINGER, &lingerOpt, sizeof(lingerOpt)); - if (ret < 0) - MS_THROW_ERROR("usrsctp_setsockopt(SO_LINGER) failed: %s", std::strerror(errno)); + if (ret < 0) + MS_THROW_ERROR("usrsctp_setsockopt(SO_LINGER) failed: %s", std::strerror(errno)); - // Set SCTP_ENABLE_STREAM_RESET. - struct sctp_assoc_value av; // NOLINT(cppcoreguidelines-pro-type-member-init) + // Set SCTP_ENABLE_STREAM_RESET. + struct sctp_assoc_value av; // NOLINT(cppcoreguidelines-pro-type-member-init) - av.assoc_value = - SCTP_ENABLE_RESET_STREAM_REQ | SCTP_ENABLE_RESET_ASSOC_REQ | SCTP_ENABLE_CHANGE_ASSOC_REQ; + av.assoc_value = + SCTP_ENABLE_RESET_STREAM_REQ | SCTP_ENABLE_RESET_ASSOC_REQ | SCTP_ENABLE_CHANGE_ASSOC_REQ; - ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_ENABLE_STREAM_RESET, &av, sizeof(av)); + ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_ENABLE_STREAM_RESET, &av, sizeof(av)); - if (ret < 0) - { - MS_THROW_ERROR("usrsctp_setsockopt(SCTP_ENABLE_STREAM_RESET) failed: %s", std::strerror(errno)); - } + if (ret < 0) + { + MS_THROW_ERROR("usrsctp_setsockopt(SCTP_ENABLE_STREAM_RESET) failed: %s", std::strerror(errno)); + } - // Set SCTP_NODELAY. - uint32_t noDelay = 1; + // Set SCTP_NODELAY. + uint32_t noDelay = 1; - ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_NODELAY, &noDelay, sizeof(noDelay)); + ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_NODELAY, &noDelay, sizeof(noDelay)); - if (ret < 0) - MS_THROW_ERROR("usrsctp_setsockopt(SCTP_NODELAY) failed: %s", std::strerror(errno)); + if (ret < 0) + MS_THROW_ERROR("usrsctp_setsockopt(SCTP_NODELAY) failed: %s", std::strerror(errno)); - // Enable events. - struct sctp_event event; // NOLINT(cppcoreguidelines-pro-type-member-init) + // Enable events. + struct sctp_event event; // NOLINT(cppcoreguidelines-pro-type-member-init) - std::memset(&event, 0, sizeof(event)); - event.se_on = 1; + std::memset(&event, 0, sizeof(event)); + event.se_on = 1; - for (size_t i{ 0 }; i < sizeof(EventTypes) / sizeof(uint16_t); ++i) - { - event.se_type = EventTypes[i]; + for (size_t i{ 0 }; i < sizeof(EventTypes) / sizeof(uint16_t); ++i) + { + event.se_type = EventTypes[i]; - ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_EVENT, &event, sizeof(event)); + ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_EVENT, &event, sizeof(event)); - if (ret < 0) - MS_THROW_ERROR("usrsctp_setsockopt(SCTP_EVENT) failed: %s", std::strerror(errno)); - } + if (ret < 0) + MS_THROW_ERROR("usrsctp_setsockopt(SCTP_EVENT) failed: %s", std::strerror(errno)); + } - // Init message. - struct sctp_initmsg initmsg; // NOLINT(cppcoreguidelines-pro-type-member-init) + // Init message. + struct sctp_initmsg initmsg; // NOLINT(cppcoreguidelines-pro-type-member-init) - std::memset(&initmsg, 0, sizeof(initmsg)); - initmsg.sinit_num_ostreams = this->os; - initmsg.sinit_max_instreams = this->mis; + std::memset(&initmsg, 0, sizeof(initmsg)); + initmsg.sinit_num_ostreams = this->os; + initmsg.sinit_max_instreams = this->mis; - ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_INITMSG, &initmsg, sizeof(initmsg)); + ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_INITMSG, &initmsg, sizeof(initmsg)); - if (ret < 0) - MS_THROW_ERROR("usrsctp_setsockopt(SCTP_INITMSG) failed: %s", std::strerror(errno)); + if (ret < 0) + MS_THROW_ERROR("usrsctp_setsockopt(SCTP_INITMSG) failed: %s", std::strerror(errno)); - // Server side. - struct sockaddr_conn sconn; // NOLINT(cppcoreguidelines-pro-type-member-init) + // Server side. + struct sockaddr_conn sconn; // NOLINT(cppcoreguidelines-pro-type-member-init) - std::memset(&sconn, 0, sizeof(sconn)); - sconn.sconn_family = AF_CONN; - sconn.sconn_port = htons(5000); - sconn.sconn_addr = static_cast(this); + std::memset(&sconn, 0, sizeof(sconn)); + sconn.sconn_family = AF_CONN; + sconn.sconn_port = htons(5000); + sconn.sconn_addr = static_cast(this); #ifdef HAVE_SCONN_LEN sconn.sconn_len = sizeof(sconn); #endif - ret = usrsctp_bind(this->socket, reinterpret_cast(&sconn), sizeof(sconn)); + ret = usrsctp_bind(this->socket, reinterpret_cast(&sconn), sizeof(sconn)); - if (ret < 0) - MS_THROW_ERROR("usrsctp_bind() failed: %s", std::strerror(errno)); - } + if (ret < 0) + MS_THROW_ERROR("usrsctp_bind() failed: %s", std::strerror(errno)); + } - SctpAssociation::~SctpAssociation() - { - MS_TRACE(); + SctpAssociation::~SctpAssociation() + { + MS_TRACE(); - usrsctp_set_ulpinfo(this->socket, nullptr); - usrsctp_close(this->socket); + usrsctp_set_ulpinfo(this->socket, nullptr); + usrsctp_close(this->socket); - // Deregister ourselves from usrsctp. - usrsctp_deregister_address(static_cast(this)); + // Deregister ourselves from usrsctp. + usrsctp_deregister_address(static_cast(this)); - delete[] this->messageBuffer; - } + delete[] this->messageBuffer; + } - void SctpAssociation::TransportConnected() - { - MS_TRACE(); + void SctpAssociation::TransportConnected() + { + MS_TRACE(); - // Just run the SCTP stack if our state is 'new'. - if (this->state != SctpState::NEW) - return; + // Just run the SCTP stack if our state is 'new'. + if (this->state != SctpState::NEW) + return; - try - { - int ret; - struct sockaddr_conn rconn; // NOLINT(cppcoreguidelines-pro-type-member-init) + try + { + int ret; + struct sockaddr_conn rconn; // NOLINT(cppcoreguidelines-pro-type-member-init) - std::memset(&rconn, 0, sizeof(rconn)); - rconn.sconn_family = AF_CONN; - rconn.sconn_port = htons(5000); - rconn.sconn_addr = static_cast(this); + std::memset(&rconn, 0, sizeof(rconn)); + rconn.sconn_family = AF_CONN; + rconn.sconn_port = htons(5000); + rconn.sconn_addr = static_cast(this); #ifdef HAVE_SCONN_LEN - rconn.sconn_len = sizeof(rconn); + rconn.sconn_len = sizeof(rconn); #endif - ret = usrsctp_connect(this->socket, reinterpret_cast(&rconn), sizeof(rconn)); + ret = usrsctp_connect(this->socket, reinterpret_cast(&rconn), sizeof(rconn)); - if (ret < 0 && errno != EINPROGRESS) - MS_THROW_ERROR("usrsctp_connect() failed: %s", std::strerror(errno)); + if (ret < 0 && errno != EINPROGRESS) + MS_THROW_ERROR("usrsctp_connect() failed: %s", std::strerror(errno)); - // Disable MTU discovery. - sctp_paddrparams peerAddrParams; // NOLINT(cppcoreguidelines-pro-type-member-init) + // Disable MTU discovery. + sctp_paddrparams peerAddrParams; // NOLINT(cppcoreguidelines-pro-type-member-init) - std::memset(&peerAddrParams, 0, sizeof(peerAddrParams)); - std::memcpy(&peerAddrParams.spp_address, &rconn, sizeof(rconn)); - peerAddrParams.spp_flags = SPP_PMTUD_DISABLE; + std::memset(&peerAddrParams, 0, sizeof(peerAddrParams)); + std::memcpy(&peerAddrParams.spp_address, &rconn, sizeof(rconn)); + peerAddrParams.spp_flags = SPP_PMTUD_DISABLE; - // The MTU value provided specifies the space available for chunks in the - // packet, so let's subtract the SCTP header size. - peerAddrParams.spp_pathmtu = SctpMtu - sizeof(peerAddrParams); + // The MTU value provided specifies the space available for chunks in the + // packet, so let's subtract the SCTP header size. + peerAddrParams.spp_pathmtu = SctpMtu - sizeof(peerAddrParams); - ret = usrsctp_setsockopt( - this->socket, IPPROTO_SCTP, SCTP_PEER_ADDR_PARAMS, &peerAddrParams, sizeof(peerAddrParams)); + ret = usrsctp_setsockopt( + this->socket, IPPROTO_SCTP, SCTP_PEER_ADDR_PARAMS, &peerAddrParams, sizeof(peerAddrParams)); - if (ret < 0) - MS_THROW_ERROR("usrsctp_setsockopt(SCTP_PEER_ADDR_PARAMS) failed: %s", std::strerror(errno)); + if (ret < 0) + MS_THROW_ERROR("usrsctp_setsockopt(SCTP_PEER_ADDR_PARAMS) failed: %s", std::strerror(errno)); - // Announce connecting state. - this->state = SctpState::CONNECTING; - this->listener->OnSctpAssociationConnecting(this); - } - catch (... /*error*/) - { - this->state = SctpState::FAILED; - this->listener->OnSctpAssociationFailed(this); + // Announce connecting state. + this->state = SctpState::CONNECTING; + this->listener->OnSctpAssociationConnecting(this); + } + catch (... /*error*/) + { + this->state = SctpState::FAILED; + this->listener->OnSctpAssociationFailed(this); throw; - } - } + } + } - void SctpAssociation::ProcessSctpData(const uint8_t* data, size_t len) - { - MS_TRACE(); + void SctpAssociation::ProcessSctpData(const uint8_t* data, size_t len) + { + MS_TRACE(); #if MS_LOG_DEV_LEVEL == 3 - MS_DUMP_DATA(data, len); + MS_DUMP_DATA(data, len); #endif - usrsctp_conninput(static_cast(this), data, len, 0); - } + usrsctp_conninput(static_cast(this), data, len, 0); + } - void SctpAssociation::SendSctpMessage( + void SctpAssociation::SendSctpMessage( const RTC::SctpStreamParameters ¶meters, uint32_t ppid, const uint8_t* msg, size_t len) - { - MS_TRACE(); + { + MS_TRACE(); - // This must be controlled by the DataConsumer. - MS_ASSERT( - len <= this->maxSctpMessageSize, - "given message exceeds max allowed message size [message size:%zu, max message size:%zu]", - len, - this->maxSctpMessageSize); + // This must be controlled by the DataConsumer. + MS_ASSERT( + len <= this->maxSctpMessageSize, + "given message exceeds max allowed message size [message size:%zu, max message size:%zu]", + len, + this->maxSctpMessageSize); - // Fill stcp_sendv_spa. - struct sctp_sendv_spa spa; // NOLINT(cppcoreguidelines-pro-type-member-init) + // Fill stcp_sendv_spa. + struct sctp_sendv_spa spa; // NOLINT(cppcoreguidelines-pro-type-member-init) - std::memset(&spa, 0, sizeof(spa)); - spa.sendv_flags = SCTP_SEND_SNDINFO_VALID; - spa.sendv_sndinfo.snd_sid = parameters.streamId; - spa.sendv_sndinfo.snd_ppid = htonl(ppid); - spa.sendv_sndinfo.snd_flags = SCTP_EOR; + std::memset(&spa, 0, sizeof(spa)); + spa.sendv_flags = SCTP_SEND_SNDINFO_VALID; + spa.sendv_sndinfo.snd_sid = parameters.streamId; + spa.sendv_sndinfo.snd_ppid = htonl(ppid); + spa.sendv_sndinfo.snd_flags = SCTP_EOR; - // If ordered it must be reliable. - if (parameters.ordered) - { - spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_NONE; - spa.sendv_prinfo.pr_value = 0; - } - // Configure reliability: https://tools.ietf.org/html/rfc3758 - else - { - spa.sendv_flags |= SCTP_SEND_PRINFO_VALID; - spa.sendv_sndinfo.snd_flags |= SCTP_UNORDERED; + // If ordered it must be reliable. + if (parameters.ordered) + { + spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_NONE; + spa.sendv_prinfo.pr_value = 0; + } + // Configure reliability: https://tools.ietf.org/html/rfc3758 + else + { + spa.sendv_flags |= SCTP_SEND_PRINFO_VALID; + spa.sendv_sndinfo.snd_flags |= SCTP_UNORDERED; - if (parameters.maxPacketLifeTime != 0) - { - spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_TTL; - spa.sendv_prinfo.pr_value = parameters.maxPacketLifeTime; - } - else if (parameters.maxRetransmits != 0) - { - spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_RTX; - spa.sendv_prinfo.pr_value = parameters.maxRetransmits; - } - } + if (parameters.maxPacketLifeTime != 0) + { + spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_TTL; + spa.sendv_prinfo.pr_value = parameters.maxPacketLifeTime; + } + else if (parameters.maxRetransmits != 0) + { + spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_RTX; + spa.sendv_prinfo.pr_value = parameters.maxRetransmits; + } + } - int ret = usrsctp_sendv( - this->socket, msg, len, nullptr, 0, &spa, static_cast(sizeof(spa)), SCTP_SENDV_SPA, 0); + int ret = usrsctp_sendv( + this->socket, msg, len, nullptr, 0, &spa, static_cast(sizeof(spa)), SCTP_SENDV_SPA, 0); - if (ret < 0) - { - MS_WARN_TAG( - sctp, - "error sending SCTP message [sid:%" PRIu16 ", ppid:%" PRIu32 ", message size:%zu]: %s", - parameters.streamId, - ppid, - len, - std::strerror(errno)); - } - } + if (ret < 0) + { + MS_WARN_TAG( + sctp, + "error sending SCTP message [sid:%" PRIu16 ", ppid:%" PRIu32 ", message size:%zu]: %s", + parameters.streamId, + ppid, + len, + std::strerror(errno)); + } + } - void SctpAssociation::HandleDataConsumer(const RTC::SctpStreamParameters ¶ms) - { - MS_TRACE(); + void SctpAssociation::HandleDataConsumer(const RTC::SctpStreamParameters ¶ms) + { + MS_TRACE(); - auto streamId = params.streamId; + auto streamId = params.streamId; - // We need more OS. - if (streamId > this->os - 1) - AddOutgoingStreams(/*force*/ false); - } + // We need more OS. + if (streamId > this->os - 1) + AddOutgoingStreams(/*force*/ false); + } - void SctpAssociation::DataProducerClosed(const RTC::SctpStreamParameters ¶ms) - { - MS_TRACE(); + void SctpAssociation::DataProducerClosed(const RTC::SctpStreamParameters ¶ms) + { + MS_TRACE(); - auto streamId = params.streamId; + auto streamId = params.streamId; - // Send SCTP_RESET_STREAMS to the remote. - // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7 - if (this->isDataChannel) - ResetSctpStream(streamId, StreamDirection::OUTGOING); - else - ResetSctpStream(streamId, StreamDirection::INCOMING); - } + // Send SCTP_RESET_STREAMS to the remote. + // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7 + if (this->isDataChannel) + ResetSctpStream(streamId, StreamDirection::OUTGOING); + else + ResetSctpStream(streamId, StreamDirection::INCOMING); + } - void SctpAssociation::DataConsumerClosed(const RTC::SctpStreamParameters ¶ms) - { - MS_TRACE(); + void SctpAssociation::DataConsumerClosed(const RTC::SctpStreamParameters ¶ms) + { + MS_TRACE(); - auto streamId = params.streamId; + auto streamId = params.streamId; - // Send SCTP_RESET_STREAMS to the remote. - ResetSctpStream(streamId, StreamDirection::OUTGOING); - } + // Send SCTP_RESET_STREAMS to the remote. + ResetSctpStream(streamId, StreamDirection::OUTGOING); + } - void SctpAssociation::ResetSctpStream(uint16_t streamId, StreamDirection direction) - { - MS_TRACE(); + void SctpAssociation::ResetSctpStream(uint16_t streamId, StreamDirection direction) + { + MS_TRACE(); - // Do nothing if an outgoing stream that could not be allocated by us. - if (direction == StreamDirection::OUTGOING && streamId > this->os - 1) - return; + // Do nothing if an outgoing stream that could not be allocated by us. + if (direction == StreamDirection::OUTGOING && streamId > this->os - 1) + return; - int ret; - struct sctp_assoc_value av; // NOLINT(cppcoreguidelines-pro-type-member-init) - socklen_t len = sizeof(av); + int ret; + struct sctp_assoc_value av; // NOLINT(cppcoreguidelines-pro-type-member-init) + socklen_t len = sizeof(av); #ifndef SCTP_RECONFIG_SUPPORTED #define SCTP_RECONFIG_SUPPORTED 0x00000029 #endif - ret = usrsctp_getsockopt(this->socket, IPPROTO_SCTP, SCTP_RECONFIG_SUPPORTED, &av, &len); + ret = usrsctp_getsockopt(this->socket, IPPROTO_SCTP, SCTP_RECONFIG_SUPPORTED, &av, &len); - if (ret == 0) - { - if (av.assoc_value != 1) - { - MS_DEBUG_TAG(sctp, "stream reconfiguration not negotiated"); + if (ret == 0) + { + if (av.assoc_value != 1) + { + MS_DEBUG_TAG(sctp, "stream reconfiguration not negotiated"); - return; - } - } - else - { - MS_WARN_TAG( - sctp, - "could not retrieve whether stream reconfiguration has been negotiated: %s\n", - std::strerror(errno)); + return; + } + } + else + { + MS_WARN_TAG( + sctp, + "could not retrieve whether stream reconfiguration has been negotiated: %s\n", + std::strerror(errno)); - return; - } + return; + } - // As per spec: https://tools.ietf.org/html/rfc6525#section-4.1 - len = sizeof(sctp_assoc_t) + (2 + 1) * sizeof(uint16_t); + // As per spec: https://tools.ietf.org/html/rfc6525#section-4.1 + len = sizeof(sctp_assoc_t) + (2 + 1) * sizeof(uint16_t); - auto* srs = static_cast(std::malloc(len)); + auto* srs = static_cast(std::malloc(len)); - switch (direction) - { - case StreamDirection::INCOMING: - srs->srs_flags = SCTP_STREAM_RESET_INCOMING; - break; + switch (direction) + { + case StreamDirection::INCOMING: + srs->srs_flags = SCTP_STREAM_RESET_INCOMING; + break; - case StreamDirection::OUTGOING: - srs->srs_flags = SCTP_STREAM_RESET_OUTGOING; - break; - } + case StreamDirection::OUTGOING: + srs->srs_flags = SCTP_STREAM_RESET_OUTGOING; + break; + } - srs->srs_number_streams = 1; - srs->srs_stream_list[0] = streamId; // No need for htonl(). + srs->srs_number_streams = 1; + srs->srs_stream_list[0] = streamId; // No need for htonl(). - ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_RESET_STREAMS, srs, len); + ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_RESET_STREAMS, srs, len); - if (ret == 0) - { - MS_DEBUG_TAG(sctp, "SCTP_RESET_STREAMS sent [streamId:%" PRIu16 "]", streamId); - } - else - { - MS_WARN_TAG(sctp, "usrsctp_setsockopt(SCTP_RESET_STREAMS) failed: %s", std::strerror(errno)); - } + if (ret == 0) + { + MS_DEBUG_TAG(sctp, "SCTP_RESET_STREAMS sent [streamId:%" PRIu16 "]", streamId); + } + else + { + MS_WARN_TAG(sctp, "usrsctp_setsockopt(SCTP_RESET_STREAMS) failed: %s", std::strerror(errno)); + } - std::free(srs); - } + std::free(srs); + } - void SctpAssociation::AddOutgoingStreams(bool force) - { - MS_TRACE(); + void SctpAssociation::AddOutgoingStreams(bool force) + { + MS_TRACE(); - uint16_t additionalOs{ 0 }; + uint16_t additionalOs{ 0 }; - if (MaxSctpStreams - this->os >= 32) - additionalOs = 32; - else - additionalOs = MaxSctpStreams - this->os; + if (MaxSctpStreams - this->os >= 32) + additionalOs = 32; + else + additionalOs = MaxSctpStreams - this->os; - if (additionalOs == 0) - { - MS_WARN_TAG(sctp, "cannot add more outgoing streams [OS:%" PRIu16 "]", this->os); + if (additionalOs == 0) + { + MS_WARN_TAG(sctp, "cannot add more outgoing streams [OS:%" PRIu16 "]", this->os); - return; - } + return; + } - auto nextDesiredOs = this->os + additionalOs; + auto nextDesiredOs = this->os + additionalOs; - // Already in progress, ignore (unless forced). - if (!force && nextDesiredOs == this->desiredOs) - return; + // Already in progress, ignore (unless forced). + if (!force && nextDesiredOs == this->desiredOs) + return; - // Update desired value. - this->desiredOs = nextDesiredOs; + // Update desired value. + this->desiredOs = nextDesiredOs; - // If not connected, defer it. - if (this->state != SctpState::CONNECTED) - { - MS_DEBUG_TAG(sctp, "SCTP not connected, deferring OS increase"); + // If not connected, defer it. + if (this->state != SctpState::CONNECTED) + { + MS_DEBUG_TAG(sctp, "SCTP not connected, deferring OS increase"); - return; - } + return; + } - struct sctp_add_streams sas; // NOLINT(cppcoreguidelines-pro-type-member-init) + struct sctp_add_streams sas; // NOLINT(cppcoreguidelines-pro-type-member-init) - std::memset(&sas, 0, sizeof(sas)); - sas.sas_instrms = 0; - sas.sas_outstrms = additionalOs; + std::memset(&sas, 0, sizeof(sas)); + sas.sas_instrms = 0; + sas.sas_outstrms = additionalOs; - MS_DEBUG_TAG(sctp, "adding %" PRIu16 " outgoing streams", additionalOs); + MS_DEBUG_TAG(sctp, "adding %" PRIu16 " outgoing streams", additionalOs); - int ret = usrsctp_setsockopt( - this->socket, IPPROTO_SCTP, SCTP_ADD_STREAMS, &sas, static_cast(sizeof(sas))); + int ret = usrsctp_setsockopt( + this->socket, IPPROTO_SCTP, SCTP_ADD_STREAMS, &sas, static_cast(sizeof(sas))); - if (ret < 0) - MS_WARN_TAG(sctp, "usrsctp_setsockopt(SCTP_ADD_STREAMS) failed: %s", std::strerror(errno)); - } + if (ret < 0) + MS_WARN_TAG(sctp, "usrsctp_setsockopt(SCTP_ADD_STREAMS) failed: %s", std::strerror(errno)); + } - void SctpAssociation::OnUsrSctpSendSctpData(void* buffer, size_t len) - { - MS_TRACE(); + void SctpAssociation::OnUsrSctpSendSctpData(void* buffer, size_t len) + { + MS_TRACE(); - const uint8_t* data = static_cast(buffer); + const uint8_t* data = static_cast(buffer); #if MS_LOG_DEV_LEVEL == 3 - MS_DUMP_DATA(data, len); + MS_DUMP_DATA(data, len); #endif - this->listener->OnSctpAssociationSendData(this, data, len); - } + this->listener->OnSctpAssociationSendData(this, data, len); + } - void SctpAssociation::OnUsrSctpReceiveSctpData( - uint16_t streamId, uint16_t ssn, uint32_t ppid, int flags, const uint8_t* data, size_t len) - { - // Ignore WebRTC DataChannel Control DATA chunks. - if (ppid == 50) - { - MS_WARN_TAG(sctp, "ignoring SCTP data with ppid:50 (WebRTC DataChannel Control)"); + void SctpAssociation::OnUsrSctpReceiveSctpData( + uint16_t streamId, uint16_t ssn, uint32_t ppid, int flags, const uint8_t* data, size_t len) + { + // Ignore WebRTC DataChannel Control DATA chunks. + if (ppid == 50) + { + MS_WARN_TAG(sctp, "ignoring SCTP data with ppid:50 (WebRTC DataChannel Control)"); - return; - } + return; + } - if (this->messageBufferLen != 0 && ssn != this->lastSsnReceived) - { - MS_WARN_TAG( - sctp, - "message chunk received with different SSN while buffer not empty, buffer discarded [ssn:%" PRIu16 - ", last ssn received:%" PRIu16 "]", - ssn, - this->lastSsnReceived); + if (this->messageBufferLen != 0 && ssn != this->lastSsnReceived) + { + MS_WARN_TAG( + sctp, + "message chunk received with different SSN while buffer not empty, buffer discarded [ssn:%" PRIu16 + ", last ssn received:%" PRIu16 "]", + ssn, + this->lastSsnReceived); - this->messageBufferLen = 0; - } + this->messageBufferLen = 0; + } - // Update last SSN received. - this->lastSsnReceived = ssn; + // Update last SSN received. + this->lastSsnReceived = ssn; - auto eor = static_cast(flags & MSG_EOR); + auto eor = static_cast(flags & MSG_EOR); - if (this->messageBufferLen + len > this->maxSctpMessageSize) - { - MS_WARN_TAG( - sctp, - "ongoing received message exceeds max allowed message size [message size:%zu, max message size:%zu, eor:%u]", - this->messageBufferLen + len, - this->maxSctpMessageSize, - eor ? 1 : 0); + if (this->messageBufferLen + len > this->maxSctpMessageSize) + { + MS_WARN_TAG( + sctp, + "ongoing received message exceeds max allowed message size [message size:%zu, max message size:%zu, eor:%u]", + this->messageBufferLen + len, + this->maxSctpMessageSize, + eor ? 1 : 0); - this->lastSsnReceived = 0; + this->lastSsnReceived = 0; - return; - } + return; + } - // If end of message and there is no buffered data, notify it directly. - if (eor && this->messageBufferLen == 0) - { - MS_DEBUG_DEV("directly notifying listener [eor:1, buffer len:0]"); + // If end of message and there is no buffered data, notify it directly. + if (eor && this->messageBufferLen == 0) + { + MS_DEBUG_DEV("directly notifying listener [eor:1, buffer len:0]"); - this->listener->OnSctpAssociationMessageReceived(this, streamId, ppid, data, len); - } - // If end of message and there is buffered data, append data and notify buffer. - else if (eor && this->messageBufferLen != 0) - { - std::memcpy(this->messageBuffer + this->messageBufferLen, data, len); - this->messageBufferLen += len; + this->listener->OnSctpAssociationMessageReceived(this, streamId, ppid, data, len); + } + // If end of message and there is buffered data, append data and notify buffer. + else if (eor && this->messageBufferLen != 0) + { + std::memcpy(this->messageBuffer + this->messageBufferLen, data, len); + this->messageBufferLen += len; - MS_DEBUG_DEV("notifying listener [eor:1, buffer len:%zu]", this->messageBufferLen); + MS_DEBUG_DEV("notifying listener [eor:1, buffer len:%zu]", this->messageBufferLen); - this->listener->OnSctpAssociationMessageReceived( - this, streamId, ppid, this->messageBuffer, this->messageBufferLen); + this->listener->OnSctpAssociationMessageReceived( + this, streamId, ppid, this->messageBuffer, this->messageBufferLen); - this->messageBufferLen = 0; - } - // If non end of message, append data to the buffer. - else if (!eor) - { - // Allocate the buffer if not already done. - if (!this->messageBuffer) - this->messageBuffer = new uint8_t[this->maxSctpMessageSize]; + this->messageBufferLen = 0; + } + // If non end of message, append data to the buffer. + else if (!eor) + { + // Allocate the buffer if not already done. + if (!this->messageBuffer) + this->messageBuffer = new uint8_t[this->maxSctpMessageSize]; - std::memcpy(this->messageBuffer + this->messageBufferLen, data, len); - this->messageBufferLen += len; + std::memcpy(this->messageBuffer + this->messageBufferLen, data, len); + this->messageBufferLen += len; - MS_DEBUG_DEV("data buffered [eor:0, buffer len:%zu]", this->messageBufferLen); - } - } + MS_DEBUG_DEV("data buffered [eor:0, buffer len:%zu]", this->messageBufferLen); + } + } - void SctpAssociation::OnUsrSctpReceiveSctpNotification(union sctp_notification* notification, size_t len) - { - if (notification->sn_header.sn_length != (uint32_t)len) - return; + void SctpAssociation::OnUsrSctpReceiveSctpNotification(union sctp_notification* notification, size_t len) + { + if (notification->sn_header.sn_length != (uint32_t)len) + return; - switch (notification->sn_header.sn_type) - { - case SCTP_ADAPTATION_INDICATION: - { - MS_DEBUG_TAG( - sctp, - "SCTP adaptation indication [%x]", - notification->sn_adaptation_event.sai_adaptation_ind); + switch (notification->sn_header.sn_type) + { + case SCTP_ADAPTATION_INDICATION: + { + MS_DEBUG_TAG( + sctp, + "SCTP adaptation indication [%x]", + notification->sn_adaptation_event.sai_adaptation_ind); - break; - } + break; + } - case SCTP_ASSOC_CHANGE: - { - switch (notification->sn_assoc_change.sac_state) - { - case SCTP_COMM_UP: - { - MS_DEBUG_TAG( - sctp, - "SCTP association connected, streams [out:%" PRIu16 ", in:%" PRIu16 "]", - notification->sn_assoc_change.sac_outbound_streams, - notification->sn_assoc_change.sac_inbound_streams); + case SCTP_ASSOC_CHANGE: + { + switch (notification->sn_assoc_change.sac_state) + { + case SCTP_COMM_UP: + { + MS_DEBUG_TAG( + sctp, + "SCTP association connected, streams [out:%" PRIu16 ", in:%" PRIu16 "]", + notification->sn_assoc_change.sac_outbound_streams, + notification->sn_assoc_change.sac_inbound_streams); - // Update our OS. - this->os = notification->sn_assoc_change.sac_outbound_streams; + // Update our OS. + this->os = notification->sn_assoc_change.sac_outbound_streams; - // Increase if requested before connected. - if (this->desiredOs > this->os) - AddOutgoingStreams(/*force*/ true); + // Increase if requested before connected. + if (this->desiredOs > this->os) + AddOutgoingStreams(/*force*/ true); - if (this->state != SctpState::CONNECTED) - { - this->state = SctpState::CONNECTED; - this->listener->OnSctpAssociationConnected(this); - } + if (this->state != SctpState::CONNECTED) + { + this->state = SctpState::CONNECTED; + this->listener->OnSctpAssociationConnected(this); + } - break; - } + break; + } - case SCTP_COMM_LOST: - { - if (notification->sn_header.sn_length > 0) - { - static const size_t BufferSize{ 1024 }; - static char buffer[BufferSize]; + case SCTP_COMM_LOST: + { + if (notification->sn_header.sn_length > 0) + { + static const size_t BufferSize{ 1024 }; + static char buffer[BufferSize]; - uint32_t len = notification->sn_header.sn_length; + uint32_t len = notification->sn_header.sn_length; - for (uint32_t i{ 0 }; i < len; ++i) - { - std::snprintf( - buffer, BufferSize, " 0x%02x", notification->sn_assoc_change.sac_info[i]); - } + for (uint32_t i{ 0 }; i < len; ++i) + { + std::snprintf( + buffer, BufferSize, " 0x%02x", notification->sn_assoc_change.sac_info[i]); + } - MS_DEBUG_TAG(sctp, "SCTP communication lost [info:%s]", buffer); - } - else - { - MS_DEBUG_TAG(sctp, "SCTP communication lost"); - } + MS_DEBUG_TAG(sctp, "SCTP communication lost [info:%s]", buffer); + } + else + { + MS_DEBUG_TAG(sctp, "SCTP communication lost"); + } - if (this->state != SctpState::CLOSED) - { - this->state = SctpState::CLOSED; - this->listener->OnSctpAssociationClosed(this); - } + if (this->state != SctpState::CLOSED) + { + this->state = SctpState::CLOSED; + this->listener->OnSctpAssociationClosed(this); + } - break; - } + break; + } - case SCTP_RESTART: - { - MS_DEBUG_TAG( - sctp, - "SCTP remote association restarted, streams [out:%" PRIu16 ", int:%" PRIu16 "]", - notification->sn_assoc_change.sac_outbound_streams, - notification->sn_assoc_change.sac_inbound_streams); + case SCTP_RESTART: + { + MS_DEBUG_TAG( + sctp, + "SCTP remote association restarted, streams [out:%" PRIu16 ", int:%" PRIu16 "]", + notification->sn_assoc_change.sac_outbound_streams, + notification->sn_assoc_change.sac_inbound_streams); - // Update our OS. - this->os = notification->sn_assoc_change.sac_outbound_streams; + // Update our OS. + this->os = notification->sn_assoc_change.sac_outbound_streams; - // Increase if requested before connected. - if (this->desiredOs > this->os) - AddOutgoingStreams(/*force*/ true); + // Increase if requested before connected. + if (this->desiredOs > this->os) + AddOutgoingStreams(/*force*/ true); - if (this->state != SctpState::CONNECTED) - { - this->state = SctpState::CONNECTED; - this->listener->OnSctpAssociationConnected(this); - } + if (this->state != SctpState::CONNECTED) + { + this->state = SctpState::CONNECTED; + this->listener->OnSctpAssociationConnected(this); + } - break; - } + break; + } - case SCTP_SHUTDOWN_COMP: - { - MS_DEBUG_TAG(sctp, "SCTP association gracefully closed"); + case SCTP_SHUTDOWN_COMP: + { + MS_DEBUG_TAG(sctp, "SCTP association gracefully closed"); - if (this->state != SctpState::CLOSED) - { - this->state = SctpState::CLOSED; - this->listener->OnSctpAssociationClosed(this); - } + if (this->state != SctpState::CLOSED) + { + this->state = SctpState::CLOSED; + this->listener->OnSctpAssociationClosed(this); + } - break; - } + break; + } - case SCTP_CANT_STR_ASSOC: - { - if (notification->sn_header.sn_length > 0) - { - static const size_t BufferSize{ 1024 }; - static char buffer[BufferSize]; + case SCTP_CANT_STR_ASSOC: + { + if (notification->sn_header.sn_length > 0) + { + static const size_t BufferSize{ 1024 }; + static char buffer[BufferSize]; - uint32_t len = notification->sn_header.sn_length; + uint32_t len = notification->sn_header.sn_length; - for (uint32_t i{ 0 }; i < len; ++i) - { - std::snprintf( - buffer, BufferSize, " 0x%02x", notification->sn_assoc_change.sac_info[i]); - } + for (uint32_t i{ 0 }; i < len; ++i) + { + std::snprintf( + buffer, BufferSize, " 0x%02x", notification->sn_assoc_change.sac_info[i]); + } - MS_WARN_TAG(sctp, "SCTP setup failed: %s", buffer); - } + MS_WARN_TAG(sctp, "SCTP setup failed: %s", buffer); + } - if (this->state != SctpState::FAILED) - { - this->state = SctpState::FAILED; - this->listener->OnSctpAssociationFailed(this); - } + if (this->state != SctpState::FAILED) + { + this->state = SctpState::FAILED; + this->listener->OnSctpAssociationFailed(this); + } - break; - } + break; + } - default:; - } + default:; + } - break; - } + break; + } - // https://tools.ietf.org/html/rfc6525#section-6.1.2. - case SCTP_ASSOC_RESET_EVENT: - { - MS_DEBUG_TAG(sctp, "SCTP association reset event received"); + // https://tools.ietf.org/html/rfc6525#section-6.1.2. + case SCTP_ASSOC_RESET_EVENT: + { + MS_DEBUG_TAG(sctp, "SCTP association reset event received"); - break; - } + break; + } - // An Operation Error is not considered fatal in and of itself, but may be - // used with an ABORT chunk to report a fatal condition. - case SCTP_REMOTE_ERROR: - { - static const size_t BufferSize{ 1024 }; - static char buffer[BufferSize]; + // An Operation Error is not considered fatal in and of itself, but may be + // used with an ABORT chunk to report a fatal condition. + case SCTP_REMOTE_ERROR: + { + static const size_t BufferSize{ 1024 }; + static char buffer[BufferSize]; - uint32_t len = notification->sn_remote_error.sre_length - sizeof(struct sctp_remote_error); + uint32_t len = notification->sn_remote_error.sre_length - sizeof(struct sctp_remote_error); - for (uint32_t i{ 0 }; i < len; i++) - { - std::snprintf(buffer, BufferSize, "0x%02x", notification->sn_remote_error.sre_data[i]); - } + for (uint32_t i{ 0 }; i < len; i++) + { + std::snprintf(buffer, BufferSize, "0x%02x", notification->sn_remote_error.sre_data[i]); + } - MS_WARN_TAG( - sctp, - "remote SCTP association error [type:0x%04x, data:%s]", - notification->sn_remote_error.sre_error, - buffer); + MS_WARN_TAG( + sctp, + "remote SCTP association error [type:0x%04x, data:%s]", + notification->sn_remote_error.sre_error, + buffer); - break; - } + break; + } - // When a peer sends a SHUTDOWN, SCTP delivers this notification to - // inform the application that it should cease sending data. - case SCTP_SHUTDOWN_EVENT: - { - MS_DEBUG_TAG(sctp, "remote SCTP association shutdown"); + // When a peer sends a SHUTDOWN, SCTP delivers this notification to + // inform the application that it should cease sending data. + case SCTP_SHUTDOWN_EVENT: + { + MS_DEBUG_TAG(sctp, "remote SCTP association shutdown"); - if (this->state != SctpState::CLOSED) - { - this->state = SctpState::CLOSED; - this->listener->OnSctpAssociationClosed(this); - } + if (this->state != SctpState::CLOSED) + { + this->state = SctpState::CLOSED; + this->listener->OnSctpAssociationClosed(this); + } - break; - } + break; + } - case SCTP_SEND_FAILED_EVENT: - { - static const size_t BufferSize{ 1024 }; - static char buffer[BufferSize]; + case SCTP_SEND_FAILED_EVENT: + { + static const size_t BufferSize{ 1024 }; + static char buffer[BufferSize]; - uint32_t len = - notification->sn_send_failed_event.ssfe_length - sizeof(struct sctp_send_failed_event); + uint32_t len = + notification->sn_send_failed_event.ssfe_length - sizeof(struct sctp_send_failed_event); - for (uint32_t i{ 0 }; i < len; ++i) - { - std::snprintf(buffer, BufferSize, "0x%02x", notification->sn_send_failed_event.ssfe_data[i]); - } + for (uint32_t i{ 0 }; i < len; ++i) + { + std::snprintf(buffer, BufferSize, "0x%02x", notification->sn_send_failed_event.ssfe_data[i]); + } - MS_WARN_TAG( - sctp, - "SCTP message sent failure [streamId:%" PRIu16 ", ppid:%" PRIu32 - ", sent:%s, error:0x%08x, info:%s]", - notification->sn_send_failed_event.ssfe_info.snd_sid, - ntohl(notification->sn_send_failed_event.ssfe_info.snd_ppid), - (notification->sn_send_failed_event.ssfe_flags & SCTP_DATA_SENT) ? "yes" : "no", - notification->sn_send_failed_event.ssfe_error, - buffer); + MS_WARN_TAG( + sctp, + "SCTP message sent failure [streamId:%" PRIu16 ", ppid:%" PRIu32 + ", sent:%s, error:0x%08x, info:%s]", + notification->sn_send_failed_event.ssfe_info.snd_sid, + ntohl(notification->sn_send_failed_event.ssfe_info.snd_ppid), + (notification->sn_send_failed_event.ssfe_flags & SCTP_DATA_SENT) ? "yes" : "no", + notification->sn_send_failed_event.ssfe_error, + buffer); - break; - } + break; + } - case SCTP_STREAM_RESET_EVENT: - { - bool incoming{ false }; - bool outgoing{ false }; - uint16_t numStreams = - (notification->sn_strreset_event.strreset_length - sizeof(struct sctp_stream_reset_event)) / - sizeof(uint16_t); + case SCTP_STREAM_RESET_EVENT: + { + bool incoming{ false }; + bool outgoing{ false }; + uint16_t numStreams = + (notification->sn_strreset_event.strreset_length - sizeof(struct sctp_stream_reset_event)) / + sizeof(uint16_t); - if (notification->sn_strreset_event.strreset_flags & SCTP_STREAM_RESET_INCOMING_SSN) - incoming = true; + if (notification->sn_strreset_event.strreset_flags & SCTP_STREAM_RESET_INCOMING_SSN) + incoming = true; - if (notification->sn_strreset_event.strreset_flags & SCTP_STREAM_RESET_OUTGOING_SSN) - outgoing = true; + if (notification->sn_strreset_event.strreset_flags & SCTP_STREAM_RESET_OUTGOING_SSN) + outgoing = true; //todo 打印sctp调试信息 - if (false /*MS_HAS_DEBUG_TAG(sctp)*/) - { - std::string streamIds; + if (false /*MS_HAS_DEBUG_TAG(sctp)*/) + { + std::string streamIds; - for (uint16_t i{ 0 }; i < numStreams; ++i) - { - auto streamId = notification->sn_strreset_event.strreset_stream_list[i]; + for (uint16_t i{ 0 }; i < numStreams; ++i) + { + auto streamId = notification->sn_strreset_event.strreset_stream_list[i]; - // Don't log more than 5 stream ids. - if (i > 4) - { - streamIds.append("..."); + // Don't log more than 5 stream ids. + if (i > 4) + { + streamIds.append("..."); - break; - } + break; + } - if (i > 0) - streamIds.append(","); + if (i > 0) + streamIds.append(","); - streamIds.append(std::to_string(streamId)); - } + streamIds.append(std::to_string(streamId)); + } - MS_DEBUG_TAG( - sctp, - "SCTP stream reset event [flags:%x, i|o:%s|%s, num streams:%" PRIu16 ", stream ids:%s]", - notification->sn_strreset_event.strreset_flags, - incoming ? "true" : "false", - outgoing ? "true" : "false", - numStreams, - streamIds.c_str()); - } + MS_DEBUG_TAG( + sctp, + "SCTP stream reset event [flags:%x, i|o:%s|%s, num streams:%" PRIu16 ", stream ids:%s]", + notification->sn_strreset_event.strreset_flags, + incoming ? "true" : "false", + outgoing ? "true" : "false", + numStreams, + streamIds.c_str()); + } - // Special case for WebRTC DataChannels in which we must also reset our - // outgoing SCTP stream. - if (incoming && !outgoing && this->isDataChannel) - { - for (uint16_t i{ 0 }; i < numStreams; ++i) - { - auto streamId = notification->sn_strreset_event.strreset_stream_list[i]; + // Special case for WebRTC DataChannels in which we must also reset our + // outgoing SCTP stream. + if (incoming && !outgoing && this->isDataChannel) + { + for (uint16_t i{ 0 }; i < numStreams; ++i) + { + auto streamId = notification->sn_strreset_event.strreset_stream_list[i]; - ResetSctpStream(streamId, StreamDirection::OUTGOING); - } - } + ResetSctpStream(streamId, StreamDirection::OUTGOING); + } + } - break; - } + break; + } - case SCTP_STREAM_CHANGE_EVENT: - { - if (notification->sn_strchange_event.strchange_flags == 0) - { - MS_DEBUG_TAG( - sctp, - "SCTP stream changed, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]", - notification->sn_strchange_event.strchange_outstrms, - notification->sn_strchange_event.strchange_instrms, - notification->sn_strchange_event.strchange_flags); - } - else if (notification->sn_strchange_event.strchange_flags & SCTP_STREAM_RESET_DENIED) - { - MS_WARN_TAG( - sctp, - "SCTP stream change denied, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]", - notification->sn_strchange_event.strchange_outstrms, - notification->sn_strchange_event.strchange_instrms, - notification->sn_strchange_event.strchange_flags); + case SCTP_STREAM_CHANGE_EVENT: + { + if (notification->sn_strchange_event.strchange_flags == 0) + { + MS_DEBUG_TAG( + sctp, + "SCTP stream changed, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]", + notification->sn_strchange_event.strchange_outstrms, + notification->sn_strchange_event.strchange_instrms, + notification->sn_strchange_event.strchange_flags); + } + else if (notification->sn_strchange_event.strchange_flags & SCTP_STREAM_RESET_DENIED) + { + MS_WARN_TAG( + sctp, + "SCTP stream change denied, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]", + notification->sn_strchange_event.strchange_outstrms, + notification->sn_strchange_event.strchange_instrms, + notification->sn_strchange_event.strchange_flags); - break; - } - else if (notification->sn_strchange_event.strchange_flags & SCTP_STREAM_RESET_FAILED) - { - MS_WARN_TAG( - sctp, - "SCTP stream change failed, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]", - notification->sn_strchange_event.strchange_outstrms, - notification->sn_strchange_event.strchange_instrms, - notification->sn_strchange_event.strchange_flags); + break; + } + else if (notification->sn_strchange_event.strchange_flags & SCTP_STREAM_RESET_FAILED) + { + MS_WARN_TAG( + sctp, + "SCTP stream change failed, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]", + notification->sn_strchange_event.strchange_outstrms, + notification->sn_strchange_event.strchange_instrms, + notification->sn_strchange_event.strchange_flags); - break; - } + break; + } - // Update OS. - this->os = notification->sn_strchange_event.strchange_outstrms; + // Update OS. + this->os = notification->sn_strchange_event.strchange_outstrms; - break; - } + break; + } - default: - { - MS_WARN_TAG( - sctp, "unhandled SCTP event received [type:%" PRIu16 "]", notification->sn_header.sn_type); - } - } - } + default: + { + MS_WARN_TAG( + sctp, "unhandled SCTP event received [type:%" PRIu16 "]", notification->sn_header.sn_type); + } + } + } //////////////////////////////////////////////////////////////////////////////////////// diff --git a/webrtc/SctpAssociation.hpp b/webrtc/SctpAssociation.hpp index 548221c5..9c46d275 100644 --- a/webrtc/SctpAssociation.hpp +++ b/webrtc/SctpAssociation.hpp @@ -18,104 +18,104 @@ namespace RTC uint16_t maxRetransmits{ 0u }; }; - class SctpAssociation - { - public: - enum class SctpState - { - NEW = 1, - CONNECTING, - CONNECTED, - FAILED, - CLOSED - }; + class SctpAssociation + { + public: + enum class SctpState + { + NEW = 1, + CONNECTING, + CONNECTED, + FAILED, + CLOSED + }; - private: - enum class StreamDirection - { - INCOMING = 1, - OUTGOING - }; + private: + enum class StreamDirection + { + INCOMING = 1, + OUTGOING + }; - public: - class Listener - { - public: - virtual void OnSctpAssociationConnecting(RTC::SctpAssociation* sctpAssociation) = 0; - virtual void OnSctpAssociationConnected(RTC::SctpAssociation* sctpAssociation) = 0; - virtual void OnSctpAssociationFailed(RTC::SctpAssociation* sctpAssociation) = 0; - virtual void OnSctpAssociationClosed(RTC::SctpAssociation* sctpAssociation) = 0; - virtual void OnSctpAssociationSendData( - RTC::SctpAssociation* sctpAssociation, const uint8_t* data, size_t len) = 0; - virtual void OnSctpAssociationMessageReceived( - RTC::SctpAssociation* sctpAssociation, - uint16_t streamId, - uint32_t ppid, - const uint8_t* msg, - size_t len) = 0; - }; + public: + class Listener + { + public: + virtual void OnSctpAssociationConnecting(RTC::SctpAssociation* sctpAssociation) = 0; + virtual void OnSctpAssociationConnected(RTC::SctpAssociation* sctpAssociation) = 0; + virtual void OnSctpAssociationFailed(RTC::SctpAssociation* sctpAssociation) = 0; + virtual void OnSctpAssociationClosed(RTC::SctpAssociation* sctpAssociation) = 0; + virtual void OnSctpAssociationSendData( + RTC::SctpAssociation* sctpAssociation, const uint8_t* data, size_t len) = 0; + virtual void OnSctpAssociationMessageReceived( + RTC::SctpAssociation* sctpAssociation, + uint16_t streamId, + uint32_t ppid, + const uint8_t* msg, + size_t len) = 0; + }; - public: - static bool IsSctp(const uint8_t* data, size_t len) - { - // clang-format off - return ( - (len >= 12) && - // Must have Source Port Number and Destination Port Number set to 5000 (hack). - (Utils::Byte::Get2Bytes(data, 0) == 5000) && - (Utils::Byte::Get2Bytes(data, 2) == 5000) - ); - // clang-format on - } + public: + static bool IsSctp(const uint8_t* data, size_t len) + { + // clang-format off + return ( + (len >= 12) && + // Must have Source Port Number and Destination Port Number set to 5000 (hack). + (Utils::Byte::Get2Bytes(data, 0) == 5000) && + (Utils::Byte::Get2Bytes(data, 2) == 5000) + ); + // clang-format on + } - public: - SctpAssociation( - Listener* listener, uint16_t os, uint16_t mis, size_t maxSctpMessageSize, bool isDataChannel); - virtual ~SctpAssociation(); + public: + SctpAssociation( + Listener* listener, uint16_t os, uint16_t mis, size_t maxSctpMessageSize, bool isDataChannel); + virtual ~SctpAssociation(); - public: - void TransportConnected(); - size_t GetMaxSctpMessageSize() const - { - return this->maxSctpMessageSize; - } - SctpState GetState() const - { - return this->state; - } - void ProcessSctpData(const uint8_t* data, size_t len); - void SendSctpMessage(const RTC::SctpStreamParameters ¶ms, uint32_t ppid, const uint8_t* msg, size_t len); - void HandleDataConsumer(const RTC::SctpStreamParameters ¶ms); - void DataProducerClosed(const RTC::SctpStreamParameters ¶ms); - void DataConsumerClosed(const RTC::SctpStreamParameters ¶ms); + public: + void TransportConnected(); + size_t GetMaxSctpMessageSize() const + { + return this->maxSctpMessageSize; + } + SctpState GetState() const + { + return this->state; + } + void ProcessSctpData(const uint8_t* data, size_t len); + void SendSctpMessage(const RTC::SctpStreamParameters ¶ms, uint32_t ppid, const uint8_t* msg, size_t len); + void HandleDataConsumer(const RTC::SctpStreamParameters ¶ms); + void DataProducerClosed(const RTC::SctpStreamParameters ¶ms); + void DataConsumerClosed(const RTC::SctpStreamParameters ¶ms); - private: - void ResetSctpStream(uint16_t streamId, StreamDirection); - void AddOutgoingStreams(bool force = false); + private: + void ResetSctpStream(uint16_t streamId, StreamDirection); + void AddOutgoingStreams(bool force = false); - public: + public: /* Callbacks fired by usrsctp events. */ virtual void OnUsrSctpSendSctpData(void* buffer, size_t len); virtual void OnUsrSctpReceiveSctpData(uint16_t streamId, uint16_t ssn, uint32_t ppid, int flags, const uint8_t* data, size_t len); virtual void OnUsrSctpReceiveSctpNotification(union sctp_notification* notification, size_t len); - private: - // Passed by argument. - Listener* listener{ nullptr }; - uint16_t os{ 1024u }; - uint16_t mis{ 1024u }; - size_t maxSctpMessageSize{ 262144u }; - bool isDataChannel{ false }; - // Allocated by this. - uint8_t* messageBuffer{ nullptr }; - // Others. - SctpState state{ SctpState::NEW }; - struct socket* socket{ nullptr }; - uint16_t desiredOs{ 0u }; - size_t messageBufferLen{ 0u }; - uint16_t lastSsnReceived{ 0u }; // Valid for us since no SCTP I-DATA support. + private: + // Passed by argument. + Listener* listener{ nullptr }; + uint16_t os{ 1024u }; + uint16_t mis{ 1024u }; + size_t maxSctpMessageSize{ 262144u }; + bool isDataChannel{ false }; + // Allocated by this. + uint8_t* messageBuffer{ nullptr }; + // Others. + SctpState state{ SctpState::NEW }; + struct socket* socket{ nullptr }; + uint16_t desiredOs{ 0u }; + size_t messageBufferLen{ 0u }; + uint16_t lastSsnReceived{ 0u }; // Valid for us since no SCTP I-DATA support. std::shared_ptr _env; - }; + }; //保证线程安全 class SctpAssociationImp : public SctpAssociation, public std::enable_shared_from_this{ diff --git a/webrtc/StunPacket.cpp b/webrtc/StunPacket.cpp index c7a403fb..9f648d55 100644 --- a/webrtc/StunPacket.cpp +++ b/webrtc/StunPacket.cpp @@ -97,785 +97,785 @@ namespace RTC return str; } - /* Class variables. */ - - const uint8_t StunPacket::magicCookie[] = { 0x21, 0x12, 0xA4, 0x42 }; - - /* Class methods. */ - - StunPacket* StunPacket::Parse(const uint8_t* data, size_t len) - { - MS_TRACE(); - - if (!StunPacket::IsStun(data, len)) - return nullptr; - - /* - The message type field is decomposed further into the following - structure: - - 0 1 - 2 3 4 5 6 7 8 9 0 1 2 3 4 5 - +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ - |M |M |M|M|M|C|M|M|M|C|M|M|M|M| - |11|10|9|8|7|1|6|5|4|0|3|2|1|0| - +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ - - Figure 3: Format of STUN Message Type Field - - Here the bits in the message type field are shown as most significant - (M11) through least significant (M0). M11 through M0 represent a 12- - bit encoding of the method. C1 and C0 represent a 2-bit encoding of - the class. - */ - - // Get type field. - uint16_t msgType = Utils::Byte::Get2Bytes(data, 0); - - // Get length field. - uint16_t msgLength = Utils::Byte::Get2Bytes(data, 2); - - // length field must be total size minus header's 20 bytes, and must be multiple of 4 Bytes. - if ((static_cast(msgLength) != len - 20) || ((msgLength & 0x03) != 0)) - { - MS_WARN_TAG( - ice, - "length field + 20 does not match total size (or it is not multiple of 4 bytes), " - "packet discarded"); - - return nullptr; - } - - // Get STUN method. - uint16_t msgMethod = (msgType & 0x000f) | ((msgType & 0x00e0) >> 1) | ((msgType & 0x3E00) >> 2); - - // Get STUN class. - uint16_t msgClass = ((data[0] & 0x01) << 1) | ((data[1] & 0x10) >> 4); - - // Create a new StunPacket (data + 8 points to the received TransactionID field). - auto* packet = new StunPacket( - static_cast(msgClass), static_cast(msgMethod), data + 8, data, len); - - /* - STUN Attributes - - After the STUN header are zero or more attributes. Each attribute - MUST be TLV encoded, with a 16-bit type, 16-bit length, and value. - Each STUN attribute MUST end on a 32-bit boundary. As mentioned - above, all fields in an attribute are transmitted most significant - bit first. - - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Type | Length | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Value (variable) .... - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ - - // Start looking for attributes after STUN header (Byte #20). - size_t pos{ 20 }; - // Flags (positions) for special MESSAGE-INTEGRITY and FINGERPRINT attributes. - bool hasMessageIntegrity{ false }; - bool hasFingerprint{ false }; - size_t fingerprintAttrPos; // Will point to the beginning of the attribute. - uint32_t fingerprint; // Holds the value of the FINGERPRINT attribute. - - // Ensure there are at least 4 remaining bytes (attribute with 0 length). - while (pos + 4 <= len) - { - // Get the attribute type. - auto attrType = static_cast(Utils::Byte::Get2Bytes(data, pos)); - - // Get the attribute length. - uint16_t attrLength = Utils::Byte::Get2Bytes(data, pos + 2); - - // Ensure the attribute length is not greater than the remaining size. - if ((pos + 4 + attrLength) > len) - { - MS_WARN_TAG(ice, "the attribute length exceeds the remaining size, packet discarded"); - - delete packet; - return nullptr; - } - - // FINGERPRINT must be the last attribute. - if (hasFingerprint) - { - MS_WARN_TAG(ice, "attribute after FINGERPRINT is not allowed, packet discarded"); - - delete packet; - return nullptr; - } - - // After a MESSAGE-INTEGRITY attribute just FINGERPRINT is allowed. - if (hasMessageIntegrity && attrType != Attribute::FINGERPRINT) - { - MS_WARN_TAG( - ice, - "attribute after MESSAGE-INTEGRITY other than FINGERPRINT is not allowed, " - "packet discarded"); - - delete packet; - return nullptr; - } - - const uint8_t* attrValuePos = data + pos + 4; - - switch (attrType) - { - case Attribute::USERNAME: - { - packet->SetUsername( - reinterpret_cast(attrValuePos), static_cast(attrLength)); - - break; - } - - case Attribute::PRIORITY: - { - // Ensure attribute length is 4 bytes. - if (attrLength != 4) - { - MS_WARN_TAG(ice, "attribute PRIORITY must be 4 bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - packet->SetPriority(Utils::Byte::Get4Bytes(attrValuePos, 0)); - - break; - } - - case Attribute::ICE_CONTROLLING: - { - // Ensure attribute length is 8 bytes. - if (attrLength != 8) - { - MS_WARN_TAG(ice, "attribute ICE-CONTROLLING must be 8 bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - packet->SetIceControlling(Utils::Byte::Get8Bytes(attrValuePos, 0)); - - break; - } - - case Attribute::ICE_CONTROLLED: - { - // Ensure attribute length is 8 bytes. - if (attrLength != 8) - { - MS_WARN_TAG(ice, "attribute ICE-CONTROLLED must be 8 bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - packet->SetIceControlled(Utils::Byte::Get8Bytes(attrValuePos, 0)); - - break; - } - - case Attribute::USE_CANDIDATE: - { - // Ensure attribute length is 0 bytes. - if (attrLength != 0) - { - MS_WARN_TAG(ice, "attribute USE-CANDIDATE must be 0 bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - packet->SetUseCandidate(); - - break; - } - - case Attribute::MESSAGE_INTEGRITY: - { - // Ensure attribute length is 20 bytes. - if (attrLength != 20) - { - MS_WARN_TAG(ice, "attribute MESSAGE-INTEGRITY must be 20 bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - hasMessageIntegrity = true; - packet->SetMessageIntegrity(attrValuePos); - - break; - } - - case Attribute::FINGERPRINT: - { - // Ensure attribute length is 4 bytes. - if (attrLength != 4) - { - MS_WARN_TAG(ice, "attribute FINGERPRINT must be 4 bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - hasFingerprint = true; - fingerprintAttrPos = pos; - fingerprint = Utils::Byte::Get4Bytes(attrValuePos, 0); - packet->SetFingerprint(); - - break; - } - - case Attribute::ERROR_CODE: - { - // Ensure attribute length >= 4bytes. - if (attrLength < 4) - { - MS_WARN_TAG(ice, "attribute ERROR-CODE must be >= 4bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - uint8_t errorClass = Utils::Byte::Get1Byte(attrValuePos, 2); - uint8_t errorNumber = Utils::Byte::Get1Byte(attrValuePos, 3); - auto errorCode = static_cast(errorClass * 100 + errorNumber); - - packet->SetErrorCode(errorCode); - - break; - } - - default:; - } - - // Set next attribute position. - pos = - static_cast(Utils::Byte::PadTo4Bytes(static_cast(pos + 4 + attrLength))); - } - - // Ensure current position matches the total length. - if (pos != len) - { - MS_WARN_TAG(ice, "computed packet size does not match total size, packet discarded"); - - delete packet; - return nullptr; - } - - // If it has FINGERPRINT attribute then verify it. - if (hasFingerprint) - { - // Compute the CRC32 of the received packet up to (but excluding) the - // FINGERPRINT attribute and XOR it with 0x5354554e. - uint32_t computedFingerprint = GetCRC32(data, fingerprintAttrPos) ^ 0x5354554e; - - // Compare with the FINGERPRINT value in the packet. - if (fingerprint != computedFingerprint) - { - MS_WARN_TAG( - ice, - "computed FINGERPRINT value does not match the value in the packet, " - "packet discarded"); - - delete packet; - return nullptr; - } - } - - return packet; - } - - /* Instance methods. */ - - StunPacket::StunPacket( - Class klass, Method method, const uint8_t* transactionId, const uint8_t* data, size_t size) - : klass(klass), method(method), transactionId(transactionId), data(const_cast(data)), - size(size) - { - MS_TRACE(); - } - - StunPacket::~StunPacket() - { - MS_TRACE(); - } + /* Class variables. */ + + const uint8_t StunPacket::magicCookie[] = { 0x21, 0x12, 0xA4, 0x42 }; + + /* Class methods. */ + + StunPacket* StunPacket::Parse(const uint8_t* data, size_t len) + { + MS_TRACE(); + + if (!StunPacket::IsStun(data, len)) + return nullptr; + + /* + The message type field is decomposed further into the following + structure: + + 0 1 + 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ + |M |M |M|M|M|C|M|M|M|C|M|M|M|M| + |11|10|9|8|7|1|6|5|4|0|3|2|1|0| + +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ + + Figure 3: Format of STUN Message Type Field + + Here the bits in the message type field are shown as most significant + (M11) through least significant (M0). M11 through M0 represent a 12- + bit encoding of the method. C1 and C0 represent a 2-bit encoding of + the class. + */ + + // Get type field. + uint16_t msgType = Utils::Byte::Get2Bytes(data, 0); + + // Get length field. + uint16_t msgLength = Utils::Byte::Get2Bytes(data, 2); + + // length field must be total size minus header's 20 bytes, and must be multiple of 4 Bytes. + if ((static_cast(msgLength) != len - 20) || ((msgLength & 0x03) != 0)) + { + MS_WARN_TAG( + ice, + "length field + 20 does not match total size (or it is not multiple of 4 bytes), " + "packet discarded"); + + return nullptr; + } + + // Get STUN method. + uint16_t msgMethod = (msgType & 0x000f) | ((msgType & 0x00e0) >> 1) | ((msgType & 0x3E00) >> 2); + + // Get STUN class. + uint16_t msgClass = ((data[0] & 0x01) << 1) | ((data[1] & 0x10) >> 4); + + // Create a new StunPacket (data + 8 points to the received TransactionID field). + auto* packet = new StunPacket( + static_cast(msgClass), static_cast(msgMethod), data + 8, data, len); + + /* + STUN Attributes + + After the STUN header are zero or more attributes. Each attribute + MUST be TLV encoded, with a 16-bit type, 16-bit length, and value. + Each STUN attribute MUST end on a 32-bit boundary. As mentioned + above, all fields in an attribute are transmitted most significant + bit first. + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Type | Length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Value (variable) .... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + + // Start looking for attributes after STUN header (Byte #20). + size_t pos{ 20 }; + // Flags (positions) for special MESSAGE-INTEGRITY and FINGERPRINT attributes. + bool hasMessageIntegrity{ false }; + bool hasFingerprint{ false }; + size_t fingerprintAttrPos; // Will point to the beginning of the attribute. + uint32_t fingerprint; // Holds the value of the FINGERPRINT attribute. + + // Ensure there are at least 4 remaining bytes (attribute with 0 length). + while (pos + 4 <= len) + { + // Get the attribute type. + auto attrType = static_cast(Utils::Byte::Get2Bytes(data, pos)); + + // Get the attribute length. + uint16_t attrLength = Utils::Byte::Get2Bytes(data, pos + 2); + + // Ensure the attribute length is not greater than the remaining size. + if ((pos + 4 + attrLength) > len) + { + MS_WARN_TAG(ice, "the attribute length exceeds the remaining size, packet discarded"); + + delete packet; + return nullptr; + } + + // FINGERPRINT must be the last attribute. + if (hasFingerprint) + { + MS_WARN_TAG(ice, "attribute after FINGERPRINT is not allowed, packet discarded"); + + delete packet; + return nullptr; + } + + // After a MESSAGE-INTEGRITY attribute just FINGERPRINT is allowed. + if (hasMessageIntegrity && attrType != Attribute::FINGERPRINT) + { + MS_WARN_TAG( + ice, + "attribute after MESSAGE-INTEGRITY other than FINGERPRINT is not allowed, " + "packet discarded"); + + delete packet; + return nullptr; + } + + const uint8_t* attrValuePos = data + pos + 4; + + switch (attrType) + { + case Attribute::USERNAME: + { + packet->SetUsername( + reinterpret_cast(attrValuePos), static_cast(attrLength)); + + break; + } + + case Attribute::PRIORITY: + { + // Ensure attribute length is 4 bytes. + if (attrLength != 4) + { + MS_WARN_TAG(ice, "attribute PRIORITY must be 4 bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + packet->SetPriority(Utils::Byte::Get4Bytes(attrValuePos, 0)); + + break; + } + + case Attribute::ICE_CONTROLLING: + { + // Ensure attribute length is 8 bytes. + if (attrLength != 8) + { + MS_WARN_TAG(ice, "attribute ICE-CONTROLLING must be 8 bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + packet->SetIceControlling(Utils::Byte::Get8Bytes(attrValuePos, 0)); + + break; + } + + case Attribute::ICE_CONTROLLED: + { + // Ensure attribute length is 8 bytes. + if (attrLength != 8) + { + MS_WARN_TAG(ice, "attribute ICE-CONTROLLED must be 8 bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + packet->SetIceControlled(Utils::Byte::Get8Bytes(attrValuePos, 0)); + + break; + } + + case Attribute::USE_CANDIDATE: + { + // Ensure attribute length is 0 bytes. + if (attrLength != 0) + { + MS_WARN_TAG(ice, "attribute USE-CANDIDATE must be 0 bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + packet->SetUseCandidate(); + + break; + } + + case Attribute::MESSAGE_INTEGRITY: + { + // Ensure attribute length is 20 bytes. + if (attrLength != 20) + { + MS_WARN_TAG(ice, "attribute MESSAGE-INTEGRITY must be 20 bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + hasMessageIntegrity = true; + packet->SetMessageIntegrity(attrValuePos); + + break; + } + + case Attribute::FINGERPRINT: + { + // Ensure attribute length is 4 bytes. + if (attrLength != 4) + { + MS_WARN_TAG(ice, "attribute FINGERPRINT must be 4 bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + hasFingerprint = true; + fingerprintAttrPos = pos; + fingerprint = Utils::Byte::Get4Bytes(attrValuePos, 0); + packet->SetFingerprint(); + + break; + } + + case Attribute::ERROR_CODE: + { + // Ensure attribute length >= 4bytes. + if (attrLength < 4) + { + MS_WARN_TAG(ice, "attribute ERROR-CODE must be >= 4bytes length, packet discarded"); + + delete packet; + return nullptr; + } + + uint8_t errorClass = Utils::Byte::Get1Byte(attrValuePos, 2); + uint8_t errorNumber = Utils::Byte::Get1Byte(attrValuePos, 3); + auto errorCode = static_cast(errorClass * 100 + errorNumber); + + packet->SetErrorCode(errorCode); + + break; + } + + default:; + } + + // Set next attribute position. + pos = + static_cast(Utils::Byte::PadTo4Bytes(static_cast(pos + 4 + attrLength))); + } + + // Ensure current position matches the total length. + if (pos != len) + { + MS_WARN_TAG(ice, "computed packet size does not match total size, packet discarded"); + + delete packet; + return nullptr; + } + + // If it has FINGERPRINT attribute then verify it. + if (hasFingerprint) + { + // Compute the CRC32 of the received packet up to (but excluding) the + // FINGERPRINT attribute and XOR it with 0x5354554e. + uint32_t computedFingerprint = GetCRC32(data, fingerprintAttrPos) ^ 0x5354554e; + + // Compare with the FINGERPRINT value in the packet. + if (fingerprint != computedFingerprint) + { + MS_WARN_TAG( + ice, + "computed FINGERPRINT value does not match the value in the packet, " + "packet discarded"); + + delete packet; + return nullptr; + } + } + + return packet; + } + + /* Instance methods. */ + + StunPacket::StunPacket( + Class klass, Method method, const uint8_t* transactionId, const uint8_t* data, size_t size) + : klass(klass), method(method), transactionId(transactionId), data(const_cast(data)), + size(size) + { + MS_TRACE(); + } + + StunPacket::~StunPacket() + { + MS_TRACE(); + } #if 0 - void StunPacket::Dump() const - { - MS_TRACE(); + void StunPacket::Dump() const + { + MS_TRACE(); - MS_DUMP(""); + MS_DUMP(""); - std::string klass; - switch (this->klass) - { - case Class::REQUEST: - klass = "Request"; - break; - case Class::INDICATION: - klass = "Indication"; - break; - case Class::SUCCESS_RESPONSE: - klass = "SuccessResponse"; - break; - case Class::ERROR_RESPONSE: - klass = "ErrorResponse"; - break; - } - if (this->method == Method::BINDING) - { - MS_DUMP(" Binding %s", klass.c_str()); - } - else - { - // This prints the unknown method number. Example: TURN Allocate => 0x003. - MS_DUMP(" %s with unknown method %#.3x", klass.c_str(), static_cast(this->method)); - } - MS_DUMP(" size: %zu bytes", this->size); + std::string klass; + switch (this->klass) + { + case Class::REQUEST: + klass = "Request"; + break; + case Class::INDICATION: + klass = "Indication"; + break; + case Class::SUCCESS_RESPONSE: + klass = "SuccessResponse"; + break; + case Class::ERROR_RESPONSE: + klass = "ErrorResponse"; + break; + } + if (this->method == Method::BINDING) + { + MS_DUMP(" Binding %s", klass.c_str()); + } + else + { + // This prints the unknown method number. Example: TURN Allocate => 0x003. + MS_DUMP(" %s with unknown method %#.3x", klass.c_str(), static_cast(this->method)); + } + MS_DUMP(" size: %zu bytes", this->size); - static char transactionId[25]; + static char transactionId[25]; - for (int i{ 0 }; i < 12; ++i) - { - // NOTE: n must be 3 because snprintf adds a \0 after printed chars. - std::snprintf(transactionId + (i * 2), 3, "%.2x", this->transactionId[i]); - } - MS_DUMP(" transactionId: %s", transactionId); - if (this->errorCode != 0u) - MS_DUMP(" errorCode: %" PRIu16, this->errorCode); - if (!this->username.empty()) - MS_DUMP(" username: %s", this->username.c_str()); - if (this->priority != 0u) - MS_DUMP(" priority: %" PRIu32, this->priority); - if (this->iceControlling != 0u) - MS_DUMP(" iceControlling: %" PRIu64, this->iceControlling); - if (this->iceControlled != 0u) - MS_DUMP(" iceControlled: %" PRIu64, this->iceControlled); - if (this->hasUseCandidate) - MS_DUMP(" useCandidate"); - if (this->xorMappedAddress != nullptr) - { - int family; - uint16_t port; - std::string ip; + for (int i{ 0 }; i < 12; ++i) + { + // NOTE: n must be 3 because snprintf adds a \0 after printed chars. + std::snprintf(transactionId + (i * 2), 3, "%.2x", this->transactionId[i]); + } + MS_DUMP(" transactionId: %s", transactionId); + if (this->errorCode != 0u) + MS_DUMP(" errorCode: %" PRIu16, this->errorCode); + if (!this->username.empty()) + MS_DUMP(" username: %s", this->username.c_str()); + if (this->priority != 0u) + MS_DUMP(" priority: %" PRIu32, this->priority); + if (this->iceControlling != 0u) + MS_DUMP(" iceControlling: %" PRIu64, this->iceControlling); + if (this->iceControlled != 0u) + MS_DUMP(" iceControlled: %" PRIu64, this->iceControlled); + if (this->hasUseCandidate) + MS_DUMP(" useCandidate"); + if (this->xorMappedAddress != nullptr) + { + int family; + uint16_t port; + std::string ip; - Utils::IP::GetAddressInfo(this->xorMappedAddress, family, ip, port); + Utils::IP::GetAddressInfo(this->xorMappedAddress, family, ip, port); - MS_DUMP(" xorMappedAddress: %s : %" PRIu16, ip.c_str(), port); - } - if (this->messageIntegrity != nullptr) - { - static char messageIntegrity[41]; + MS_DUMP(" xorMappedAddress: %s : %" PRIu16, ip.c_str(), port); + } + if (this->messageIntegrity != nullptr) + { + static char messageIntegrity[41]; - for (int i{ 0 }; i < 20; ++i) - { - std::snprintf(messageIntegrity + (i * 2), 3, "%.2x", this->messageIntegrity[i]); - } + for (int i{ 0 }; i < 20; ++i) + { + std::snprintf(messageIntegrity + (i * 2), 3, "%.2x", this->messageIntegrity[i]); + } - MS_DUMP(" messageIntegrity: %s", messageIntegrity); - } - if (this->hasFingerprint) - MS_DUMP(" has fingerprint"); + MS_DUMP(" messageIntegrity: %s", messageIntegrity); + } + if (this->hasFingerprint) + MS_DUMP(" has fingerprint"); - MS_DUMP(""); - } + MS_DUMP(""); + } #endif - StunPacket::Authentication StunPacket::CheckAuthentication( - const std::string& localUsername, const std::string& localPassword) - { - MS_TRACE(); + StunPacket::Authentication StunPacket::CheckAuthentication( + const std::string& localUsername, const std::string& localPassword) + { + MS_TRACE(); - switch (this->klass) - { - case Class::REQUEST: - case Class::INDICATION: - { - // Both USERNAME and MESSAGE-INTEGRITY must be present. - if (!this->messageIntegrity || this->username.empty()) - return Authentication::BAD_REQUEST; + switch (this->klass) + { + case Class::REQUEST: + case Class::INDICATION: + { + // Both USERNAME and MESSAGE-INTEGRITY must be present. + if (!this->messageIntegrity || this->username.empty()) + return Authentication::BAD_REQUEST; - // Check that USERNAME attribute begins with our local username plus ":". - size_t localUsernameLen = localUsername.length(); + // Check that USERNAME attribute begins with our local username plus ":". + size_t localUsernameLen = localUsername.length(); - if ( - this->username.length() <= localUsernameLen || this->username.at(localUsernameLen) != ':' || - (this->username.compare(0, localUsernameLen, localUsername) != 0)) - { - return Authentication::UNAUTHORIZED; - } + if ( + this->username.length() <= localUsernameLen || this->username.at(localUsernameLen) != ':' || + (this->username.compare(0, localUsernameLen, localUsername) != 0)) + { + return Authentication::UNAUTHORIZED; + } - break; - } - // This method cannot check authentication in received responses (as we - // are ICE-Lite and don't generate requests). - case Class::SUCCESS_RESPONSE: - case Class::ERROR_RESPONSE: - { - MS_ERROR("cannot check authentication for a STUN response"); + break; + } + // This method cannot check authentication in received responses (as we + // are ICE-Lite and don't generate requests). + case Class::SUCCESS_RESPONSE: + case Class::ERROR_RESPONSE: + { + MS_ERROR("cannot check authentication for a STUN response"); - return Authentication::BAD_REQUEST; - } - } + return Authentication::BAD_REQUEST; + } + } - // If there is FINGERPRINT it must be discarded for MESSAGE-INTEGRITY calculation, - // so the header length field must be modified (and later restored). - if (this->hasFingerprint) - // Set the header length field: full size - header length (20) - FINGERPRINT length (8). - Utils::Byte::Set2Bytes(this->data, 2, static_cast(this->size - 20 - 8)); + // If there is FINGERPRINT it must be discarded for MESSAGE-INTEGRITY calculation, + // so the header length field must be modified (and later restored). + if (this->hasFingerprint) + // Set the header length field: full size - header length (20) - FINGERPRINT length (8). + Utils::Byte::Set2Bytes(this->data, 2, static_cast(this->size - 20 - 8)); - // Calculate the HMAC-SHA1 of the message according to MESSAGE-INTEGRITY rules. + // Calculate the HMAC-SHA1 of the message according to MESSAGE-INTEGRITY rules. auto computedMessageIntegrity = openssl_HMACsha1( localPassword.data(),localPassword.size(), this->data, (this->messageIntegrity - 4) - this->data); - Authentication result; + Authentication result; - // Compare the computed HMAC-SHA1 with the MESSAGE-INTEGRITY in the packet. - if (std::memcmp(this->messageIntegrity, computedMessageIntegrity.data(), computedMessageIntegrity.size()) == 0) - result = Authentication::OK; - else - result = Authentication::UNAUTHORIZED; + // Compare the computed HMAC-SHA1 with the MESSAGE-INTEGRITY in the packet. + if (std::memcmp(this->messageIntegrity, computedMessageIntegrity.data(), computedMessageIntegrity.size()) == 0) + result = Authentication::OK; + else + result = Authentication::UNAUTHORIZED; - // Restore the header length field. - if (this->hasFingerprint) - Utils::Byte::Set2Bytes(this->data, 2, static_cast(this->size - 20)); + // Restore the header length field. + if (this->hasFingerprint) + Utils::Byte::Set2Bytes(this->data, 2, static_cast(this->size - 20)); - return result; - } + return result; + } - StunPacket* StunPacket::CreateSuccessResponse() - { - MS_TRACE(); + StunPacket* StunPacket::CreateSuccessResponse() + { + MS_TRACE(); - MS_ASSERT( - this->klass == Class::REQUEST, - "attempt to create a success response for a non Request STUN packet"); + MS_ASSERT( + this->klass == Class::REQUEST, + "attempt to create a success response for a non Request STUN packet"); - return new StunPacket(Class::SUCCESS_RESPONSE, this->method, this->transactionId, nullptr, 0); - } + return new StunPacket(Class::SUCCESS_RESPONSE, this->method, this->transactionId, nullptr, 0); + } - StunPacket* StunPacket::CreateErrorResponse(uint16_t errorCode) - { - MS_TRACE(); + StunPacket* StunPacket::CreateErrorResponse(uint16_t errorCode) + { + MS_TRACE(); - MS_ASSERT( - this->klass == Class::REQUEST, - "attempt to create an error response for a non Request STUN packet"); + MS_ASSERT( + this->klass == Class::REQUEST, + "attempt to create an error response for a non Request STUN packet"); - auto* response = - new StunPacket(Class::ERROR_RESPONSE, this->method, this->transactionId, nullptr, 0); + auto* response = + new StunPacket(Class::ERROR_RESPONSE, this->method, this->transactionId, nullptr, 0); - response->SetErrorCode(errorCode); + response->SetErrorCode(errorCode); - return response; - } + return response; + } - void StunPacket::Authenticate(const std::string& password) - { - // Just for Request, Indication and SuccessResponse messages. - if (this->klass == Class::ERROR_RESPONSE) - { - MS_ERROR("cannot set password for ErrorResponse messages"); + void StunPacket::Authenticate(const std::string& password) + { + // Just for Request, Indication and SuccessResponse messages. + if (this->klass == Class::ERROR_RESPONSE) + { + MS_ERROR("cannot set password for ErrorResponse messages"); - return; - } + return; + } - this->password = password; - } + this->password = password; + } - void StunPacket::Serialize(uint8_t* buffer) - { - MS_TRACE(); + void StunPacket::Serialize(uint8_t* buffer) + { + MS_TRACE(); - // Some useful variables. - uint16_t usernamePaddedLen{ 0 }; - uint16_t xorMappedAddressPaddedLen{ 0 }; - bool addXorMappedAddress = - ((this->xorMappedAddress != nullptr) && this->method == StunPacket::Method::BINDING && - this->klass == Class::SUCCESS_RESPONSE); - bool addErrorCode = ((this->errorCode != 0u) && this->klass == Class::ERROR_RESPONSE); - bool addMessageIntegrity = (this->klass != Class::ERROR_RESPONSE && !this->password.empty()); - bool addFingerprint{ true }; // Do always. + // Some useful variables. + uint16_t usernamePaddedLen{ 0 }; + uint16_t xorMappedAddressPaddedLen{ 0 }; + bool addXorMappedAddress = + ((this->xorMappedAddress != nullptr) && this->method == StunPacket::Method::BINDING && + this->klass == Class::SUCCESS_RESPONSE); + bool addErrorCode = ((this->errorCode != 0u) && this->klass == Class::ERROR_RESPONSE); + bool addMessageIntegrity = (this->klass != Class::ERROR_RESPONSE && !this->password.empty()); + bool addFingerprint{ true }; // Do always. - // Update data pointer. - this->data = buffer; + // Update data pointer. + this->data = buffer; - // First calculate the total required size for the entire packet. - this->size = 20; // Header. + // First calculate the total required size for the entire packet. + this->size = 20; // Header. - if (!this->username.empty()) - { - usernamePaddedLen = Utils::Byte::PadTo4Bytes(static_cast(this->username.length())); - this->size += 4 + usernamePaddedLen; - } + if (!this->username.empty()) + { + usernamePaddedLen = Utils::Byte::PadTo4Bytes(static_cast(this->username.length())); + this->size += 4 + usernamePaddedLen; + } - if (this->priority != 0u) - this->size += 4 + 4; + if (this->priority != 0u) + this->size += 4 + 4; - if (this->iceControlling != 0u) - this->size += 4 + 8; + if (this->iceControlling != 0u) + this->size += 4 + 8; - if (this->iceControlled != 0u) - this->size += 4 + 8; + if (this->iceControlled != 0u) + this->size += 4 + 8; - if (this->hasUseCandidate) - this->size += 4; + if (this->hasUseCandidate) + this->size += 4; - if (addXorMappedAddress) - { - switch (this->xorMappedAddress->sa_family) - { - case AF_INET: - { - xorMappedAddressPaddedLen = 8; - this->size += 4 + 8; + if (addXorMappedAddress) + { + switch (this->xorMappedAddress->sa_family) + { + case AF_INET: + { + xorMappedAddressPaddedLen = 8; + this->size += 4 + 8; - break; - } + break; + } - case AF_INET6: - { - xorMappedAddressPaddedLen = 20; - this->size += 4 + 20; + case AF_INET6: + { + xorMappedAddressPaddedLen = 20; + this->size += 4 + 20; - break; - } + break; + } - default: - { - MS_ERROR("invalid inet family in XOR-MAPPED-ADDRESS attribute"); + default: + { + MS_ERROR("invalid inet family in XOR-MAPPED-ADDRESS attribute"); - addXorMappedAddress = false; - } - } - } + addXorMappedAddress = false; + } + } + } - if (addErrorCode) - this->size += 4 + 4; + if (addErrorCode) + this->size += 4 + 4; - if (addMessageIntegrity) - this->size += 4 + 20; + if (addMessageIntegrity) + this->size += 4 + 20; - if (addFingerprint) - this->size += 4 + 4; + if (addFingerprint) + this->size += 4 + 4; - // Merge class and method fields into type. - uint16_t typeField = (static_cast(this->method) & 0x0f80) << 2; + // Merge class and method fields into type. + uint16_t typeField = (static_cast(this->method) & 0x0f80) << 2; - typeField |= (static_cast(this->method) & 0x0070) << 1; - typeField |= (static_cast(this->method) & 0x000f); - typeField |= (static_cast(this->klass) & 0x02) << 7; - typeField |= (static_cast(this->klass) & 0x01) << 4; + typeField |= (static_cast(this->method) & 0x0070) << 1; + typeField |= (static_cast(this->method) & 0x000f); + typeField |= (static_cast(this->klass) & 0x02) << 7; + typeField |= (static_cast(this->klass) & 0x01) << 4; - // Set type field. - Utils::Byte::Set2Bytes(buffer, 0, typeField); - // Set length field. - Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size) - 20); - // Set magic cookie. - std::memcpy(buffer + 4, StunPacket::magicCookie, 4); - // Set TransactionId field. - std::memcpy(buffer + 8, this->transactionId, 12); - // Update the transaction ID pointer. - this->transactionId = buffer + 8; - // Add atributes. - size_t pos{ 20 }; + // Set type field. + Utils::Byte::Set2Bytes(buffer, 0, typeField); + // Set length field. + Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size) - 20); + // Set magic cookie. + std::memcpy(buffer + 4, StunPacket::magicCookie, 4); + // Set TransactionId field. + std::memcpy(buffer + 8, this->transactionId, 12); + // Update the transaction ID pointer. + this->transactionId = buffer + 8; + // Add atributes. + size_t pos{ 20 }; - // Add USERNAME. - if (usernamePaddedLen != 0u) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::USERNAME)); - Utils::Byte::Set2Bytes(buffer, pos + 2, static_cast(this->username.length())); - std::memcpy(buffer + pos + 4, this->username.c_str(), this->username.length()); - pos += 4 + usernamePaddedLen; - } + // Add USERNAME. + if (usernamePaddedLen != 0u) + { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::USERNAME)); + Utils::Byte::Set2Bytes(buffer, pos + 2, static_cast(this->username.length())); + std::memcpy(buffer + pos + 4, this->username.c_str(), this->username.length()); + pos += 4 + usernamePaddedLen; + } - // Add PRIORITY. - if (this->priority != 0u) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::PRIORITY)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 4); - Utils::Byte::Set4Bytes(buffer, pos + 4, this->priority); - pos += 4 + 4; - } + // Add PRIORITY. + if (this->priority != 0u) + { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::PRIORITY)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 4); + Utils::Byte::Set4Bytes(buffer, pos + 4, this->priority); + pos += 4 + 4; + } - // Add ICE-CONTROLLING. - if (this->iceControlling != 0u) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::ICE_CONTROLLING)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 8); - Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlling); - pos += 4 + 8; - } + // Add ICE-CONTROLLING. + if (this->iceControlling != 0u) + { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::ICE_CONTROLLING)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 8); + Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlling); + pos += 4 + 8; + } - // Add ICE-CONTROLLED. - if (this->iceControlled != 0u) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::ICE_CONTROLLED)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 8); - Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlled); - pos += 4 + 8; - } + // Add ICE-CONTROLLED. + if (this->iceControlled != 0u) + { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::ICE_CONTROLLED)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 8); + Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlled); + pos += 4 + 8; + } - // Add USE-CANDIDATE. - if (this->hasUseCandidate) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::USE_CANDIDATE)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 0); - pos += 4; - } + // Add USE-CANDIDATE. + if (this->hasUseCandidate) + { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::USE_CANDIDATE)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 0); + pos += 4; + } - // Add XOR-MAPPED-ADDRESS - if (addXorMappedAddress) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::XOR_MAPPED_ADDRESS)); - Utils::Byte::Set2Bytes(buffer, pos + 2, xorMappedAddressPaddedLen); + // Add XOR-MAPPED-ADDRESS + if (addXorMappedAddress) + { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::XOR_MAPPED_ADDRESS)); + Utils::Byte::Set2Bytes(buffer, pos + 2, xorMappedAddressPaddedLen); - uint8_t* attrValue = buffer + pos + 4; + uint8_t* attrValue = buffer + pos + 4; - switch (this->xorMappedAddress->sa_family) - { - case AF_INET: - { - // Set first byte to 0. - attrValue[0] = 0; - // Set inet family. - attrValue[1] = 0x01; - // Set port and XOR it. - std::memcpy( - attrValue + 2, - &(reinterpret_cast(this->xorMappedAddress))->sin_port, - 2); - attrValue[2] ^= StunPacket::magicCookie[0]; - attrValue[3] ^= StunPacket::magicCookie[1]; - // Set address and XOR it. - std::memcpy( - attrValue + 4, - &(reinterpret_cast(this->xorMappedAddress))->sin_addr.s_addr, - 4); - attrValue[4] ^= StunPacket::magicCookie[0]; - attrValue[5] ^= StunPacket::magicCookie[1]; - attrValue[6] ^= StunPacket::magicCookie[2]; - attrValue[7] ^= StunPacket::magicCookie[3]; + switch (this->xorMappedAddress->sa_family) + { + case AF_INET: + { + // Set first byte to 0. + attrValue[0] = 0; + // Set inet family. + attrValue[1] = 0x01; + // Set port and XOR it. + std::memcpy( + attrValue + 2, + &(reinterpret_cast(this->xorMappedAddress))->sin_port, + 2); + attrValue[2] ^= StunPacket::magicCookie[0]; + attrValue[3] ^= StunPacket::magicCookie[1]; + // Set address and XOR it. + std::memcpy( + attrValue + 4, + &(reinterpret_cast(this->xorMappedAddress))->sin_addr.s_addr, + 4); + attrValue[4] ^= StunPacket::magicCookie[0]; + attrValue[5] ^= StunPacket::magicCookie[1]; + attrValue[6] ^= StunPacket::magicCookie[2]; + attrValue[7] ^= StunPacket::magicCookie[3]; - pos += 4 + 8; + pos += 4 + 8; - break; - } + break; + } - case AF_INET6: - { - // Set first byte to 0. - attrValue[0] = 0; - // Set inet family. - attrValue[1] = 0x02; - // Set port and XOR it. - std::memcpy( - attrValue + 2, - &(reinterpret_cast(this->xorMappedAddress))->sin6_port, - 2); - attrValue[2] ^= StunPacket::magicCookie[0]; - attrValue[3] ^= StunPacket::magicCookie[1]; - // Set address and XOR it. - std::memcpy( - attrValue + 4, - &(reinterpret_cast(this->xorMappedAddress))->sin6_addr.s6_addr, - 16); - attrValue[4] ^= StunPacket::magicCookie[0]; - attrValue[5] ^= StunPacket::magicCookie[1]; - attrValue[6] ^= StunPacket::magicCookie[2]; - attrValue[7] ^= StunPacket::magicCookie[3]; - attrValue[8] ^= this->transactionId[0]; - attrValue[9] ^= this->transactionId[1]; - attrValue[10] ^= this->transactionId[2]; - attrValue[11] ^= this->transactionId[3]; - attrValue[12] ^= this->transactionId[4]; - attrValue[13] ^= this->transactionId[5]; - attrValue[14] ^= this->transactionId[6]; - attrValue[15] ^= this->transactionId[7]; - attrValue[16] ^= this->transactionId[8]; - attrValue[17] ^= this->transactionId[9]; - attrValue[18] ^= this->transactionId[10]; - attrValue[19] ^= this->transactionId[11]; + case AF_INET6: + { + // Set first byte to 0. + attrValue[0] = 0; + // Set inet family. + attrValue[1] = 0x02; + // Set port and XOR it. + std::memcpy( + attrValue + 2, + &(reinterpret_cast(this->xorMappedAddress))->sin6_port, + 2); + attrValue[2] ^= StunPacket::magicCookie[0]; + attrValue[3] ^= StunPacket::magicCookie[1]; + // Set address and XOR it. + std::memcpy( + attrValue + 4, + &(reinterpret_cast(this->xorMappedAddress))->sin6_addr.s6_addr, + 16); + attrValue[4] ^= StunPacket::magicCookie[0]; + attrValue[5] ^= StunPacket::magicCookie[1]; + attrValue[6] ^= StunPacket::magicCookie[2]; + attrValue[7] ^= StunPacket::magicCookie[3]; + attrValue[8] ^= this->transactionId[0]; + attrValue[9] ^= this->transactionId[1]; + attrValue[10] ^= this->transactionId[2]; + attrValue[11] ^= this->transactionId[3]; + attrValue[12] ^= this->transactionId[4]; + attrValue[13] ^= this->transactionId[5]; + attrValue[14] ^= this->transactionId[6]; + attrValue[15] ^= this->transactionId[7]; + attrValue[16] ^= this->transactionId[8]; + attrValue[17] ^= this->transactionId[9]; + attrValue[18] ^= this->transactionId[10]; + attrValue[19] ^= this->transactionId[11]; - pos += 4 + 20; + pos += 4 + 20; - break; - } - } - } + break; + } + } + } - // Add ERROR-CODE. - if (addErrorCode) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::ERROR_CODE)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 4); + // Add ERROR-CODE. + if (addErrorCode) + { + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::ERROR_CODE)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 4); - auto codeClass = static_cast(this->errorCode / 100); - uint8_t codeNumber = static_cast(this->errorCode) - (codeClass * 100); + auto codeClass = static_cast(this->errorCode / 100); + uint8_t codeNumber = static_cast(this->errorCode) - (codeClass * 100); - Utils::Byte::Set2Bytes(buffer, pos + 4, 0); - Utils::Byte::Set1Byte(buffer, pos + 6, codeClass); - Utils::Byte::Set1Byte(buffer, pos + 7, codeNumber); - pos += 4 + 4; - } + Utils::Byte::Set2Bytes(buffer, pos + 4, 0); + Utils::Byte::Set1Byte(buffer, pos + 6, codeClass); + Utils::Byte::Set1Byte(buffer, pos + 7, codeNumber); + pos += 4 + 4; + } - // Add MESSAGE-INTEGRITY. - if (addMessageIntegrity) - { - // Ignore FINGERPRINT. - if (addFingerprint) - Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size - 20 - 8)); + // Add MESSAGE-INTEGRITY. + if (addMessageIntegrity) + { + // Ignore FINGERPRINT. + if (addFingerprint) + Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size - 20 - 8)); - // Calculate the HMAC-SHA1 of the packet according to MESSAGE-INTEGRITY rules. + // Calculate the HMAC-SHA1 of the packet according to MESSAGE-INTEGRITY rules. auto computedMessageIntegrity = openssl_HMACsha1(this->password.data(), this->password.size(), buffer, pos); Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::MESSAGE_INTEGRITY)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 20); - std::memcpy(buffer + pos + 4, computedMessageIntegrity.data(), computedMessageIntegrity.size()); + Utils::Byte::Set2Bytes(buffer, pos + 2, 20); + std::memcpy(buffer + pos + 4, computedMessageIntegrity.data(), computedMessageIntegrity.size()); - // Update the pointer. - this->messageIntegrity = buffer + pos + 4; - pos += 4 + 20; + // Update the pointer. + this->messageIntegrity = buffer + pos + 4; + pos += 4 + 20; - // Restore length field. - if (addFingerprint) - Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size - 20)); - } - else - { - // Unset the pointer (if it was set). - this->messageIntegrity = nullptr; - } + // Restore length field. + if (addFingerprint) + Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size - 20)); + } + else + { + // Unset the pointer (if it was set). + this->messageIntegrity = nullptr; + } - // Add FINGERPRINT. - if (addFingerprint) - { - // Compute the CRC32 of the packet up to (but excluding) the FINGERPRINT - // attribute and XOR it with 0x5354554e. - uint32_t computedFingerprint = GetCRC32(buffer, pos) ^ 0x5354554e; + // Add FINGERPRINT. + if (addFingerprint) + { + // Compute the CRC32 of the packet up to (but excluding) the FINGERPRINT + // attribute and XOR it with 0x5354554e. + uint32_t computedFingerprint = GetCRC32(buffer, pos) ^ 0x5354554e; - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::FINGERPRINT)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 4); - Utils::Byte::Set4Bytes(buffer, pos + 4, computedFingerprint); - pos += 4 + 4; + Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::FINGERPRINT)); + Utils::Byte::Set2Bytes(buffer, pos + 2, 4); + Utils::Byte::Set4Bytes(buffer, pos + 4, computedFingerprint); + pos += 4 + 4; - // Set flag. - this->hasFingerprint = true; - } - else - { - this->hasFingerprint = false; - } + // Set flag. + this->hasFingerprint = true; + } + else + { + this->hasFingerprint = false; + } - MS_ASSERT(pos == this->size, "pos != this->size"); - } + MS_ASSERT(pos == this->size, "pos != this->size"); + } } // namespace RTC diff --git a/webrtc/StunPacket.hpp b/webrtc/StunPacket.hpp index a6b2c940..2776a9b6 100644 --- a/webrtc/StunPacket.hpp +++ b/webrtc/StunPacket.hpp @@ -26,188 +26,188 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. namespace RTC { - class StunPacket - { - public: - // STUN message class. - enum class Class : uint16_t - { - REQUEST = 0, - INDICATION = 1, - SUCCESS_RESPONSE = 2, - ERROR_RESPONSE = 3 - }; + class StunPacket + { + public: + // STUN message class. + enum class Class : uint16_t + { + REQUEST = 0, + INDICATION = 1, + SUCCESS_RESPONSE = 2, + ERROR_RESPONSE = 3 + }; - // STUN message method. - enum class Method : uint16_t - { - BINDING = 1 - }; + // STUN message method. + enum class Method : uint16_t + { + BINDING = 1 + }; - // Attribute type. - enum class Attribute : uint16_t - { - MAPPED_ADDRESS = 0x0001, - USERNAME = 0x0006, - MESSAGE_INTEGRITY = 0x0008, - ERROR_CODE = 0x0009, - UNKNOWN_ATTRIBUTES = 0x000A, - REALM = 0x0014, - NONCE = 0x0015, - XOR_MAPPED_ADDRESS = 0x0020, - PRIORITY = 0x0024, - USE_CANDIDATE = 0x0025, - SOFTWARE = 0x8022, - ALTERNATE_SERVER = 0x8023, - FINGERPRINT = 0x8028, - ICE_CONTROLLED = 0x8029, - ICE_CONTROLLING = 0x802A - }; + // Attribute type. + enum class Attribute : uint16_t + { + MAPPED_ADDRESS = 0x0001, + USERNAME = 0x0006, + MESSAGE_INTEGRITY = 0x0008, + ERROR_CODE = 0x0009, + UNKNOWN_ATTRIBUTES = 0x000A, + REALM = 0x0014, + NONCE = 0x0015, + XOR_MAPPED_ADDRESS = 0x0020, + PRIORITY = 0x0024, + USE_CANDIDATE = 0x0025, + SOFTWARE = 0x8022, + ALTERNATE_SERVER = 0x8023, + FINGERPRINT = 0x8028, + ICE_CONTROLLED = 0x8029, + ICE_CONTROLLING = 0x802A + }; - // Authentication result. - enum class Authentication - { - OK = 0, - UNAUTHORIZED = 1, - BAD_REQUEST = 2 - }; + // Authentication result. + enum class Authentication + { + OK = 0, + UNAUTHORIZED = 1, + BAD_REQUEST = 2 + }; - public: - static bool IsStun(const uint8_t* data, size_t len) - { - // clang-format off - return ( - // STUN headers are 20 bytes. - (len >= 20) && - // DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes - (data[0] < 3) && - // Magic cookie must match. - (data[4] == StunPacket::magicCookie[0]) && (data[5] == StunPacket::magicCookie[1]) && - (data[6] == StunPacket::magicCookie[2]) && (data[7] == StunPacket::magicCookie[3]) - ); - // clang-format on - } - static StunPacket* Parse(const uint8_t* data, size_t len); + public: + static bool IsStun(const uint8_t* data, size_t len) + { + // clang-format off + return ( + // STUN headers are 20 bytes. + (len >= 20) && + // DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes + (data[0] < 3) && + // Magic cookie must match. + (data[4] == StunPacket::magicCookie[0]) && (data[5] == StunPacket::magicCookie[1]) && + (data[6] == StunPacket::magicCookie[2]) && (data[7] == StunPacket::magicCookie[3]) + ); + // clang-format on + } + static StunPacket* Parse(const uint8_t* data, size_t len); - private: - static const uint8_t magicCookie[]; + private: + static const uint8_t magicCookie[]; - public: - StunPacket( - Class klass, Method method, const uint8_t* transactionId, const uint8_t* data, size_t size); - ~StunPacket(); + public: + StunPacket( + Class klass, Method method, const uint8_t* transactionId, const uint8_t* data, size_t size); + ~StunPacket(); - void Dump() const; - Class GetClass() const - { - return this->klass; - } - Method GetMethod() const - { - return this->method; - } - const uint8_t* GetData() const - { - return this->data; - } - size_t GetSize() const - { - return this->size; - } - void SetUsername(const char* username, size_t len) - { - this->username.assign(username, len); - } - void SetPriority(uint32_t priority) - { - this->priority = priority; - } - void SetIceControlling(uint64_t iceControlling) - { - this->iceControlling = iceControlling; - } - void SetIceControlled(uint64_t iceControlled) - { - this->iceControlled = iceControlled; - } - void SetUseCandidate() - { - this->hasUseCandidate = true; - } - void SetXorMappedAddress(const struct sockaddr* xorMappedAddress) - { - this->xorMappedAddress = xorMappedAddress; - } - void SetErrorCode(uint16_t errorCode) - { - this->errorCode = errorCode; - } - void SetMessageIntegrity(const uint8_t* messageIntegrity) - { - this->messageIntegrity = messageIntegrity; - } - void SetFingerprint() - { - this->hasFingerprint = true; - } - const std::string& GetUsername() const - { - return this->username; - } - uint32_t GetPriority() const - { - return this->priority; - } - uint64_t GetIceControlling() const - { - return this->iceControlling; - } - uint64_t GetIceControlled() const - { - return this->iceControlled; - } - bool HasUseCandidate() const - { - return this->hasUseCandidate; - } - uint16_t GetErrorCode() const - { - return this->errorCode; - } - bool HasMessageIntegrity() const - { - return (this->messageIntegrity ? true : false); - } - bool HasFingerprint() const - { - return this->hasFingerprint; - } - Authentication CheckAuthentication( - const std::string& localUsername, const std::string& localPassword); - StunPacket* CreateSuccessResponse(); - StunPacket* CreateErrorResponse(uint16_t errorCode); - void Authenticate(const std::string& password); - void Serialize(uint8_t* buffer); + void Dump() const; + Class GetClass() const + { + return this->klass; + } + Method GetMethod() const + { + return this->method; + } + const uint8_t* GetData() const + { + return this->data; + } + size_t GetSize() const + { + return this->size; + } + void SetUsername(const char* username, size_t len) + { + this->username.assign(username, len); + } + void SetPriority(uint32_t priority) + { + this->priority = priority; + } + void SetIceControlling(uint64_t iceControlling) + { + this->iceControlling = iceControlling; + } + void SetIceControlled(uint64_t iceControlled) + { + this->iceControlled = iceControlled; + } + void SetUseCandidate() + { + this->hasUseCandidate = true; + } + void SetXorMappedAddress(const struct sockaddr* xorMappedAddress) + { + this->xorMappedAddress = xorMappedAddress; + } + void SetErrorCode(uint16_t errorCode) + { + this->errorCode = errorCode; + } + void SetMessageIntegrity(const uint8_t* messageIntegrity) + { + this->messageIntegrity = messageIntegrity; + } + void SetFingerprint() + { + this->hasFingerprint = true; + } + const std::string& GetUsername() const + { + return this->username; + } + uint32_t GetPriority() const + { + return this->priority; + } + uint64_t GetIceControlling() const + { + return this->iceControlling; + } + uint64_t GetIceControlled() const + { + return this->iceControlled; + } + bool HasUseCandidate() const + { + return this->hasUseCandidate; + } + uint16_t GetErrorCode() const + { + return this->errorCode; + } + bool HasMessageIntegrity() const + { + return (this->messageIntegrity ? true : false); + } + bool HasFingerprint() const + { + return this->hasFingerprint; + } + Authentication CheckAuthentication( + const std::string& localUsername, const std::string& localPassword); + StunPacket* CreateSuccessResponse(); + StunPacket* CreateErrorResponse(uint16_t errorCode); + void Authenticate(const std::string& password); + void Serialize(uint8_t* buffer); - private: - // Passed by argument. - Class klass; // 2 bytes. - Method method; // 2 bytes. - const uint8_t* transactionId{ nullptr }; // 12 bytes. - uint8_t* data{ nullptr }; // Pointer to binary data. - size_t size{ 0u }; // The full message size (including header). - // STUN attributes. - std::string username; // Less than 513 bytes. - uint32_t priority{ 0u }; // 4 bytes unsigned integer. - uint64_t iceControlling{ 0u }; // 8 bytes unsigned integer. - uint64_t iceControlled{ 0u }; // 8 bytes unsigned integer. - bool hasUseCandidate{ false }; // 0 bytes. - const uint8_t* messageIntegrity{ nullptr }; // 20 bytes. - bool hasFingerprint{ false }; // 4 bytes. - const struct sockaddr* xorMappedAddress{ nullptr }; // 8 or 20 bytes. - uint16_t errorCode{ 0u }; // 4 bytes (no reason phrase). - std::string password; - }; + private: + // Passed by argument. + Class klass; // 2 bytes. + Method method; // 2 bytes. + const uint8_t* transactionId{ nullptr }; // 12 bytes. + uint8_t* data{ nullptr }; // Pointer to binary data. + size_t size{ 0u }; // The full message size (including header). + // STUN attributes. + std::string username; // Less than 513 bytes. + uint32_t priority{ 0u }; // 4 bytes unsigned integer. + uint64_t iceControlling{ 0u }; // 8 bytes unsigned integer. + uint64_t iceControlled{ 0u }; // 8 bytes unsigned integer. + bool hasUseCandidate{ false }; // 0 bytes. + const uint8_t* messageIntegrity{ nullptr }; // 20 bytes. + bool hasFingerprint{ false }; // 4 bytes. + const struct sockaddr* xorMappedAddress{ nullptr }; // 8 or 20 bytes. + uint16_t errorCode{ 0u }; // 4 bytes (no reason phrase). + std::string password; + }; } // namespace RTC #endif diff --git a/webrtc/WebRtcPlayer.cpp b/webrtc/WebRtcPlayer.cpp index 3b533594..bbdb6274 100644 --- a/webrtc/WebRtcPlayer.cpp +++ b/webrtc/WebRtcPlayer.cpp @@ -70,21 +70,17 @@ void WebRtcPlayer::onStartWebRTC() { } } void WebRtcPlayer::onDestory() { - WebRtcTransportImp::onDestory(); - auto duration = getDuration(); auto bytes_usage = getBytesUsage(); //流量统计事件广播 GET_CONFIG(uint32_t, iFlowThreshold, General::kFlowThreshold); if (_reader && getSession()) { - WarnL << "RTC播放器(" - << _media_info.shortUrl() - << ")结束播放,耗时(s):" << duration; + WarnL << "RTC播放器(" << _media_info.shortUrl() << ")结束播放,耗时(s):" << duration; if (bytes_usage >= iFlowThreshold * 1024) { - NoticeCenter::Instance().emitEvent(Broadcast::kBroadcastFlowReport, _media_info, bytes_usage, duration, - true, static_cast(*getSession())); + NoticeCenter::Instance().emitEvent(Broadcast::kBroadcastFlowReport, _media_info, bytes_usage, duration, true, static_cast(*getSession())); } } + WebRtcTransportImp::onDestory(); } void WebRtcPlayer::onRtcConfigure(RtcConfigure &configure) const { diff --git a/webrtc/WebRtcPusher.cpp b/webrtc/WebRtcPusher.cpp index 06aba8c1..d5c6d063 100644 --- a/webrtc/WebRtcPusher.cpp +++ b/webrtc/WebRtcPusher.cpp @@ -118,20 +118,15 @@ void WebRtcPusher::onStartWebRTC() { } void WebRtcPusher::onDestory() { - WebRtcTransportImp::onDestory(); - auto duration = getDuration(); auto bytes_usage = getBytesUsage(); //流量统计事件广播 GET_CONFIG(uint32_t, iFlowThreshold, General::kFlowThreshold); if (getSession()) { - WarnL << "RTC推流器(" - << _media_info.shortUrl() - << ")结束推流,耗时(s):" << duration; + WarnL << "RTC推流器(" << _media_info.shortUrl() << ")结束推流,耗时(s):" << duration; if (bytes_usage >= iFlowThreshold * 1024) { - NoticeCenter::Instance().emitEvent(Broadcast::kBroadcastFlowReport, _media_info, bytes_usage, duration, - false, static_cast(*getSession())); + NoticeCenter::Instance().emitEvent(Broadcast::kBroadcastFlowReport, _media_info, bytes_usage, duration, false, static_cast(*getSession())); } } @@ -142,6 +137,7 @@ void WebRtcPusher::onDestory() { auto push_src = std::move(_push_src); getPoller()->doDelayTask(_continue_push_ms, [push_src]() { return 0; }); } + WebRtcTransportImp::onDestory(); } void WebRtcPusher::onRtcConfigure(RtcConfigure &configure) const { diff --git a/webrtc/WebRtcSession.cpp b/webrtc/WebRtcSession.cpp index c797ddb0..ade1ce20 100644 --- a/webrtc/WebRtcSession.cpp +++ b/webrtc/WebRtcSession.cpp @@ -48,8 +48,6 @@ EventPoller::Ptr WebRtcSession::queryPoller(const Buffer::Ptr &buffer) { //////////////////////////////////////////////////////////////////////////////// WebRtcSession::WebRtcSession(const Socket::Ptr &sock) : Session(sock) { - socklen_t addr_len = sizeof(_peer_addr); - getpeername(sock->rawFD(), (struct sockaddr *)&_peer_addr, &addr_len); _over_tcp = sock->sockType() == SockNum::Sock_TCP; } @@ -87,14 +85,12 @@ void WebRtcSession::onRecv_l(const char *data, size_t len) { //3、销毁原先的socket和WebRtcSession(原先的对象跟WebRtcTransport不在同一条线程) throw std::runtime_error("webrtc over tcp change poller: " + getPoller()->getThreadName() + " -> " + sock->getPoller()->getThreadName()); } - - transport->setSession(shared_from_this()); _transport = std::move(transport); InfoP(this); } _ticker.resetTime(); CHECK(_transport); - _transport->inputSockData((char *)data, len, (struct sockaddr *)&_peer_addr); + _transport->inputSockData((char *)data, len, this); } void WebRtcSession::onRecv(const Buffer::Ptr &buffer) { @@ -114,9 +110,13 @@ void WebRtcSession::onError(const SockException &err) { if (!_transport) { return; } + auto self = shared_from_this(); auto transport = std::move(_transport); - getPoller()->async([transport] { + getPoller()->async([transport, self]() mutable { //延时减引用,防止使用transport对象时,销毁对象 + transport->removeTuple(self.get()); + //确保transport在Session对象前销毁,防止WebRtcTransport::onDestory()时获取不到Session对象 + transport = nullptr; }, false); } diff --git a/webrtc/WebRtcSession.h b/webrtc/WebRtcSession.h index 6ce881ba..f70d5e74 100644 --- a/webrtc/WebRtcSession.h +++ b/webrtc/WebRtcSession.h @@ -46,7 +46,6 @@ private: bool _over_tcp = false; bool _find_transport = true; Ticker _ticker; - struct sockaddr_storage _peer_addr; std::weak_ptr _server; WebRtcTransportImp::Ptr _transport; }; diff --git a/webrtc/WebRtcTransport.cpp b/webrtc/WebRtcTransport.cpp index e9de6f6b..09d0d466 100644 --- a/webrtc/WebRtcTransport.cpp +++ b/webrtc/WebRtcTransport.cpp @@ -15,6 +15,7 @@ #include "Rtcp/Rtcp.h" #include "Rtcp/RtcpFCI.h" #include "Rtcp/RtcpContext.h" +#include "Rtsp/Rtsp.h" #include "Rtsp/RtpReceiver.h" #include "WebRtcTransport.h" @@ -74,6 +75,17 @@ static void translateIPFromEnv(std::vector &v) { } } +const char* sockTypeStr(Session* session) { + if (session) { + switch (session->getSock()->sockType()) { + case SockNum::Sock_TCP: return "tcp"; + case SockNum::Sock_UDP: return "udp"; + default: break; + } + } + return "unknown"; +} + WebRtcTransport::WebRtcTransport(const EventPoller::Ptr &poller) { _poller = poller; _identifier = "zlm_" + to_string(++s_key); @@ -108,16 +120,18 @@ void WebRtcTransport::OnIceServerSendStunPacket( sendSockData((char *)packet->GetData(), packet->GetSize(), tuple); } -void WebRtcTransport::OnIceServerSelectedTuple(const RTC::IceServer *iceServer, RTC::TransportTuple *tuple) { - InfoL; +void WebRtcTransportImp::OnIceServerSelectedTuple(const RTC::IceServer *iceServer, RTC::TransportTuple *tuple) { + InfoL << getIdentifier() << " select tuple " << sockTypeStr(tuple) << " " << tuple->get_peer_ip() << ":" << tuple->get_peer_port(); + tuple->setSendFlushFlag(false); + unrefSelf(); } void WebRtcTransport::OnIceServerConnected(const RTC::IceServer *iceServer) { - InfoL; + InfoL << getIdentifier(); } void WebRtcTransport::OnIceServerCompleted(const RTC::IceServer *iceServer) { - InfoL; + InfoL << getIdentifier(); if (_answer_sdp->media[0].role == DtlsRole::passive) { _dtls_transport->Run(RTC::DtlsTransport::Role::SERVER); } else { @@ -126,7 +140,7 @@ void WebRtcTransport::OnIceServerCompleted(const RTC::IceServer *iceServer) { } void WebRtcTransport::OnIceServerDisconnected(const RTC::IceServer *iceServer) { - InfoL; + InfoL << getIdentifier(); } ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -134,7 +148,7 @@ void WebRtcTransport::OnIceServerDisconnected(const RTC::IceServer *iceServer) { void WebRtcTransport::OnDtlsTransportConnected( const RTC::DtlsTransport *dtlsTransport, RTC::SrtpSession::CryptoSuite srtpCryptoSuite, uint8_t *srtpLocalKey, size_t srtpLocalKeyLen, uint8_t *srtpRemoteKey, size_t srtpRemoteKeyLen, std::string &remoteCert) { - InfoL; + InfoL << getIdentifier(); _srtp_session_send = std::make_shared( RTC::SrtpSession::Type::OUTBOUND, srtpCryptoSuite, srtpLocalKey, srtpLocalKeyLen); _srtp_session_recv = std::make_shared( @@ -152,16 +166,16 @@ void WebRtcTransport::OnDtlsTransportSendData( } void WebRtcTransport::OnDtlsTransportConnecting(const RTC::DtlsTransport *dtlsTransport) { - InfoL; + InfoL << getIdentifier(); } void WebRtcTransport::OnDtlsTransportFailed(const RTC::DtlsTransport *dtlsTransport) { - InfoL; + InfoL << getIdentifier(); onShutdown(SockException(Err_shutdown, "dtls transport failed")); } void WebRtcTransport::OnDtlsTransportClosed(const RTC::DtlsTransport *dtlsTransport) { - InfoL; + InfoL << getIdentifier(); onShutdown(SockException(Err_shutdown, "dtls close notify received")); } @@ -177,7 +191,7 @@ void WebRtcTransport::OnDtlsTransportApplicationDataReceived( ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef ENABLE_SCTP void WebRtcTransport::OnSctpAssociationConnecting(RTC::SctpAssociation *sctpAssociation) { - TraceL; + TraceL << getIdentifier(); } void WebRtcTransport::OnSctpAssociationConnected(RTC::SctpAssociation *sctpAssociation) { @@ -214,8 +228,9 @@ void WebRtcTransport::sendSockData(const char *buf, size_t len, RTC::TransportTu onSendSockData(std::move(pkt), true, tuple ? tuple : _ice_server->GetSelectedTuple()); } -RTC::TransportTuple *WebRtcTransport::getSelectedTuple() const { - return _ice_server->GetSelectedTuple(); +Session::Ptr WebRtcTransport::getSession() const { + auto tuple = _ice_server->GetSelectedTuple(true); + return tuple ? tuple->shared_from_this() : nullptr; } void WebRtcTransport::sendRtcpRemb(uint32_t ssrc, size_t bit_rate) { @@ -287,22 +302,12 @@ std::string WebRtcTransport::getAnswerSdp(const string &offer) { } } -static bool is_dtls(char *buf) { +static bool isDtls(char *buf) { return ((*buf > 19) && (*buf < 64)); } -static bool is_rtp(char *buf) { - RtpHeader *header = (RtpHeader *)buf; - return ((header->pt < 64) || (header->pt >= 96)); -} - -static bool is_rtcp(char *buf) { - RtpHeader *header = (RtpHeader *)buf; - return ((header->pt >= 64) && (header->pt < 96)); -} - static string getPeerAddress(RTC::TransportTuple *tuple) { - return SockUtil::inet_ntoa(tuple); + return tuple->get_peer_ip(); } void WebRtcTransport::inputSockData(char *buf, int len, RTC::TransportTuple *tuple) { @@ -315,11 +320,11 @@ void WebRtcTransport::inputSockData(char *buf, int len, RTC::TransportTuple *tup _ice_server->ProcessStunPacket(packet.get(), tuple); return; } - if (is_dtls(buf)) { + if (isDtls(buf)) { _dtls_transport->ProcessDtlsData((uint8_t *)buf, len); return; } - if (is_rtp(buf)) { + if (isRtp(buf, len)) { if (!_srtp_session_recv) { WarnL << "received rtp packet when dtls not completed from:" << getPeerAddress(tuple); return; @@ -329,7 +334,7 @@ void WebRtcTransport::inputSockData(char *buf, int len, RTC::TransportTuple *tup } return; } - if (is_rtcp(buf)) { + if (isRtcp(buf, len)) { if (!_srtp_session_recv) { WarnL << "received rtcp packet when dtls not completed from:" << getPeerAddress(tuple); return; @@ -418,24 +423,27 @@ void WebRtcTransportImp::onDestory() { } void WebRtcTransportImp::onSendSockData(Buffer::Ptr buf, bool flush, RTC::TransportTuple *tuple) { - if (!_selected_session) { - WarnL << "send data failed:" << buf->size(); - return; + if (tuple == nullptr) { + tuple = _ice_server->GetSelectedTuple(); + if (!tuple) { + WarnL << "send data failed:" << buf->size(); + return; + } } // 一次性发送一帧的rtp数据,提高网络io性能 - if (_selected_session->getSock()->sockType() == SockNum::Sock_TCP) { + if (tuple->getSock()->sockType() == SockNum::Sock_TCP) { // 增加tcp两字节头 auto len = buf->size(); char tcp_len[2] = { 0 }; tcp_len[0] = (len >> 8) & 0xff; tcp_len[1] = len & 0xff; - _selected_session->SockSender::send(tcp_len, 2); + tuple->SockSender::send(tcp_len, 2); } - _selected_session->send(std::move(buf)); + tuple->send(std::move(buf)); if (flush) { - _selected_session->flushAll(); + tuple->flushAll(); } } @@ -1049,28 +1057,14 @@ void WebRtcTransportImp::onBeforeEncryptRtp(const char *buf, int &len, void *ctx void WebRtcTransportImp::onShutdown(const SockException &ex) { WarnL << ex.what(); unrefSelf(); - for (auto &pr : _history_sessions) { - auto session = pr.second.lock(); - if (session) { - session->shutdown(ex); - } + for (auto &tuple : _ice_server->GetTuples()) { + tuple->shutdown(ex); } } -void WebRtcTransportImp::setSession(Session::Ptr session) { - _history_sessions.emplace(session.get(), session); - if (_selected_session) { - InfoL << "rtc network changed: " << _selected_session->get_peer_ip() << ":" - << _selected_session->get_peer_port() << " -> " << session->get_peer_ip() << ":" - << session->get_peer_port() << ", id:" << getIdentifier(); - } - _selected_session = std::move(session); - _selected_session->setSendFlushFlag(false); - unrefSelf(); -} - -const Session::Ptr &WebRtcTransportImp::getSession() const { - return _selected_session; +void WebRtcTransportImp::removeTuple(RTC::TransportTuple *tuple) { + InfoL << getIdentifier() << " remove tuple " << tuple->get_peer_ip() << ":" << tuple->get_peer_port(); + this->_ice_server->RemoveTuple(tuple); } uint64_t WebRtcTransportImp::getBytesUsage() const { diff --git a/webrtc/WebRtcTransport.h b/webrtc/WebRtcTransport.h index 3978864b..dfae8012 100644 --- a/webrtc/WebRtcTransport.h +++ b/webrtc/WebRtcTransport.h @@ -110,6 +110,7 @@ public: void sendRtcpPacket(const char *buf, int len, bool flush, void *ctx = nullptr); const EventPoller::Ptr& getPoller() const; + Session::Ptr getSession() const; protected: //// dtls相关的回调 //// @@ -130,7 +131,6 @@ protected: protected: //// ice相关的回调 /// void OnIceServerSendStunPacket(const RTC::IceServer *iceServer, const RTC::StunPacket *packet, RTC::TransportTuple *tuple) override; - void OnIceServerSelectedTuple(const RTC::IceServer *iceServer, RTC::TransportTuple *tuple) override; void OnIceServerConnected(const RTC::IceServer *iceServer) override; void OnIceServerCompleted(const RTC::IceServer *iceServer) override; void OnIceServerDisconnected(const RTC::IceServer *iceServer) override; @@ -159,7 +159,6 @@ protected: virtual void onRtcpBye() = 0; protected: - RTC::TransportTuple* getSelectedTuple() const; void sendRtcpRemb(uint32_t ssrc, size_t bit_rate); void sendRtcpPli(uint32_t ssrc); @@ -170,11 +169,11 @@ private: protected: RtcSession::Ptr _offer_sdp; RtcSession::Ptr _answer_sdp; + std::shared_ptr _ice_server; private: std::string _identifier; EventPoller::Ptr _poller; - std::shared_ptr _ice_server; std::shared_ptr _dtls_transport; std::shared_ptr _srtp_session_send; std::shared_ptr _srtp_session_recv; @@ -239,8 +238,6 @@ public: using Ptr = std::shared_ptr; ~WebRtcTransportImp() override; - void setSession(Session::Ptr session); - const Session::Ptr& getSession() const; uint64_t getBytesUsage() const; uint64_t getDuration() const; bool canSendRtp() const; @@ -248,8 +245,10 @@ public: void onSendRtp(const RtpPacket::Ptr &rtp, bool flush, bool rtx = false); void createRtpChannel(const std::string &rid, uint32_t ssrc, MediaTrack &track); + void removeTuple(RTC::TransportTuple* tuple); protected: + void OnIceServerSelectedTuple(const RTC::IceServer *iceServer, RTC::TransportTuple *tuple) override; WebRtcTransportImp(const EventPoller::Ptr &poller,bool preferred_tcp = false); void OnDtlsTransportApplicationDataReceived(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data, size_t len) override; void onStartWebRTC() override; @@ -292,10 +291,6 @@ private: Ticker _alive_ticker; //pli rtcp计时器 Ticker _pli_ticker; - //当前选中的udp链接 - Session::Ptr _selected_session; - //链接迁移前后使用过的udp链接 - std::unordered_map > _history_sessions; //twcc rtcp发送上下文对象 TwccContext _twcc_ctx; //根据发送rtp的track类型获取相关信息