Compare commits

...

44 Commits

Author SHA1 Message Date
孙小云 a9f2eb2b80 开启HLS 2025-12-10 14:06:52 +08:00
孙小云 957f96305e 添加说明 2025-12-10 11:24:56 +08:00
孙小云 97298b76f9 删除无效备份 2025-12-09 15:04:17 +08:00
孙小云 88e3b3deda 删除无效备份 2025-12-09 14:45:40 +08:00
孙小云 95991bf58e 修改空指针BUG 2025-12-09 11:04:48 +08:00
孙小云 a138f864ee 添加注释 2025-12-09 10:42:10 +08:00
孙小云 d77ba30d3a 添加多端口支持 2025-12-09 10:25:50 +08:00
孙小云 327fa23296 修改说明 2025-12-09 10:17:11 +08:00
孙小云 fbc1af2db4 修改说明文档 2025-12-09 09:19:08 +08:00
孙小云 e4c9d0c3bf ddd 2025-12-09 09:09:45 +08:00
孙小云 4c11c451aa webrtc 2025-12-09 09:09:02 +08:00
孙小云 b4e2b31836 xx 2025-12-09 08:52:00 +08:00
孙小云 7e1e259bbe t4est 2025-12-09 08:22:13 +08:00
孙小云 dcd5f26fb5 修改说明文档 2025-12-07 13:06:33 +08:00
孙小云 b3438f9dcf add media构建 2025-12-07 12:58:53 +08:00
孙小云 548ed256da 将 ZLM 从独立 git 仓库转换为普通目录
- 删除 ZLM 的 .git 目录
- 将 ZLM 及其所有源码添加到主项目
- 原 git 仓库信息已备份到 /tmp/zlm_git_remote_backup.txt
- 包含所有 3rdpart 依赖库代码
2025-12-07 12:57:53 +08:00
孙小云 9e965e3e3c xx 2025-12-06 23:15:33 +08:00
孙小云 db83764b3a 更新版本号为version 2025-12-06 22:32:42 +08:00
孙小云 5b9786a501 更新使用说明 2025-12-06 22:30:38 +08:00
孙小云 e15737a935 将 wvpcode 从子模块转换为项目直接内容 2025-12-06 22:23:35 +08:00
孙小云 10926f6af7 修改npm 2025-12-06 22:18:00 +08:00
孙小云 5284e16cb4 update submodule 2025-12-06 22:13:26 +08:00
孙小云 39c075ddbe 使用缓存 2025-12-06 22:10:31 +08:00
孙小云 a3e2e3bfe3 使用缓存 2025-12-06 22:09:57 +08:00
孙小云 b31a07d1fb web 2025-12-06 17:53:25 +08:00
孙小云 7c16ebf3e4 修改配置 2025-12-06 17:10:17 +08:00
孙小云 1e4db4588c 添加linux的支持 2025-12-06 16:43:16 +08:00
孙小云 8190269338 添加使用说明 2025-12-06 15:41:26 +08:00
孙小云 8fb3476840 添加描述文档 2025-12-06 15:20:37 +08:00
孙小云 7325ab9ba1 修改格式错误 2025-12-06 15:07:49 +08:00
孙小云 25cea46313 统一配置 2025-12-06 15:01:14 +08:00
孙小云 e9642d6054 统一配置 2025-12-06 15:00:45 +08:00
孙小云 23507271da 修改配置 2025-12-06 14:14:28 +08:00
孙小云 1276aa411a 删除无用文件 2025-12-06 10:03:57 +08:00
孙小云 0f8c7b6780 删除多余文件 2025-12-06 09:55:11 +08:00
孙小云 9ada445586 删除libs 2025-12-06 09:52:52 +08:00
孙小云 3d633dc032 删除多余文件 2025-12-06 09:52:40 +08:00
孙小云 a94982db84 删除多余文件 2025-12-06 09:52:16 +08:00
孙小云 8108266ca9 删除github文件 2025-12-06 09:51:53 +08:00
孙小云 14e0048e3b 删除src文件 2025-12-06 09:51:35 +08:00
孙小云 474e3a671c 删除web文件 2025-12-06 09:51:15 +08:00
孙小云 a595f6387b 添加清空文件 2025-12-06 09:50:39 +08:00
孙小云 fbd78d42a3 删除doc文件 2025-12-06 09:50:08 +08:00
孙小云 1a3659177c feat: upgrade to JDK 17 and add wvpcode submodule
- Upgrade Dockerfile from JDK 11 to JDK 17 (eclipse-temurin)
- Add wvpcode as git submodule for independent source management
- Remove unused be.teletask.onvif-java submodule
- Add documentation for docker-compose and build process

🤖 Generated with Claude Code
2025-12-05 17:14:30 +08:00
2871 changed files with 331342 additions and 5720 deletions

View File

@ -1,37 +0,0 @@
---
name: "[ BUG ] "
about: 关于wvp的bug与zlm有关的建议直接在zlm的issue中提问
title: 'BUG'
labels: 'wvp的bug'
assignees: ''
---
**环境信息:**
- 1. 部署方式 wvp-pro docker / zlm(docker) + 编译wvp-pro/ wvp-prp + zlm都是编译部署/
- 2. 部署环境 windows / ubuntu/ centos ...
- 3. 端口开放情况
- 4. 是否是公网部署
- 5. 是否使用https
- 6. 接入设备/平台品牌
- 7. 你做过哪些尝试
- 8. 代码更新时间
- 9. 是否是4G设备接入
**描述错误**
描述下您遇到的问题
**如何复现**
有明确复现步骤的问题会很容易被解决
**截图**
**抓包文件**
**日志**
```
日志内容放这里, 文件的话请直接上传
```

View File

@ -1,13 +0,0 @@
---
name: "[ 新功能 ]"
about: 新功能
title: '希望wVP实现的新功能此功能应与你的具体业务无关'
labels: ''
assignees: ''
---
**项目的详细需求**
**这样的实现什么作用**

View File

@ -1,31 +0,0 @@
---
name: "[ 技术咨询 ] "
about: 对于使用中遇到问题
title: '技术咨询'
labels: '技术咨询'
assignees: ''
---
**环境信息:**
- 1. 部署方式 wvp-pro docker / zlm(docker) + 编译wvp-pro/ wvp-prp + zlm都是编译部署/
- 2. 部署环境 windows / ubuntu/ centos ...
- 3. 端口开放情况
- 4. 是否是公网部署
- 5. 是否使用https
- 6. 方便的话提供下使用的设备品牌或平台
- 7. 你做过哪些尝试
- 8. 代码更新时间(旧版本请更新最新版本代码测试)
**内容描述:**
**截图**
**抓包文件**
**日志**
```
日志内容放这里, 文件的话请直接上传
```

View File

@ -1,75 +0,0 @@
name: release-ubuntu
on:
push:
tags:
- "v*.*.*" # 触发条件是推送标签 如git tag v2.7.4 git push origin v2.7.4
jobs:
build-ubuntu:
runs-on: ubuntu-latest
strategy:
matrix:
arch: [amd64]
max-parallel: 1 # 最大并行数
steps:
- name: Checkout
uses: actions/checkout@v4 # github action运行环境
- name: Create release # 创建文件夹
run: |
rm -rf release
mkdir release
echo ${{ github.sha }} > Release.txt
cp Release.txt LICENSE release/
cat Release.txt
- name: Set up JDK 1.8
uses: actions/setup-java@v4
with:
# Eclipse基金会维护的开源Java发行版 因为github action参考java的用这个 所以用这个
# 还有microsoft(微软维护的openjdk发行版) oracle(商用SDK)等
distribution: 'temurin'
java-version: '8'
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: '20.x' # Node.js版本 20系列的最新稳定版
- name: Compile backend
run: |
mvn package
mvn package -P war
- name: Compile frontend
run: |
cd ./web
npm install
npm run build:prod
cd ../
- name: Package Files
run: |
cp -r ./src/main/resources/static release/ # 复制前端文件
cp ./target/*.jar release/ # 复制 JAR 文件
cp ./src/main/resources/application-dev.yml release/application.yml
BRANCH=${{ github.event.base_ref }}
BRANCH_NAME=$(echo "$BRANCH" | grep -oP 'refs/heads/\K.*')
echo "BRANCH_NAME= ${BRANCH_NAME}"
# 如果无法获取,使用默认分支
if [[ -z "BRANCH_NAME" ]]; then
BRANCH_NAME="${{ github.event.repository.default_branch }}"
fi
TAG_NAME="${GITHUB_REF#refs/tags/}"
ZIP_FILE_NAME="${BRANCH_NAME}-${TAG_NAME}.zip"
zip -r "$ZIP_FILE_NAME" release
echo "ZIP_FILE_NAME=$ZIP_FILE_NAME" >> $GITHUB_ENV
- name: Release
uses: softprops/action-gh-release@v2
if: startsWith(github.ref, 'refs/tags/')
with:
files: ${{ env.ZIP_FILE_NAME }}

1
.gitignore vendored
View File

@ -6,6 +6,7 @@
logs/* logs/*
# BlueJ files # BlueJ files
*.ctxt *.ctxt
*.DS_Store*
# Mobile Tools for Java (J2ME) # Mobile Tools for Java (J2ME)
.mtj.tmp/ .mtj.tmp/

3
.gitmodules vendored
View File

@ -1,3 +0,0 @@
[submodule "be.teletask.onvif-java"]
path = be.teletask.onvif-java
url = https://gitee.com/pan648540858/be.teletask.onvif-java.git

89
ZLM/.clang-format Normal file
View File

@ -0,0 +1,89 @@
# This is for clang-format >= 9.0.
#
# clang-format --version
# clang-format version 9.0.1 (Red Hat 9.0.1-2.module+el8.2.0+5494+7b8075cf)
#
# 详细说明见: https://clang.llvm.org/docs/ClangFormatStyleOptions.html
# 部分参数会随版本变化.
---
Language: Cpp
# 基于 WebKit 的风格, https://www.webkit.org/coding/coding-style.html
BasedOnStyle: WebKit
# 以下各选项按字母排序
# public/protected/private 不缩进
AccessModifierOffset: -4
# 参数过长时统一换行
AlignAfterOpenBracket: AlwaysBreak
# clang-format >= 13 required, map 之类的内部列对齐
# AlignArrayOfStructures: Left
# 换行符统一在 ColumnLimit 最右侧
AlignEscapedNewlines: Right
# 不允许短代码块单行, 即不允许单行代码: if (x) return;
AllowShortBlocksOnASingleLine: false
# 只允许 Inline 函数单行
AllowShortFunctionsOnASingleLine: Inline
# 模板声明换行
AlwaysBreakTemplateDeclarations: Yes
# 左开括号不换行
BreakBeforeBraces: Custom
BraceWrapping:
AfterCaseLabel: false
AfterClass: false
# BraceWrappingAfterControlStatementStyle: MultiLine
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterStruct: false
AfterUnion: false
AfterExternBlock: false
BeforeCatch: false
BeforeElse: false
BeforeLambdaBody: false
BeforeWhile: false
IndentBraces: false
SplitEmptyFunction: false
SplitEmptyRecord: false
SplitEmptyNamespace: false
# 构造函数初始化时在 `,` 前换行, 和 `:` 对齐显得整齐
BreakConstructorInitializers: BeforeComma
# 继承过长需要换行时也在 `,` 前
BreakInheritanceList: BeforeComma
# 列宽 160
ColumnLimit: 160
# c++11 括号内起始/结束无空格, false 会加上
Cpp11BracedListStyle: false
# 命名空间后的注释会修正为: // namespace_name
FixNamespaceComments: true
#switch case的缩进
IndentCaseLabels: true
#允许单行case
AllowShortCaseLabelsOnASingleLine: true
# clang-format >= 13 required, lambda 函数内部缩进级别和外部一致, 默认会增加一级缩进
# LambdaBodyIndentation: OuterScope
# 命名空间不缩进
NamespaceIndentation: None
# PPIndentWidth: 2
# */& 靠近变量, 向右靠
PointerAlignment: Right
# c++11 使用 {} 构造时和变量加个空格
SpaceBeforeCpp11BracedList: true
# 继承时 `:` 前加空格
SpaceBeforeInheritanceColon: true
# () 前不加空格, do/for/if/switch/while 除外
SpaceBeforeParens: ControlStatements
# 空 {} 中不加空格
SpaceInEmptyBlock: false
Standard: C++11
# Tab 占 4 位
TabWidth: 4
# 不使用 TAB
UseTab: Never
---
Language: Java
---
Language: JavaScript
...

2
ZLM/.gitattributes vendored Normal file
View File

@ -0,0 +1,2 @@
*.h linguist-language=cpp
*.c linguist-language=cpp

95
ZLM/.github/ISSUE_TEMPLATE/bug.md vendored Normal file
View File

@ -0,0 +1,95 @@
---
name: bug 反馈
about: 反馈 ZLMediaKit 代码本身的 bug
title: "[BUG] BUG现象描述(必填)"
labels: bug
assignees: ''
---
<!--
请仔细阅读相关注释提示, 请务必根据提示填写相关信息.
1. 信息不完整会影响问题的解决速度.
1. 乱七八糟的渲染格式也会影响开发者心情, 同样会影响问题的解决. 提交前请务必点击 Preview/预览下反馈的显示效果.
1. 不要删除模版内容, 模版的注释部分的内容不会显示,不需要删除,直接在各部分注释外面补充相关信息即可.
-->
<!--
markdown 语法参考:
* https://docs.github.com/cn/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax
* https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax
-->
## 现象描述
<!--
在使用什么功能产生的问题? 其异常表现是什么?
如: 在测试 WebRTC 功能时, 使用 Chrome 浏览器访问 ZLMediait 自带网页播放 FFmpeg 以 RTSP 协议推送的图像有卡顿/花屏.
-->
## 如何复现?
<!--
明确的复现步骤对快速解决问题极有帮助.
格式参考:
1. 首先 ...
1. 然后 ...
1. 期望 ..., 结果 ...
-->
## 相关日志或截图
<!--
由于日志通长较长, 建议将日志信息填写到下面的 "日志内容..."
如果是程序异常崩溃/终止, 相关调用栈信息也极为有用, 可复制下面的格式, 添加相关调用栈信息.
替换下面的 "日志内容..." 为实际日志内容.
-->
<details>
<summary>展开查看详细日志</summary>
<pre>
```
#详细日志粘在这里!
```
</pre>
</details>
## 配置
<!--
部分常见问题是由于配置错误导致的, 建议仔细阅读配置文件中的注释信息
替换下面的 "配置内容..." 为实际配置内容.
-->
<details>
<summary>展开查看详细配置</summary>
<pre>
```ini
#config.ini内容粘在这里!
```
</pre>
</details>
## 各种环境信息
<!--
请填写相关环境信息, 详细的环境信息有助于快速复现定位问题.
* 代码提交记录, 可使用命令 `git rev-parse HEAD` 进行查看.
* 操作系统及版本, 如: Windows 10, CentOS 7, ...
* 硬件信息, 如: Intel, AMD, ARM, 飞腾, 龙芯, ...
-->
* **代码提交记录/git commit hash**:
* **操作系统及版本**:
* **硬件信息**:
* **crash backtrace**:
```
#崩溃信息backtrace粘贴至此
```
* **其他需要补充的信息**:

57
ZLM/.github/ISSUE_TEMPLATE/compile.md vendored Normal file
View File

@ -0,0 +1,57 @@
---
name: 编译问题反馈
about: 反馈 ZLMediaKit 编译相关的问题
title: "[编译问题] 编译问题描述(必填)"
labels: 编译问题
assignees: ''
---
<!--
请仔细阅读相关注释提示, 请务必根据提示填写相关信息.
1. 信息不完整会影响问题的解决速度.
1. 乱七八糟的渲染格式也会影响开发者心情, 同样会影响问题的解决. 提交前请务必点击 Preview/预览下反馈的显示效果.
1. 不要删除模版内容, 模版的注释部分的内容不会显示,不需要删除,直接在各部分注释外面补充相关信息即可.
-->
<!--
markdown 语法参考:
* https://docs.github.com/cn/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax
* https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax
-->
## 相关日志及环境信息
<!--
由于编译日志通长较长, 建议将日志信息填写到下面 `````` block 内,或者上传日志文件
-->
**清除编译缓存后,完整执行 cmake && make 命令的输出**
<details>
<summary>展开查看详细编译日志</summary>
<pre>
```
详细日志粘在这里!
```
</pre>
</details>
编译目录下的 `CMakeCache.txt` 文件内容,请直接上传为附件。
## 各种环境信息
<!--
请填写相关环境信息, 详细的环境信息有助于快速复现定位问题.
* 代码提交记录, 可使用命令 `git rev-parse HEAD` 进行查看.
* 操作系统及版本, 如: Windows 10, CentOS 7, ...
* 硬件信息, 如: Intel, AMD, ARM, 飞腾, 龙芯, ...
-->
* **代码提交记录/git commit hash**:
* **操作系统及版本**:
* **硬件信息**:
* **其他需要补充的信息**:

6
ZLM/.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@ -0,0 +1,6 @@
blank_issues_enabled: false
contact_links:
- name: 技术咨询
url: https://t.zsxq.com/FcVK5
about: 请在知识星球发起技术咨询

14
ZLM/.github/ISSUE_TEMPLATE/feature.md vendored Normal file
View File

@ -0,0 +1,14 @@
---
name: 新增功能请求
about: 请求新增某些新功能或新特性,或者对已有功能的改进
title: "[功能请求] 需求描述(必填)"
labels: 意见建议
assignees: ''
---
## 描述该功能的用处,可以提供相关资料描述该功能
## 该功能是否用于改进项目缺陷,如果是,请描述现有缺陷
## 描述你期望实现该功能的方式和最终效果

59
ZLM/.github/workflows/android.yml vendored Normal file
View File

@ -0,0 +1,59 @@
name: Android
on: [push, pull_request]
jobs:
build:
runs-on: ubuntu-24.04
steps:
- name: 下载源码
uses: actions/checkout@v1
- name: 配置JDK
uses: actions/setup-java@v3
with:
java-version: '11'
distribution: 'temurin'
cache: gradle
- name: 下载submodule源码
run: mv -f .gitmodules_github .gitmodules && git submodule sync && git submodule update --init
- name: 赋予gradlew文件可执行权限
run: chmod +x ./Android/gradlew
- name: 编译
run: cd Android && ./gradlew build
- name: 设置环境变量
run: |
echo "BRANCH=$(echo ${GITHUB_REF#refs/heads/} | tr -s "/\?%*:|\"<>" "_")" >> $GITHUB_ENV
echo "BRANCH2=$(echo ${GITHUB_REF#refs/heads/} )" >> $GITHUB_ENV
echo "DATE=$(date +%Y-%m-%d)" >> $GITHUB_ENV
- name: 打包二进制
id: upload
uses: actions/upload-artifact@v4
with:
name: ${{ github.workflow }}_${{ env.BRANCH }}_${{ env.DATE }}
path: Android/app/build/outputs/apk/debug/*
if-no-files-found: error
retention-days: 90
- name: issue评论
if: github.event_name != 'pull_request' && github.ref != 'refs/heads/feature/test'
uses: actions/github-script@v7
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
github.rest.issues.createComment({
issue_number: ${{vars.VERSION_ISSUE_NO}},
owner: context.repo.owner,
repo: context.repo.repo,
body: '- 下载地址: [${{ github.workflow }}_${{ env.BRANCH }}_${{ env.DATE }}](${{ steps.upload.outputs.artifact-url }})\n'
+ '- 分支: ${{ env.BRANCH2 }}\n'
+ '- git hash: ${{ github.sha }} \n'
+ '- 编译日期: ${{ env.DATE }}\n'
+ '- 编译记录: [${{ github.run_id }}](https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }})\n'
+ '- 开启特性: 未开启openssl/webrtc/datachannel等功能\n'
+ '- 打包ci名: ${{ github.workflow }}\n'
})

62
ZLM/.github/workflows/codeql.yml vendored Normal file
View File

@ -0,0 +1,62 @@
name: CodeQL
on: [push, pull_request]
jobs:
analyze:
name: Analyze
runs-on: ubuntu-24.04
permissions:
actions: read
contents: read
security-events: write
strategy:
fail-fast: false
matrix:
language: [ 'cpp', 'javascript' ]
# CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
# Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support
steps:
- uses: actions/checkout@v1
- name: Initialize CodeQL
uses: github/codeql-action/init@v2
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
# By default, queries listed here will override any specified in a config file.
# Prefix the list here with "+" to use these queries and those in the config file.
# Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
# queries: security-extended,security-and-quality
- name: 下载submodule源码
run: mv -f .gitmodules_github .gitmodules && git submodule sync && git submodule update --init
- name: apt-get安装依赖库(非必选)
run: sudo apt-get update && sudo apt-get install -y cmake libssl-dev libsdl-dev libavcodec-dev libavutil-dev libswscale-dev libresample-dev
- name: 下载 SRTP
uses: actions/checkout@v2
with:
repository: cisco/libsrtp
fetch-depth: 1
ref: v2.7.0
path: 3rdpart/libsrtp
- name: 编译 SRTP
run: cd 3rdpart/libsrtp && ./configure --enable-openssl && make -j4 && sudo make install
- name: 编译
run: mkdir -p linux_build && cd linux_build && cmake .. -DENABLE_WEBRTC=true -DENABLE_FFMPEG=true && make -j $(nproc)
- name: 运行MediaServer
run: pwd && cd release/linux/Debug && sudo ./MediaServer -d &
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v2

88
ZLM/.github/workflows/docker.yml vendored Normal file
View File

@ -0,0 +1,88 @@
name: Docker
on:
push:
branches:
- "master"
- "feature/*"
- "release/*"
env:
# Use docker.io for Docker Hub if empty
REGISTRY: docker.io
IMAGE_NAME: zlmediakit/zlmediakit
jobs:
build:
runs-on: ubuntu-24.04
permissions:
contents: read
packages: write
# This is used to complete the identity challenge
# with sigstore/fulcio when running outside of PRs.
id-token: write
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: 下载submodule源码
run: mv -f .gitmodules_github .gitmodules && git submodule sync && git submodule update --init
# Install the cosign tool except on PR
# https://github.com/sigstore/cosign-installer
- name: Install cosign
uses: sigstore/cosign-installer@d572c9c13673d2e0a26fabf90b5748f36886883f
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
# Workaround: https://github.com/docker/build-push-action/issues/461
- name: Setup Docker buildx
uses: docker/setup-buildx-action@79abd3f86f79a9d68a23c75a09a9a85889262adf
# Login against a Docker registry except on PR
# https://github.com/docker/login-action
- name: Log into registry ${{ env.REGISTRY }}
uses: docker/login-action@28218f9b04b4f3f62068d7b6ce6ca5b26e35336c
with:
registry: ${{ env.REGISTRY }}
username: zlmediakit
password: ${{ secrets.DOCKER_IO_SECRET }}
# Extract metadata (tags, labels) for Docker
# https://github.com/docker/metadata-action
- name: Extract Docker metadata
id: meta
uses: docker/metadata-action@98669ae865ea3cffbcbaa878cf57c20bbf1c6c38
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
# Build and push Docker image with Buildx (don't push on PR)
# https://github.com/docker/build-push-action
- name: Build and push Docker image
if: github.event_name != 'pull_request' && github.ref != 'refs/heads/feature/test'
id: build-and-push
uses: docker/build-push-action@ac9327eae2b366085ac7f6a2d02df8aa8ead720a
with:
context: .
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: MODEL=Release
platforms: linux/amd64,linux/arm64
# Sign the resulting Docker image digest except on PRs.
# This will only write to the public Rekor transparency log when the Docker
# repository is public to avoid leaking data. If you would like to publish
# transparency data even for private images, pass --force to cosign below.
# https://github.com/sigstore/cosign
# - name: Sign the published Docker image
# if: ${{ github.event_name != 'pull_request' }}
# env:
# COSIGN_EXPERIMENTAL: "true"
# # This step uses the identity token to provision an ephemeral certificate
# # against the sigstore community Fulcio instance.
# run: cosign sign ${{ steps.meta.outputs.tags }}@${{ steps.build-and-push.outputs.digest }}

58
ZLM/.github/workflows/issue_lint.yml vendored Normal file
View File

@ -0,0 +1,58 @@
name: issue_lint
on:
issues:
types: [opened]
jobs:
issue_lint:
runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v3
- uses: actions/github-script@v6
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs').promises;
const getTitles = (str) => (
[...str.matchAll(/^## (.*)/gm)].map((m) => m[0])
);
const titles = getTitles(context.payload.issue.body);
for (let file of await fs.readdir('.github/ISSUE_TEMPLATE')) {
if (!file.endsWith('.md')) {
continue;
}
const template = await fs.readFile(`.github/ISSUE_TEMPLATE/${file}`, 'utf-8');
const templateTitles = getTitles(template);
if (templateTitles.every((title) => titles.includes(title))) {
process.exit(0);
}
}
await github.rest.issues.createComment({
owner: context.issue.owner,
repo: context.issue.repo,
issue_number: context.issue.number,
body: '此issue由于不符合模板规范已经自动关闭请重新按照模板规范确保包含模板中所有章节标题再提交\n',
});
await github.rest.issues.addLabels({
owner: context.issue.owner,
repo: context.issue.repo,
issue_number: context.issue.number,
labels: ['自动关闭']
});
await github.rest.issues.update({
owner: context.issue.owner,
repo: context.issue.repo,
issue_number: context.issue.number,
state: 'closed',
});

135
ZLM/.github/workflows/linux.yml vendored Normal file
View File

@ -0,0 +1,135 @@
name: Linux
on: [push, pull_request]
jobs:
build:
runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v1
- name: 下载submodule源码
run: mv -f .gitmodules_github .gitmodules && git submodule sync && git submodule update --init
- name: 下载 SRTP
uses: actions/checkout@v2
with:
repository: cisco/libsrtp
fetch-depth: 1
ref: v2.3.0
path: 3rdpart/libsrtp
- name: 下载 openssl
uses: actions/checkout@v2
with:
repository: openssl/openssl
fetch-depth: 1
ref: OpenSSL_1_1_1
path: 3rdpart/openssl
- name: 下载 usrsctp
uses: actions/checkout@v2
with:
repository: sctplab/usrsctp
fetch-depth: 1
ref: 0.9.5.0
path: 3rdpart/usrsctp
- name: 启动 Docker 容器, 在Docker 容器中执行脚本
run: |
docker pull centos:7
docker run -v $(pwd):/root -w /root --rm centos:7 sh -c "
#!/bin/bash
set -x
# Backup original CentOS-Base.repo file
cp /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS-Base.repo.backup
# Define new repository configuration
cat <<EOF > /etc/yum.repos.d/CentOS-Base.repo
[base]
name=CentOS-7 - Base - mirrors.aliyun.com
baseurl=http://mirrors.aliyun.com/centos/7/os/x86_64/
gpgcheck=1
gpgkey=http://mirrors.aliyun.com/centos/RPM-GPG-KEY-CentOS-7
[updates]
name=CentOS-7 - Updates - mirrors.aliyun.com
baseurl=http://mirrors.aliyun.com/centos/7/updates/x86_64/
gpgcheck=1
gpgkey=http://mirrors.aliyun.com/centos/RPM-GPG-KEY-CentOS-7
EOF
# Clean yum cache and recreate it
yum clean all
yum makecache
echo \"CentOS 7 软件源已成功切换\"
yum install -y git wget gcc gcc-c++ make
mkdir -p /root/install
cd 3rdpart/openssl
./config no-shared --prefix=/root/install
make -j $(nproc)
make install
cd ../../
wget https://github.com/Kitware/CMake/releases/download/v3.29.5/cmake-3.29.5.tar.gz
tar -xf cmake-3.29.5.tar.gz
cd cmake-3.29.5
OPENSSL_ROOT_DIR=/root/install ./configure
make -j $(nproc)
make install
cd ..
cd 3rdpart/usrsctp
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_POSITION_INDEPENDENT_CODE=ON ..
make -j $(nproc)
make install
cd ../../../
cd 3rdpart/libsrtp && ./configure --enable-openssl --with-openssl-dir=/root/install && make -j $(nproc) && make install
cd ../../
mkdir -p linux_build && cd linux_build && cmake .. -DOPENSSL_ROOT_DIR=/root/install -DCMAKE_BUILD_TYPE=Release && make -j $(nproc)
"
- name: 设置环境变量
run: |
echo "BRANCH=$(echo ${GITHUB_REF#refs/heads/} | tr -s "/\?%*:|\"<>" "_")" >> $GITHUB_ENV
echo "BRANCH2=$(echo ${GITHUB_REF#refs/heads/} )" >> $GITHUB_ENV
echo "DATE=$(date +%Y-%m-%d)" >> $GITHUB_ENV
- name: 打包二进制
id: upload
uses: actions/upload-artifact@v4
with:
name: ${{ github.workflow }}_${{ env.BRANCH }}_${{ env.DATE }}
path: release/*
if-no-files-found: error
retention-days: 90
- name: issue评论
if: github.event_name != 'pull_request' && github.ref != 'refs/heads/feature/test'
uses: actions/github-script@v7
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
github.rest.issues.createComment({
issue_number: ${{vars.VERSION_ISSUE_NO}},
owner: context.repo.owner,
repo: context.repo.repo,
body: '- 下载地址: [${{ github.workflow }}_${{ env.BRANCH }}_${{ env.DATE }}](${{ steps.upload.outputs.artifact-url }})\n'
+ '- 分支: ${{ env.BRANCH2 }}\n'
+ '- git hash: ${{ github.sha }} \n'
+ '- 编译日期: ${{ env.DATE }}\n'
+ '- 编译记录: [${{ github.run_id }}](https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }})\n'
+ '- 打包ci名: ${{ github.workflow }}\n'
+ '- 开启特性: openssl/webrtc/datachannel\n'
+ '- 说明: 本二进制在centos7(x64)上编译,请确保您的机器系统不低于此版本\n'
})

71
ZLM/.github/workflows/macos.yml vendored Normal file
View File

@ -0,0 +1,71 @@
name: macOS
on: [push, pull_request]
jobs:
build:
runs-on: macOS-latest
steps:
- uses: actions/checkout@v1
- name: 下载submodule源码
run: mv -f .gitmodules_github .gitmodules && git submodule sync && git submodule update --init
- name: 配置 vcpkg
uses: lukka/run-vcpkg@v7
with:
vcpkgDirectory: '${{github.workspace}}/vcpkg'
vcpkgTriplet: arm64-osx
# 2025.07.11
vcpkgGitCommitId: 'efcfaaf60d7ec57a159fc3110403d939bfb69729'
vcpkgArguments: 'openssl libsrtp[openssl] usrsctp'
- name: 安装指定 CMake
uses: jwlawson/actions-setup-cmake@v2
with:
cmake-version: '3.30.5'
- name: 编译
uses: lukka/run-cmake@v3
with:
useVcpkgToolchainFile: true
buildDirectory: '${{github.workspace}}/build'
cmakeAppendedArgs: ''
cmakeBuildType: 'Release'
- name: 设置环境变量
run: |
echo "BRANCH=$(echo ${GITHUB_REF#refs/heads/} | tr -s "/\?%*:|\"<>" "_")" >> $GITHUB_ENV
echo "BRANCH2=$(echo ${GITHUB_REF#refs/heads/} )" >> $GITHUB_ENV
echo "DATE=$(date +%Y-%m-%d)" >> $GITHUB_ENV
- name: 打包二进制
id: upload
uses: actions/upload-artifact@v4
with:
name: ${{ github.workflow }}_${{ env.BRANCH }}_${{ env.DATE }}
path: release/*
if-no-files-found: error
retention-days: 90
- name: issue评论
if: github.event_name != 'pull_request' && github.ref != 'refs/heads/feature/test'
uses: actions/github-script@v7
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
github.rest.issues.createComment({
issue_number: ${{vars.VERSION_ISSUE_NO}},
owner: context.repo.owner,
repo: context.repo.repo,
body: '- 下载地址: [${{ github.workflow }}_${{ env.BRANCH }}_${{ env.DATE }}](${{ steps.upload.outputs.artifact-url }})\n'
+ '- 分支: ${{ env.BRANCH2 }}\n'
+ '- git hash: ${{ github.sha }} \n'
+ '- 编译日期: ${{ env.DATE }}\n'
+ '- 编译记录: [${{ github.run_id }}](https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }})\n'
+ '- 打包ci名: ${{ github.workflow }}\n'
+ '- 开启特性: openssl/webrtc/datachannel\n'
+ '- 说明: 此二进制为arm64版本\n'
})

27
ZLM/.github/workflows/style.yml vendored Normal file
View File

@ -0,0 +1,27 @@
name: style check
on: [pull_request]
jobs:
check:
runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v2
with:
# with all history
fetch-depth: 0
- name: Validate BOM
run: |
ret=0
for i in $(git diff --name-only origin/${GITHUB_BASE_REF}...${GITHUB_SHA}); do
if [ -f ${i} ]; then
case ${i} in
*.c|*.cc|*.cpp|*.h)
if file ${i} | grep -qv BOM; then
echo "Missing BOM in ${i}" && ret=1;
fi
;;
esac
fi
done
exit ${ret}

68
ZLM/.github/workflows/windows.yml vendored Normal file
View File

@ -0,0 +1,68 @@
name: Windows
on: [push, pull_request]
jobs:
build:
runs-on: windows-2022
steps:
- uses: actions/checkout@v1
- name: 下载submodule源码
run: mv -Force .gitmodules_github .gitmodules && git submodule sync && git submodule update --init
- name: 配置 vcpkg
uses: lukka/run-vcpkg@v7
with:
vcpkgDirectory: '${{github.workspace}}/vcpkg'
vcpkgTriplet: x64-windows-static
# 2025.07.11
vcpkgGitCommitId: 'efcfaaf60d7ec57a159fc3110403d939bfb69729'
vcpkgArguments: 'openssl libsrtp[openssl] usrsctp'
- name: 编译
uses: lukka/run-cmake@v3
with:
useVcpkgToolchainFile: true
buildDirectory: '${{github.workspace}}/build'
cmakeAppendedArgs: ''
cmakeBuildType: 'Release'
- name: 设置环境变量
run: |
$dateString = Get-Date -Format "yyyy-MM-dd"
$branch = $env:GITHUB_REF -replace "refs/heads/", "" -replace "[\\/\\\?\%\*:\|\x22<>]", "_"
$branch2 = $env:GITHUB_REF -replace "refs/heads/", ""
echo "BRANCH=$branch" >> $env:GITHUB_ENV
echo "BRANCH2=$branch2" >> $env:GITHUB_ENV
echo "DATE=$dateString" >> $env:GITHUB_ENV
- name: 打包二进制
id: upload
uses: actions/upload-artifact@v4
with:
name: ${{ github.workflow }}_${{ env.BRANCH }}_${{ env.DATE }}
path: release/*
if-no-files-found: error
retention-days: 90
- name: issue评论
if: github.event_name != 'pull_request' && github.ref != 'refs/heads/feature/test'
uses: actions/github-script@v7
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
github.rest.issues.createComment({
issue_number: ${{vars.VERSION_ISSUE_NO}},
owner: context.repo.owner,
repo: context.repo.repo,
body: '- 下载地址: [${{ github.workflow }}_${{ env.BRANCH }}_${{ env.DATE }}](${{ steps.upload.outputs.artifact-url }})\n'
+ '- 分支: ${{ env.BRANCH2 }}\n'
+ '- git hash: ${{ github.sha }} \n'
+ '- 编译日期: ${{ env.DATE }}\n'
+ '- 编译记录: [${{ github.run_id }}](https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }})\n'
+ '- 打包ci名: ${{ github.workflow }}\n'
+ '- 开启特性: openssl/webrtc/datachannel\n'
+ '- 说明: 此二进制为x64版本\n'
})

51
ZLM/.gitignore vendored Normal file
View File

@ -0,0 +1,51 @@
# Compiled Object files
*.slo
*.lo
*.o
*.obj
*.d
# Precompiled Headers
*.gch
*.pch
# Compiled Dynamic libraries
#*.dylib
#*.dll
# Fortran module files
*.mod
*.smod
# Compiled Static libraries
*.lai
*.la
*.lib
# Executables
#*.exe
*.out
*.app
/X64/
*.DS_Store
/cmake-build-debug/
/cmake-build-release/
/linux/
/.vs/
/.vscode/
/.idea/
/c_wrapper/.idea/
/release/
/out/
/Android/.idea/
/Android/app/src/main/cpp/libs_export/
/3rdpart/media-server/.idea/
/3rdpart/media-server/.idea/
/build/
/3rdpart/media-server/.idea/
/ios/
/cmake-build-*
/3rdpart/ZLToolKit/cmake-build-mq/

123
ZLM/3rdpart/CMakeLists.txt Normal file
View File

@ -0,0 +1,123 @@
# MIT License
#
# Copyright (c) 2016-2022 The ZLMediaKit project authors. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
##############################################################################
# jsoncpp
file(GLOB JSONCPP_SRC_LIST
${CMAKE_CURRENT_SOURCE_DIR}/jsoncpp/include/json/*.h
${CMAKE_CURRENT_SOURCE_DIR}/jsoncpp/src/lib_json/*.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jsoncpp/src/lib_json/*.h)
add_library(jsoncpp STATIC ${JSONCPP_SRC_LIST})
target_compile_options(jsoncpp
PRIVATE ${COMPILE_OPTIONS_DEFAULT})
target_include_directories(jsoncpp
PRIVATE
"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>/jsoncpp/include"
PUBLIC
"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>/jsoncpp/include")
update_cached_list(MK_LINK_LIBRARIES jsoncpp)
##############################################################################
# media-server
set(MediaServer_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/media-server")
# TODO:
# movflv MP4
if (ENABLE_MP4)
# MOV
set(MediaServer_MOV_ROOT ${MediaServer_ROOT}/libmov)
aux_source_directory(${MediaServer_MOV_ROOT}/include MOV_SRC_LIST)
aux_source_directory(${MediaServer_MOV_ROOT}/source MOV_SRC_LIST)
add_library(mov STATIC ${MOV_SRC_LIST})
add_library(MediaServer::mov ALIAS mov)
target_compile_options(mov PRIVATE ${COMPILE_OPTIONS_DEFAULT})
target_include_directories(mov
PRIVATE
"$<BUILD_INTERFACE:${MediaServer_MOV_ROOT}/include>"
PUBLIC
"$<BUILD_INTERFACE:${MediaServer_MOV_ROOT}/include>")
# FLV
set(MediaServer_FLV_ROOT ${MediaServer_ROOT}/libflv)
aux_source_directory(${MediaServer_FLV_ROOT}/include FLV_SRC_LIST)
aux_source_directory(${MediaServer_FLV_ROOT}/source FLV_SRC_LIST)
add_library(flv STATIC ${FLV_SRC_LIST})
add_library(MediaServer::flv ALIAS flv)
target_compile_options(flv PRIVATE ${COMPILE_OPTIONS_DEFAULT})
target_include_directories(flv
PRIVATE
"$<BUILD_INTERFACE:${MediaServer_FLV_ROOT}/include>"
PUBLIC
"$<BUILD_INTERFACE:${MediaServer_FLV_ROOT}/include>")
update_cached_list(MK_LINK_LIBRARIES MediaServer::flv MediaServer::mov)
if (ENABLE_MP4)
message(STATUS "ENABLE_MP4 defined")
update_cached_list(MK_COMPILE_DEFINITIONS ENABLE_MP4)
endif ()
endif ()
# mpeg ts
if(ENABLE_RTPPROXY OR ENABLE_HLS)
# mpeg
set(MediaServer_MPEG_ROOT ${MediaServer_ROOT}/libmpeg)
aux_source_directory(${MediaServer_MPEG_ROOT}/include MPEG_SRC_LIST)
aux_source_directory(${MediaServer_MPEG_ROOT}/source MPEG_SRC_LIST)
add_library(mpeg STATIC ${MPEG_SRC_LIST})
add_library(MediaServer::mpeg ALIAS mpeg)
# media-server
# MPEG_H26X_VERIFY -
# MPEG_ZERO_PAYLOAD_LENGTH - hik
# MPEG_DAHUA_AAC_FROM_G711 - dahua
target_compile_options(mpeg
PRIVATE ${COMPILE_OPTIONS_DEFAULT} -DMPEG_H26X_VERIFY -DMPEG_ZERO_PAYLOAD_LENGTH -DMPEG_DAHUA_AAC_FROM_G711)
target_include_directories(mpeg
PRIVATE
"$<BUILD_INTERFACE:${MediaServer_MPEG_ROOT}/include>"
PUBLIC
"$<BUILD_INTERFACE:${MediaServer_MPEG_ROOT}/include>")
update_cached_list(MK_LINK_LIBRARIES MediaServer::mpeg)
if(ENABLE_RTPPROXY)
message(STATUS "ENABLE_RTPPROXY defined")
update_cached_list(MK_COMPILE_DEFINITIONS ENABLE_RTPPROXY)
endif()
if(ENABLE_HLS)
message(STATUS "ENABLE_HLS defined")
update_cached_list(MK_COMPILE_DEFINITIONS ENABLE_HLS)
endif()
endif()
##############################################################################
# toolkit
add_subdirectory(ZLToolKit)
#
add_library(ZLMediaKit::ToolKit ALIAS ZLToolKit)
#
update_cached_list(MK_LINK_LIBRARIES ZLMediaKit::ToolKit)

View File

@ -0,0 +1,89 @@
# This is for clang-format >= 9.0.
#
# clang-format --version
# clang-format version 9.0.1 (Red Hat 9.0.1-2.module+el8.2.0+5494+7b8075cf)
#
# 详细说明见: https://clang.llvm.org/docs/ClangFormatStyleOptions.html
# 部分参数会随版本变化.
---
Language: Cpp
# 基于 WebKit 的风格, https://www.webkit.org/coding/coding-style.html
BasedOnStyle: WebKit
# 以下各选项按字母排序
# public/protected/private 不缩进
AccessModifierOffset: -4
# 参数过长时统一换行
AlignAfterOpenBracket: AlwaysBreak
# clang-format >= 13 required, map 之类的内部列对齐
# AlignArrayOfStructures: Left
# 换行符统一在 ColumnLimit 最右侧
AlignEscapedNewlines: Right
# 不允许短代码块单行, 即不允许单行代码: if (x) return;
AllowShortBlocksOnASingleLine: false
# 只允许 Inline 函数单行
AllowShortFunctionsOnASingleLine: Inline
# 模板声明换行
AlwaysBreakTemplateDeclarations: Yes
# 左开括号不换行
BreakBeforeBraces: Custom
BraceWrapping:
AfterCaseLabel: false
AfterClass: false
# BraceWrappingAfterControlStatementStyle: MultiLine
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterStruct: false
AfterUnion: false
AfterExternBlock: false
BeforeCatch: false
BeforeElse: false
BeforeLambdaBody: false
BeforeWhile: false
IndentBraces: false
SplitEmptyFunction: false
SplitEmptyRecord: false
SplitEmptyNamespace: false
# 构造函数初始化时在 `,` 前换行, 和 `:` 对齐显得整齐
BreakConstructorInitializers: BeforeComma
# 继承过长需要换行时也在 `,` 前
BreakInheritanceList: BeforeComma
# 列宽 160
ColumnLimit: 160
# c++11 括号内起始/结束无空格, false 会加上
Cpp11BracedListStyle: false
# 命名空间后的注释会修正为: // namespace_name
FixNamespaceComments: true
#switch case的缩进
IndentCaseLabels: true
#允许单行case
AllowShortCaseLabelsOnASingleLine: true
# clang-format >= 13 required, lambda 函数内部缩进级别和外部一致, 默认会增加一级缩进
# LambdaBodyIndentation: OuterScope
# 命名空间不缩进
NamespaceIndentation: None
# PPIndentWidth: 2
# */& 靠近变量, 向右靠
PointerAlignment: Right
# c++11 使用 {} 构造时和变量加个空格
SpaceBeforeCpp11BracedList: true
# 继承时 `:` 前加空格
SpaceBeforeInheritanceColon: true
# () 前不加空格, do/for/if/switch/while 除外
SpaceBeforeParens: ControlStatements
# 空 {} 中不加空格
SpaceInEmptyBlock: false
Standard: C++11
# Tab 占 4 位
TabWidth: 4
# 不使用 TAB
UseTab: Never
---
Language: Java
---
Language: JavaScript
...

View File

@ -0,0 +1,46 @@
name: Linux
on: [push, pull_request]
env:
# Customize the CMake build type here (Release, Debug, RelWithDebInfo, etc.)
BUILD_TYPE: Release
jobs:
build:
# The CMake configure and build commands are platform agnostic and should work equally
# well on Windows or Mac. You can convert this to a matrix build if you need
# cross-platform coverage.
# See: https://docs.github.com/en/free-pro-team@latest/actions/learn-github-actions/managing-complex-workflows#using-a-build-matrix
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Create Build Environment
# Some projects don't allow in-source building, so create a separate build directory
# We'll use this as our working directory for all subsequent commands
run: cmake -E make_directory ${{github.workspace}}/build
- name: Configure CMake
# Use a bash shell so we can use the same syntax for environment variable
# access regardless of the host operating system
shell: bash
working-directory: ${{github.workspace}}/build
# Note the current convention is to use the -S and -B options here to specify source
# and build directories, but this is only available with CMake 3.13 and higher.
# The CMake binaries on the Github Actions machines are (as of this writing) 3.12
run: cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=$BUILD_TYPE
- name: Build
working-directory: ${{github.workspace}}/build
shell: bash
# Execute the build. You can specify a specific target with "--target <NAME>"
run: cmake --build . --config $BUILD_TYPE -j $(nproc)
- name: Test
working-directory: ${{github.workspace}}/build
shell: bash
# Execute tests defined by the CMake configuration.
# See https://cmake.org/cmake/help/latest/manual/ctest.1.html for more detail
run: ctest -C $BUILD_TYPE

View File

@ -0,0 +1,40 @@
name: MacOS
on: [push, pull_request]
env:
# Customize the CMake build type here (Release, Debug, RelWithDebInfo, etc.)
BUILD_TYPE: Release
jobs:
build:
# The CMake configure and build commands are platform agnostic and should work equally
# well on Windows or Mac. You can convert this to a matrix build if you need
# cross-platform coverage.
# See: https://docs.github.com/en/free-pro-team@latest/actions/learn-github-actions/managing-complex-workflows#using-a-build-matrix
runs-on: macOS-latest
steps:
- uses: actions/checkout@v2
- name: 配置 vcpkg
uses: lukka/run-vcpkg@v7
with:
vcpkgDirectory: '${{github.workspace}}/vcpkg'
vcpkgTriplet: arm64-osx
# 2025.07.11
vcpkgGitCommitId: 'efcfaaf60d7ec57a159fc3110403d939bfb69729'
vcpkgArguments: 'openssl'
- name: 安装指定 CMake
uses: jwlawson/actions-setup-cmake@v2
with:
cmake-version: '3.30.5'
- name: 编译
uses: lukka/run-cmake@v3
with:
useVcpkgToolchainFile: true
buildDirectory: '${{github.workspace}}/build'
cmakeAppendedArgs: ''
cmakeBuildType: 'RelWithDebInfo'

View File

@ -0,0 +1,27 @@
name: style check
on: [pull_request]
jobs:
check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
with:
# with all history
fetch-depth: 0
- name: Validate BOM
run: |
ret=0
for i in $(git diff --name-only origin/${GITHUB_BASE_REF}...${GITHUB_SHA}); do
if [ -f ${i} ]; then
case ${i} in
*.c|*.cc|*.cpp|*.h)
if file ${i} | grep -qv BOM; then
echo "Missing BOM in ${i}" && ret=1;
fi
;;
esac
fi
done
exit ${ret}

View File

@ -0,0 +1,30 @@
name: Windows
on: [push, pull_request]
jobs:
build:
runs-on: windows-2022
steps:
- uses: actions/checkout@v1
with:
submodules: 'recursive'
fetch-depth: 1
- name: 配置 vcpkg
uses: lukka/run-vcpkg@v7
with:
vcpkgDirectory: '${{github.workspace}}/vcpkg'
vcpkgTriplet: x64-windows-static
# 2021.05.12
vcpkgGitCommitId: '5568f110b509a9fd90711978a7cb76bae75bb092'
vcpkgArguments: 'openssl'
- name: 编译
uses: lukka/run-cmake@v3
with:
useVcpkgToolchainFile: true
buildDirectory: '${{github.workspace}}/build'
cmakeAppendedArgs: ''
cmakeBuildType: 'RelWithDebInfo'

34
ZLM/3rdpart/ZLToolKit/.gitignore vendored Normal file
View File

@ -0,0 +1,34 @@
# Compiled Object files
*.slo
*.lo
*.o
*.obj
*.d
# Precompiled Headers
*.gch
*.pch
# Compiled Dynamic libraries
#*.dylib
#*.dll
# Fortran module files
*.mod
*.smod
# Compiled Static libraries
*.lai
*.la
*.lib
# Executables
*.exe
*.out
*.app
/X64/
*.DS_Store
/cmake-build-debug/
/.idea/
/.vs

View File

@ -0,0 +1,13 @@
language: cpp
sudo: required
dist: trusty
compiler:
- gcc
os:
- linux
before_install:
script:
- ./build_for_linux.sh

View File

@ -0,0 +1,5 @@
#代码贡献者列表提交pr时请留下您的联系方式
#Code contributor list, please leave your contact information when submitting a pull request
xiongziliang <771730766@qq.com>
[清涩绿茶](https://github.com/baiyfcu)

View File

@ -0,0 +1,145 @@
cmake_minimum_required(VERSION 3.1.3...3.26)
project(ZLToolKit)
#使c++11
set(CMAKE_CXX_STANDARD 11)
# -fPIC
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
#
set(CMAKE_VERBOSE_MAKEFILE ON)
#
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
option(ENABLE_OPENSSL "enable openssl" ON)
option(ENABLE_MYSQL "enable mysql" ON)
option(ENABLE_WEPOLL "Enable wepoll" ON)
option(ASAN_USE_DELETE "use delele[] or free when asan enabled" OFF)
option(BUILD_SHARED_LIBS "Build all libraries shared" ON)
include(CheckStructHasMember)
include(CheckSymbolExists)
list(APPEND CMAKE_REQUIRED_DEFINITIONS -D_GNU_SOURCE)
check_struct_has_member("struct mmsghdr" msg_hdr sys/socket.h HAVE_MMSG_HDR)
check_symbol_exists(sendmmsg sys/socket.h HAVE_SENDMMSG_API)
check_symbol_exists(recvmmsg sys/socket.h HAVE_RECVMMSG_API)
# 便
function(update_cached name value)
set("${name}" "${value}" CACHE INTERNAL "*** Internal ***" FORCE)
endfunction()
function(update_cached_list name)
set(_tmp_list "${${name}}")
list(APPEND _tmp_list "${ARGN}")
list(REMOVE_DUPLICATES _tmp_list)
update_cached(${name} "${_tmp_list}")
endfunction()
update_cached(TK_INC_PATHS "")
update_cached(TK_LINK_LIBRARIES "")
update_cached(TK_COMPILE_DEFINITIONS "")
update_cached(TK_COMPILE_OPTIONS "")
if (HAVE_MMSG_HDR)
update_cached_list(TK_COMPILE_DEFINITIONS HAVE_MMSG_HDR)
endif ()
if (HAVE_SENDMMSG_API)
update_cached_list(TK_COMPILE_DEFINITIONS HAVE_SENDMMSG_API)
endif ()
if (HAVE_RECVMMSG_API)
update_cached_list(TK_COMPILE_DEFINITIONS HAVE_RECVMMSG_API)
endif ()
# check the socket buffer size set by the upper cmake project, if it is set, use the setting of the upper cmake project, otherwise set it to 256K
# if the socket buffer size is set to 0, it means that the socket buffer size is not set, and the kernel default value is used(just for linux)
if (DEFINED SOCKET_DEFAULT_BUF_SIZE)
if (SOCKET_DEFAULT_BUF_SIZE EQUAL 0)
message(STATUS "Socket default buffer size is not set, use the kernel default value")
else ()
message(STATUS "Socket default buffer size is set to ${SOCKET_DEFAULT_BUF_SIZE}")
endif ()
update_cached_list(TK_COMPILE_DEFINITIONS SOCKET_DEFAULT_BUF_SIZE=${SOCKET_DEFAULT_BUF_SIZE})
endif ()
#
file(GLOB SRC_LIST
${CMAKE_CURRENT_SOURCE_DIR}/src/*/*.c
${CMAKE_CURRENT_SOURCE_DIR}/src/*/*.mm
${CMAKE_CURRENT_SOURCE_DIR}/src/*/*.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/*/*/*.cpp)
if (WIN32)
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
if (MSVC)
update_cached_list(TK_COMPILE_OPTIONS "/utf-8")
endif ()
update_cached_list(TK_LINK_LIBRARIES WS2_32 Iphlpapi shlwapi)
#Windows.hWinsock.h
update_cached_list(TK_COMPILE_DEFINITIONS WIN32_LEAN_AND_MEAN MP4V2_NO_STDINT_DEFS _CRT_SECURE_NO_WARNINGS _WINSOCK_DEPRECATED_NO_WARNINGS)
else ()
# Windowsgetopt api
list(FILTER SRC_LIST EXCLUDE REGEX "getopt.c$")
update_cached_list(TK_COMPILE_OPTIONS "-Wno-comment" "-Wno-deprecated-declarations" "-Wno-predefined-identifier-outside-function")
endif ()
if (NOT WIN32 OR NOT ENABLE_WEPOLL)
# wepoll
list(FILTER SRC_LIST EXCLUDE REGEX "wepoll.c$")
else ()
update_cached_list(TK_COMPILE_DEFINITIONS HAS_EPOLL)
update_cached_list(TK_INC_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/win32/)
endif ()
#.mm
if (NOT APPLE)
list(FILTER SRC_LIST EXCLUDE REGEX "Socket_ios.mm$")
endif ()
#openssl
if (ENABLE_OPENSSL)
find_package(OpenSSL)
if (OPENSSL_FOUND)
update_cached_list(TK_INC_PATHS ${OPENSSL_INCLUDE_DIR})
update_cached_list(TK_LINK_LIBRARIES ${OPENSSL_LIBRARIES})
update_cached_list(TK_COMPILE_DEFINITIONS ENABLE_OPENSSL)
endif ()
endif ()
#mysql
if (ENABLE_MYSQL)
find_package(MYSQL)
if (MYSQL_FOUND)
update_cached_list(TK_INC_PATHS ${MYSQL_INCLUDE_DIR})
update_cached_list(TK_INC_PATHS ${MYSQL_INCLUDE_DIR}/mysql)
update_cached_list(TK_LINK_LIBRARIES ${MYSQL_LIBRARIES})
update_cached_list(TK_COMPILE_DEFINITIONS ENABLE_MYSQL)
endif ()
endif ()
#使delete[]freeasanMacOS
if (ASAN_USE_DELETE)
update_cached_list(TK_COMPILE_DEFINITIONS ASAN_USE_DELETE)
endif ()
#
add_library(${PROJECT_NAME}_deps INTERFACE)
target_link_libraries(${PROJECT_NAME}_deps INTERFACE ${TK_LINK_LIBRARIES})
#
add_library(${PROJECT_NAME} ${SRC_LIST})
#
target_include_directories(${PROJECT_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/src ${TK_INC_PATHS})
target_link_libraries(${PROJECT_NAME} PUBLIC ${PROJECT_NAME}_deps)
target_compile_definitions(${PROJECT_NAME} PUBLIC ${TK_COMPILE_DEFINITIONS})
target_compile_options(${PROJECT_NAME} PUBLIC ${TK_COMPILE_OPTIONS})
set_target_properties(${PROJECT_NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/lib
ARCHIVE_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/lib)
if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
#root
add_subdirectory(tests)
endif ()

View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,138 @@
# 一个基于C++11简单易用的轻量级网络编程框架
![](https://github.com/ZLMediaKit/ZLToolKit/actions/workflows/linux.yml/badge.svg)
![](https://github.com/ZLMediaKit/ZLToolKit/actions/workflows/macos.yml/badge.svg)
![](https://github.com/ZLMediaKit/ZLToolKit/actions/workflows/windows.yml/badge.svg)
## 项目特点
- 基于C++11开发避免使用裸指针代码稳定可靠同时跨平台移植简单方便代码清晰简洁。
- 使用epoll+线程池+异步网络IO模式开发并发性能优越。
- 代码经过大量的稳定性、性能测试,可满足商用服务器项目。
- 支持linux、macos、ios、android、windows平台
- 了解更多:[ZLMediaKit](https://github.com/ZLMediaKit/ZLMediaKit)
## 特性
- 网络库
- tcp/udp客户端接口简单易用并且是线程安全的用户不必关心具体的socket api操作。
- tcp/udp服务器使用非常简单只要实现具体的tcp/udp会话Session类逻辑,使用模板的方式可以快速的构建高性能的服务器。
- 对套接字多种操作的封装。
- 线程库
- 使用线程实现的简单易用的定时器。
- 信号量。
- 线程组。
- 简单易用的线程池可以异步或同步执行任务支持functional 和 lambad表达式。
- 工具库
- 文件操作。
- std::cout风格的日志库支持颜色高亮、代码定位、异步打印。
- INI配置文件的读写。
- 监听者模式的消息广播器。
- 基于智能指针的循环池,不需要显式手动释放。
- 环形缓冲,支持主动读取和读取事件两种模式。
- mysql链接池使用占位符方式生成sql语句支持同步异步操作。
- 简单易用的ssl加解密黑盒支持多线程。
- 其他一些有用的工具。
- 命令行解析工具,可以很便捷的实现可配置应用程序
## 网络IO适配
| | Linux(Android) | Windows | MacOS(iOS/Unix) |
|:----:|:-----------------:|:-------------------:|:----------------:|
| 多路复用 | epoll/select | wepoll(iocp)/select | kqueue/select |
| udp | recvmmsg/sendmmsg | recvfrom/WSASend | recvfrom/sendto |
| tcp | recvfrom/sendmsg | recvfrom/WSASend | recvfrom/sendmsg |
## 编译(Linux)
- 我的编译环境
- Ubuntu16.04 64 bit + gcc5.4(最低gcc4.7)
- cmake 3.5.1
- 编译
```
cd ZLToolKit
./build_for_linux.sh
```
## 编译(macOS)
- 我的编译环境
- macOS Sierra(10.12.1) + xcode8.3.1
- Homebrew 1.1.3
- cmake 3.8.0
- 编译
```
cd ZLToolKit
./build_for_mac.sh
```
## 编译(iOS)
- 编译环境:`请参考macOS的编译指导。`
- 编译
```
cd ZLToolKit
./build_for_ios.sh
```
- 你也可以生成Xcode工程再编译
```
cd ZLToolKit
mkdir -p build
cd build
# 生成Xcode工程工程文件在build目录下
cmake .. -DCMAKE_TOOLCHAIN_FILE=../cmake/iOS.cmake -DIOS_PLATFORM=SIMULATOR64 -G "Xcode"
```
## 编译(Android)
- 我的编译环境
- macOS Sierra(10.12.1) + xcode8.3.1
- Homebrew 1.1.3
- cmake 3.8.0
- [android-ndk-r14b](https://dl.google.com/android/repository/android-ndk-r14b-darwin-x86_64.zip)
- 编译
```
cd ZLToolKit
export ANDROID_NDK_ROOT=/path/to/ndk
./build_for_android.sh
```
## 编译(Windows)
- 我的编译环境
- windows 10
- visual studio 2017
- [openssl](http://slproweb.com/download/Win32OpenSSL-1_1_0f.exe)
- [mysqlclient](https://dev.mysql.com/downloads/file/?id=472430)
- [cmake-gui](https://cmake.org/files/v3.10/cmake-3.10.0-rc1-win32-x86.msi)
- 编译
```
  1 使用cmake-gui打开工程并生成vs工程文件.
  2 找到工程文件(ZLToolKit.sln),双击用vs2017打开.
  3 选择编译Release 版本.
  4 依次编译 ZLToolKit_static、ZLToolKit_shared、ALL_BUILD、INSTALL.
5 找到目标文件并运行测试用例.
  6 找到安装的头文件及库文件(在源码所在分区根目录).
```
## 授权协议
本项目自有代码使用宽松的MIT协议在保留版权信息的情况下可以自由应用于各自商用、非商业的项目。
但是本项目也零碎的使用了一些其他的开源代码,在商用的情况下请自行替代或剔除;
由于使用本项目而产生的商业纠纷或侵权行为一概与本项项目及开发者无关,请自行承担法律风险。
## QA
- 该库性能怎么样?
基于ZLToolKit我实现了一个流媒体服务器[ZLMediaKit](https://github.com/ZLMediaKit/ZLMediaKit);作者已经对其进行了性能测试,可以查看[benchmark.md](https://github.com/ZLMediaKit/ZLMediaKit/blob/master/benchmark.md)了解详情。
- 该库稳定性怎么样?
该库经过作者严格的valgrind测试长时间大负荷的测试作者也使用该库进行了多个线上项目的开发。实践证明该库稳定性很好可以无看门狗脚本的方式连续运行几个月。
- 在windows下编译很多错误
由于本项目主体代码在macOS/linux下开发部分源码采用的是无bom头的UTF-8编码由于windows对于utf-8支持不甚友好所以如果发现编译错误请先尝试添加bom头再编译。
## 联系方式
- 邮箱:<1213642868@qq.com>本项目相关或网络编程相关问题请走issue流程否则恕不邮件答复
- QQ群542509000

View File

@ -0,0 +1,132 @@
# - Try to find MySQL / MySQL Embedded library
# Find the MySQL includes and client library
# This module defines
# MYSQL_INCLUDE_DIR, where to find mysql.h
# MYSQL_LIBRARIES, the libraries needed to use MySQL.
# MYSQL_LIB_DIR, path to the MYSQL_LIBRARIES
# MYSQL_EMBEDDED_LIBRARIES, the libraries needed to use MySQL Embedded.
# MYSQL_EMBEDDED_LIB_DIR, path to the MYSQL_EMBEDDED_LIBRARIES
# MYSQL_FOUND, If false, do not try to use MySQL.
# MYSQL_EMBEDDED_FOUND, If false, do not try to use MySQL Embedded.
# Copyright (c) 2006-2008, Jarosław Staniek <staniek@kde.org>
#
# Redistribution and use is allowed according to the terms of the BSD license.
# For details see the accompanying COPYING-CMAKE-SCRIPTS file.
include(CheckCXXSourceCompiles)
if(WIN32)
find_path(MYSQL_INCLUDE_DIR mysql.h
PATHS
$ENV{MYSQL_INCLUDE_DIR}
$ENV{MYSQL_DIR}/include
$ENV{ProgramFiles}/MySQL/*/include
$ENV{SystemDrive}/MySQL/*/include
$ENV{ProgramW6432}/MySQL/*/include
)
else(WIN32)
#Mac OS, mysql.hmysql
#/usr/local/mysql/include/mysql.h
find_path(MYSQL_INCLUDE_DIR mysql.h
PATHS
$ENV{MYSQL_INCLUDE_DIR}
$ENV{MYSQL_DIR}/include
/usr/local/mysql/include
/usr/local/mysql/include/mysql
/opt/mysql/mysql/include
PATH_SUFFIXES
mysql
)
endif(WIN32)
if(WIN32)
if (${CMAKE_BUILD_TYPE})
string(TOLOWER ${CMAKE_BUILD_TYPE} CMAKE_BUILD_TYPE_TOLOWER)
endif()
# path suffix for debug/release mode
# binary_dist: mysql binary distribution
# build_dist: custom build
if(CMAKE_BUILD_TYPE_TOLOWER MATCHES "debug")
set(binary_dist debug)
set(build_dist Debug)
else(CMAKE_BUILD_TYPE_TOLOWER MATCHES "debug")
ADD_DEFINITIONS(-DDBUG_OFF)
set(binary_dist opt)
set(build_dist Release)
endif(CMAKE_BUILD_TYPE_TOLOWER MATCHES "debug")
# find_library(MYSQL_LIBRARIES NAMES mysqlclient
set(MYSQL_LIB_PATHS
$ENV{MYSQL_DIR}/lib/${binary_dist}
$ENV{MYSQL_DIR}/libmysql/${build_dist}
$ENV{MYSQL_DIR}/client/${build_dist}
$ENV{ProgramFiles}/MySQL/*/lib/${binary_dist}
$ENV{SystemDrive}/MySQL/*/lib/${binary_dist}
$ENV{MYSQL_DIR}/lib/opt
$ENV{MYSQL_DIR}/client/release
$ENV{ProgramFiles}/MySQL/*/lib/opt
$ENV{ProgramFiles}/MySQL/*/lib/
$ENV{SystemDrive}/MySQL/*/lib/opt
$ENV{ProgramW6432}/MySQL/*/lib
)
find_library(MYSQL_LIBRARIES NAMES libmysql
PATHS
${MYSQL_LIB_PATHS}
)
else(WIN32)
# find_library(MYSQL_LIBRARIES NAMES mysqlclient
set(MYSQL_LIB_PATHS
$ENV{MYSQL_DIR}/libmysql_r/.libs
$ENV{MYSQL_DIR}/lib
$ENV{MYSQL_DIR}/lib/mysql
/usr/local/mysql/lib
/opt/mysql/mysql/lib
$ENV{MYSQL_DIR}/libmysql_r/.libs
$ENV{MYSQL_DIR}/lib
$ENV{MYSQL_DIR}/lib/mysql
/usr/local/mysql/lib
/opt/mysql/mysql/lib
PATH_SUFFIXES
mysql
)
find_library(MYSQL_LIBRARIES NAMES mysqlclient
PATHS
${MYSQL_LIB_PATHS}
)
endif(WIN32)
find_library(MYSQL_EMBEDDED_LIBRARIES NAMES mysqld
PATHS
${MYSQL_LIB_PATHS}
)
if(MYSQL_LIBRARIES)
get_filename_component(MYSQL_LIB_DIR ${MYSQL_LIBRARIES} PATH)
endif(MYSQL_LIBRARIES)
if(MYSQL_EMBEDDED_LIBRARIES)
get_filename_component(MYSQL_EMBEDDED_LIB_DIR ${MYSQL_EMBEDDED_LIBRARIES} PATH)
endif(MYSQL_EMBEDDED_LIBRARIES)
set( CMAKE_REQUIRED_INCLUDES ${MYSQL_INCLUDE_DIR} )
set( CMAKE_REQUIRED_LIBRARIES ${MYSQL_EMBEDDED_LIBRARIES} )
check_cxx_source_compiles( "#include <mysql.h>\nint main() { int i = MYSQL_OPT_USE_EMBEDDED_CONNECTION; }" HAVE_MYSQL_OPT_EMBEDDED_CONNECTION )
if(MYSQL_INCLUDE_DIR AND MYSQL_LIBRARIES)
set(MYSQL_FOUND TRUE)
message(STATUS "Found MySQL: ${MYSQL_INCLUDE_DIR}, ${MYSQL_LIBRARIES}")
else(MYSQL_INCLUDE_DIR AND MYSQL_LIBRARIES)
set(MYSQL_FOUND FALSE)
message(STATUS "MySQL not found.")
endif(MYSQL_INCLUDE_DIR AND MYSQL_LIBRARIES)
if(MYSQL_INCLUDE_DIR AND MYSQL_EMBEDDED_LIBRARIES AND HAVE_MYSQL_OPT_EMBEDDED_CONNECTION)
set(MYSQL_EMBEDDED_FOUND TRUE)
message(STATUS "Found MySQL Embedded: ${MYSQL_INCLUDE_DIR}, ${MYSQL_EMBEDDED_LIBRARIES}")
else(MYSQL_INCLUDE_DIR AND MYSQL_EMBEDDED_LIBRARIES AND HAVE_MYSQL_OPT_EMBEDDED_CONNECTION)
set(MYSQL_EMBEDDED_FOUND FALSE)
message(STATUS "MySQL Embedded not found.")
endif(MYSQL_INCLUDE_DIR AND MYSQL_EMBEDDED_LIBRARIES AND HAVE_MYSQL_OPT_EMBEDDED_CONNECTION)
mark_as_advanced(MYSQL_INCLUDE_DIR MYSQL_LIBRARIES MYSQL_EMBEDDED_LIBRARIES)

View File

@ -0,0 +1,35 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include <cstdlib>
#include "Buffer.h"
#include "Util/onceToken.h"
namespace toolkit {
StatisticImp(Buffer)
StatisticImp(BufferRaw)
StatisticImp(BufferLikeString)
BufferRaw::Ptr BufferRaw::create(size_t size) {
#if 0
static ResourcePool<BufferRaw> packet_pool;
static onceToken token([]() {
packet_pool.setSize(1024);
});
auto ret = packet_pool.obtain2();
ret->setSize(0);
return ret;
#else
return Ptr(new BufferRaw(size));
#endif
}
}//namespace toolkit

View File

@ -0,0 +1,491 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef ZLTOOLKIT_BUFFER_H
#define ZLTOOLKIT_BUFFER_H
#include <cassert>
#include <memory>
#include <string>
#include <vector>
#include <type_traits>
#include <functional>
#include "Util/util.h"
#include "Util/ResourcePool.h"
namespace toolkit {
template <typename T> struct is_pointer : public std::false_type {};
template <typename T> struct is_pointer<std::shared_ptr<T> > : public std::true_type {};
template <typename T> struct is_pointer<std::shared_ptr<T const> > : public std::true_type {};
template <typename T> struct is_pointer<T*> : public std::true_type {};
template <typename T> struct is_pointer<const T*> : public std::true_type {};
//缓存基类 [AUTO-TRANSLATED:d130ab72]
//Cache base class
class Buffer : public noncopyable {
public:
using Ptr = std::shared_ptr<Buffer>;
Buffer() = default;
virtual ~Buffer() = default;
//返回数据长度 [AUTO-TRANSLATED:955f731c]
//Return data length
virtual char *data() const = 0;
virtual size_t size() const = 0;
virtual std::string toString() const {
return std::string(data(), size());
}
virtual size_t getCapacity() const {
return size();
}
private:
//对象个数统计 [AUTO-TRANSLATED:3b43e8c2]
//Object count statistics
ObjectStatistic<Buffer> _statistic;
};
template <typename C>
class BufferOffset : public Buffer {
public:
using Ptr = std::shared_ptr<BufferOffset>;
BufferOffset(C data, size_t offset = 0, size_t len = 0) : _data(std::move(data)) {
setup(offset, len);
}
~BufferOffset() override = default;
char *data() const override {
return const_cast<char *>(getPointer<C>(_data)->data()) + _offset;
}
size_t size() const override {
return _size;
}
std::string toString() const override {
return std::string(data(), size());
}
private:
void setup(size_t offset = 0, size_t size = 0) {
auto max_size = getPointer<C>(_data)->size();
assert(offset + size <= max_size);
if (!size) {
size = max_size - offset;
}
_size = size;
_offset = offset;
}
template<typename T>
static typename std::enable_if<::toolkit::is_pointer<T>::value, const T &>::type
getPointer(const T &data) {
return data;
}
template<typename T>
static typename std::enable_if<!::toolkit::is_pointer<T>::value, const T *>::type
getPointer(const T &data) {
return &data;
}
private:
C _data;
size_t _size;
size_t _offset;
};
using BufferString = BufferOffset<std::string>;
//指针式缓存对象, [AUTO-TRANSLATED:c8403290]
//Pointer-style cache object,
class BufferRaw : public Buffer {
public:
using Ptr = std::shared_ptr<BufferRaw>;
static Ptr create(size_t size = 0);
~BufferRaw() override {
if (_data) {
delete[] _data;
}
}
//在写入数据时请确保内存是否越界 [AUTO-TRANSLATED:5602043e]
//When writing data, please ensure that the memory does not overflow
char *data() const override {
return _data;
}
//有效数据大小 [AUTO-TRANSLATED:b8dcbda7]
//Effective data size
size_t size() const override {
return _size;
}
//分配内存大小 [AUTO-TRANSLATED:cce87adf]
//Allocated memory size
void setCapacity(size_t capacity) {
if (_data) {
do {
if (capacity > _capacity) {
//请求的内存大于当前内存,那么重新分配 [AUTO-TRANSLATED:65306424]
//If the requested memory is greater than the current memory, reallocate
break;
}
if (_capacity < 2 * 1024) {
//2K以下不重复开辟内存直接复用 [AUTO-TRANSLATED:056416c0]
//Less than 2K, do not repeatedly allocate memory, reuse directly
return;
}
if (2 * capacity > _capacity) {
//如果请求的内存大于当前内存的一半,那么也复用 [AUTO-TRANSLATED:c189d660]
//If the requested memory is greater than half of the current memory, also reuse
return;
}
} while (false);
delete[] _data;
}
_data = new char[capacity];
_capacity = capacity;
}
//设置有效数据大小 [AUTO-TRANSLATED:efc4fb3e]
//Set valid data size
virtual void setSize(size_t size) {
if (size > _capacity) {
throw std::invalid_argument("Buffer::setSize out of range");
}
_size = size;
}
//赋值数据 [AUTO-TRANSLATED:0b91b213]
//Assign data
void assign(const char *data, size_t size = 0) {
if (size <= 0) {
size = strlen(data);
}
setCapacity(size + 1);
memcpy(_data, data, size);
_data[size] = '\0';
setSize(size);
}
size_t getCapacity() const override {
return _capacity;
}
protected:
friend class ResourcePool_l<BufferRaw>;
BufferRaw(size_t capacity = 0) {
if (capacity) {
setCapacity(capacity);
}
}
BufferRaw(const char *data, size_t size = 0) {
assign(data, size);
}
private:
size_t _size = 0;
size_t _capacity = 0;
char *_data = nullptr;
//对象个数统计 [AUTO-TRANSLATED:3b43e8c2]
//Object count statistics
ObjectStatistic<BufferRaw> _statistic;
};
class BufferLikeString : public Buffer {
public:
~BufferLikeString() override = default;
BufferLikeString() {
_erase_head = 0;
_erase_tail = 0;
}
BufferLikeString(std::string str) {
_str = std::move(str);
_erase_head = 0;
_erase_tail = 0;
}
BufferLikeString &operator=(std::string str) {
_str = std::move(str);
_erase_head = 0;
_erase_tail = 0;
return *this;
}
BufferLikeString(const char *str) {
_str = str;
_erase_head = 0;
_erase_tail = 0;
}
BufferLikeString &operator=(const char *str) {
_str = str;
_erase_head = 0;
_erase_tail = 0;
return *this;
}
BufferLikeString(BufferLikeString &&that) {
_str = std::move(that._str);
_erase_head = that._erase_head;
_erase_tail = that._erase_tail;
that._erase_head = 0;
that._erase_tail = 0;
}
BufferLikeString &operator=(BufferLikeString &&that) {
_str = std::move(that._str);
_erase_head = that._erase_head;
_erase_tail = that._erase_tail;
that._erase_head = 0;
that._erase_tail = 0;
return *this;
}
BufferLikeString(const BufferLikeString &that) {
_str = that._str;
_erase_head = that._erase_head;
_erase_tail = that._erase_tail;
}
BufferLikeString &operator=(const BufferLikeString &that) {
_str = that._str;
_erase_head = that._erase_head;
_erase_tail = that._erase_tail;
return *this;
}
char *data() const override {
return (char *) _str.data() + _erase_head;
}
size_t size() const override {
return _str.size() - _erase_tail - _erase_head;
}
BufferLikeString &erase(size_t pos = 0, size_t n = std::string::npos) {
if (pos == 0) {
//移除前面的数据 [AUTO-TRANSLATED:b025d3c5]
//Remove data from the front
if (n != std::string::npos) {
//移除部分 [AUTO-TRANSLATED:a650bef2]
//Remove part
if (n > size()) {
//移除太多数据了 [AUTO-TRANSLATED:64460d15]
//Removed too much data
throw std::out_of_range("BufferLikeString::erase out_of_range in head");
}
//设置起始便宜量 [AUTO-TRANSLATED:7a0250bd]
//Set starting offset
_erase_head += n;
data()[size()] = '\0';
return *this;
}
//移除全部数据 [AUTO-TRANSLATED:3d016f79]
//Remove all data
_erase_head = 0;
_erase_tail = _str.size();
data()[0] = '\0';
return *this;
}
if (n == std::string::npos || pos + n >= size()) {
//移除末尾所有数据 [AUTO-TRANSLATED:efaf1165]
//Remove all data from the end
if (pos >= size()) {
//移除太多数据 [AUTO-TRANSLATED:dc9347c3]
//Removed too much data
throw std::out_of_range("BufferLikeString::erase out_of_range in tail");
}
_erase_tail += size() - pos;
data()[size()] = '\0';
return *this;
}
//移除中间的 [AUTO-TRANSLATED:fd25344c]
//Remove the middle
if (pos + n > size()) {
//超过长度限制 [AUTO-TRANSLATED:9ae84929]
//Exceeds the length limit
throw std::out_of_range("BufferLikeString::erase out_of_range in middle");
}
_str.erase(_erase_head + pos, n);
return *this;
}
BufferLikeString &append(const BufferLikeString &str) {
return append(str.data(), str.size());
}
BufferLikeString &append(const std::string &str) {
return append(str.data(), str.size());
}
BufferLikeString &append(const char *data) {
return append(data, strlen(data));
}
BufferLikeString &append(const char *data, size_t len) {
if (len <= 0) {
return *this;
}
if (_erase_head > _str.capacity() / 2) {
moveData();
}
if (_erase_tail == 0) {
_str.append(data, len);
return *this;
}
_str.insert(_erase_head + size(), data, len);
return *this;
}
void push_back(char c) {
if (_erase_tail == 0) {
_str.push_back(c);
return;
}
data()[size()] = c;
--_erase_tail;
data()[size()] = '\0';
}
BufferLikeString &insert(size_t pos, const char *s, size_t n) {
_str.insert(_erase_head + pos, s, n);
return *this;
}
BufferLikeString &assign(const char *data) {
return assign(data, strlen(data));
}
BufferLikeString &assign(const char *data, size_t len) {
if (len <= 0) {
return *this;
}
if (data >= _str.data() && data < _str.data() + _str.size()) {
_erase_head = data - _str.data();
if (data + len > _str.data() + _str.size()) {
throw std::out_of_range("BufferLikeString::assign out_of_range");
}
_erase_tail = _str.data() + _str.size() - (data + len);
return *this;
}
_str.assign(data, len);
_erase_head = 0;
_erase_tail = 0;
return *this;
}
void clear() {
_erase_head = 0;
_erase_tail = 0;
_str.clear();
}
char &operator[](size_t pos) {
if (pos >= size()) {
throw std::out_of_range("BufferLikeString::operator[] out_of_range");
}
return data()[pos];
}
const char &operator[](size_t pos) const {
return (*const_cast<BufferLikeString *>(this))[pos];
}
size_t capacity() const {
return _str.capacity();
}
void reserve(size_t size) {
_str.reserve(size);
}
void resize(size_t size, char c = '\0') {
auto old_size = this->size();
if (size == old_size) {
return;
}
if (size > old_size) {
auto append = size - old_size;
if (append > _erase_tail) {
_str.resize(append - _erase_tail, c);
memset(const_cast<char *>(_str.data()) + _erase_head + old_size, c, _erase_tail);
_erase_tail = 0;
} else {
_erase_tail -= append;
memset(const_cast<char *>(_str.data()) + _erase_head + old_size, c, append);
}
} else {
auto erased = old_size - size;
_erase_tail += erased;
memset(const_cast<char *>(_str.data()) + _erase_head + size, c, erased);
}
}
bool empty() const {
return size() <= 0;
}
std::string substr(size_t pos, size_t n = std::string::npos) const {
if (n == std::string::npos) {
//获取末尾所有的 [AUTO-TRANSLATED:8a0b92b6]
//Get all at the end
if (pos >= size()) {
throw std::out_of_range("BufferLikeString::substr out_of_range");
}
return _str.substr(_erase_head + pos, size() - pos);
}
//获取部分 [AUTO-TRANSLATED:d01310a4]
//Get part
if (pos + n > size()) {
throw std::out_of_range("BufferLikeString::substr out_of_range");
}
return _str.substr(_erase_head + pos, n);
}
protected:
size_t _erase_head;
size_t _erase_tail;
std::string _str;
private:
void moveData() {
if (_erase_head) {
_str.erase(0, _erase_head);
_erase_head = 0;
}
}
//对象个数统计 [AUTO-TRANSLATED:3b43e8c2]
//Object count statistics
ObjectStatistic<BufferLikeString> _statistic;
};
}//namespace toolkit
#endif //ZLTOOLKIT_BUFFER_H

View File

@ -0,0 +1,622 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include <assert.h>
#include "BufferSock.h"
#include "Util/logger.h"
#include "Util/uv_errno.h"
#if defined(__linux__) || defined(__linux)
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#ifndef MSG_WAITFORONE
#define MSG_WAITFORONE 0x10000
#endif
#ifndef HAVE_MMSG_HDR
struct mmsghdr {
struct msghdr msg_hdr;
unsigned msg_len;
};
#endif
#ifndef HAVE_SENDMMSG_API
#include <unistd.h>
#include <sys/syscall.h>
static inline int sendmmsg(int fd, struct mmsghdr *mmsg,
unsigned vlen, unsigned flags)
{
return syscall(__NR_sendmmsg, fd, mmsg, vlen, flags);
}
#endif
#ifndef HAVE_RECVMMSG_API
#include <unistd.h>
#include <sys/syscall.h>
static inline int recvmmsg(int fd, struct mmsghdr *mmsg,
unsigned vlen, unsigned flags, struct timespec *timeout)
{
return syscall(__NR_recvmmsg, fd, mmsg, vlen, flags, timeout);
}
#endif
#endif// defined(__linux__) || defined(__linux)
namespace toolkit {
StatisticImp(BufferList)
/////////////////////////////////////// BufferSock ///////////////////////////////////////
BufferSock::BufferSock(Buffer::Ptr buffer, struct sockaddr *addr, int addr_len) {
if (addr) {
_addr_len = addr_len ? addr_len : SockUtil::get_sock_len(addr);
memcpy(&_addr, addr, _addr_len);
}
assert(buffer);
_buffer = std::move(buffer);
}
char *BufferSock::data() const {
return _buffer->data();
}
size_t BufferSock::size() const {
return _buffer->size();
}
const struct sockaddr *BufferSock::sockaddr() const {
return (struct sockaddr *)&_addr;
}
socklen_t BufferSock::socklen() const {
return _addr_len;
}
/////////////////////////////////////// BufferCallBack ///////////////////////////////////////
class BufferCallBack {
public:
BufferCallBack(List<std::pair<Buffer::Ptr, bool> > list, BufferList::SendResult cb)
: _cb(std::move(cb))
, _pkt_list(std::move(list)) {}
~BufferCallBack() {
sendCompleted(false);
}
void sendCompleted(bool flag) {
if (_cb) {
//全部发送成功或失败回调 [AUTO-TRANSLATED:6b9a9abf]
//All send success or failure callback
while (!_pkt_list.empty()) {
_cb(_pkt_list.front().first, flag);
_pkt_list.pop_front();
}
} else {
_pkt_list.clear();
}
}
void sendFrontSuccess() {
if (_cb) {
//发送成功回调 [AUTO-TRANSLATED:52759efc]
//Send success callback
_cb(_pkt_list.front().first, true);
}
_pkt_list.pop_front();
}
protected:
BufferList::SendResult _cb;
List<std::pair<Buffer::Ptr, bool> > _pkt_list;
};
/////////////////////////////////////// BufferSendMsg ///////////////////////////////////////
#if defined(_WIN32)
using SocketBuf = WSABUF;
#else
using SocketBuf = iovec;
#endif
class BufferSendMsg final : public BufferList, public BufferCallBack {
public:
using SocketBufVec = std::vector<SocketBuf>;
BufferSendMsg(List<std::pair<Buffer::Ptr, bool> > list, SendResult cb);
~BufferSendMsg() override = default;
bool empty() override;
size_t count() override;
ssize_t send(int fd, int flags) override;
private:
void reOffset(size_t n);
ssize_t send_l(int fd, int flags);
private:
size_t _iovec_off = 0;
size_t _remain_size = 0;
SocketBufVec _iovec;
};
bool BufferSendMsg::empty() {
return _remain_size == 0;
}
size_t BufferSendMsg::count() {
return _iovec.size() - _iovec_off;
}
ssize_t BufferSendMsg::send_l(int fd, int flags) {
ssize_t n;
#if !defined(_WIN32)
do {
struct msghdr msg;
msg.msg_name = nullptr;
msg.msg_namelen = 0;
msg.msg_iov = &(_iovec[_iovec_off]);
msg.msg_iovlen = _iovec.size() - _iovec_off;
if (msg.msg_iovlen > IOV_MAX) {
msg.msg_iovlen = IOV_MAX;
}
msg.msg_control = nullptr;
msg.msg_controllen = 0;
msg.msg_flags = flags;
n = sendmsg(fd, &msg, flags);
} while (-1 == n && UV_EINTR == get_uv_error(true));
#else
do {
DWORD sent = 0;
n = WSASend(fd, const_cast<LPWSABUF>(&_iovec[0]), static_cast<DWORD>(_iovec.size()), &sent, static_cast<DWORD>(flags), 0, 0);
if (n == SOCKET_ERROR) {return -1;}
n = sent;
} while (n < 0 && UV_ECANCELED == get_uv_error(true));
#endif
if (n >= (ssize_t)_remain_size) {
//全部写完了 [AUTO-TRANSLATED:c990f48a]
//All written
_remain_size = 0;
sendCompleted(true);
return n;
}
if (n > 0) {
//部分发送成功 [AUTO-TRANSLATED:4c240905]
//Partial send success
reOffset(n);
return n;
}
//一个字节都未发送 [AUTO-TRANSLATED:c33c611b]
//Not a single byte sent
return n;
}
ssize_t BufferSendMsg::send(int fd, int flags) {
auto remain_size = _remain_size;
while (_remain_size && send_l(fd, flags) != -1);
ssize_t sent = remain_size - _remain_size;
if (sent > 0) {
//部分或全部发送成功 [AUTO-TRANSLATED:a3f5e70e]
//Partial or all send success
return sent;
}
//一个字节都未发送成功 [AUTO-TRANSLATED:858b63e5]
//Not a single byte sent successfully
return -1;
}
void BufferSendMsg::reOffset(size_t n) {
_remain_size -= n;
size_t offset = 0;
for (auto i = _iovec_off; i != _iovec.size(); ++i) {
auto &ref = _iovec[i];
#if !defined(_WIN32)
offset += ref.iov_len;
#else
offset += ref.len;
#endif
if (offset < n) {
//此包发送完毕 [AUTO-TRANSLATED:759b9f0e]
//This package is sent
sendFrontSuccess();
continue;
}
_iovec_off = i;
if (offset == n) {
//这是末尾发送完毕的一个包 [AUTO-TRANSLATED:6a3b77e4]
//This is the last package sent
++_iovec_off;
sendFrontSuccess();
break;
}
//这是末尾发送部分成功的一个包 [AUTO-TRANSLATED:64645cef]
//This is the last package partially sent
size_t remain = offset - n;
#if !defined(_WIN32)
ref.iov_base = (char *)ref.iov_base + ref.iov_len - remain;
ref.iov_len = remain;
#else
ref.buf = (CHAR *)ref.buf + ref.len - remain;
ref.len = remain;
#endif
break;
}
}
BufferSendMsg::BufferSendMsg(List<std::pair<Buffer::Ptr, bool>> list, SendResult cb)
: BufferCallBack(std::move(list), std::move(cb))
, _iovec(_pkt_list.size()) {
auto it = _iovec.begin();
_pkt_list.for_each([&](std::pair<Buffer::Ptr, bool> &pr) {
#if !defined(_WIN32)
it->iov_base = pr.first->data();
it->iov_len = pr.first->size();
_remain_size += it->iov_len;
#else
it->buf = pr.first->data();
it->len = pr.first->size();
_remain_size += it->len;
#endif
++it;
});
}
/////////////////////////////////////// BufferSendTo ///////////////////////////////////////
class BufferSendTo final: public BufferList, public BufferCallBack {
public:
BufferSendTo(List<std::pair<Buffer::Ptr, bool> > list, SendResult cb, bool is_udp);
~BufferSendTo() override = default;
bool empty() override;
size_t count() override;
ssize_t send(int fd, int flags) override;
private:
bool _is_udp;
size_t _offset = 0;
};
BufferSendTo::BufferSendTo(List<std::pair<Buffer::Ptr, bool>> list, BufferList::SendResult cb, bool is_udp)
: BufferCallBack(std::move(list), std::move(cb))
, _is_udp(is_udp) {}
bool BufferSendTo::empty() {
return _pkt_list.empty();
}
size_t BufferSendTo::count() {
return _pkt_list.size();
}
static inline BufferSock *getBufferSockPtr(std::pair<Buffer::Ptr, bool> &pr) {
if (!pr.second) {
return nullptr;
}
return static_cast<BufferSock *>(pr.first.get());
}
ssize_t BufferSendTo::send(int fd, int flags) {
size_t sent = 0;
ssize_t n;
while (!_pkt_list.empty()) {
auto &front = _pkt_list.front();
auto &buffer = front.first;
if (_is_udp) {
auto ptr = getBufferSockPtr(front);
n = ::sendto(fd, buffer->data() + _offset, buffer->size() - _offset, flags, ptr ? ptr->sockaddr() : nullptr, ptr ? ptr->socklen() : 0);
} else {
n = ::send(fd, buffer->data() + _offset, buffer->size() - _offset, flags);
}
if (n >= 0) {
assert(n);
_offset += n;
if (_offset == buffer->size()) {
sendFrontSuccess();
_offset = 0;
}
sent += n;
continue;
}
//n == -1的情况 [AUTO-TRANSLATED:305fb5bc]
//n == -1 case
if (get_uv_error(true) == UV_EINTR) {
//被打断,需要继续发送 [AUTO-TRANSLATED:6ef0b34d]
//interrupted, need to continue sending
continue;
}
//其他原因导致的send返回-1 [AUTO-TRANSLATED:299cddb7]
//other reasons causing send to return -1
break;
}
return sent ? sent : -1;
}
/////////////////////////////////////// BufferSendMmsg ///////////////////////////////////////
#if defined(__linux__) || defined(__linux)
class BufferSendMMsg : public BufferList, public BufferCallBack {
public:
BufferSendMMsg(List<std::pair<Buffer::Ptr, bool> > list, SendResult cb);
~BufferSendMMsg() override = default;
bool empty() override;
size_t count() override;
ssize_t send(int fd, int flags) override;
private:
void reOffset(size_t n);
ssize_t send_l(int fd, int flags);
private:
size_t _remain_size = 0;
std::vector<struct iovec> _iovec;
std::vector<struct mmsghdr> _hdrvec;
};
bool BufferSendMMsg::empty() {
return _remain_size == 0;
}
size_t BufferSendMMsg::count() {
return _hdrvec.size();
}
ssize_t BufferSendMMsg::send_l(int fd, int flags) {
ssize_t n;
do {
n = sendmmsg(fd, &_hdrvec[0], _hdrvec.size(), flags);
} while (-1 == n && UV_EINTR == get_uv_error(true));
if (n > 0) {
//部分或全部发送成功 [AUTO-TRANSLATED:a3f5e70e]
//partially or fully sent successfully
reOffset(n);
return n;
}
//一个字节都未发送 [AUTO-TRANSLATED:c33c611b]
//not a single byte sent
return n;
}
ssize_t BufferSendMMsg::send(int fd, int flags) {
auto remain_size = _remain_size;
while (_remain_size && send_l(fd, flags) != -1);
ssize_t sent = remain_size - _remain_size;
if (sent > 0) {
//部分或全部发送成功 [AUTO-TRANSLATED:a3f5e70e]
//partially or fully sent successfully
return sent;
}
//一个字节都未发送成功 [AUTO-TRANSLATED:858b63e5]
//not a single byte sent successfully
return -1;
}
void BufferSendMMsg::reOffset(size_t n) {
for (auto it = _hdrvec.begin(); it != _hdrvec.end();) {
auto &hdr = *it;
auto &io = *(hdr.msg_hdr.msg_iov);
assert(hdr.msg_len <= io.iov_len);
_remain_size -= hdr.msg_len;
if (hdr.msg_len == io.iov_len) {
//这个udp包全部发送成功 [AUTO-TRANSLATED:fce1cc86]
//this UDP packet sent successfully
it = _hdrvec.erase(it);
sendFrontSuccess();
continue;
}
//部分发送成功 [AUTO-TRANSLATED:4c240905]
//partially sent successfully
io.iov_base = (char *)io.iov_base + hdr.msg_len;
io.iov_len -= hdr.msg_len;
break;
}
}
BufferSendMMsg::BufferSendMMsg(List<std::pair<Buffer::Ptr, bool>> list, SendResult cb)
: BufferCallBack(std::move(list), std::move(cb))
, _iovec(_pkt_list.size())
, _hdrvec(_pkt_list.size()) {
auto i = 0U;
_pkt_list.for_each([&](std::pair<Buffer::Ptr, bool> &pr) {
auto &io = _iovec[i];
io.iov_base = pr.first->data();
io.iov_len = pr.first->size();
_remain_size += io.iov_len;
auto ptr = getBufferSockPtr(pr);
auto &mmsg = _hdrvec[i];
auto &msg = mmsg.msg_hdr;
mmsg.msg_len = 0;
msg.msg_name = ptr ? (void *)ptr->sockaddr() : nullptr;
msg.msg_namelen = ptr ? ptr->socklen() : 0;
msg.msg_iov = &io;
msg.msg_iovlen = 1;
msg.msg_control = nullptr;
msg.msg_controllen = 0;
msg.msg_flags = 0;
++i;
});
}
#endif //defined(__linux__) || defined(__linux)
BufferList::Ptr BufferList::create(List<std::pair<Buffer::Ptr, bool> > list, SendResult cb, bool is_udp) {
#if defined(_WIN32)
if (is_udp) {
// sendto/send 方案,待优化 [AUTO-TRANSLATED:e94184aa]
//sendto/send scheme, to be optimized
return std::make_shared<BufferSendTo>(std::move(list), std::move(cb), is_udp);
}
// WSASend方案 [AUTO-TRANSLATED:9ac7bb81]
//WSASend scheme
return std::make_shared<BufferSendMsg>(std::move(list), std::move(cb));
#elif defined(__linux__) || defined(__linux)
if (is_udp) {
// sendmmsg方案 [AUTO-TRANSLATED:4596c2c4]
//sendmmsg scheme
return std::make_shared<BufferSendMMsg>(std::move(list), std::move(cb));
}
// sendmsg方案 [AUTO-TRANSLATED:8846f9c4]
//sendmsg scheme
return std::make_shared<BufferSendMsg>(std::move(list), std::move(cb));
#else
if (is_udp) {
// sendto/send 方案, 可优化? [AUTO-TRANSLATED:21dbae7c]
//sendto/send scheme, can be optimized?
return std::make_shared<BufferSendTo>(std::move(list), std::move(cb), is_udp);
}
// sendmsg方案 [AUTO-TRANSLATED:8846f9c4]
//sendmsg scheme
return std::make_shared<BufferSendMsg>(std::move(list), std::move(cb));
#endif
}
#if defined(__linux) || defined(__linux__)
class SocketRecvmmsgBuffer : public SocketRecvBuffer {
public:
SocketRecvmmsgBuffer(size_t count, size_t size)
: _size(size)
, _iovec(count)
, _mmsgs(count)
, _buffers(count)
, _address(count) {
for (auto i = 0u; i < count; ++i) {
auto buf = BufferRaw::create();
buf->setCapacity(size);
_buffers[i] = buf;
auto &mmsg = _mmsgs[i];
auto &addr = _address[i];
mmsg.msg_len = 0;
mmsg.msg_hdr.msg_name = &addr;
mmsg.msg_hdr.msg_namelen = sizeof(addr);
mmsg.msg_hdr.msg_iov = &_iovec[i];
mmsg.msg_hdr.msg_iov->iov_base = buf->data();
mmsg.msg_hdr.msg_iov->iov_len = buf->getCapacity() - 1;
mmsg.msg_hdr.msg_iovlen = 1;
mmsg.msg_hdr.msg_control = nullptr;
mmsg.msg_hdr.msg_controllen = 0;
mmsg.msg_hdr.msg_flags = 0;
}
}
ssize_t recvFromSocket(int fd, ssize_t &count) override {
for (auto i = 0; i < _last_count; ++i) {
auto &mmsg = _mmsgs[i];
mmsg.msg_hdr.msg_namelen = sizeof(struct sockaddr_storage);
auto &buf = _buffers[i];
if (!buf) {
auto raw = BufferRaw::create();
raw->setCapacity(_size);
buf = raw;
mmsg.msg_hdr.msg_iov->iov_base = buf->data();
}
}
do {
count = recvmmsg(fd, &_mmsgs[0], _mmsgs.size(), 0, nullptr);
} while (-1 == count && UV_EINTR == get_uv_error(true));
_last_count = count;
if (count <= 0) {
return count;
}
ssize_t nread = 0;
for (auto i = 0; i < count; ++i) {
auto &mmsg = _mmsgs[i];
nread += mmsg.msg_len;
auto buf = std::static_pointer_cast<BufferRaw>(_buffers[i]);
buf->setSize(mmsg.msg_len);
buf->data()[mmsg.msg_len] = '\0';
}
return nread;
}
Buffer::Ptr &getBuffer(size_t index) override { return _buffers[index]; }
struct sockaddr_storage &getAddress(size_t index) override { return _address[index]; }
private:
size_t _size;
ssize_t _last_count { 0 };
std::vector<struct iovec> _iovec;
std::vector<struct mmsghdr> _mmsgs;
std::vector<Buffer::Ptr> _buffers;
std::vector<struct sockaddr_storage> _address;
};
#endif
class SocketRecvFromBuffer : public SocketRecvBuffer {
public:
SocketRecvFromBuffer(size_t size): _size(size) {}
ssize_t recvFromSocket(int fd, ssize_t &count) override {
ssize_t nread;
socklen_t len = sizeof(_address);
if (!_buffer) {
allocBuffer();
}
do {
nread = recvfrom(fd, _buffer->data(), _buffer->getCapacity() - 1, 0, (struct sockaddr *)&_address, &len);
} while (-1 == nread && UV_EINTR == get_uv_error(true));
if (nread > 0) {
count = 1;
_buffer->data()[nread] = '\0';
std::static_pointer_cast<BufferRaw>(_buffer)->setSize(nread);
}
return nread;
}
Buffer::Ptr &getBuffer(size_t index) override { return _buffer; }
struct sockaddr_storage &getAddress(size_t index) override { return _address; }
private:
void allocBuffer() {
auto buf = BufferRaw::create();
buf->setCapacity(_size);
_buffer = std::move(buf);
}
private:
size_t _size;
Buffer::Ptr _buffer;
struct sockaddr_storage _address;
};
static constexpr auto kPacketCount = 32;
static constexpr auto kBufferCapacity = 4 * 1024u;
SocketRecvBuffer::Ptr SocketRecvBuffer::create(bool is_udp) {
#if defined(__linux) || defined(__linux__)
if (is_udp) {
return std::make_shared<SocketRecvmmsgBuffer>(kPacketCount, kBufferCapacity);
}
#endif
return std::make_shared<SocketRecvFromBuffer>(kPacketCount * kBufferCapacity);
}
} //toolkit

View File

@ -0,0 +1,87 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef ZLTOOLKIT_BUFFERSOCK_H
#define ZLTOOLKIT_BUFFERSOCK_H
#if !defined(_WIN32)
#include <sys/uio.h>
#include <limits.h>
#endif
#include <cassert>
#include <memory>
#include <string>
#include <vector>
#include <type_traits>
#include <functional>
#include "Util/util.h"
#include "Util/List.h"
#include "Util/ResourcePool.h"
#include "sockutil.h"
#include "Buffer.h"
namespace toolkit {
#if !defined(IOV_MAX)
#define IOV_MAX 1024
#endif
class BufferSock : public Buffer {
public:
using Ptr = std::shared_ptr<BufferSock>;
BufferSock(Buffer::Ptr ptr, struct sockaddr *addr = nullptr, int addr_len = 0);
~BufferSock() override = default;
char *data() const override;
size_t size() const override;
const struct sockaddr *sockaddr() const;
socklen_t socklen() const;
private:
int _addr_len = 0;
struct sockaddr_storage _addr;
Buffer::Ptr _buffer;
};
class BufferList : public noncopyable {
public:
using Ptr = std::shared_ptr<BufferList>;
using SendResult = toolkit::function_safe<void(const Buffer::Ptr &buffer, bool send_success)>;
BufferList() = default;
virtual ~BufferList() = default;
virtual bool empty() = 0;
virtual size_t count() = 0;
virtual ssize_t send(int fd, int flags) = 0;
static Ptr create(List<std::pair<Buffer::Ptr, bool> > list, SendResult cb, bool is_udp);
private:
//对象个数统计 [AUTO-TRANSLATED:3b43e8c2]
//Object count statistics
ObjectStatistic<BufferList> _statistic;
};
class SocketRecvBuffer {
public:
using Ptr = std::shared_ptr<SocketRecvBuffer>;
virtual ~SocketRecvBuffer() = default;
virtual ssize_t recvFromSocket(int fd, ssize_t &count) = 0;
virtual Buffer::Ptr &getBuffer(size_t index) = 0;
virtual struct sockaddr_storage &getAddress(size_t index) = 0;
static Ptr create(bool is_udp);
};
}
#endif //ZLTOOLKIT_BUFFERSOCK_H

View File

@ -0,0 +1,911 @@
/*
* Copyright (c) 2021 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include "Kcp.h"
#include "Util/Byte.hpp"
using namespace std;
namespace toolkit {
static inline uint32_t _imin_(uint32_t a, uint32_t b) {
return a <= b ? a : b;
}
static inline uint32_t _imax_(uint32_t a, uint32_t b) {
return a >= b ? a : b;
}
static inline uint32_t _ibound_(uint32_t lower, uint32_t middle, uint32_t upper) {
return _imin_(_imax_(lower, middle), upper);
}
static inline long _itimediff(uint32_t later, uint32_t earlier) {
return ((int32_t)(later - earlier));
}
uint32_t getCurrent() {
return (uint32_t)(getCurrentMillisecond() & 0xfffffffful);
}
//////////// KcpHeader //////////////////////////
bool KcpHeader::loadHeaderFromData(const char *data, size_t len) {
if (HEADER_SIZE > len) {
WarnL << "data len: " << len << " too small";
return false;
}
int offset = 0;
_conv = Byte::Get4BytesLE((const uint8_t*)data, 0);
offset += 4;
_cmd = (Cmd)Byte::Get1Byte((const uint8_t*)data, offset);
offset += 1;
_frg = Byte::Get1Byte((const uint8_t*)data, offset);
offset += 1;
_wnd = Byte::Get2BytesLE((const uint8_t*)data, offset);
offset += 2;
_ts = Byte::Get4BytesLE((const uint8_t*)data, offset);
offset += 4;
_sn = Byte::Get4BytesLE((const uint8_t*)data, offset);
offset += 4;
_una = Byte::Get4BytesLE((const uint8_t*)data, offset);
offset += 4;
_len = Byte::Get4BytesLE((const uint8_t*)data, offset);
return true;
}
bool KcpHeader::storeHeaderToData(char *buf, size_t size) {
if (HEADER_SIZE > size) {
ErrorL << "size too smalle " << size;
return false;
}
char *ptr = buf;
int offset = 0;
Byte::Set4BytesLE((uint8_t*)buf, offset, _conv);
offset += 4;
Byte::Set1Byte((uint8_t*)buf, offset, (uint8_t)_cmd);
offset += 1;
Byte::Set1Byte((uint8_t*)buf, offset, _frg);
offset += 1;
Byte::Set2BytesLE((uint8_t*)buf, offset, _wnd);
offset += 2;
Byte::Set4BytesLE((uint8_t*)buf, offset, _ts);
offset += 4;
Byte::Set4BytesLE((uint8_t*)buf, offset, _sn);
offset += 4;
Byte::Set4BytesLE((uint8_t*)buf, offset, _una);
offset += 4;
Byte::Set4BytesLE((uint8_t*)buf, offset, _len);
return true;
}
//////////// KcpPacket //////////////////////////
KcpPacket::~KcpPacket() {
}
KcpPacket::Ptr KcpPacket::parse(const char* data, size_t len) {
auto packet = std::make_shared<KcpPacket>();
if (packet->loadFromData(data, len)) {
return packet;
}
return nullptr;
}
bool KcpPacket::loadFromData(const char *data, size_t len) {
if (!loadHeaderFromData(data, len)) {
return false;
}
auto packetSize = getPacketSize();
if (len < packetSize) {
WarnL << "data len: " << len << " is smaller than packet len: " << packetSize;
return false;
}
assign((const char *)(data), packetSize);
return true;
}
bool KcpPacket::storeToData() {
return storeHeaderToData(data(), size());
}
//////////// KcpTransport //////////////////////////
KcpTransport::KcpTransport(bool server_mode) {
_server_mode = server_mode;
if (!server_mode) {
//客户端 conv 随机生成
_conv = makeRandNum();
_conv_init = true;
}
_buffer_pool = BufferRaw::create(_mtu);
}
KcpTransport::KcpTransport(bool server_mode, const EventPoller::Ptr &poller)
: KcpTransport(server_mode) {
_poller = poller ? poller : EventPollerPool::Instance().getPoller();
}
KcpTransport::~KcpTransport() {
update();
}
ssize_t KcpTransport::send(const Buffer::Ptr& buf, bool flush) {
if (!_timer) {
startTimer();
}
if (!_conv_init) {
WarnL << "conv should set before send";
return -1;
}
auto size = buf->size();
if (size <= 0) {
return 0;
}
if (size >= _mss * IKCP_WND_RCV) {
WarnL << "size : "<< size << "over size, send fail";
//分片过大,拒绝发送
return -1;
}
auto cache = BufferRaw::create(size);
cache->assign(buf->data(), size);
_poller->async([=] {
auto data = cache->data();
auto leftLen = size;
auto extendLen = mergeSendQueue(data, leftLen);
data += extendLen;
leftLen -= extendLen;
// fragment
int count = (leftLen + _mss - 1) / _mss;
for (int i = 0; i < count; i++) {
auto len = std::min<size_t>(leftLen, _mss);
auto packet = std::make_shared<KcpDataPacket>(_conv, len);
memcpy(packet->getPayloadData(), data, len);
packet->setFrg(!_stream? (count - i - 1) : 0);
_snd_queue.push_back(packet);
data += len;
leftLen -= len;
}
if (flush) {
update();
// sendSendQueue();
}
}, true);
return size;
}
void KcpTransport::input(const Buffer::Ptr& buf) {
if (!_timer) {
startTimer();
}
auto cache = BufferRaw::create(buf->size());
cache->assign(buf->data(), buf->size());
_poller->async([=] {
// DebugL << hexdump(cache->data(), cache->size());
auto data = cache->data();
auto size = cache->size();
uint32_t current = getCurrent();
uint32_t prev_una = _snd_una;
uint32_t maxack = 0;
uint32_t latest_ts = 0;
bool fastAckFlag = false;
bool hasData = false;
while (size) {
auto packet = KcpPacket::parse(data, size);
if (!packet) {
WarnL << "parse kcp packet fail";
break;
}
data += packet->size();
size -= packet->size();
if (!_conv_init) {
_conv = packet->getConv();
_conv_init = true;
} else {
if (_conv != packet->getConv()) {
WarnL << "_conv check fail, skip this packet";
continue;
}
}
auto cmd = packet->getCmd();
if (cmd != KcpHeader::Cmd::CMD_PUSH && cmd != KcpHeader::Cmd::CMD_ACK &&
cmd != KcpHeader::Cmd::CMD_WASK && cmd != KcpHeader::Cmd::CMD_WINS) {
WarnL << "unknow cmd: " << (uint8_t)cmd;
continue;
}
handleAnyPacket(packet);
switch (cmd) {
case KcpHeader::Cmd::CMD_ACK: {
auto sn = packet->getSn();
auto ts = packet->getTs();
handleCmdAck(packet, current);
if (!fastAckFlag) {
fastAckFlag = true;
maxack = sn;
latest_ts = ts;
} else {
if (sn > maxack) {
if (!_fastack_conserve || ts > latest_ts) {
//激进模式
maxack = sn;
latest_ts = ts;
}
}
}
}
break;
case KcpHeader::Cmd::CMD_PUSH:
handleCmdPush(packet);
hasData = true;
break;
case KcpHeader::Cmd::CMD_WASK:
_probe |= IKCP_ASK_TELL;
break;
case KcpHeader::Cmd::CMD_WINS:
break;
default:
WarnL << "unknow cmd: " << (uint32_t)cmd;
break;
}
}
if (fastAckFlag) {
updateFastAck(maxack, latest_ts);
}
if (_snd_una > prev_una) {
//有新的应答,尝试增大拥塞窗口
increaseCwnd();
}
if (hasData) {
onData();
}
}, true);
return;
}
void KcpTransport::startTimer() {
if (!_poller) {
_poller = EventPollerPool::Instance().getPoller();
}
std::weak_ptr<KcpTransport> weak_self = std::static_pointer_cast<KcpTransport>(shared_from_this());
float interval = float(_interval)/ 1000.0;
_timer = std::make_shared<Timer>(interval, [weak_self]() -> bool {
auto strong_self = weak_self.lock();
if (!strong_self) {
return false;
}
strong_self->update();
return true;
}, _poller);
return;
}
void KcpTransport::onData() {
bool fastRecover = false;
sortRecvBuf();
if (_rcv_queue.size() >= _rcv_wnd) {
//接受队列当前超过接收窗口大小
fastRecover = true;
}
// merge fragment
while (int size = peeksize()) {
while (1) {
int offset = 0;
auto buffer = BufferRaw::create(size);
buffer->setSize(size);
auto packet = _rcv_queue.front();
_rcv_queue.pop_front();
memcpy(buffer->data() + offset, packet->getPayloadData(), packet->getLen());
offset += packet->getLen();
if (packet->getFrg() == 0) {
onRead(buffer);
break;
}
}
}
// fast recover
if (_rcv_queue.size() < _rcv_wnd && fastRecover) {
// ready to send back IKCP_CMD_WINS
// tell remote my window size
_probe |= IKCP_ASK_TELL;
}
return;
}
int KcpTransport::peeksize() {
if (_rcv_queue.empty()) {
return 0;
}
//分包数据还没发送完全
if (_rcv_queue.size() < _rcv_queue.front()->getFrg() + 1) {
return 0;
}
int length = 0;
for (auto it = _rcv_queue.begin(); it != _rcv_queue.end(); it++) {
auto seg = *it;
length += seg->getLen();
if (seg->getFrg() == 0) {
break;
}
}
return length;
}
// move available data from rcv_buf -> rcv_queue
void KcpTransport::sortRecvBuf() {
#if 0
//直送应用层,不考虑接受队列满的情况
if (_rcv_queue.size() >= _rcv_wnd) {
//接收队列满
return;
}
#endif
while (!_rcv_buf.empty()) {
auto packet = _rcv_buf.front();
if (packet->getSn() == _rcv_nxt) {
//接收缓存中序号正确,且接受队列窗口足够
//将接收缓存中的包转到接受队列中
_rcv_buf.pop_front();
_rcv_queue.push_back(packet);
_rcv_nxt++;
} else {
break;
}
}
return;
}
// move data from snd_queue to snd_buf
void KcpTransport::sortSendQueue() {
uint32_t current = getCurrent();
uint32_t cwnd = _imin_(_snd_wnd, _rmt_wnd);
if (!_nocwnd) {
cwnd = _imin_(_cwnd, cwnd);
}
cwnd = _imax_(1, cwnd);
while (!_snd_queue.empty()) {
if (_snd_nxt >= _snd_una + cwnd) {
// WarnL << "snd cwnd over size";
break;
}
auto packet = _snd_queue.front();
_snd_queue.pop_front();
packet->setConv(_conv);
packet->setCmd(KcpHeader::Cmd::CMD_PUSH);
packet->setSn(_snd_nxt++);
packet->setXmit(0);
packet->setFastack(0);
#if 0
packet->setTs(current);
packet->setWnd(getWaitSnd());
packet->setUna(_rcv_nxt);
packet->setResendts(current);
packet->setRto(_rx_rto);
#endif
_snd_buf.push_back(packet);
}
return;
}
size_t KcpTransport::mergeSendQueue(const char *buffer, size_t len) {
if (len <= 0) {
return 0;
}
// 流发送模式,表示可以将当前buffer合并之前的包后面
if (!_stream) {
return 0;
}
//发送队列没有数据,不用合并
if (_snd_queue.empty()) {
return 0;
}
auto packet = _snd_queue.front();
size_t oldLen = packet->getLen();
if (oldLen >= _mss) {
//前一个包已经达到_mss长度,不允许合并
return 0;
}
size_t extendLen = std::min<size_t>(len, _mss - oldLen);
packet->setPayLoadSize(oldLen + extendLen);
memcpy(packet->getPayloadData() + oldLen, buffer, extendLen);
packet->setLen(oldLen + extendLen);
packet->setFrg(0);
return extendLen;
}
void KcpTransport::updateRtt(int32_t rtt) {
if (rtt < 0) {
return;
}
int32_t rto = 0;
//Jacobson/Karels RTT估算算法
if (_rx_srtt == 0) {
_rx_srtt = rtt;
_rx_rttval = rtt / 2;
} else {
long delta = abs(rtt - _rx_srtt);
_rx_rttval = (3 * _rx_rttval + delta) / 4;
_rx_srtt = (7 * _rx_srtt + rtt) / 8;
if (_rx_srtt < 1) {
_rx_srtt = 1;
}
}
rto = _rx_srtt + _imax_(_interval, 4 * _rx_rttval);
_rx_rto = _ibound_(_rx_minrto, rto, IKCP_RTO_MAX);
return;
}
void KcpTransport::dropCacheByUna(uint32_t una) {
// TraceL << "recv una: " << una;
if (una <= _snd_una) {
return;
}
while (!_snd_buf.empty()) {
if (una <= _snd_buf.front()->getSn()) {
break;
}
_snd_buf.pop_front();
}
_snd_una = _snd_buf.empty()? _snd_nxt : _snd_buf.front()->getSn();
return;
}
void KcpTransport::dropCacheByAck(uint32_t sn) {
// TraceL << "recv ack sn: " << sn;
if (sn < _snd_una) {
return;
}
for (auto it = _snd_buf.begin(); it != _snd_buf.end(); it++) {
if (sn < (*it)->getSn()) {
break;
} else if (sn == (*it)->getSn()) {
_snd_buf.erase(it);
break;
}
}
_snd_una = _snd_buf.empty()? _snd_nxt : _snd_buf.front()->getSn();
return;
}
void KcpTransport::updateFastAck(uint32_t sn, uint32_t ts) {
if (sn < _snd_una || sn >= _snd_nxt) {
return;
}
for (auto it = _snd_buf.begin(); it != _snd_buf.end(); it++) {
auto seg = *it;
if (sn < seg->getSn()) {
break;
} else if (sn != seg->getSn()) {
if (!_fastack_conserve || ts > seg->getTs()) {
seg->setFastack(seg->getFastack() + 1);
}
}
}
return;
}
void KcpTransport::increaseCwnd() {
if (_cwnd >= _rmt_wnd) {
return;
}
uint32_t mss = _mss;
if (_cwnd < _ssthresh) {
//慢启动阶段,拥塞窗口指数增长
_cwnd++;
_incr += mss;
} else {
//拥塞避免阶段,拥塞窗口线性增长
if (_incr < mss) {
_incr = mss;
}
_incr += (mss * mss) / _incr + (mss / 16);
if ((_cwnd + 1) * mss <= _incr) {
#if 1
_cwnd = (_incr + mss - 1) / ((mss > 0)? mss : 1);
#else
_cwnd++;
#endif
}
}
//控制不超过远端窗口大小
if (_cwnd > _rmt_wnd) {
_cwnd = _rmt_wnd;
_incr = _rmt_wnd * mss;
}
return;
}
void KcpTransport::handleAnyPacket(KcpPacket::Ptr packet) {
_rmt_wnd = packet->getWnd();
dropCacheByUna(packet->getUna());
return;
}
void KcpTransport::handleCmdAck(KcpPacket::Ptr packet, uint32_t current) {
updateRtt(current - packet->getTs());
dropCacheByAck(packet->getSn());
return;
}
void KcpTransport::handleCmdPush(KcpPacket::Ptr packet) {
auto sn = packet->getSn();
auto ts = packet->getTs();
// TraceL << "recv packet sn: " << sn << ", frg: " << (uint32_t)packet->getFrg();
if (sn >= _rcv_nxt + _rcv_wnd) {
// TraceL << "sn: " << sn << " is over wnd, _rcv_nxt: " << _rcv_nxt << ":, skip";
//超出接受窗口数据
return;
}
_acklist.push_back(std::make_pair(sn, ts));
if (sn < _rcv_nxt) {
// TraceL << "sn: " << sn << " is smaller than _rcv_nxt: " << _rcv_nxt << ":, skip";
return;
}
for (auto it = _rcv_buf.begin(); it != _rcv_buf.end(); it++) {
auto old = *it;
if (old->getSn() == sn) {
// TraceL << "sn: " << sn << " is repeat skip";
return;
}
if (old->getSn() > sn) {
_rcv_buf.insert(it, packet);
return;
}
}
_rcv_buf.push_back(packet);
return;
}
//获取当前空闲接受队列窗口
int KcpTransport::getRcvWndUnused() {
auto wnd = _rcv_wnd - _rcv_queue.size();
if (wnd > 0) {
return wnd;
}
return 0;
}
void KcpTransport::update() {
sendAckList();
sendProbePacket();
sendSendQueue();
}
void KcpTransport::sendSendQueue() {
uint32_t resent;
uint32_t rtomin;
bool change = false;
bool lost = false;
uint32_t current = getCurrent();
sortSendQueue();
// calculate resent
resent = (_fastresend > 0)? (uint32_t)_fastresend : 0xffffffff;
rtomin = (_delay_mode == DelayMode::DELAY_MODE_NORMAL)? (_rx_rto >> 3) : 0;
// flush data segments
for (auto it = _snd_buf.begin(); it != _snd_buf.end(); it++) {
bool needsend = false;
auto packet = *it;
auto xmit = packet->getXmit();
//没重传过,第一次发送数据包
if (xmit == 0) {
// TraceL << "normal send sn: " << packet->getSn();
needsend = true;
packet->setXmit(xmit + 1);
packet->setRto(_rx_rto);
packet->setResendts(current + _rx_rto + rtomin);
} else if (current >= packet->getResendts()) {
//普通重传
// TraceL << "resend sn: " << packet->getSn() << ", xmit: " << packet->getXmit();
needsend = true;
packet->setXmit(xmit + 1);
_xmit++;
auto rto = packet->getRto();
if (_delay_mode == DelayMode::DELAY_MODE_NORMAL == 0) {
packet->setRto(rto + _imax_(rto, (uint32_t)_rx_rto));
} else {
int32_t step = (_delay_mode == DelayMode::DELAY_MODE_FAST)? ((int32_t)(rto)) : _rx_rto;
packet->setRto(rto + step / 2);
}
packet->setResendts(current + rto);
lost = true;
} else if (packet->getFastack() >= resent) {
//快速重传
if ((int)xmit <= _fastlimit || _fastlimit <= 0) {
// TraceL << "fast resend sn: " << packet->getSn() << ", xmit: " << packet->getXmit();
auto rto = packet->getRto();
needsend = true;
packet->setXmit(xmit + 1);
packet->setFastack(0);
packet->setResendts(current + rto);
change = true;
}
}
if (needsend) {
int need;
packet->setTs(current);
packet->setWnd(getRcvWndUnused());
packet->setUna(_rcv_nxt);
sendPacket(packet);
if (packet->getXmit() >= _dead_link) {
onErr(SockException(Err_other,
(StrPrinter << "resend time : " << packet->getXmit() << " over " << _dead_link)));
}
}
}
flushPool();
decreaseCwnd(change, lost);
return;
}
void KcpTransport::sendAckList() {
while (!_acklist.empty()) {
auto front = _acklist.front();
_acklist.pop_front();
auto packet = std::make_shared<KcpAckPacket>(_conv);
packet->setWnd(getRcvWndUnused());
packet->setUna(_rcv_nxt);
packet->setSn(front.first);
packet->setTs(front.second);
sendPacket(packet);
// TraceL << "send ack sn: " << packet->getSn() << ", una: " << _rcv_nxt;
}
return;
}
void KcpTransport::sendProbePacket() {
uint32_t current = getCurrent();
// probe window size (if remote window size equals zero)
if (_rmt_wnd == 0) {
if (_probe_wait == 0) {
_probe_wait = IKCP_PROBE_INIT;
_ts_probe = current + _probe_wait;
} else {
if (_itimediff(current, _ts_probe) >= 0) {
if (_probe_wait < IKCP_PROBE_INIT) {
_probe_wait = IKCP_PROBE_INIT;
}
_probe_wait += _probe_wait / 2;
if (_probe_wait > IKCP_PROBE_LIMIT) {
_probe_wait = IKCP_PROBE_LIMIT;
}
_ts_probe = current + _probe_wait;
_probe |= IKCP_ASK_SEND;
}
}
} else {
_ts_probe = 0;
_probe_wait = 0;
}
// flush window probing commands
if (_probe & IKCP_ASK_SEND) {
auto packet = std::make_shared<KcpProbePacket>(_conv);
sendPacket(packet);
}
// flush window probing commands
if (_probe & IKCP_ASK_TELL) {
auto packet = std::make_shared<KcpTellPacket>(_conv);
sendPacket(packet);
}
_probe = 0;
return;
}
int KcpTransport::getWaitSnd() {
return _snd_buf.size() + _snd_queue.size();
}
// update ssthresh
void KcpTransport::decreaseCwnd(bool change, bool lost) {
//处理因为快速重传或者丢包的情况下,进行拥塞窗口处理
uint32_t resent = (_fastresend > 0)? (uint32_t)_fastresend : 0xffffffff;
// calculate window size
uint32_t cwnd = _imin_(_snd_wnd, _rmt_wnd);
if (_nocwnd == 0) {
cwnd = _imin_(_cwnd, cwnd);
}
//快速重传表明网络出现轻微拥塞,采用相对温和的调整策略。
//主动降低发送速率,但不是因为实际的丢包(可能是乱序)
if (change) {
//调整慢启动阈值为在途数据量的一半
uint32_t inflight = _snd_nxt - _snd_una;
_ssthresh = inflight / 2;
if (_ssthresh < IKCP_THRESH_MIN) {
_ssthresh = IKCP_THRESH_MIN;
}
_cwnd = _ssthresh + resent;
_incr = _cwnd * _mss;
}
//超时重传表明网络严重拥塞,采用激进的调整策略。
if (lost) {
_ssthresh = cwnd / 2;
if (_ssthresh < IKCP_THRESH_MIN) {
_ssthresh = IKCP_THRESH_MIN;
}
//重置拥塞窗口,回到慢启动阶段
_cwnd = 1;
_incr = _mss;
}
if (_cwnd < 1) {
_cwnd = 1;
_incr = _mss;
}
return;
}
void KcpTransport::setMtu(int mtu) {
if (mtu < 50 || mtu < KcpHeader::HEADER_SIZE) {
std::string err = (StrPrinter << "kcp setMtu " << mtu << "to small");
throw std::runtime_error(err);
}
_mtu = mtu;
_mss = _mtu - KcpHeader::HEADER_SIZE;
return;
}
void KcpTransport::setInterval(int interval) {
_interval = _ibound_(10, interval, 5000);
return;
}
void KcpTransport::setRxMinrto(int rx_minrto) {
_rx_minrto = rx_minrto;
return;
}
void KcpTransport::setDelayMode(DelayMode delay_mode) {
if (delay_mode < DelayMode::DELAY_MODE_NORMAL
|| delay_mode > DelayMode::DELAY_MODE_NO_DELAY) {
return;
}
_delay_mode = delay_mode;
if (delay_mode == DelayMode::DELAY_MODE_NORMAL) {
_rx_minrto = IKCP_RTO_MIN;
} else {
_rx_minrto = IKCP_RTO_NDL;
}
return;
}
void KcpTransport::setFastackConserve(bool flag) {
_fastack_conserve = flag;
return;
}
void KcpTransport::setNoCwnd(bool flag) {
_nocwnd = flag;
return;
}
void KcpTransport::setStreamMode(bool flag) {
_stream = flag;
return;
}
void KcpTransport::setFastResend(int resend) {
_fastresend = resend;
return;
}
void KcpTransport::setWndSize(int sndwnd, int rcvwnd) {
if (sndwnd > 0) {
_snd_wnd = sndwnd;
}
if (rcvwnd > 0) { // must >= max fragment size
_rcv_wnd = _imax_(rcvwnd, IKCP_WND_RCV);
}
return;
}
void KcpTransport::sendPacket(KcpPacket::Ptr pkt, bool flush) {
pkt->storeToData();
if (pkt->size() + _buffer_pool->size() > _mtu) {
flushPool();
}
memcpy(_buffer_pool->data() + _buffer_pool->size(), pkt->data(), pkt->size());
_buffer_pool->setSize(_buffer_pool->size() + pkt->size());
if (flush) {
flushPool();
}
return;
}
void KcpTransport::flushPool() {
onWrite(_buffer_pool);
_buffer_pool->setSize(0);
}
} // namespace toolkit

View File

@ -0,0 +1,385 @@
/*
* Copyright (c) 2021 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*
* code reference github.com/skywind3000/kcp/releases/tag/1.7.
*/
#ifndef TOOLKIT_NETWORK_KCP_H
#define TOOLKIT_NETWORK_KCP_H
#include "Network/Buffer.h"
#include "Network/sockutil.h"
#include "Poller/EventPoller.h"
#include "Poller/Timer.h"
#include "Util/TimeTicker.h"
#include "Socket.h"
namespace toolkit {
class KcpHeader {
public:
static const size_t HEADER_SIZE = 24;
enum class Cmd : uint8_t {
CMD_PUSH = 81, // cmd: push data
CMD_ACK = 82, // cmd: ack
CMD_WASK = 83, // cmd: window probe (ask)
CMD_WINS = 84, // cmd: window size (tell)
};
uint32_t _conv; // 会话ID,用于标识一个会话
Cmd _cmd; // 命令字段,用于标识数据包类型
uint8_t _frg = 0; // 分片序号,用于消息分片,0表示最后一片
uint16_t _wnd; // 接受窗口大小
uint32_t _ts; // 时间戳,2^32ms,约49.7天会溢出一次
uint32_t _sn; // 序列号
uint32_t _una; // 待接收的第一个未确认包序号
uint32_t _len = 0; // payload部分数据长度(不包含头长度)
public:
// Getters for KcpHeader members
uint32_t getConv() const { return _conv; }
Cmd getCmd() const { return _cmd; }
uint8_t getFrg() const { return _frg; }
uint16_t getWnd() const { return _wnd; }
uint32_t getTs() const { return _ts; }
uint32_t getSn() const { return _sn; }
uint32_t getUna() const { return _una; }
uint32_t getLen() const { return _len; }
// Setters for KcpHeader members
void setConv(uint32_t conv) { _conv = conv; }
void setCmd(Cmd cmd) { _cmd = cmd; }
void setFrg(uint8_t frg) { _frg = frg; }
void setWnd(uint16_t wnd) { _wnd = wnd; }
void setTs(uint32_t ts) { _ts = ts; }
void setSn(uint32_t sn) { _sn = sn; }
void setUna(uint32_t una) { _una = una; }
void setLen(uint32_t len) { _len = len; }
uint32_t getPacketSize() const { return _len + HEADER_SIZE; }
bool loadHeaderFromData(const char *data, size_t len);
bool storeHeaderToData(char *buf, size_t size);
};
class KcpPacket : public KcpHeader, public toolkit::BufferRaw {
public:
using Ptr = std::shared_ptr<KcpPacket>;
static KcpPacket::Ptr parse(const char* data, size_t len);
KcpPacket() {};
KcpPacket(uint32_t conv, Cmd cmd, size_t payloadSize) {
setConv(conv);
setCmd(cmd);
setPayLoadSize(payloadSize);
};
KcpPacket(size_t payloadSize) {
setPayLoadSize(payloadSize);
}
virtual ~KcpPacket();
bool storeToData();
char *getPayloadData() {
return data() + HEADER_SIZE;
};
uint32_t getResendts() const { return _resendts; }
uint32_t getRto() const { return _rto; }
uint32_t getFastack() const { return _fastack; }
uint32_t getXmit() const { return _xmit; }
void setResendts(uint32_t resendts) { _resendts = resendts; }
void setRto(uint32_t rto) {_rto = rto; }
void setFastack(uint32_t fastack) { _fastack = fastack; }
void setXmit(uint32_t xmit) { _xmit = xmit; }
void setPayLoadSize(size_t len) {
setCapacity(len + HEADER_SIZE + 1);
setSize(len + HEADER_SIZE);
setLen(len);
}
protected:
bool loadFromData(const char *data, size_t len);
private:
uint32_t _resendts; // 重传超时时间戳,表示该数据包下次重传的时间戳
uint32_t _rto; // 超时重传时间表示数据包在多长时间没收到ACK就重传,会基于rtt动态调整
uint32_t _fastack; // 快速确认计数器
uint32_t _xmit; // 传输次数,用于统计重传次数
};
//数据包
class KcpDataPacket : public KcpPacket {
public:
KcpDataPacket(uint32_t conv, size_t payloadSize)
: KcpPacket(conv, KcpHeader::Cmd::CMD_WASK, payloadSize) {
}
};
//ACK包
class KcpAckPacket : public KcpPacket {
public:
KcpAckPacket(uint32_t conv)
: KcpPacket(conv, KcpHeader::Cmd::CMD_ACK, 0) {
}
};
//探测窗口大小包
class KcpProbePacket : public KcpPacket {
public:
KcpProbePacket(uint32_t conv)
: KcpPacket(conv, KcpHeader::Cmd::CMD_WASK, 0) {
}
};
//告知窗口大小包
class KcpTellPacket : public KcpPacket {
public:
KcpTellPacket(uint32_t conv)
: KcpPacket(conv, KcpHeader::Cmd::CMD_WINS, 0) {
}
};
//可以根据实际需要调整参数
//参考kcp V.1.7实现由以下推荐模式和参数
//默认,开启流控: setDelayMode(DELAY_MODE_NORMAL); setInterval(10); setFastResend(0); setNoCwnd(false)
//普通,关闭流控: setDelayMode(DELAY_MODE_NORMAL); setInterval(10); setFastResend(0); setNoCwnd(true)
//快速,关闭流控: setDelayMode(DELAY_MODE_NO_DELAY); setInterval(10); setFastResend(1); setNoCwnd(true); setRxMinrto(10)
class KcpTransport : public std::enable_shared_from_this<KcpTransport> {
public:
using Ptr = std::shared_ptr<KcpTransport>;
enum DelayMode {
DELAY_MODE_NORMAL = 0, // 正常模式, 每次重发rto翻倍,往外增加12.5%的最小rto
DELAY_MODE_FAST = 1, // 快速模式, 每次重发rto增加当前包rto的一半,不额外增加延时
DELAY_MODE_NO_DELAY = 2, // 极速模式, 每次重发rto增加基础rto的一半,不额外增加延时
};
static const uint32_t IKCP_ASK_SEND = 1; // need to send IKCP_CMD_WASK
static const uint32_t IKCP_ASK_TELL = 2; // need to send IKCP_CMD_WINS
static const uint32_t IKCP_RTO_NDL = 30; // no delay min rto
static const uint32_t IKCP_RTO_MIN = 100; // normal min rto
static const uint32_t IKCP_RTO_DEF = 200;
static const uint32_t IKCP_RTO_MAX = 60000;
static const uint32_t IKCP_WND_SND = 32;
static const uint32_t IKCP_WND_RCV = 128; // must >= max fragment size
static const uint32_t IKCP_MTU_DEF = 1400;
static const uint32_t IKCP_ACK_FAST = 3;
static const uint32_t IKCP_INTERVAL = 100;
static const uint32_t IKCP_THRESH_INIT = 2;
static const uint32_t IKCP_THRESH_MIN = 2;
static const uint32_t IKCP_PROBE_INIT = 7000; // 7 secs to probe window size
static const uint32_t IKCP_PROBE_LIMIT = 120000; // up to 120 secs to probe window
using onReadCB = std::function<void(const Buffer::Ptr &buf)>;
using onWriteCB = std::function<void(const Buffer::Ptr &buf)>;
using OnErr = std::function<void(const SockException &)>;
KcpTransport(bool serverMode);
KcpTransport(bool serverMode, const EventPoller::Ptr &poller);
virtual ~KcpTransport();
void setOnRead(onReadCB cb) { _on_read = std::move(cb); }
void setOnWrite(onWriteCB cb) { _on_write = std::move(cb); }
void setOnErr(OnErr cb) { _on_err = std::move(cb); }
void setPoller(const EventPoller::Ptr &poller) {
_poller = poller ? poller : EventPollerPool::Instance().getPoller();
}
// 应用层将数据放到发送队列中
ssize_t send(const Buffer::Ptr &buf, bool flush = false);
// 应用层将socket层接收到的数据输入
void input(const Buffer::Ptr &buf);
// change MTU size, default is 1400
void setMtu(int mtu);
void setInterval(int intervoal);
void setRxMinrto(int rx_minrto);
// set maximum window size: sndwnd=32, rcvwnd=32 by default
void setWndSize(int sndwnd, int rcvwnd);
//设置低延时模式
//默认DELAY_MODE_NORMAL
void setDelayMode(DelayMode delay_mode);
//设置快速重传的阈值
//默认0,即不会快速重传
void setFastResend(int resend);
//设置快速重传保守模式
//默认保守模式
void setFastackConserve(bool flag);
//设置是否关闭拥塞控制
//默认开启
void setNoCwnd(bool flag);
//设置是否开启流传输模式
//默认不开启
void setStreamMode(bool flag);
protected:
void onWrite(const Buffer::Ptr &buf) {
if (_on_write) {
_on_write(buf);
}
}
void onRead(const Buffer::Ptr &buf) {
if (_on_read) {
_on_read(buf);
}
}
void onErr(const SockException &err) {
DebugL;
if (_on_err) {
_on_err(err);
}
}
void startTimer();
//处理收到的数据,rcv_buf中有新数据时调用
void onData();
//测量rcv_queue 下一个可以提取的包的长度
int peeksize();
void handleAnyPacket(KcpPacket::Ptr packet);
void handleCmdAck(KcpPacket::Ptr packet, uint32_t current);
void handleCmdPush(KcpPacket::Ptr packet);
// move available data from rcv_buf -> rcv_queue
void sortRecvBuf();
void sortSendQueue();
//流模式,合并发送包
size_t mergeSendQueue(const char *buffer, size_t len);
// 将发送队列的数据真正发送出去
void update();
void sendSendQueue();
void sendAckList();
void sendProbePacket();
void sendPacket(KcpPacket::Ptr pkt, bool flush = false);
void flushPool();
//将发送缓存中对端已经确认的数据包丢弃
//UNA模式,指定序列之前的包都已经确认,可以Drop
void dropCacheByUna(uint32_t una);
//将发送缓存中对端已经确认的数据包丢弃
//ACK模式,仅指定序列的包被确认
void dropCacheByAck(uint32_t sn);
//更新rtt
void updateRtt(int32_t rtt);
//更新发送cache中packet的Faskack计数
void updateFastAck(uint32_t sn, uint32_t ts);
//扩大拥塞窗口
void increaseCwnd();
//缩小拥塞窗口
void decreaseCwnd(bool change, bool lost);
// get how many packet is waiting to be sent
int getWaitSnd();
int getRcvWndUnused();
private:
onReadCB _on_read = nullptr;
onWriteCB _on_write = nullptr;
OnErr _on_err = nullptr;
bool _server_mode;
bool _conv_init = false;
EventPoller::Ptr _poller = nullptr;
Timer::Ptr _timer;
//刷新计时器
Ticker _alive_ticker;
bool _fastack_conserve = false; //快速重传保守模式
uint32_t _conv; // 会话ID,用于标识一个会话
uint32_t _mtu = IKCP_MTU_DEF; // 最大传输单元,默认1400
uint32_t _mss = IKCP_MTU_DEF - KcpPacket::HEADER_SIZE; // 最大分片大小,由MTU计算得到
uint32_t _interval = IKCP_INTERVAL; //内部flush的率先哪个间隔
uint32_t _fastresend = 0; //快速重传触发阈值,当packet的_fastack超过该值时,触发快速重传
int _fastlimit = 5; //快速重传限制,限制触发快速重传的最大次数,防止过度重传
uint32_t _xmit = 0; //重传次数计数器
uint32_t _dead_link = 20; //最大重传次数,当某个包的重传次数超过该值时,认为链路断开
uint32_t _snd_una = 0; //发送缓冲区中第一个未确认的包序号
uint32_t _snd_nxt = 0; //下一个待分配的序号
uint32_t _rcv_nxt = 0; //接收队列中待接收的下一个包序号
uint32_t _ts_recent = 0; //最近一次收到数据包的时间戳
uint32_t _ts_lastack = 0;//最近一次发送ACK的时间戳
//rtt
int32_t _rx_rttval = 0; //RTT方差
int32_t _rx_srtt = 0; //RTT(平滑后)
int32_t _rx_rto = IKCP_RTO_DEF; //重传超时时间(会基于rtt和rtt方差动态调整)
int32_t _rx_minrto = IKCP_RTO_MIN; //最小重传超时时间,防止RTO过小
//for 拥塞窗口控制
uint32_t _snd_wnd = IKCP_WND_SND; //发送队列窗口,用于限制发送速率,用户配置(单位分片数量)
uint32_t _rcv_wnd = IKCP_WND_RCV; //接收队列窗口,用于限制接收速率,用户配置(单位分片数量)
uint32_t _rmt_wnd = IKCP_WND_RCV; //对端接收缓存拥塞窗口,对端通告(单位分片数量)
uint32_t _cwnd = 1; //发送缓存拥塞窗口大小,算法动态调整(单位分片数量)
uint32_t _incr = 0; //拥塞窗口增量,用于拥塞控制算法中动态窗口大小(单位字节)
uint32_t _ssthresh = IKCP_THRESH_INIT; //慢启动阈值
uint32_t _probe = 0; //探测标志,用于探测对端窗口大小
uint32_t _ts_probe = 0; //探测时间戳,记录发送窗口探测包的时间戳
uint32_t _probe_wait = 0;//探测等待时间, 控制探测包发送的时间间隔
DelayMode _delay_mode = DELAY_MODE_NORMAL;
int _nocwnd = false; //是否禁用拥塞控制
bool _stream = false; //是否开启流传输模式
//传输链路: userdata->_snd_queue->_snd_buf->网络发送
//_snd_queue:无限制
//_snd_buf: min(_snd_wnd, _rmt_wnd, _cwnd)
//传输链路: 网络接收->_rcv_buf->_snd_queue->userdata
//_rcv_buf:无限制,乱序数据暂存
//_snd_queue: _rcv_wnd
std::list<KcpDataPacket::Ptr> _snd_queue; //发送队列,还未进入发送窗口
std::list<KcpDataPacket::Ptr> _rcv_queue; //接收队列,已经接收完全的包等待交给应用层
std::list<KcpDataPacket::Ptr> _snd_buf; //发送缓存,已经进入发送窗口,用于重传
std::list<KcpDataPacket::Ptr> _rcv_buf; //接收缓存,已经接受,但是因为乱序丢包等还不能交给应用层
//待发送的ACK列表
std::deque<std::pair<uint32_t /*sn*/, uint32_t /*ts*/>>_acklist;
BufferRaw::Ptr _buffer_pool; //用于合并多个kcp包到一个udp包中
};
} // namespace toolkit
#endif // TOOLKIT_NETWORK_KCP_H

View File

@ -0,0 +1,87 @@
/*
* Copyright (c) 2021 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include "Server.h"
using namespace std;
namespace toolkit {
Server::Server(EventPoller::Ptr poller) {
_poller = poller ? std::move(poller) : EventPollerPool::Instance().getPoller();
}
////////////////////////////////////////////////////////////////////////////////////
SessionHelper::SessionHelper(const std::weak_ptr<Server> &server, Session::Ptr session, std::string cls) {
_server = server;
_session = std::move(session);
_cls = std::move(cls);
//记录session至全局的map方便后面管理 [AUTO-TRANSLATED:f90fce35]
//Record the session in the global map for easy management later
_session_map = SessionMap::Instance().shared_from_this();
_identifier = _session->getIdentifier();
_session_map->add(_identifier, _session);
}
SessionHelper::~SessionHelper() {
if (!_server.lock()) {
//务必通知Session已从TcpServer脱离 [AUTO-TRANSLATED:6f55a358]
//Must notify that the session has been detached from TcpServer
_session->onError(SockException(Err_other, "Server shutdown"));
}
//从全局map移除相关记录 [AUTO-TRANSLATED:f0b0b2ad]
//Remove the related record from the global map
_session_map->del(_identifier);
}
const Session::Ptr &SessionHelper::session() const {
return _session;
}
const std::string &SessionHelper::className() const {
return _cls;
}
////////////////////////////////////////////////////////////////////////////////////
bool SessionMap::add(const string &tag, const Session::Ptr &session) {
lock_guard<mutex> lck(_mtx_session);
return _map_session.emplace(tag, session).second;
}
bool SessionMap::del(const string &tag) {
lock_guard<mutex> lck(_mtx_session);
return _map_session.erase(tag);
}
Session::Ptr SessionMap::get(const string &tag) {
lock_guard<mutex> lck(_mtx_session);
auto it = _map_session.find(tag);
if (it == _map_session.end()) {
return nullptr;
}
return it->second.lock();
}
void SessionMap::for_each_session(const function<void(const string &id, const Session::Ptr &session)> &cb) {
lock_guard<mutex> lck(_mtx_session);
for (auto it = _map_session.begin(); it != _map_session.end();) {
auto session = it->second.lock();
if (!session) {
it = _map_session.erase(it);
continue;
}
cb(it->first, session);
++it;
}
}
} // namespace toolkit

View File

@ -0,0 +1,93 @@
/*
* Copyright (c) 2021 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef ZLTOOLKIT_SERVER_H
#define ZLTOOLKIT_SERVER_H
#include <unordered_map>
#include "Util/mini.h"
#include "Session.h"
namespace toolkit {
// 全局的 Session 记录对象, 方便后面管理 [AUTO-TRANSLATED:1c2725cb]
//Global Session record object, convenient for later management
// 线程安全的 [AUTO-TRANSLATED:efbca605]
//Thread-safe
class SessionMap : public std::enable_shared_from_this<SessionMap> {
public:
friend class SessionHelper;
using Ptr = std::shared_ptr<SessionMap>;
//单例 [AUTO-TRANSLATED:8c2c95b4]
//Singleton
static SessionMap &Instance();
~SessionMap() = default;
//获取Session [AUTO-TRANSLATED:08c6e0f2]
//Get Session
Session::Ptr get(const std::string &tag);
void for_each_session(const std::function<void(const std::string &id, const Session::Ptr &session)> &cb);
private:
SessionMap() = default;
//移除Session [AUTO-TRANSLATED:b6023f67]
//Remove Session
bool del(const std::string &tag);
//添加Session [AUTO-TRANSLATED:4bdf8277]
//Add Session
bool add(const std::string &tag, const Session::Ptr &session);
private:
std::mutex _mtx_session;
std::unordered_map<std::string, std::weak_ptr<Session> > _map_session;
};
class Server;
class SessionHelper {
public:
bool enable = true;
using Ptr = std::shared_ptr<SessionHelper>;
SessionHelper(const std::weak_ptr<Server> &server, Session::Ptr session, std::string cls);
~SessionHelper();
const Session::Ptr &session() const;
const std::string &className() const;
private:
std::string _cls;
std::string _identifier;
Session::Ptr _session;
SessionMap::Ptr _session_map;
std::weak_ptr<Server> _server;
};
// server 基类, 暂时仅用于剥离 SessionHelper 对 TcpServer 的依赖 [AUTO-TRANSLATED:2fe50ede]
//Server base class, temporarily only used to decouple SessionHelper from TcpServer
// 后续将 TCP 与 UDP 服务通用部分加到这里. [AUTO-TRANSLATED:3d8429f3]
//Later, the common parts of TCP and UDP services will be added here.
class Server : public std::enable_shared_from_this<Server>, public mINI {
public:
using Ptr = std::shared_ptr<Server>;
explicit Server(EventPoller::Ptr poller = nullptr);
virtual ~Server() = default;
protected:
EventPoller::Ptr _poller;
};
} // namespace toolkit
#endif // ZLTOOLKIT_SERVER_H

View File

@ -0,0 +1,40 @@
/*
* Copyright (c) 2021 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include <atomic>
#include "Session.h"
using namespace std;
namespace toolkit {
class TcpSession : public Session {};
class UdpSession : public Session {};
StatisticImp(UdpSession)
StatisticImp(TcpSession)
Session::Session(const Socket::Ptr &sock) : SocketHelper(sock) {
if (sock->sockType() == SockNum::Sock_TCP) {
_statistic_tcp.reset(new ObjectStatistic<TcpSession>);
} else {
_statistic_udp.reset(new ObjectStatistic<UdpSession>);
}
}
string Session::getIdentifier() const {
if (_id.empty()) {
static atomic<uint64_t> s_session_index{0};
_id = to_string(++s_session_index) + '-' + to_string(getSock()->rawFD());
}
return _id;
}
} // namespace toolkit

View File

@ -0,0 +1,127 @@
/*
* Copyright (c) 2021 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef ZLTOOLKIT_SESSION_H
#define ZLTOOLKIT_SESSION_H
#include <memory>
#include "Socket.h"
#include "Util/util.h"
#include "Util/SSLBox.h"
#include "Kcp.h"
namespace toolkit {
// 会话, 用于存储一对客户端与服务端间的关系 [AUTO-TRANSLATED:d69736ea]
//Session, used to store the relationship between a client and a server
class Server;
class TcpSession;
class UdpSession;
class Session : public SocketHelper {
public:
using Ptr = std::shared_ptr<Session>;
Session(const Socket::Ptr &sock);
~Session() override = default;
/**
* Session , Server Session
* @param server,
* After creating a Session, the Server will pass its configuration parameters to the Session through this function
* @param server, server object
* [AUTO-TRANSLATED:5ce03e96]
*/
virtual void attachServer(const Server &server) {}
/**
* Session
* @return
* As the unique identifier of this Session
* @return unique identifier
* [AUTO-TRANSLATED:3b046f26]
*/
std::string getIdentifier() const override;
private:
mutable std::string _id;
std::unique_ptr<toolkit::ObjectStatistic<toolkit::TcpSession> > _statistic_tcp;
std::unique_ptr<toolkit::ObjectStatistic<toolkit::UdpSession> > _statistic_udp;
};
// 通过该模板可以让TCP服务器快速支持TLS [AUTO-TRANSLATED:fea218e6]
//This template allows the TCP server to quickly support TLS
template <typename SessionType>
class SessionWithSSL : public SessionType {
public:
template <typename... ArgsType>
SessionWithSSL(ArgsType &&...args)
: SessionType(std::forward<ArgsType>(args)...) {
_ssl_box.setOnEncData([&](const Buffer::Ptr &buf) { public_send(buf); });
_ssl_box.setOnDecData([&](const Buffer::Ptr &buf) { public_onRecv(buf); });
}
~SessionWithSSL() override { _ssl_box.flush(); }
void onRecv(const Buffer::Ptr &buf) override { _ssl_box.onRecv(buf); }
// 添加public_onRecv和public_send函数是解决较低版本gcc一个lambad中不能访问protected或private方法的bug [AUTO-TRANSLATED:7b16e05b]
//Adding public_onRecv and public_send functions is to solve a bug in lower versions of gcc where a lambda cannot access protected or private methods
inline void public_onRecv(const Buffer::Ptr &buf) { SessionType::onRecv(buf); }
inline void public_send(const Buffer::Ptr &buf) { SessionType::send(buf); }
bool overSsl() const override { return true; }
protected:
ssize_t send(Buffer::Ptr buf) override {
auto size = buf->size();
_ssl_box.onSend(std::move(buf));
return size;
}
private:
SSL_Box _ssl_box;
};
// 通过该模板可以让UDP服务器快速支持KCP
template <typename SessionType>
class SessionWithKCP : public SessionType {
public:
template <typename... ArgsType>
SessionWithKCP(ArgsType &&...args)
: SessionType(std::forward<ArgsType>(args)...) {
_kcp_box = std::make_shared<KcpTransport>(true, std::forward<ArgsType>(args)...);
_kcp_box->setOnWrite([&](const Buffer::Ptr &buf) { public_send(buf); });
_kcp_box->setOnRead([&](const Buffer::Ptr &buf) { public_onRecv(buf); });
_kcp_box->setOnErr([&](const SockException &ex) { public_onErr(ex); });
}
~SessionWithKCP() override { }
void onRecv(const Buffer::Ptr &buf) override { _kcp_box->input(buf); }
inline void public_onRecv(const Buffer::Ptr &buf) { SessionType::onRecv(buf); }
inline void public_send(const Buffer::Ptr &buf) { SessionType::send(buf); }
inline void public_onErr(const SockException &ex) { SessionType::onError(ex); }
protected:
ssize_t send(Buffer::Ptr buf) override {
return _kcp_box->send(std::move(buf));
}
private:
KcpTransport::Ptr _kcp_box;
};
} // namespace toolkit
#endif // ZLTOOLKIT_SESSION_H

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,86 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/xia-chu/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#import "Socket.h"
#include "Util/logger.h"
#if defined (OS_IPHONE)
#import <Foundation/Foundation.h>
#endif //OS_IPHONE
namespace toolkit {
#if defined (OS_IPHONE)
bool SockNum::setSocketOfIOS(int sock){
CFStreamCreatePairWithSocket(NULL, (CFSocketNativeHandle)sock, (CFReadStreamRef *)(&readStream), (CFWriteStreamRef*)(&writeStream));
if (readStream)
CFReadStreamSetProperty((CFReadStreamRef)readStream, kCFStreamPropertyShouldCloseNativeSocket, kCFBooleanFalse);
if (writeStream)
CFWriteStreamSetProperty((CFWriteStreamRef)writeStream, kCFStreamPropertyShouldCloseNativeSocket, kCFBooleanFalse);
if ((readStream == NULL) || (writeStream == NULL))
{
WarnL<<"Unable to create read and write stream...";
if (readStream)
{
CFReadStreamClose((CFReadStreamRef)readStream);
CFRelease(readStream);
readStream = NULL;
}
if (writeStream)
{
CFWriteStreamClose((CFWriteStreamRef)writeStream);
CFRelease(writeStream);
writeStream = NULL;
}
return false;
}
Boolean r1 = CFReadStreamSetProperty((CFReadStreamRef)readStream, kCFStreamNetworkServiceType, kCFStreamNetworkServiceTypeVoIP);
Boolean r2 = CFWriteStreamSetProperty((CFWriteStreamRef)writeStream, kCFStreamNetworkServiceType, kCFStreamNetworkServiceTypeVoIP);
if (!r1 || !r2)
{
return false;
}
CFStreamStatus readStatus = CFReadStreamGetStatus((CFReadStreamRef)readStream);
CFStreamStatus writeStatus = CFWriteStreamGetStatus((CFWriteStreamRef)writeStream);
if ((readStatus == kCFStreamStatusNotOpen) || (writeStatus == kCFStreamStatusNotOpen))
{
BOOL r1 = CFReadStreamOpen((CFReadStreamRef)readStream);
BOOL r2 = CFWriteStreamOpen((CFWriteStreamRef)writeStream);
if (!r1 || !r2)
{
WarnL<<"Error in CFStreamOpen";
return false;
}
}
//NSLog(@"setSocketOfIOS:%d",sock);
return true;
}
void SockNum::unsetSocketOfIOS(int sock){
//NSLog(@"unsetSocketOfIOS:%d",sock);
if (readStream) {
CFReadStreamClose((CFReadStreamRef)readStream);
readStream=NULL;
}
if (writeStream) {
CFWriteStreamClose((CFWriteStreamRef)writeStream);
writeStream=NULL;
}
}
#endif //OS_IPHONE
} // namespace toolkit

View File

@ -0,0 +1,179 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include "TcpClient.h"
using namespace std;
namespace toolkit {
StatisticImp(TcpClient)
TcpClient::TcpClient(const EventPoller::Ptr &poller) : SocketHelper(nullptr) {
setPoller(poller ? poller : EventPollerPool::Instance().getPoller());
setOnCreateSocket([](const EventPoller::Ptr &poller) {
//TCP客户端默认开启互斥锁 [AUTO-TRANSLATED:94fad9cd]
//TCP client defaults to enabling mutex lock
return Socket::createSocket(poller, true);
});
}
TcpClient::~TcpClient() {
TraceL << "~" << TcpClient::getIdentifier();
}
void TcpClient::shutdown(const SockException &ex) {
_timer.reset();
SocketHelper::shutdown(ex);
}
bool TcpClient::alive() const {
if (_timer) {
//连接中或已连接 [AUTO-TRANSLATED:bf2b744a]
//Connecting or already connected
return true;
}
//在websocket client(zlmediakit)相关代码中, [AUTO-TRANSLATED:d309d587]
//In websocket client (zlmediakit) related code,
//_timer一直为空但是socket fd有效alive状态也应该返回true [AUTO-TRANSLATED:344889b8]
//_timer is always empty, but socket fd is valid, and alive status should also return true
auto sock = getSock();
return sock && sock->alive();
}
void TcpClient::setNetAdapter(const string &local_ip) {
_net_adapter = local_ip;
}
void TcpClient::startConnect(const string &url, uint16_t port, float timeout_sec, uint16_t local_port) {
weak_ptr<TcpClient> weak_self = static_pointer_cast<TcpClient>(shared_from_this());
_timer = std::make_shared<Timer>(2.0f, [weak_self]() {
auto strong_self = weak_self.lock();
if (!strong_self) {
return false;
}
strong_self->onManager();
return true;
}, getPoller());
setSock(createSocket());
auto sock_ptr = getSock().get();
sock_ptr->setOnErr([weak_self, sock_ptr](const SockException &ex) {
auto strong_self = weak_self.lock();
if (!strong_self) {
return;
}
if (sock_ptr != strong_self->getSock().get()) {
//已经重连socket上次的socket的事件忽略掉 [AUTO-TRANSLATED:9bf35a7a]
//Socket has been reconnected, last socket's event is ignored
return;
}
strong_self->_timer.reset();
TraceL << strong_self->getIdentifier() << " on err: " << ex;
strong_self->onError(ex);
});
TraceL << getIdentifier() << " start connect " << url << ":" << port;
sock_ptr->connect(url, port, [weak_self](const SockException &err) {
auto strong_self = weak_self.lock();
if (strong_self) {
strong_self->onSockConnect(err);
}
}, timeout_sec, _net_adapter, local_port);
}
void TcpClient::onSockConnect(const SockException &ex) {
TraceL << getIdentifier() << " connect result: " << ex;
if (ex) {
//连接失败 [AUTO-TRANSLATED:33415985]
//Connection failed
_timer.reset();
onConnect(ex);
return;
}
auto sock_ptr = getSock().get();
weak_ptr<TcpClient> weak_self = static_pointer_cast<TcpClient>(shared_from_this());
sock_ptr->setOnFlush([weak_self, sock_ptr]() {
auto strong_self = weak_self.lock();
if (!strong_self) {
return false;
}
if (sock_ptr != strong_self->getSock().get()) {
//已经重连socket上传socket的事件忽略掉 [AUTO-TRANSLATED:243a8c95]
//Socket has been reconnected, upload socket's event is ignored
return false;
}
strong_self->onFlush();
return true;
});
sock_ptr->setOnRead([weak_self, sock_ptr](const Buffer::Ptr &pBuf, struct sockaddr *, int) {
auto strong_self = weak_self.lock();
if (!strong_self) {
return;
}
if (sock_ptr != strong_self->getSock().get()) {
//已经重连socket上传socket的事件忽略掉 [AUTO-TRANSLATED:243a8c95]
//Socket has been reconnected, upload socket's event is ignored
return;
}
try {
strong_self->onRecv(pBuf);
} catch (std::exception &ex) {
strong_self->shutdown(SockException(Err_other, ex.what()));
}
});
onConnect(ex);
}
std::string TcpClient::getIdentifier() const {
if (_id.empty()) {
static atomic<uint64_t> s_index{ 0 };
_id = toolkit::demangle(typeid(*this).name()) + "-" + to_string(++s_index);
}
return _id;
}
size_t TcpClient::getSendSpeed() const {
auto sock = getSock();
if (sock) {
return sock->getSendSpeed();
}
return 0;
}
size_t TcpClient::getRecvSpeed() const {
auto sock = getSock();
if (sock) {
return sock->getRecvSpeed();
}
return 0;
}
size_t TcpClient::getRecvTotalBytes() const {
auto sock = getSock();
if (sock) {
return sock->getRecvTotalBytes();
}
return 0;
}
size_t TcpClient::getSendTotalBytes() const {
auto sock = getSock();
if (sock) {
return sock->getSendTotalBytes();
}
return 0;
}
} /* namespace toolkit */

View File

@ -0,0 +1,229 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef NETWORK_TCPCLIENT_H
#define NETWORK_TCPCLIENT_H
#include <memory>
#include "Socket.h"
#include "Util/SSLBox.h"
namespace toolkit {
//Tcp客户端Socket对象默认开始互斥锁 [AUTO-TRANSLATED:5cc9a824]
//Tcp client, Socket object defaults to starting mutex lock
class TcpClient : public SocketHelper {
public:
using Ptr = std::shared_ptr<TcpClient>;
TcpClient(const EventPoller::Ptr &poller = nullptr);
~TcpClient() override;
/**
* tcp服务器
* @param url ip或域名
* @param port
* @param timeout_sec ,
* @param local_port
* Start connecting to the TCP server
* @param url Server IP or domain name
* @param port Server port
* @param timeout_sec Timeout time, in seconds
* @param local_port Local port
* [AUTO-TRANSLATED:7aa87355]
*/
virtual void startConnect(const std::string &url, uint16_t port, float timeout_sec = 5, uint16_t local_port = 0);
/**
* tcp服务器
* @param url ip或域名
* @proxy_host ip
* @proxy_port
* @param timeout_sec ,
* @param local_port
* Start connecting to the TCP server through a proxy
* @param url Server IP or domain name
* @proxy_host Proxy IP
* @proxy_port Proxy port
* @param timeout_sec Timeout time, in seconds
* @param local_port Local port
* [AUTO-TRANSLATED:2739bd58]
*/
virtual void startConnectWithProxy(const std::string &url, const std::string &proxy_host, uint16_t proxy_port, float timeout_sec = 5, uint16_t local_port = 0){};
/**
*
* @param ex onErr事件时的参数
* Actively disconnect the connection
* @param ex Parameter when triggering the onErr event
* [AUTO-TRANSLATED:5f6f3017]
*/
void shutdown(const SockException &ex = SockException(Err_shutdown, "self shutdown")) override;
/**
* truefalse
* Returns true if connected or connecting, returns false if disconnected
* [AUTO-TRANSLATED:60595edc]
*/
virtual bool alive() const;
/**
* ,使
* @param local_ip ip
* Set the network card adapter, use this network card to communicate with the server
* @param local_ip Local network card IP
* [AUTO-TRANSLATED:2549c18d]
*/
virtual void setNetAdapter(const std::string &local_ip);
/**
*
* Unique identifier
* [AUTO-TRANSLATED:6b21021f]
*/
std::string getIdentifier() const override;
size_t getSendSpeed() const;
size_t getRecvSpeed() const;
size_t getRecvTotalBytes() const;
size_t getSendTotalBytes() const;
protected:
/**
*
* @param ex
* Connection result callback
* @param ex Success or failure
* [AUTO-TRANSLATED:103bb2cb]
*/
virtual void onConnect(const SockException &ex) = 0;
/**
* tcp连接成功后每2秒触发一次该事件
* Trigger this event every 2 seconds after a successful TCP connection
* [AUTO-TRANSLATED:37b40b5d]
*/
void onManager() override {}
private:
void onSockConnect(const SockException &ex);
private:
mutable std::string _id;
std::string _net_adapter = "::";
std::shared_ptr<Timer> _timer;
//对象个数统计 [AUTO-TRANSLATED:3b43e8c2]
//Object count statistics
ObjectStatistic<TcpClient> _statistic;
};
//用于实现TLS客户端的模板对象 [AUTO-TRANSLATED:e4d399a3]
//Template object for implementing TLS client
template<typename TcpClientType>
class TcpClientWithSSL : public TcpClientType {
public:
using Ptr = std::shared_ptr<TcpClientWithSSL>;
template<typename ...ArgsType>
TcpClientWithSSL(ArgsType &&...args):TcpClientType(std::forward<ArgsType>(args)...) {}
~TcpClientWithSSL() override {
if (_ssl_box) {
_ssl_box->flush();
}
}
void onRecv(const Buffer::Ptr &buf) override {
if (_ssl_box) {
_ssl_box->onRecv(buf);
} else {
TcpClientType::onRecv(buf);
}
}
// 使能其他未被重写的send函数 [AUTO-TRANSLATED:5f01f91b]
//Enable other unoverridden send functions
using TcpClientType::send;
ssize_t send(Buffer::Ptr buf) override {
if (_ssl_box) {
auto size = buf->size();
_ssl_box->onSend(std::move(buf));
return size;
}
return TcpClientType::send(std::move(buf));
}
//添加public_onRecv和public_send函数是解决较低版本gcc一个lambad中不能访问protected或private方法的bug [AUTO-TRANSLATED:210f092e]
//Adding public_onRecv and public_send functions is to solve a bug in lower version gcc where a lambda cannot access protected or private methods
inline void public_onRecv(const Buffer::Ptr &buf) {
TcpClientType::onRecv(buf);
}
inline void public_send(const Buffer::Ptr &buf) {
TcpClientType::send(buf);
}
void startConnect(const std::string &url, uint16_t port, float timeout_sec = 5, uint16_t local_port = 0) override {
_host = url;
TcpClientType::startConnect(url, port, timeout_sec, local_port);
}
void startConnectWithProxy(const std::string &url, const std::string &proxy_host, uint16_t proxy_port, float timeout_sec = 5, uint16_t local_port = 0) override {
_host = url;
TcpClientType::startConnect(proxy_host, proxy_port, timeout_sec, local_port);
}
bool overSsl() const override { return (bool)_ssl_box; }
protected:
void onConnect(const SockException &ex) override {
if (!ex) {
_ssl_box = std::make_shared<SSL_Box>(false);
_ssl_box->setOnDecData([this](const Buffer::Ptr &buf) {
public_onRecv(buf);
});
_ssl_box->setOnEncData([this](const Buffer::Ptr &buf) {
public_send(buf);
});
if (!isIP(_host.data())) {
//设置ssl域名 [AUTO-TRANSLATED:1286a860]
//Set ssl domain
_ssl_box->setHost(_host.data());
}
}
TcpClientType::onConnect(ex);
}
/**
* ssl, 302http与https的转换
* Reset ssl, mainly to solve some 302 redirects when switching between http and https
* [AUTO-TRANSLATED:12ad26da]
*/
void setDoNotUseSSL() {
_ssl_box.reset();
}
private:
std::string _host;
std::shared_ptr<SSL_Box> _ssl_box;
};
} /* namespace toolkit */
#endif /* NETWORK_TCPCLIENT_H */

View File

@ -0,0 +1,306 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include "TcpServer.h"
#include "Util/uv_errno.h"
#include "Util/onceToken.h"
using namespace std;
namespace toolkit {
INSTANCE_IMP(SessionMap)
StatisticImp(TcpServer)
TcpServer::TcpServer(const EventPoller::Ptr &poller) : Server(poller) {
_multi_poller = !poller;
setOnCreateSocket(nullptr);
}
void TcpServer::setupEvent() {
_socket = createSocket(_poller);
weak_ptr<TcpServer> weak_self = std::static_pointer_cast<TcpServer>(shared_from_this());
#if 1
_socket->setOnBeforeAccept([weak_self](const EventPoller::Ptr &poller) -> Socket::Ptr {
if (auto strong_self = weak_self.lock()) {
return strong_self->onBeforeAcceptConnection(poller);
}
return nullptr;
});
_socket->setOnAccept([weak_self](Socket::Ptr &sock, shared_ptr<void> &complete) {
if (auto strong_self = weak_self.lock()) {
auto ptr = sock->getPoller().get();
auto server = strong_self->getServer(ptr);
ptr->async([server, sock, complete]() {
// 该tcp客户端派发给对应线程的TcpServer服务器 [AUTO-TRANSLATED:662b882f]
// This TCP client is dispatched to the corresponding thread of the TcpServer server
server->onAcceptConnection(sock);
});
}
});
#else
_socket->setOnAccept([weak_self](Socket::Ptr &sock, shared_ptr<void> &complete) {
if (auto strong_self = weak_self.lock()) {
if (strong_self->_multi_poller) {
EventPollerPool::Instance().getExecutor([sock, complete, weak_self](const TaskExecutor::Ptr &exe) {
if (auto strong_self = weak_self.lock()) {
sock->moveTo(static_pointer_cast<EventPoller>(exe));
strong_self->getServer(sock->getPoller().get())->onAcceptConnection(sock);
}
});
} else {
strong_self->onAcceptConnection(sock);
}
}
});
#endif
}
TcpServer::~TcpServer() {
if (_main_server && _socket && _socket->rawFD() != -1) {
InfoL << "Close tcp server [" << _socket->get_local_ip() << "]: " << _socket->get_local_port();
}
_timer.reset();
//先关闭socket监听防止收到新的连接 [AUTO-TRANSLATED:cd65064f]
//First close the socket listening to prevent receiving new connections
_socket.reset();
_session_map.clear();
_cloned_server.clear();
}
uint16_t TcpServer::getPort() {
if (!_socket) {
return 0;
}
return _socket->get_local_port();
}
void TcpServer::setOnCreateSocket(Socket::onCreateSocket cb) {
if (cb) {
_on_create_socket = std::move(cb);
} else {
_on_create_socket = [](const EventPoller::Ptr &poller) {
return Socket::createSocket(poller, false);
};
}
for (auto &pr : _cloned_server) {
pr.second->setOnCreateSocket(cb);
}
}
TcpServer::Ptr TcpServer::onCreatServer(const EventPoller::Ptr &poller) {
return Ptr(new TcpServer(poller), [poller](TcpServer *ptr) { poller->async([ptr]() { delete ptr; }); });
}
Socket::Ptr TcpServer::onBeforeAcceptConnection(const EventPoller::Ptr &poller) {
assert(_poller->isCurrentThread());
//此处改成自定义获取poller对象防止负载不均衡 [AUTO-TRANSLATED:16c66457]
//Modify this to a custom way of getting the poller object to prevent load imbalance
return createSocket(_multi_poller ? EventPollerPool::Instance().getPoller(false) : _poller);
}
void TcpServer::cloneFrom(const TcpServer &that) {
if (!that._socket) {
throw std::invalid_argument("TcpServer::cloneFrom other with null socket");
}
setupEvent();
_main_server = false;
_on_create_socket = that._on_create_socket;
_session_alloc = that._session_alloc;
_multi_poller = that._multi_poller;
weak_ptr<TcpServer> weak_self = std::static_pointer_cast<TcpServer>(shared_from_this());
_timer = std::make_shared<Timer>(2.0f, [weak_self]() -> bool {
auto strong_self = weak_self.lock();
if (!strong_self) {
return false;
}
strong_self->onManagerSession();
return true;
}, _poller);
this->mINI::operator=(that);
_parent = static_pointer_cast<TcpServer>(const_cast<TcpServer &>(that).shared_from_this());
}
// 接收到客户端连接请求 [AUTO-TRANSLATED:8a67b72a]
//Received a client connection request
Session::Ptr TcpServer::onAcceptConnection(const Socket::Ptr &sock) {
assert(_poller->isCurrentThread());
weak_ptr<TcpServer> weak_self = std::static_pointer_cast<TcpServer>(shared_from_this());
//创建一个Session;这里实现创建不同的服务会话实例 [AUTO-TRANSLATED:9ed745be]
//Create a Session; here implement creating different service session instances
auto helper = _session_alloc(std::static_pointer_cast<TcpServer>(shared_from_this()), sock);
auto session = helper->session();
//把本服务器的配置传递给Session [AUTO-TRANSLATED:e3711484]
//Pass the configuration of this server to the Session
session->attachServer(*this);
//_session_map::emplace肯定能成功 [AUTO-TRANSLATED:09d4aef7]
//_session_map::emplace will definitely succeed
auto success = _session_map.emplace(helper.get(), helper).second;
assert(success == true);
weak_ptr<Session> weak_session = session;
//会话接收数据事件 [AUTO-TRANSLATED:f3f4cbbb]
//Session receives data event
sock->setOnRead([weak_session](const Buffer::Ptr &buf, struct sockaddr *, int) {
//获取会话强应用 [AUTO-TRANSLATED:187497e6]
//Get the strong application of the session
auto strong_session = weak_session.lock();
if (!strong_session) {
return;
}
try {
strong_session->onRecv(buf);
} catch (SockException &ex) {
strong_session->shutdown(ex);
} catch (exception &ex) {
strong_session->shutdown(SockException(Err_shutdown, ex.what()));
}
});
SessionHelper *ptr = helper.get();
auto cls = ptr->className();
//会话接收到错误事件 [AUTO-TRANSLATED:b000e868]
//Session receives an error event
sock->setOnErr([weak_self, weak_session, ptr, cls](const SockException &err) {
//在本函数作用域结束时移除会话对象 [AUTO-TRANSLATED:5c4433b8]
//Remove the session object when the function scope ends
//目的是确保移除会话前执行其onError函数 [AUTO-TRANSLATED:1e6c65df]
//The purpose is to ensure that the onError function is executed before removing the session
//同时避免其onError函数抛异常时没有移除会话对象 [AUTO-TRANSLATED:6d541cbd]
//And avoid not removing the session object when the onError function throws an exception
onceToken token(nullptr, [&]() {
//移除掉会话 [AUTO-TRANSLATED:e7c27790]
//Remove the session
auto strong_self = weak_self.lock();
if (!strong_self) {
return;
}
assert(strong_self->_poller->isCurrentThread());
if (!strong_self->_is_on_manager) {
//该事件不是onManager时触发的直接操作map [AUTO-TRANSLATED:d90ee039]
//This event is not triggered by onManager, directly operate on the map
strong_self->_session_map.erase(ptr);
} else {
//遍历map时不能直接删除元素 [AUTO-TRANSLATED:0f00040c]
//Cannot directly delete elements when traversing the map
strong_self->_poller->async([weak_self, ptr]() {
auto strong_self = weak_self.lock();
if (strong_self) {
strong_self->_session_map.erase(ptr);
}
}, false);
}
});
//获取会话强应用 [AUTO-TRANSLATED:187497e6]
//Get the strong reference of the session
auto strong_session = weak_session.lock();
if (strong_session) {
//触发onError事件回调 [AUTO-TRANSLATED:825d16df]
//Trigger the onError event callback
TraceP(strong_session) << cls << " on err: " << err;
strong_session->onError(err);
}
});
return session;
}
void TcpServer::start_l(uint16_t port, const std::string &host, uint32_t backlog) {
setupEvent();
//新建一个定时器定时管理这些tcp会话 [AUTO-TRANSLATED:ef859bd7]
//Create a new timer to manage these TCP sessions periodically
weak_ptr<TcpServer> weak_self = std::static_pointer_cast<TcpServer>(shared_from_this());
_timer = std::make_shared<Timer>(2.0f, [weak_self]() -> bool {
auto strong_self = weak_self.lock();
if (!strong_self) {
return false;
}
strong_self->onManagerSession();
return true;
}, _poller);
if (_multi_poller) {
EventPollerPool::Instance().for_each([&](const TaskExecutor::Ptr &executor) {
EventPoller::Ptr poller = static_pointer_cast<EventPoller>(executor);
if (poller == _poller) {
return;
}
auto &serverRef = _cloned_server[poller.get()];
if (!serverRef) {
serverRef = onCreatServer(poller);
}
if (serverRef) {
serverRef->cloneFrom(*this);
}
});
}
if (!_socket->listen(port, host.c_str(), backlog)) {
// 创建tcp监听失败可能是由于端口占用或权限问题 [AUTO-TRANSLATED:88ebdefc]
//TCP listener creation failed, possibly due to port occupation or permission issues
string err = (StrPrinter << "Listen on " << host << " " << port << " failed: " << get_uv_errmsg(true));
throw std::runtime_error(err);
}
for (auto &pr: _cloned_server) {
// 启动子Server [AUTO-TRANSLATED:1820131c]
//Start the child Server
pr.second->_socket->cloneSocket(*_socket);
}
InfoL << "TCP server listening on [" << host << "]: " << port;
}
void TcpServer::onManagerSession() {
assert(_poller->isCurrentThread());
onceToken token([&]() {
_is_on_manager = true;
}, [&]() {
_is_on_manager = false;
});
for (auto &pr : _session_map) {
//遍历时可能触发onErr事件(也会操作_session_map) [AUTO-TRANSLATED:7760b80d]
//When traversing, the onErr event may be triggered (also operates on _session_map)
try {
pr.second->session()->onManager();
} catch (exception &ex) {
WarnL << ex.what();
}
}
}
Socket::Ptr TcpServer::createSocket(const EventPoller::Ptr &poller) {
return _on_create_socket(poller);
}
TcpServer::Ptr TcpServer::getServer(const EventPoller *poller) const {
auto parent = _parent.lock();
auto &ref = parent ? parent->_cloned_server : _cloned_server;
auto it = ref.find(poller);
if (it != ref.end()) {
//派发到cloned server [AUTO-TRANSLATED:8765ab56]
//Dispatch to the cloned server
return it->second;
}
//派发到parent server [AUTO-TRANSLATED:4cf34169]
//Dispatch to the parent server
return static_pointer_cast<TcpServer>(parent ? parent : const_cast<TcpServer *>(this)->shared_from_this());
}
Session::Ptr TcpServer::createSession(const Socket::Ptr &sock) {
return getServer(sock->getPoller().get())->onAcceptConnection(sock);
}
} /* namespace toolkit */

View File

@ -0,0 +1,136 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef TCPSERVER_TCPSERVER_H
#define TCPSERVER_TCPSERVER_H
#include <memory>
#include <functional>
#include <unordered_map>
#include "Server.h"
#include "Session.h"
#include "Poller/Timer.h"
#include "Util/util.h"
namespace toolkit {
//TCP服务器可配置的配置通过Session::attachServer方法传递给会话对象 [AUTO-TRANSLATED:4e55c332]
//Configurable TCP server; configuration is passed to the session object through the Session::attachServer method
class TcpServer : public Server {
public:
using Ptr = std::shared_ptr<TcpServer>;
/**
* tcp服务器listen fd的accept事件会加入到所有的poller线程中监听
* TcpServer::start函数时TcpServer对象
* TcpServer对象通过Socket对象克隆的方式在多个poller线程中监听同一个listen fd
* TCP服务器将会通过抢占式accept的方式把客户端均匀的分布到不同的poller线程
*
* Creates a TCP server, the accept event of the listen fd will be added to all poller threads for listening
* When calling the TcpServer::start function, multiple child TcpServer objects will be created internally,
* These child TcpServer objects will be cloned through the Socket object in multiple poller threads to listen to the same listen fd
* This way, the TCP server will distribute clients evenly across different poller threads through a preemptive accept approach
* This approach can achieve client load balancing and improve connection acceptance speed
* [AUTO-TRANSLATED:761a6b1e]
*/
explicit TcpServer(const EventPoller::Ptr &poller = nullptr);
~TcpServer() override;
/**
* @brief tcp server
* @param port 0
* @param host ip
* @param backlog tcp listen backlog
* @brief Starts the TCP server
* @param port Local port, 0 for random
* @param host Listening network card IP
* @param backlog TCP listen backlog
* [AUTO-TRANSLATED:9bab69b6]
*/
template <typename SessionType>
void start(uint16_t port, const std::string &host = "::", uint32_t backlog = 1024, const std::function<void(std::shared_ptr<SessionType> &)> &cb = nullptr) {
static std::string cls_name = toolkit::demangle(typeid(SessionType).name());
// Session创建器通过它创建不同类型的服务器 [AUTO-TRANSLATED:f5585e1e]
//Session creator, creates different types of servers through it
_session_alloc = [cb](const TcpServer::Ptr &server, const Socket::Ptr &sock) {
auto session = std::shared_ptr<SessionType>(new SessionType(sock), [](SessionType *ptr) {
TraceP(static_cast<Session *>(ptr)) << "~" << cls_name;
delete ptr;
});
if (cb) {
cb(session);
}
TraceP(static_cast<Session *>(session.get())) << cls_name;
session->setOnCreateSocket(server->_on_create_socket);
return std::make_shared<SessionHelper>(server, std::move(session), cls_name);
};
start_l(port, host, backlog);
}
/**
* @brief ,
* @brief Gets the server listening port number, the server can choose to listen on a random port
* [AUTO-TRANSLATED:125ff8d8]
*/
uint16_t getPort();
/**
* @brief socket构建行为
* @brief Custom socket construction behavior
* [AUTO-TRANSLATED:4cf98e86]
*/
void setOnCreateSocket(Socket::onCreateSocket cb);
/**
* socket对象创建Session对象
* socket归属poller线程执行本函数
* Creates a Session object based on the socket object
* Ensures that this function is executed in the poller thread that owns the socket
* [AUTO-TRANSLATED:1d52d9ee]
*/
Session::Ptr createSession(const Socket::Ptr &socket);
protected:
virtual void cloneFrom(const TcpServer &that);
virtual TcpServer::Ptr onCreatServer(const EventPoller::Ptr &poller);
virtual Session::Ptr onAcceptConnection(const Socket::Ptr &sock);
virtual Socket::Ptr onBeforeAcceptConnection(const EventPoller::Ptr &poller);
private:
void onManagerSession();
Socket::Ptr createSocket(const EventPoller::Ptr &poller);
void start_l(uint16_t port, const std::string &host, uint32_t backlog);
Ptr getServer(const EventPoller *) const;
void setupEvent();
private:
bool _multi_poller;
bool _is_on_manager = false;
bool _main_server = true;
std::weak_ptr<TcpServer> _parent;
Socket::Ptr _socket;
std::shared_ptr<Timer> _timer;
Socket::onCreateSocket _on_create_socket;
std::unordered_map<SessionHelper *, SessionHelper::Ptr> _session_map;
std::function<SessionHelper::Ptr(const TcpServer::Ptr &server, const Socket::Ptr &)> _session_alloc;
std::unordered_map<const EventPoller *, Ptr> _cloned_server;
//对象个数统计 [AUTO-TRANSLATED:3b43e8c2]
//Object count statistics
ObjectStatistic<TcpServer> _statistic;
};
} /* namespace toolkit */
#endif /* TCPSERVER_TCPSERVER_H */

View File

@ -0,0 +1,130 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include "UdpClient.h"
using namespace std;
namespace toolkit {
StatisticImp(UdpClient)
UdpClient::UdpClient(const EventPoller::Ptr &poller) : SocketHelper(nullptr) {
setPoller(poller ? poller : EventPollerPool::Instance().getPoller());
setOnCreateSocket([](const EventPoller::Ptr &poller) {
//TCP客户端默认开启互斥锁
return Socket::createSocket(poller, true);
});
}
UdpClient::~UdpClient() {
TraceL << "~" << UdpClient::getIdentifier();
}
void UdpClient::startConnect(const string &peer_host, uint16_t peer_port, uint16_t local_port) {
weak_ptr<UdpClient> weak_self = static_pointer_cast<UdpClient>(shared_from_this());
_timer = std::make_shared<Timer>(2.0f, [weak_self]() {
auto strong_self = weak_self.lock();
if (!strong_self) {
return false;
}
strong_self->onManager();
return true;
}, getPoller());
setSock(createSocket());
auto sock_ptr = getSock().get();
sock_ptr->setOnErr([weak_self, sock_ptr](const SockException &ex) {
auto strong_self = weak_self.lock();
if (!strong_self) {
return;
}
if (sock_ptr != strong_self->getSock().get()) {
//已经重连socket上次的socket的事件忽略掉
return;
}
strong_self->_timer.reset();
TraceL << strong_self->getIdentifier() << " on err: " << ex;
strong_self->onError(ex);
});
sock_ptr->setOnFlush([weak_self, sock_ptr]() {
auto strong_self = weak_self.lock();
if (!strong_self) {
return false;
}
if (sock_ptr != strong_self->getSock().get()) {
//已经重连socket上传socket的事件忽略掉
return false;
}
strong_self->onFlush();
return true;
});
sock_ptr->setOnRead([weak_self, sock_ptr](const Buffer::Ptr &pBuf, struct sockaddr * addr, int addr_len) {
auto strong_self = weak_self.lock();
if (!strong_self) {
return;
}
if (sock_ptr != strong_self->getSock().get()) {
//已经重连socket上传socket的事件忽略掉
return;
}
try {
strong_self->onRecvFrom(pBuf, addr, addr_len);
} catch (std::exception &ex) {
strong_self->shutdown(SockException(Err_other, ex.what()));
}
});
bool ret = getSock()->bindUdpSock(local_port, _net_adapter);
if (!ret) {
WarnL << "UDP output bind local error";
}
auto peer_addr = SockUtil::make_sockaddr(peer_host.c_str(), peer_port);
//只能软绑定
ret = getSock()->bindPeerAddr((struct sockaddr *)&peer_addr, 0, true);
if (!ret) {
WarnL << "UDP output bind peer error";
}
// TraceL << getIdentifier() << " start connect " << url << ":" << peer_port;
}
void UdpClient::shutdown(const SockException &ex) {
_timer.reset();
SocketHelper::shutdown(ex);
}
bool UdpClient::alive() const {
if (_timer) {
//连接中或已连接
return true;
}
//在websocket client(zlmediakit)相关代码中,
//_timer一直为空但是socket fd有效alive状态也应该返回true
auto sock = getSock();
return sock && sock->alive();
}
void UdpClient::setNetAdapter(const string &local_ip) {
_net_adapter = local_ip;
}
std::string UdpClient::getIdentifier() const {
if (_id.empty()) {
static atomic<uint64_t> s_index { 0 };
_id = toolkit::demangle(typeid(*this).name()) + "-" + to_string(++s_index);
}
return _id;
}
} /* namespace toolkit */

View File

@ -0,0 +1,194 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef NETWORK_UDPCLIENT_H
#define NETWORK_UDPCLIENT_H
#include <memory>
#include "Socket.h"
#include "Util/SSLBox.h"
#include "Kcp.h"
namespace toolkit {
//Udp客户端Socket对象默认开始互斥锁
class UdpClient : public SocketHelper {
public:
using Ptr = std::shared_ptr<UdpClient>;
using OnRecvFrom = std::function<void(const Buffer::Ptr &buf, struct sockaddr *addr, int addr_len)>;
using OnErr = std::function<void(const SockException &)>;
UdpClient(const EventPoller::Ptr &poller = nullptr);
~UdpClient() override;
/**
* udp服务器
* @param peer_host ip或域名
* @param peer_port
* @param local_port
*/
virtual void startConnect(const std::string &peer_host, uint16_t peer_port, uint16_t local_port = 0);
/**
*
* @param ex onErr事件时的参数
*/
void shutdown(const SockException &ex = SockException(Err_shutdown, "self shutdown")) override;
/**
* truefalse
*/
virtual bool alive() const;
/**
* ,使
* @param local_ip ip
*/
virtual void setNetAdapter(const std::string &local_ip);
/**
*
*/
std::string getIdentifier() const override;
void setOnRecvFrom(OnRecvFrom cb) {
_on_recvfrom = std::move(cb);
}
void setOnError(OnErr cb) {
_on_err = std::move(cb);
}
protected:
virtual void onRecvFrom(const Buffer::Ptr &buf, struct sockaddr *addr, int addr_len) {
if (_on_recvfrom) {
_on_recvfrom(buf, addr, addr_len);
}
}
void onRecv(const Buffer::Ptr &buf) override {}
void onError(const SockException &err) override {
DebugL;
if (_on_err) {
_on_err(err);
}
}
/**
* udp连接成功后每2秒触发一次该事件
*/
void onManager() override {}
private:
mutable std::string _id;
std::string _net_adapter = "::";
std::shared_ptr<Timer> _timer;
//对象个数统计
ObjectStatistic<UdpClient> _statistic;
OnRecvFrom _on_recvfrom;
OnErr _on_err;
};
//用于实现KCP客户端的模板对象
template<typename UdpClientType>
class UdpClientWithKcp : public UdpClientType {
public:
using Ptr = std::shared_ptr<UdpClientWithKcp>;
template<typename ...ArgsType>
UdpClientWithKcp(ArgsType &&...args)
:UdpClientType(std::forward<ArgsType>(args)...) {
_kcp_box = std::make_shared<KcpTransport>(false);
_kcp_box->setOnWrite([&](const Buffer::Ptr &buf) { public_send(buf); });
_kcp_box->setOnRead([&](const Buffer::Ptr &buf) { public_onRecv(buf); });
_kcp_box->setOnErr([&](const SockException &ex) { public_onErr(ex); });
}
~UdpClientWithKcp() override { }
void onRecvFrom(const Buffer::Ptr &buf, struct sockaddr *addr, int addr_len) override {
//KCP 暂不支持一个UDP Socket 对多个目标,因此先忽略addr参数
_kcp_box->input(buf);
}
ssize_t send(Buffer::Ptr buf) override {
return _kcp_box->send(std::move(buf));
}
ssize_t sendto(Buffer::Ptr buf, struct sockaddr *addr = nullptr, socklen_t addr_len = 0) override {
//KCP 暂不支持一个UDP Socket 对多个目标,因此先忽略addr参数
return _kcp_box->send(std::move(buf));
}
inline void public_onRecv(const Buffer::Ptr &buf) {
//KCP 暂不支持一个UDP Socket 对多个目标,因此固定采用bind的地址参数
UdpClientType::onRecvFrom(buf, (struct sockaddr*)&_peer_addr, _peer_addr_len);
}
inline void public_send(const Buffer::Ptr &buf) {
UdpClientType::send(buf);
}
inline void public_onErr(const SockException &ex) { UdpClientWithKcp::onError(ex); }
virtual void startConnect(const std::string &peer_host, uint16_t peer_port, uint16_t local_port = 0) override {
_kcp_box->setPoller(UdpClientType::getPoller());
_peer_addr = SockUtil::make_sockaddr(peer_host.data(), peer_port);
_peer_addr_len = SockUtil::get_sock_len((const struct sockaddr*)&_peer_addr);
UdpClientType::startConnect(peer_host, peer_port, local_port);
}
void setMtu(int mtu) {
_kcp_box->setMtu(mtu);
}
void setInterval(int intervoal) {
_kcp_box->setInterval(intervoal);
}
void setRxMinrto(int rx_minrto) {
_kcp_box->setRxMinrto(rx_minrto);
}
void setWndSize(int sndwnd, int rcvwnd) {
_kcp_box->setWndSize(sndwnd, rcvwnd);
}
void setDelayMode(KcpTransport::DelayMode delay_mode) {
_kcp_box->setDelayMode(delay_mode);
}
void setFastResend(int resend) {
_kcp_box->setFastResend(resend);
}
void setFastackConserve(bool flag) {
_kcp_box->setFastackConserve(flag);
}
void setNoCwnd(bool flag) {
_kcp_box->setNoCwnd(flag);
}
void setStreamMode(bool flag) {
_kcp_box->setStreamMode(flag);
}
private:
struct sockaddr_storage _peer_addr;
int _peer_addr_len = 0;
KcpTransport::Ptr _kcp_box;
};
} /* namespace toolkit */
#endif /* NETWORK_UDPCLIENT_H */

View File

@ -0,0 +1,406 @@
/*
* Copyright (c) 2021 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include "Util/uv_errno.h"
#include "Util/onceToken.h"
#include "UdpServer.h"
using namespace std;
namespace toolkit {
static const uint8_t s_in6_addr_maped[]
= { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00 };
static constexpr auto kUdpDelayCloseMS = 3 * 1000;
static UdpServer::PeerIdType makeSockId(sockaddr *addr, int) {
UdpServer::PeerIdType ret;
switch (addr->sa_family) {
case AF_INET : {
ret[0] = ((struct sockaddr_in *) addr)->sin_port >> 8;
ret[1] = ((struct sockaddr_in *) addr)->sin_port & 0xFF;
//ipv4地址统一转换为ipv6方式处理 [AUTO-TRANSLATED:ad7cf8c3]
//Convert ipv4 addresses to ipv6 for unified processing
memcpy(&ret[2], &s_in6_addr_maped, 12);
memcpy(&ret[14], &(((struct sockaddr_in *) addr)->sin_addr), 4);
return ret;
}
case AF_INET6 : {
ret[0] = ((struct sockaddr_in6 *) addr)->sin6_port >> 8;
ret[1] = ((struct sockaddr_in6 *) addr)->sin6_port & 0xFF;
memcpy(&ret[2], &(((struct sockaddr_in6 *)addr)->sin6_addr), 16);
return ret;
}
default: throw std::invalid_argument("invalid sockaddr address");
}
}
UdpServer::UdpServer(const EventPoller::Ptr &poller) : Server(poller) {
_multi_poller = !poller;
setOnCreateSocket(nullptr);
}
void UdpServer::setupEvent() {
_socket = createSocket(_poller);
std::weak_ptr<UdpServer> weak_self = std::static_pointer_cast<UdpServer>(shared_from_this());
_socket->setOnRead([weak_self](Buffer::Ptr &buf, struct sockaddr *addr, int addr_len) {
if (auto strong_self = weak_self.lock()) {
strong_self->onRead(buf, addr, addr_len);
}
});
}
UdpServer::~UdpServer() {
if (!_cloned && _socket && _socket->rawFD() != -1) {
InfoL << "Close udp server [" << _socket->get_local_ip() << "]: " << _socket->get_local_port();
}
_timer.reset();
_socket.reset();
_cloned_server.clear();
if (!_cloned && _session_mutex && _session_map) {
lock_guard<std::recursive_mutex> lck(*_session_mutex);
_session_map->clear();
}
}
void UdpServer::start_l(uint16_t port, const std::string &host) {
setupEvent();
//主server才创建session map其他cloned server共享之 [AUTO-TRANSLATED:113cf4fd]
//Only the main server creates a session map, other cloned servers share it
_session_mutex = std::make_shared<std::recursive_mutex>();
_session_map = std::make_shared<SessionMapType>();
// 新建一个定时器定时管理这些 udp 会话,这些对象只由主server做超时管理cloned server不管理 [AUTO-TRANSLATED:d20478a2]
//Create a timer to manage these udp sessions periodically, these objects are only managed by the main server, cloned servers do not manage them
std::weak_ptr<UdpServer> weak_self = std::static_pointer_cast<UdpServer>(shared_from_this());
_timer = std::make_shared<Timer>(2.0f, [weak_self]() -> bool {
if (auto strong_self = weak_self.lock()) {
strong_self->onManagerSession();
return true;
}
return false;
}, _poller);
if (_multi_poller) {
// clone server至不同线程让udp server支持多线程 [AUTO-TRANSLATED:15a85c8f]
//Clone the server to different threads to support multi-threading for the udp server
EventPollerPool::Instance().for_each([&](const TaskExecutor::Ptr &executor) {
auto poller = std::static_pointer_cast<EventPoller>(executor);
if (poller == _poller) {
return;
}
auto &serverRef = _cloned_server[poller.get()];
if (!serverRef) {
serverRef = onCreatServer(poller);
}
if (serverRef) {
serverRef->cloneFrom(*this);
}
});
}
if (!_socket->bindUdpSock(port, host.c_str())) {
// udp 绑定端口失败, 可能是由于端口占用或权限问题 [AUTO-TRANSLATED:c31eedba]
//Failed to bind udp port, possibly due to port occupation or permission issues
std::string err = (StrPrinter << "Bind udp socket on " << host << " " << port << " failed: " << get_uv_errmsg(true));
throw std::runtime_error(err);
}
for (auto &pr: _cloned_server) {
// 启动子Server [AUTO-TRANSLATED:1820131c]
//Start the child server
#if 0
pr.second->_socket->cloneSocket(*_socket);
#else
// 实验发现cloneSocket方式虽然可以节省fd资源但是在某些系统上线程漂移问题更严重 [AUTO-TRANSLATED:d6a88e17]
//Experiments have found that the cloneSocket method can save fd resources, but the thread drift problem is more serious on some systems
pr.second->_socket->bindUdpSock(_socket->get_local_port(), _socket->get_local_ip());
#endif
}
InfoL << "UDP server bind to [" << host << "]: " << port;
}
UdpServer::Ptr UdpServer::onCreatServer(const EventPoller::Ptr &poller) {
return Ptr(new UdpServer(poller), [poller](UdpServer *ptr) { poller->async([ptr]() { delete ptr; }); });
}
void UdpServer::cloneFrom(const UdpServer &that) {
if (!that._socket) {
throw std::invalid_argument("UdpServer::cloneFrom other with null socket");
}
// 将socket的创建回调复制前置, 确保所有的socket都可以通过上层创建
_on_create_socket = that._on_create_socket;
setupEvent();
_cloned = true;
// clone callbacks
_session_alloc = that._session_alloc;
_session_mutex = that._session_mutex;
_session_map = that._session_map;
_multi_poller = that._multi_poller;
// clone properties
this->mINI::operator=(that);
}
void UdpServer::onRead(Buffer::Ptr &buf, sockaddr *addr, int addr_len) {
const auto id = makeSockId(addr, addr_len);
onRead_l(true, id, buf, addr, addr_len);
}
static void emitSessionRecv(const SessionHelper::Ptr &helper, const Buffer::Ptr &buf) {
if (!helper->enable) {
// 延时销毁中 [AUTO-TRANSLATED:24d3d333]
//Delayed destruction in progress
return;
}
try {
helper->session()->onRecv(buf);
} catch (SockException &ex) {
helper->session()->shutdown(ex);
} catch (exception &ex) {
helper->session()->shutdown(SockException(Err_shutdown, ex.what()));
}
}
void UdpServer::onRead_l(bool is_server_fd, const UdpServer::PeerIdType &id, Buffer::Ptr &buf, sockaddr *addr, int addr_len) {
// udp server fd收到数据时触发此函数大部分情况下数据应该在peer fd触发此函数应该不是热点函数 [AUTO-TRANSLATED:f347ff20]
//This function is triggered when the udp server fd receives data; in most cases, the data should be triggered by the peer fd, and this function should not be a hot spot
bool is_new = false;
if (auto helper = getOrCreateSession(id, buf, addr, addr_len, is_new)) {
if (helper->session()->getPoller()->isCurrentThread()) {
//当前线程收到数据,直接处理数据 [AUTO-TRANSLATED:07e5a596]
//The current thread receives data and processes it directly
emitSessionRecv(helper, buf);
} else {
//数据漂移到其他线程,需要先切换线程 [AUTO-TRANSLATED:15235f6f]
//Data migration to another thread requires switching threads first
#if !defined(_WIN32)
WarnL << "UDP packet incoming from other thread";
#endif
std::weak_ptr<SessionHelper> weak_helper = helper;
//由于socket读buffer是该线程上所有socket共享复用的所以不能跨线程使用必须先转移走 [AUTO-TRANSLATED:1134538b]
//Since the socket read buffer is shared and reused by all sockets on this thread, it cannot be used across threads and must be transferred first
auto cacheable_buf = std::move(buf);
helper->session()->async([weak_helper, cacheable_buf]() {
if (auto strong_helper = weak_helper.lock()) {
emitSessionRecv(strong_helper, cacheable_buf);
}
});
}
#if !defined(NDEBUG) && !defined(_WIN32)
if (!is_new) {
TraceL << "UDP packet incoming from " << (is_server_fd ? "server fd" : "other peer fd");
}
#endif
}
}
void UdpServer::onManagerSession() {
decltype(_session_map) copy_map;
{
std::lock_guard<std::recursive_mutex> lock(*_session_mutex);
//拷贝map防止遍历时移除对象 [AUTO-TRANSLATED:ebbc7595]
//Copy the map to prevent objects from being removed during traversal
copy_map = std::make_shared<SessionMapType>(*_session_map);
}
auto lam = [copy_map]() {
for (auto &pr : *copy_map) {
auto &session = pr.second->session();
if (!session->getPoller()->isCurrentThread()) {
// 该session不归属该poller管理 [AUTO-TRANSLATED:d5edb552]
//This session does not belong to the management of this poller
continue;
}
try {
// UDP 会话需要处理超时 [AUTO-TRANSLATED:0a51f8a1]
//UDP sessions need to handle timeouts
session->onManager();
} catch (exception &ex) {
WarnL << "Exception occurred when emit onManager: " << ex.what();
}
}
};
if (_multi_poller){
EventPollerPool::Instance().for_each([lam](const TaskExecutor::Ptr &executor) {
std::static_pointer_cast<EventPoller>(executor)->async(lam);
});
} else {
lam();
}
}
SessionHelper::Ptr UdpServer::getOrCreateSession(const UdpServer::PeerIdType &id, Buffer::Ptr &buf, sockaddr *addr, int addr_len, bool &is_new) {
{
//减小临界区 [AUTO-TRANSLATED:3d6089d8]
//Reduce the critical section
std::lock_guard<std::recursive_mutex> lock(*_session_mutex);
auto it = _session_map->find(id);
if (it != _session_map->end()) {
return it->second;
}
}
is_new = true;
return createSession(id, buf, addr, addr_len);
}
SessionHelper::Ptr UdpServer::createSession(const PeerIdType &id, Buffer::Ptr &buf, struct sockaddr *addr, int addr_len) {
// 此处改成自定义获取poller对象防止负载不均衡 [AUTO-TRANSLATED:194e8460]
//Change to custom acquisition of poller objects to prevent load imbalance
auto socket = createSocket(_multi_poller ? EventPollerPool::Instance().getPoller(false) : _poller, buf, addr, addr_len);
if (!socket) {
//创建socket失败本次onRead事件收到的数据直接丢弃 [AUTO-TRANSLATED:b218d68c]
//Socket creation failed, the data received by this onRead event is discarded
return nullptr;
}
auto addr_str = string((char *) addr, addr_len);
std::weak_ptr<UdpServer> weak_self = std::static_pointer_cast<UdpServer>(shared_from_this());
auto helper_creator = [this, weak_self, socket, addr_str, id]() -> SessionHelper::Ptr {
auto server = weak_self.lock();
if (!server) {
return nullptr;
}
//如果已经创建该客户端对应的UdpSession类那么直接返回 [AUTO-TRANSLATED:c57a0d71]
//If the UdpSession class corresponding to this client has already been created, return directly
lock_guard<std::recursive_mutex> lck(*_session_mutex);
auto it = _session_map->find(id);
if (it != _session_map->end()) {
return it->second;
}
assert(_socket);
socket->bindUdpSock(_socket->get_local_port(), _socket->get_local_ip());
socket->bindPeerAddr((struct sockaddr *) addr_str.data(), addr_str.size());
auto helper = _session_alloc(server, socket);
// 把本服务器的配置传递给 Session [AUTO-TRANSLATED:e3ed95ab]
//Pass the configuration of this server to the Session
helper->session()->attachServer(*this);
std::weak_ptr<SessionHelper> weak_helper = helper;
socket->setOnRead([weak_self, weak_helper, id](Buffer::Ptr &buf, struct sockaddr *addr, int addr_len) {
auto strong_self = weak_self.lock();
if (!strong_self) {
return;
}
auto new_id = makeSockId(addr, addr_len);
//快速判断是否为本会话的的数据, 通常应该成立 [AUTO-TRANSLATED:d5d147e4]
//Quickly determine if it's data for the current session, usually should be true
if (id == new_id) {
if (auto strong_helper = weak_helper.lock()) {
emitSessionRecv(strong_helper, buf);
}
return;
}
//收到非本peer fd的数据让server去派发此数据到合适的session对象 [AUTO-TRANSLATED:e5f44445]
//Received data from a non-current peer fd, let the server dispatch this data to the appropriate session object
strong_self->onRead_l(false, new_id, buf, addr, addr_len);
});
socket->setOnErr([weak_self, weak_helper, id](const SockException &err) {
// 在本函数作用域结束时移除会话对象 [AUTO-TRANSLATED:b2ade305]
//Remove the session object when this function scope ends
// 目的是确保移除会话前执行其 onError 函数 [AUTO-TRANSLATED:7d0329d7]
//The purpose is to ensure the onError function is executed before removing the session
// 同时避免其 onError 函数抛异常时没有移除会话对象 [AUTO-TRANSLATED:354191bd]
//And avoid not removing the session object when its onError function throws an exception
onceToken token(nullptr, [&]() {
// 移除掉会话 [AUTO-TRANSLATED:1d786335]
//Remove the session
auto strong_self = weak_self.lock();
if (!strong_self) {
return;
}
// 延时移除udp session, 防止频繁快速重建对象 [AUTO-TRANSLATED:50dbd694]
//Delay removing the UDP session to prevent frequent and rapid object reconstruction
strong_self->_poller->doDelayTask(kUdpDelayCloseMS, [weak_self, id]() {
if (auto strong_self = weak_self.lock()) {
// 从共享map中移除本session对象 [AUTO-TRANSLATED:47ecbf11]
//Remove the current session object from the shared map
lock_guard<std::recursive_mutex> lck(*strong_self->_session_mutex);
strong_self->_session_map->erase(id);
}
return 0;
});
});
// 获取会话强应用 [AUTO-TRANSLATED:42283ea0]
//Get a strong reference to the session
if (auto strong_helper = weak_helper.lock()) {
// 触发 onError 事件回调 [AUTO-TRANSLATED:82070c3c]
//Trigger the onError event callback
TraceP(strong_helper->session()) << strong_helper->className() << " on err: " << err;
strong_helper->enable = false;
strong_helper->session()->onError(err);
}
});
auto pr = _session_map->emplace(id, std::move(helper));
assert(pr.second);
return pr.first->second;
};
if (socket->getPoller()->isCurrentThread()) {
// 该socket分配在本线程直接创建helper对象 [AUTO-TRANSLATED:18c9d95b]
//This socket is allocated in this thread, directly create a helper object
return helper_creator();
}
// 该socket分配在其他线程需要先转移走buffer然后在其所在线程创建helper对象并处理数据 [AUTO-TRANSLATED:7816a13f]
//This socket is allocated in another thread, need to transfer the buffer first, then create a helper object in its thread and process the data
auto cacheable_buf = std::move(buf);
socket->getPoller()->async([helper_creator, cacheable_buf]() {
// 在该socket所在线程创建helper对象 [AUTO-TRANSLATED:db8d6622]
//Create a helper object in the thread where the socket is located
auto helper = helper_creator();
if (helper) {
// 可能未实质创建hlepr对象成功可能获取到其他线程创建的helper对象 [AUTO-TRANSLATED:091f648e]
//May not have actually created a helper object successfully, may have obtained a helper object created by another thread
helper->session()->getPoller()->async([helper, cacheable_buf]() {
// 该数据不能丢弃给session对象消费 [AUTO-TRANSLATED:6941e5fa]
//This data cannot be discarded, provided to the session object for consumption
emitSessionRecv(helper, cacheable_buf);
});
}
});
return nullptr;
}
void UdpServer::setOnCreateSocket(onCreateSocket cb) {
if (cb) {
_on_create_socket = std::move(cb);
} else {
_on_create_socket = [](const EventPoller::Ptr &poller, const Buffer::Ptr &buf, struct sockaddr *addr, int addr_len) {
return Socket::createSocket(poller, false);
};
}
for (auto &pr : _cloned_server) {
pr.second->setOnCreateSocket(cb);
}
}
uint16_t UdpServer::getPort() {
if (!_socket) {
return 0;
}
return _socket->get_local_port();
}
Socket::Ptr UdpServer::createSocket(const EventPoller::Ptr &poller, const Buffer::Ptr &buf, struct sockaddr *addr, int addr_len) {
return _on_create_socket(poller, buf, addr, addr_len);
}
StatisticImp(UdpServer)
} // namespace toolkit

View File

@ -0,0 +1,201 @@
/*
* Copyright (c) 2021 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef TOOLKIT_NETWORK_UDPSERVER_H
#define TOOLKIT_NETWORK_UDPSERVER_H
#if __cplusplus >= 201703L
#include <array>
#include <string_view>
#endif
#include "Server.h"
#include "Session.h"
namespace toolkit {
class UdpServer : public Server {
public:
#if __cplusplus >= 201703L
class PeerIdType : public std::array<char, 18> {
#else
class PeerIdType : public std::string {
#endif
public:
#if __cplusplus < 201703L
PeerIdType() {
resize(18);
}
#endif
bool operator==(const PeerIdType &that) const {
return as<uint64_t>(0) == that.as<uint64_t>(0) &&
as<uint64_t>(8) == that.as<uint64_t>(8) &&
as<uint16_t>(16) == that.as<uint16_t>(16);
}
private:
template <class T>
const T& as(size_t offset) const {
return *(reinterpret_cast<const T *>(data() + offset));
}
};
using Ptr = std::shared_ptr<UdpServer>;
using onCreateSocket = std::function<Socket::Ptr(const EventPoller::Ptr &, const Buffer::Ptr &, struct sockaddr *, int)>;
explicit UdpServer(const EventPoller::Ptr &poller = nullptr);
~UdpServer() override;
/**
* @brief
* @brief Start listening to the server
* [AUTO-TRANSLATED:342e9d0e]
*/
template<typename SessionType>
void start(uint16_t port, const std::string &host = "::", const std::function<void(std::shared_ptr<SessionType> &)> &cb = nullptr) {
static std::string cls_name = toolkit::demangle(typeid(SessionType).name());
// Session 创建器, 通过它创建不同类型的服务器 [AUTO-TRANSLATED:a419bcd3]
//Session creator, creates different types of servers through it
_session_alloc = [cb](const UdpServer::Ptr &server, const Socket::Ptr &sock) {
auto session = std::shared_ptr<SessionType>(new SessionType(sock), [](SessionType * ptr) {
TraceP(static_cast<Session *>(ptr)) << "~" << cls_name;
delete ptr;
});
if (cb) {
cb(session);
}
TraceP(static_cast<Session *>(session.get())) << cls_name;
auto sock_creator = server->_on_create_socket;
session->setOnCreateSocket([sock_creator](const EventPoller::Ptr &poller) {
return sock_creator(poller, nullptr, nullptr, 0);
});
return std::make_shared<SessionHelper>(server, std::move(session), cls_name);
};
start_l(port, host);
}
/**
* @brief ,
* @brief Get the server listening port number, the server can choose to listen to a random port
* [AUTO-TRANSLATED:125ff8d8]
*/
uint16_t getPort();
/**
* @brief socket构建行为
* @brief Custom socket construction behavior
* [AUTO-TRANSLATED:4cf98e86]
*/
void setOnCreateSocket(onCreateSocket cb);
protected:
virtual Ptr onCreatServer(const EventPoller::Ptr &poller);
virtual void cloneFrom(const UdpServer &that);
private:
struct PeerIdHash {
#if __cplusplus >= 201703L
size_t operator()(const PeerIdType &v) const noexcept { return std::hash<std::string_view> {}(std::string_view(v.data(), v.size())); }
#else
size_t operator()(const PeerIdType &v) const noexcept { return std::hash<std::string> {}(v); }
#endif
};
using SessionMapType = std::unordered_map<PeerIdType, SessionHelper::Ptr, PeerIdHash>;
/**
* @brief udp server
* @param port 0
* @param host ip
* @brief Start UDP server
* @param port Local port, 0 for random
* @param host Listening network card IP
* [AUTO-TRANSLATED:1c46778d]
*/
void start_l(uint16_t port, const std::string &host = "::");
/**
* @brief Session, UDP
* @brief Periodically manage Session, UDP sessions need to handle timeouts as needed
* [AUTO-TRANSLATED:86ff2f9c]
*/
void onManagerSession();
void onRead(Buffer::Ptr &buf, struct sockaddr *addr, int addr_len);
/**
* @brief ,server fdpeer fd
* @param is_server_fd server fd
* @param id id
* @param buf
* @param addr
* @param addr_len
* @brief Receive data, may come from server fd or peer fd
* @param is_server_fd Whether it is a server fd
* @param id Client ID
* @param buf Data
* @param addr Client address
* @param addr_len Client address length
* [AUTO-TRANSLATED:1c02c9de]
*/
void onRead_l(bool is_server_fd, const PeerIdType &id, Buffer::Ptr &buf, struct sockaddr *addr, int addr_len);
/**
* @brief
* @brief Get or create a session based on peer information
* [AUTO-TRANSLATED:c7e1f0c3]
*/
SessionHelper::Ptr getOrCreateSession(const PeerIdType &id, Buffer::Ptr &buf, struct sockaddr *addr, int addr_len, bool &is_new);
/**
* @brief ,
* @brief Create a session and perform necessary settings
* [AUTO-TRANSLATED:355c4256]
*/
SessionHelper::Ptr createSession(const PeerIdType &id, Buffer::Ptr &buf, struct sockaddr *addr, int addr_len);
/**
* @brief socket
* @brief Create a socket
* [AUTO-TRANSLATED:c9aacad4]
*/
Socket::Ptr createSocket(const EventPoller::Ptr &poller, const Buffer::Ptr &buf = nullptr, struct sockaddr *addr = nullptr, int addr_len = 0);
void setupEvent();
private:
bool _cloned = false;
bool _multi_poller;
Socket::Ptr _socket;
std::shared_ptr<Timer> _timer;
onCreateSocket _on_create_socket;
//cloned server共享主server的session map防止数据在不同server间漂移 [AUTO-TRANSLATED:9a149e52]
//Cloned server shares the session map with the main server, preventing data drift between different servers
std::shared_ptr<std::recursive_mutex> _session_mutex;
std::shared_ptr<SessionMapType> _session_map;
//主server持有cloned server的引用 [AUTO-TRANSLATED:04a6403a]
//Main server holds a reference to the cloned server
std::unordered_map<EventPoller *, Ptr> _cloned_server;
std::function<SessionHelper::Ptr(const UdpServer::Ptr &, const Socket::Ptr &)> _session_alloc;
// 对象个数统计 [AUTO-TRANSLATED:f4a012d0]
//Object count statistics
ObjectStatistic<UdpServer> _statistic;
};
} // namespace toolkit
#endif // TOOLKIT_NETWORK_UDPSERVER_H

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,587 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef NETWORK_SOCKUTIL_H
#define NETWORK_SOCKUTIL_H
#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
#include <iphlpapi.h>
#pragma comment (lib, "Ws2_32.lib")
#pragma comment(lib,"Iphlpapi.lib")
#else
#include <netdb.h>
#include <arpa/inet.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <net/if.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#endif // defined(_WIN32)
#include <cstring>
#include <cstdint>
#include <map>
#include <vector>
#include <string>
namespace toolkit {
#if defined(_WIN32)
#ifndef socklen_t
#define socklen_t int
#endif //!socklen_t
int ioctl(int fd, long cmd, u_long *ptr);
int close(int fd);
#endif // defined(_WIN32)
#if !defined(SOCKET_DEFAULT_BUF_SIZE)
#define SOCKET_DEFAULT_BUF_SIZE (256 * 1024)
#else
#if SOCKET_DEFAULT_BUF_SIZE == 0 && !defined(__linux__)
// just for linux, because in some high-throughput environments,
// kernel control is more efficient and reasonable than program
// settings. For example, refer to cloudflare's blog
#undef SOCKET_DEFAULT_BUF_SIZE
#define SOCKET_DEFAULT_BUF_SIZE (256 * 1024)
#endif
#endif
#define TCP_KEEPALIVE_INTERVAL 30
#define TCP_KEEPALIVE_PROBE_TIMES 9
#define TCP_KEEPALIVE_TIME 120
//套接字工具类封装了socket、网络的一些基本操作 [AUTO-TRANSLATED:33a88b27]
//Socket tool class, encapsulating some basic socket and network operations
class SockUtil {
public:
struct SockAddrHash {
std::size_t operator()(const sockaddr_storage& addr) const {
switch (addr.ss_family) {
case AF_INET: {
const struct sockaddr_in* addr_in = reinterpret_cast<const struct sockaddr_in*>(&addr);
return std::hash<uint32_t>()(addr_in->sin_addr.s_addr) ^ std::hash<uint16_t>()(addr_in->sin_port);
}
case AF_INET6: {
const struct sockaddr_in6* addr_in6 = reinterpret_cast<const struct sockaddr_in6*>(&addr);
std::size_t h = 0;
for (int i = 0; i < 16; ++i) {
h ^= std::hash<uint8_t>()(addr_in6->sin6_addr.s6_addr[i]) << (i % 8);
}
return h ^ std::hash<uint16_t>()(addr_in6->sin6_port);
}
default:
return 0;
}
}
};
struct SockAddrEqual {
bool operator()(const sockaddr_storage& a, const sockaddr_storage& b) const {
return toolkit::SockUtil::is_same_addr(reinterpret_cast<const struct sockaddr*>(&a), reinterpret_cast<const struct sockaddr*>(&b));
}
};
/**
* tcp客户端套接字并连接服务器
* @param host ip或域名
* @param port
* @param async
* @param local_ip ip
* @param local_port
* @return -1socket fd号
* Create a TCP client socket and connect to the server
* @param host Server IP or domain name
* @param port Server port number
* @param async Whether to connect asynchronously
* @param local_ip Local network card IP to bind
* @param local_port Local port number to bind
* @return -1 represents failure, others are socket fd numbers
* [AUTO-TRANSLATED:3f0a872c]
*/
static int connect(const char *host, uint16_t port, bool async = true, const char *local_ip = "::", uint16_t local_port = 0);
/**
* tcp监听套接字
* @param port
* @param local_ip ip
* @param back_log accept列队长度
* @return -1socket fd号
* Create a TCP listening socket
* @param port Local port to listen on
* @param local_ip Local network card IP to bind
* @param back_log Accept queue length
* @return -1 represents failure, others are socket fd numbers
* [AUTO-TRANSLATED:d56ad901]
*/
static int listen(const uint16_t port, const char *local_ip = "::", int back_log = 1024);
/**
* udp套接字
* @param port
* @param local_ip ip
* @param enable_reuse bind端口
* @return -1socket fd号
* Create a UDP socket
* @param port Local port to listen on
* @param local_ip Local network card IP to bind
* @param enable_reuse Whether to allow repeated bind port
* @return -1 represents failure, others are socket fd numbers
* [AUTO-TRANSLATED:a3762f0f]
*/
static int bindUdpSock(const uint16_t port, const char *local_ip = "::", bool enable_reuse = true);
/**
* @brief sock
* @param sock, socket fd
* @return 0 , -1
* @brief Release the binding relationship related to sock
* @param sock, socket fd number
* @return 0 Success, -1 Failure
* [AUTO-TRANSLATED:50b002e8]
*/
static int dissolveUdpSock(int sock);
/**
* TCP_NODELAYTCP交互延时
* @param fd socket fd号
* @param on
* @return 0-1
* Enable TCP_NODELAY to reduce TCP interaction delay
* @param fd socket fd number
* @param on Whether to enable
* @return 0 represents success, -1 represents failure
* [AUTO-TRANSLATED:11b57392]
*/
static int setNoDelay(int fd, bool on = true);
/**
* socket不触发SIG_PIPE信号(mac有效)
* @param fd socket fd号
* @return 0-1
* Write socket does not trigger SIG_PIPE signal (seems to be effective only on Mac)
* @param fd socket fd number
* @return 0 represents success, -1 represents failure
* [AUTO-TRANSLATED:bdb49ca5]
*/
static int setNoSigpipe(int fd);
/**
* socket是否阻塞
* @param fd socket fd号
* @param noblock
* @return 0-1
* Set whether the read and write socket is blocked
* @param fd socket fd number
* @param noblock Whether to block
* @return 0 represents success, -1 represents failure
* [AUTO-TRANSLATED:2f9717df]
*/
static int setNoBlocked(int fd, bool noblock = true);
/**
* socket接收缓存8K左右
*
* @param fd socket fd号
* @param size
* @return 0-1
* Set the socket receive buffer, default is around 8K, generally has an upper limit
* Can be adjusted through kernel configuration file
* @param fd socket fd number
* @param size Receive buffer size
* @return 0 represents success, -1 represents failure
* [AUTO-TRANSLATED:4dcaa8b8]
*/
static int setRecvBuf(int fd, int size = SOCKET_DEFAULT_BUF_SIZE);
/**
* socket接收缓存8K左右
*
* @param fd socket fd号
* @param size
* @return 0-1
* Set the socket receive buffer, default is around 8K, generally has an upper limit
* Can be adjusted through kernel configuration file
* @param fd socket fd number
* @param size Receive buffer size
* @return 0 represents success, -1 represents failure
* [AUTO-TRANSLATED:4dcaa8b8]
*/
static int setSendBuf(int fd, int size = SOCKET_DEFAULT_BUF_SIZE);
/**
* (TIME_WAITE状态)
* @param fd socket fd号
* @param on
* @return 0-1
* Set subsequent bindable reuse port (in TIME_WAIT state)
* @param fd socket fd number
* @param on whether to enable this feature
* @return 0 represents success, -1 for failure
* [AUTO-TRANSLATED:4dcb4dff]
*/
static int setReuseable(int fd, bool on = true, bool reuse_port = true);
/**
* udp广播信息
* @param fd socket fd号
* @param on
* @return 0-1
* Run sending or receiving UDP broadcast messages
* @param fd socket fd number
* @param on whether to enable this feature
* @return 0 represents success, -1 for failure
* [AUTO-TRANSLATED:d5ce73e0]
*/
static int setBroadcast(int fd, bool on = true);
/**
* TCP KeepAlive特性
* @param fd socket fd号
* @param on
* @param idle keepalive空闲时间
* @param interval keepalive探测时间间隔
* @param times keepalive探测次数
* @return 0-1
* Enable TCP KeepAlive feature
* @param fd socket fd number
* @param on whether to enable this feature
* @param idle keepalive idle time
* @param interval keepalive probe time interval
* @param times keepalive probe times
* @return 0 represents success, -1 for failure
* [AUTO-TRANSLATED:9b44a8ec]
*/
static int setKeepAlive(int fd, bool on = true, int interval = TCP_KEEPALIVE_INTERVAL, int idle = TCP_KEEPALIVE_TIME, int times = TCP_KEEPALIVE_PROBE_TIMES);
/**
* FD_CLOEXEC特性()
* @param fd fd号socket
* @param on
* @return 0-1
* Enable FD_CLOEXEC feature (related to multiple processes)
* @param fd fd number, not necessarily a socket
* @param on whether to enable this feature
* @return 0 represents success, -1 for failure
* [AUTO-TRANSLATED:964368da]
*/
static int setCloExec(int fd, bool on = true);
/**
* SO_LINGER特性
* @param sock socket fd号
* @param second socket超时时间
* @return 0-1
* Enable SO_LINGER feature
* @param sock socket fd number
* @param second kernel waiting time for closing socket timeout, in seconds
* @return 0 represents success, -1 for failure
* [AUTO-TRANSLATED:92230daf]
*/
static int setCloseWait(int sock, int second = 0);
/**
* dns解析
* @param host ip
* @param port
* @param addr sockaddr结构体
* @return
* DNS resolution
* @param host domain name or IP
* @param port port number
* @param addr sockaddr structure
* @return whether successful
* [AUTO-TRANSLATED:3b79cf5d]
*/
static bool getDomainIP(const char *host, uint16_t port, struct sockaddr_storage &addr, int ai_family = AF_INET,
int ai_socktype = SOCK_STREAM, int ai_protocol = IPPROTO_TCP, int expire_sec = 60);
/**
* ttl
* @param sock socket fd号
* @param ttl ttl值
* @return 0-1
* Set multicast TTL
* @param sock socket fd number
* @param ttl TTL value
* @return 0 represents success, -1 for failure
* [AUTO-TRANSLATED:1828beb5]
*/
static int setMultiTTL(int sock, uint8_t ttl = 64);
/**
*
* @param sock socket fd号
* @param local_ip ip
* @return 0-1
* Set multicast sending network card
* @param sock socket fd number
* @param local_ip local network card IP
* @return 0 represents success, -1 for failure
* [AUTO-TRANSLATED:25e8e9d7]
*/
static int setMultiIF(int sock, const char *local_ip);
/**
*
* @param fd socket fd号
* @param acc
* @return 0-1
* Set whether to receive multicast packets sent by the local machine
* @param fd socket fd number
* @param acc whether to receive
* @return 0 represents success, -1 for failure
* [AUTO-TRANSLATED:83cec1e8]
*/
static int setMultiLOOP(int fd, bool acc = false);
/**
*
* @param fd socket fd号
* @param addr
* @param local_ip ip
* @return 0-1
* Join multicast
* @param fd socket fd number
* @param addr multicast address
* @param local_ip local network card IP
* @return 0 represents success, -1 for failure
* [AUTO-TRANSLATED:45523b25]
*/
static int joinMultiAddr(int fd, const char *addr, const char *local_ip = "0.0.0.0");
/**
* 退
* @param fd socket fd号
* @param addr
* @param local_ip ip
* @return 0-1
* Exit multicast
* @param fd socket fd number
* @param addr multicast address
* @param local_ip local network card ip
* @return 0 represents success, -1 for failure
* [AUTO-TRANSLATED:081785d3]
*/
static int leaveMultiAddr(int fd, const char *addr, const char *local_ip = "0.0.0.0");
/**
*
* @param sock socket fd号
* @param addr
* @param src_ip
* @param local_ip ip
* @return 0-1
* Join multicast and only receive multicast data from the specified source
* @param sock socket fd number
* @param addr multicast address
* @param src_ip source address
* @param local_ip local network card ip
* @return 0 represents success, -1 for failure
* [AUTO-TRANSLATED:061989eb]
*/
static int joinMultiAddrFilter(int sock, const char *addr, const char *src_ip, const char *local_ip = "0.0.0.0");
/**
* 退
* @param fd socket fd号
* @param addr
* @param src_ip
* @param local_ip ip
* @return 0-1
* Exit multicast
* @param fd socket fd number
* @param addr multicast address
* @param src_ip source address
* @param local_ip local network card ip
* @return 0 represents success, -1 for failure
* [AUTO-TRANSLATED:9cd166c7]
*/
static int leaveMultiAddrFilter(int fd, const char *addr, const char *src_ip, const char *local_ip = "0.0.0.0");
/**
* socket当前发生的错误
* @param fd socket fd号
* @return
* Get the current error of the socket
* @param fd socket fd number
* @return error code
* [AUTO-TRANSLATED:e4500a0f]
*/
static int getSockError(int fd);
/**
*
* @return vector<map<ip:name> >
* Get the list of network cards
* @return vector<map<ip:name> >
* [AUTO-TRANSLATED:94687465]
*/
static std::vector<std::map<std::string, std::string>> getInterfaceList();
/**
* ip
* Get the default local ip of the host
* [AUTO-TRANSLATED:9eb5d031]
*/
static std::string get_local_ip();
/**
* socket绑定的本地ip
* @param sock socket fd号
* Get the local ip bound to the socket
* @param sock socket fd number
* [AUTO-TRANSLATED:4e7b6040]
*/
static std::string get_local_ip(int sock);
/**
* socket绑定的本地端口
* @param sock socket fd号
* Get the local port bound to the socket
* @param sock socket fd number
* [AUTO-TRANSLATED:7b212118]
*/
static uint16_t get_local_port(int sock);
/**
* socket绑定的远端ip
* @param sock socket fd号
* Get the remote ip bound to the socket
* @param sock socket fd number
* [AUTO-TRANSLATED:952ddef8]
*/
static std::string get_peer_ip(int sock);
/**
* socket绑定的远端端口
* @param sock socket fd号
* Get the remote port bound to the socket
* @param sock socket fd number
* [AUTO-TRANSLATED:3b9bcf2e]
*/
static uint16_t get_peer_port(int sock);
static bool support_ipv6();
/**
* 线in_addr转ip字符串
* Thread-safe conversion of in_addr to IP string
* [AUTO-TRANSLATED:e0ff8b4b]
*/
static std::string inet_ntoa(const struct in_addr &addr);
static std::string inet_ntoa(const struct in6_addr &addr);
static std::string inet_ntoa(const struct sockaddr *addr, bool mapV4 = true);
static uint16_t inet_port(const struct sockaddr *addr);
static struct sockaddr_storage make_sockaddr(const char *ip, uint16_t port);
static socklen_t get_sock_len(const struct sockaddr *addr);
static bool get_sock_local_addr(int fd, struct sockaddr_storage &addr);
static bool get_sock_peer_addr(int fd, struct sockaddr_storage &addr);
static bool is_same_addr(const struct sockaddr* a, const struct sockaddr* b);
/**
* ip
* @param if_name
* Get the IP of the network card
* @param if_name Network card name
* [AUTO-TRANSLATED:e88f1554]
*/
static std::string get_ifr_ip(const char *if_name);
/**
*
* @param local_op ip
* Get the network card name
* @param local_op Network card IP
* [AUTO-TRANSLATED:cdaad7f0]
*/
static std::string get_ifr_name(const char *local_op);
/**
*
* @param if_name
* Get the subnet mask based on the network card name
* @param if_name Network card name
* [AUTO-TRANSLATED:a6714ee2]
*/
static std::string get_ifr_mask(const char *if_name);
/**
* 广
* @param if_name
* Get the broadcast address based on the network card name
* @param if_name Network card name
* [AUTO-TRANSLATED:20348c92]
*/
static std::string get_ifr_brdaddr(const char *if_name);
/**
* ip是否为同一网段
* @param src_ip ip
* @param dts_ip ip
* Determine if two IPs are in the same network segment
* @param src_ip My IP
* @param dts_ip Peer IP
* [AUTO-TRANSLATED:95acb68f]
*/
static bool in_same_lan(const char *src_ip, const char *dts_ip);
/**
* ipv4地址
* Determine if it is an IPv4 address
* [AUTO-TRANSLATED:b5af4ea0]
*/
static bool is_ipv4(const char *str);
/**
* ipv6地址
* Determine if it is an IPv6 address
* [AUTO-TRANSLATED:70526900]
*/
static bool is_ipv6(const char *str);
};
} // namespace toolkit
#endif // !NETWORK_SOCKUTIL_H

View File

@ -0,0 +1,646 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include "SelectWrap.h"
#include "EventPoller.h"
#include "Util/util.h"
#include "Util/uv_errno.h"
#include "Util/TimeTicker.h"
#include "Util/NoticeCenter.h"
#include "Network/sockutil.h"
#if defined(HAS_EPOLL)
#include <sys/epoll.h>
#if !defined(EPOLLEXCLUSIVE)
#define EPOLLEXCLUSIVE 0
#endif
#define EPOLL_SIZE 1024
//防止epoll惊群 [AUTO-TRANSLATED:ad53c775]
//Prevent epoll thundering
#ifndef EPOLLEXCLUSIVE
#define EPOLLEXCLUSIVE 0
#endif
#define toEpoll(event) (((event) & Event_Read) ? EPOLLIN : 0) \
| (((event) & Event_Write) ? EPOLLOUT : 0) \
| (((event) & Event_Error) ? (EPOLLHUP | EPOLLERR) : 0) \
| (((event) & Event_LT) ? 0 : EPOLLET)
#define toPoller(epoll_event) (((epoll_event) & (EPOLLIN | EPOLLRDNORM | EPOLLHUP)) ? Event_Read : 0) \
| (((epoll_event) & (EPOLLOUT | EPOLLWRNORM)) ? Event_Write : 0) \
| (((epoll_event) & EPOLLHUP) ? Event_Error : 0) \
| (((epoll_event) & EPOLLERR) ? Event_Error : 0)
#define create_event() epoll_create(EPOLL_SIZE)
#if !defined(_WIN32)
#define close_event(fd) close(fd)
#else
#define close_event(fd) epoll_close(fd)
#endif
#endif //HAS_EPOLL
#if defined(HAS_KQUEUE)
#include <sys/event.h>
#define KEVENT_SIZE 1024
#define create_event() kqueue()
#define close_event(fd) close(fd)
#endif // HAS_KQUEUE
using namespace std;
namespace toolkit {
EventPoller &EventPoller::Instance() {
return *(EventPollerPool::Instance().getFirstPoller());
}
void EventPoller::addEventPipe() {
SockUtil::setNoBlocked(_pipe.readFD());
SockUtil::setNoBlocked(_pipe.writeFD());
// 添加内部管道事件 [AUTO-TRANSLATED:6a72e39a]
//Add internal pipe event
if (addEvent(_pipe.readFD(), EventPoller::Event_Read, [this](int event) { onPipeEvent(); }) == -1) {
throw std::runtime_error("Add pipe fd to poller failed");
}
}
EventPoller::EventPoller(std::string name) {
#if defined(HAS_EPOLL) || defined(HAS_KQUEUE)
_event_fd = create_event();
if (_event_fd == INVALID_EVENT_FD) {
throw runtime_error(StrPrinter << "Create event fd failed: " << get_uv_errmsg());
}
#if !defined(_WIN32)
SockUtil::setCloExec(_event_fd);
#endif
#endif
_name = std::move(name);
_logger = Logger::Instance().shared_from_this();
addEventPipe();
}
void EventPoller::shutdown() {
async_l([]() {
throw ExitException();
}, false, true);
if (_loop_thread) {
//防止作为子进程时崩溃 [AUTO-TRANSLATED:68727e34]
//Prevent crash when running as a child process
try { _loop_thread->join(); } catch (...) { _loop_thread->detach(); }
delete _loop_thread;
_loop_thread = nullptr;
}
}
EventPoller::~EventPoller() {
shutdown();
#if defined(HAS_EPOLL) || defined(HAS_KQUEUE)
if (_event_fd != INVALID_EVENT_FD) {
close_event(_event_fd);
_event_fd = INVALID_EVENT_FD;
}
#endif
//退出前清理管道中的数据 [AUTO-TRANSLATED:60e26f9a]
//Clean up pipe data before exiting
onPipeEvent(true);
InfoL << getThreadName();
}
int EventPoller::addEvent(int fd, int event, PollEventCB cb) {
TimeTicker();
if (!cb) {
WarnL << "PollEventCB is empty";
return -1;
}
if (isCurrentThread()) {
#if defined(HAS_EPOLL)
struct epoll_event ev = {0};
ev.events = toEpoll(event) ;
ev.data.fd = fd;
int ret = epoll_ctl(_event_fd, EPOLL_CTL_ADD, fd, &ev);
if (ret != -1) {
_event_map.emplace(fd, std::make_shared<PollEventCB>(std::move(cb)));
}
_fd_count = _event_map.size();
return ret;
#elif defined(HAS_KQUEUE)
struct kevent kev[2];
int index = 0;
if (event & Event_Read) {
EV_SET(&kev[index++], fd, EVFILT_READ, EV_ADD | EV_CLEAR, 0, 0, nullptr);
}
if (event & Event_Write) {
EV_SET(&kev[index++], fd, EVFILT_WRITE, EV_ADD | EV_CLEAR, 0, 0, nullptr);
}
int ret = kevent(_event_fd, kev, index, nullptr, 0, nullptr);
if (ret != -1) {
_event_map.emplace(fd, std::make_shared<PollEventCB>(std::move(cb)));
}
_fd_count = _event_map.size();
return ret;
#else
#ifndef _WIN32
// win32平台socket套接字不等于文件描述符所以可能不适用这个限制 [AUTO-TRANSLATED:6adfc664]
//On the win32 platform, the socket does not equal the file descriptor, so this restriction may not apply
if (fd >= FD_SETSIZE) {
WarnL << "select() can not watch fd bigger than " << FD_SETSIZE;
return -1;
}
#endif
auto record = std::make_shared<Poll_Record>();
record->fd = fd;
record->event = event;
record->call_back = std::move(cb);
_event_map.emplace(fd, record);
_fd_count = _event_map.size();
return 0;
#endif
}
async([this, fd, event, cb]() mutable {
addEvent(fd, event, std::move(cb));
});
return 0;
}
int EventPoller::delEvent(int fd, PollCompleteCB cb) {
TimeTicker();
if (!cb) {
cb = [](bool success) {};
}
if (isCurrentThread()) {
#if defined(HAS_EPOLL)
int ret = -1;
if (_event_map.erase(fd)) {
_event_cache_expired.emplace(fd);
ret = epoll_ctl(_event_fd, EPOLL_CTL_DEL, fd, nullptr);
}
cb(ret != -1);
_fd_count = _event_map.size();
return ret;
#elif defined(HAS_KQUEUE)
int ret = -1;
if (_event_map.erase(fd)) {
_event_cache_expired.emplace(fd);
struct kevent kev[2];
int index = 0;
EV_SET(&kev[index++], fd, EVFILT_READ, EV_DELETE, 0, 0, nullptr);
EV_SET(&kev[index++], fd, EVFILT_WRITE, EV_DELETE, 0, 0, nullptr);
ret = kevent(_event_fd, kev, index, nullptr, 0, nullptr);
}
cb(ret != -1);
_fd_count = _event_map.size();
return ret;
#else
int ret = -1;
if (_event_map.erase(fd)) {
_event_cache_expired.emplace(fd);
ret = 0;
}
cb(ret != -1);
_fd_count = _event_map.size();
return ret;
#endif //HAS_EPOLL
}
//跨线程操作 [AUTO-TRANSLATED:4e116519]
//Cross-thread operation
async([this, fd, cb]() mutable {
delEvent(fd, std::move(cb));
});
return 0;
}
int EventPoller::modifyEvent(int fd, int event, PollCompleteCB cb) {
TimeTicker();
if (!cb) {
cb = [](bool success) {};
}
if (isCurrentThread()) {
#if defined(HAS_EPOLL)
struct epoll_event ev = { 0 };
ev.events = toEpoll(event);
ev.data.fd = fd;
auto ret = epoll_ctl(_event_fd, EPOLL_CTL_MOD, fd, &ev);
cb(ret != -1);
return ret;
#elif defined(HAS_KQUEUE)
struct kevent kev[2];
int index = 0;
EV_SET(&kev[index++], fd, EVFILT_READ, event & Event_Read ? EV_ADD | EV_CLEAR : EV_DELETE, 0, 0, nullptr);
EV_SET(&kev[index++], fd, EVFILT_WRITE, event & Event_Write ? EV_ADD | EV_CLEAR : EV_DELETE, 0, 0, nullptr);
int ret = kevent(_event_fd, kev, index, nullptr, 0, nullptr);
cb(ret != -1);
return ret;
#else
auto it = _event_map.find(fd);
if (it != _event_map.end()) {
it->second->event = event;
}
cb(it != _event_map.end());
return it != _event_map.end() ? 0 : -1;
#endif // HAS_EPOLL
}
async([this, fd, event, cb]() mutable {
modifyEvent(fd, event, std::move(cb));
});
return 0;
}
size_t EventPoller::fdCount() const {
return _fd_count;
}
Task::Ptr EventPoller::async(TaskIn task, bool may_sync) {
return async_l(std::move(task), may_sync, false);
}
Task::Ptr EventPoller::async_first(TaskIn task, bool may_sync) {
return async_l(std::move(task), may_sync, true);
}
Task::Ptr EventPoller::async_l(TaskIn task, bool may_sync, bool first) {
TimeTicker();
if (may_sync && isCurrentThread()) {
task();
return nullptr;
}
auto ret = std::make_shared<Task>(std::move(task));
{
lock_guard<mutex> lck(_mtx_task);
if (first) {
_list_task.emplace_front(ret);
} else {
_list_task.emplace_back(ret);
}
}
//写数据到管道,唤醒主线程 [AUTO-TRANSLATED:2ead8182]
//Write data to the pipe and wake up the main thread
_pipe.write("", 1);
return ret;
}
bool EventPoller::isCurrentThread() {
return !_loop_thread || _loop_thread->get_id() == this_thread::get_id();
}
inline void EventPoller::onPipeEvent(bool flush) {
char buf[1024];
int err = 0;
if (!flush) {
for (;;) {
if ((err = _pipe.read(buf, sizeof(buf))) > 0) {
// 读到管道数据,继续读,直到读空为止 [AUTO-TRANSLATED:47bd325c]
//Read data from the pipe, continue reading until it's empty
continue;
}
if (err == 0 || get_uv_error(true) != UV_EAGAIN) {
// 收到eof或非EAGAIN(无更多数据)错误,说明管道无效了,重新打开管道 [AUTO-TRANSLATED:5f7a013d]
//Received eof or non-EAGAIN (no more data) error, indicating that the pipe is invalid, reopen the pipe
ErrorL << "Invalid pipe fd of event poller, reopen it";
delEvent(_pipe.readFD());
_pipe.reOpen();
addEventPipe();
}
break;
}
}
decltype(_list_task) _list_swap;
{
lock_guard<mutex> lck(_mtx_task);
_list_swap.swap(_list_task);
}
_list_swap.for_each([&](const Task::Ptr &task) {
try {
(*task)();
} catch (ExitException &) {
_exit_flag = true;
} catch (std::exception &ex) {
ErrorL << "Exception occurred when do async task: " << ex.what();
}
});
}
SocketRecvBuffer::Ptr EventPoller::getSharedBuffer(bool is_udp) {
#if !defined(__linux) && !defined(__linux__)
// 非Linux平台下tcp和udp共享recvfrom方案使用同一个buffer [AUTO-TRANSLATED:2d2ee7bf]
//On non-Linux platforms, tcp and udp share the recvfrom scheme, using the same buffer
is_udp = 0;
#endif
auto ret = _shared_buffer[is_udp].lock();
if (!ret) {
ret = SocketRecvBuffer::create(is_udp);
_shared_buffer[is_udp] = ret;
}
return ret;
}
thread::id EventPoller::getThreadId() const {
return _loop_thread ? _loop_thread->get_id() : thread::id();
}
const std::string& EventPoller::getThreadName() const {
return _name;
}
static thread_local std::weak_ptr<EventPoller> s_current_poller;
// static
EventPoller::Ptr EventPoller::getCurrentPoller() {
return s_current_poller.lock();
}
void EventPoller::runLoop(bool blocked, bool ref_self) {
if (blocked) {
if (ref_self) {
s_current_poller = shared_from_this();
}
_sem_run_started.post();
_exit_flag = false;
int64_t minDelay;
#if defined(HAS_EPOLL)
struct epoll_event events[EPOLL_SIZE];
while (!_exit_flag) {
minDelay = getMinDelay();
startSleep(); // 用于统计当前线程负载情况
int ret = epoll_wait(_event_fd, events, EPOLL_SIZE, minDelay);
sleepWakeUp(); // 用于统计当前线程负载情况
if (ret <= 0) {
// 超时或被打断 [AUTO-TRANSLATED:7005fded]
// Timed out or interrupted
continue;
}
_event_cache_expired.clear();
for (int i = 0; i < ret; ++i) {
struct epoll_event &ev = events[i];
int fd = ev.data.fd;
if (_event_cache_expired.count(fd)) {
// event cache refresh
continue;
}
auto it = _event_map.find(fd);
if (it == _event_map.end()) {
epoll_ctl(_event_fd, EPOLL_CTL_DEL, fd, nullptr);
continue;
}
auto cb = it->second;
try {
(*cb)(toPoller(ev.events));
} catch (std::exception &ex) {
ErrorL << "Exception occurred when do event task: " << ex.what();
}
}
}
#elif defined(HAS_KQUEUE)
struct kevent kevents[KEVENT_SIZE];
while (!_exit_flag) {
minDelay = getMinDelay();
struct timespec timeout = { (long)minDelay / 1000, (long)minDelay % 1000 * 1000000 };
startSleep();
int ret = kevent(_event_fd, nullptr, 0, kevents, KEVENT_SIZE, minDelay == -1 ? nullptr : &timeout);
sleepWakeUp();
if (ret <= 0) {
continue;
}
_event_cache_expired.clear();
for (int i = 0; i < ret; ++i) {
auto &kev = kevents[i];
auto fd = kev.ident;
if (_event_cache_expired.count(fd)) {
// event cache refresh
continue;
}
auto it = _event_map.find(fd);
if (it == _event_map.end()) {
EV_SET(&kev, fd, kev.filter, EV_DELETE, 0, 0, nullptr);
kevent(_event_fd, &kev, 1, nullptr, 0, nullptr);
continue;
}
auto cb = it->second;
int event = 0;
switch (kev.filter) {
case EVFILT_READ: event = Event_Read; break;
case EVFILT_WRITE: event = Event_Write; break;
default: WarnL << "unknown kevent filter: " << kev.filter; break;
}
try {
(*cb)(event);
} catch (std::exception &ex) {
ErrorL << "Exception occurred when do event task: " << ex.what();
}
}
}
#else
int ret, max_fd;
FdSet set_read, set_write, set_err;
List<Poll_Record::Ptr> callback_list;
struct timeval tv;
while (!_exit_flag) {
// 定时器事件中可能操作_event_map [AUTO-TRANSLATED:f2a50ee2]
// Possible operations on _event_map in timer events
minDelay = getMinDelay();
tv.tv_sec = (decltype(tv.tv_sec))(minDelay / 1000);
tv.tv_usec = 1000 * (minDelay % 1000);
set_read.fdZero();
set_write.fdZero();
set_err.fdZero();
max_fd = 0;
for (auto &pr : _event_map) {
if (pr.first > max_fd) {
max_fd = pr.first;
}
if (pr.second->event & Event_Read) {
set_read.fdSet(pr.first); // 监听管道可读事件
}
if (pr.second->event & Event_Write) {
set_write.fdSet(pr.first); // 监听管道可写事件
}
if (pr.second->event & Event_Error) {
set_err.fdSet(pr.first); // 监听管道错误事件
}
}
startSleep(); // 用于统计当前线程负载情况
ret = zl_select(max_fd + 1, &set_read, &set_write, &set_err, minDelay == -1 ? nullptr : &tv);
sleepWakeUp(); // 用于统计当前线程负载情况
if (ret <= 0) {
// 超时或被打断 [AUTO-TRANSLATED:7005fded]
// Timed out or interrupted
continue;
}
_event_cache_expired.clear();
// 收集select事件类型 [AUTO-TRANSLATED:9a5c41d3]
// Collect select event types
for (auto &pr : _event_map) {
int event = 0;
if (set_read.isSet(pr.first)) {
event |= Event_Read;
}
if (set_write.isSet(pr.first)) {
event |= Event_Write;
}
if (set_err.isSet(pr.first)) {
event |= Event_Error;
}
if (event != 0) {
pr.second->attach = event;
callback_list.emplace_back(pr.second);
}
}
callback_list.for_each([&](Poll_Record::Ptr &record) {
if (_event_cache_expired.count(record->fd)) {
// event cache refresh
return;
}
try {
record->call_back(record->attach);
} catch (std::exception &ex) {
ErrorL << "Exception occurred when do event task: " << ex.what();
}
});
callback_list.clear();
}
#endif //HAS_EPOLL
} else {
_loop_thread = new thread(&EventPoller::runLoop, this, true, ref_self);
_sem_run_started.wait();
}
}
int64_t EventPoller::flushDelayTask(uint64_t now_time) {
decltype(_delay_task_map) task_copy;
task_copy.swap(_delay_task_map);
for (auto it = task_copy.begin(); it != task_copy.end() && it->first <= now_time; it = task_copy.erase(it)) {
//已到期的任务 [AUTO-TRANSLATED:849cdc29]
//Expired tasks
try {
auto next_delay = (*(it->second))();
if (next_delay) {
//可重复任务,更新时间截止线 [AUTO-TRANSLATED:c7746a21]
//Repeatable tasks, update deadline
_delay_task_map.emplace(next_delay + now_time, std::move(it->second));
}
} catch (std::exception &ex) {
ErrorL << "Exception occurred when do delay task: " << ex.what();
}
}
task_copy.insert(_delay_task_map.begin(), _delay_task_map.end());
task_copy.swap(_delay_task_map);
auto it = _delay_task_map.begin();
if (it == _delay_task_map.end()) {
//没有剩余的定时器了 [AUTO-TRANSLATED:23b1119e]
//No remaining timers
return -1;
}
//最近一个定时器的执行延时 [AUTO-TRANSLATED:2535621b]
//Delay in execution of the last timer
return it->first - now_time;
}
int64_t EventPoller::getMinDelay() {
auto it = _delay_task_map.begin();
if (it == _delay_task_map.end()) {
//没有剩余的定时器了 [AUTO-TRANSLATED:23b1119e]
//No remaining timers
return -1;
}
auto now = getCurrentMillisecond();
if (it->first > now) {
//所有任务尚未到期 [AUTO-TRANSLATED:8d80eabf]
//All tasks have not expired
return it->first - now;
}
//执行已到期的任务并刷新休眠延时 [AUTO-TRANSLATED:cd6348b7]
//Execute expired tasks and refresh sleep delay
return flushDelayTask(now);
}
EventPoller::DelayTask::Ptr EventPoller::doDelayTask(uint64_t delay_ms, function<uint64_t()> task) {
DelayTask::Ptr ret = std::make_shared<DelayTask>(std::move(task));
auto time_line = getCurrentMillisecond() + delay_ms;
async_first([time_line, ret, this]() {
//异步执行的目的是刷新select或epoll的休眠时间 [AUTO-TRANSLATED:a6b5c8d7]
//The purpose of asynchronous execution is to refresh the sleep time of select or epoll
_delay_task_map.emplace(time_line, ret);
});
return ret;
}
///////////////////////////////////////////////
static size_t s_pool_size = 0;
static bool s_enable_cpu_affinity = true;
INSTANCE_IMP(EventPollerPool)
EventPoller::Ptr EventPollerPool::getFirstPoller() {
return static_pointer_cast<EventPoller>(_threads.front());
}
EventPoller::Ptr EventPollerPool::getPoller(bool prefer_current_thread) {
auto poller = EventPoller::getCurrentPoller();
if (prefer_current_thread && _prefer_current_thread && poller) {
return poller;
}
return static_pointer_cast<EventPoller>(getExecutor());
}
void EventPollerPool::preferCurrentThread(bool flag) {
_prefer_current_thread = flag;
}
const std::string EventPollerPool::kOnStarted = "kBroadcastEventPollerPoolStarted";
EventPollerPool::EventPollerPool() {
auto size = addPoller("event poller", s_pool_size, ThreadPool::PRIORITY_HIGHEST, true, s_enable_cpu_affinity);
NOTICE_EMIT(EventPollerPoolOnStartedArgs, kOnStarted, *this, size);
InfoL << "EventPoller created size: " << size;
}
void EventPollerPool::setPoolSize(size_t size) {
s_pool_size = size;
}
void EventPollerPool::enableCpuAffinity(bool enable) {
s_enable_cpu_affinity = enable;
}
} // namespace toolkit

View File

@ -0,0 +1,422 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef EventPoller_h
#define EventPoller_h
#include <mutex>
#include <thread>
#include <string>
#include <functional>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include "PipeWrap.h"
#include "Util/logger.h"
#include "Util/List.h"
#include "Thread/TaskExecutor.h"
#include "Thread/ThreadPool.h"
#include "Network/Buffer.h"
#include "Network/BufferSock.h"
#if defined(__linux__) || defined(__linux)
#define HAS_EPOLL
#endif //__linux__
#if defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
#define HAS_KQUEUE
#endif // __APPLE__
#if defined(HAS_EPOLL) || defined(HAS_KQUEUE)
#if defined(_WIN32)
using epoll_fd = void *;
constexpr epoll_fd INVALID_EVENT_FD = nullptr;
#else
using epoll_fd = int;
constexpr epoll_fd INVALID_EVENT_FD = -1;
#endif
#endif
namespace toolkit {
class EventPoller : public TaskExecutor, public AnyStorage, public std::enable_shared_from_this<EventPoller> {
public:
friend class TaskExecutorGetterImp;
using Ptr = std::shared_ptr<EventPoller>;
using PollEventCB = std::function<void(int event)>;
using PollCompleteCB = std::function<void(bool success)>;
using DelayTask = TaskCancelableImp<uint64_t(void)>;
typedef enum {
Event_Read = 1 << 0, // 读事件
Event_Write = 1 << 1, // 写事件
Event_Error = 1 << 2, // 错误事件
Event_LT = 1 << 3, // 水平触发
} Poll_Event;
~EventPoller();
/**
* EventPollerPool单例中的第一个EventPoller实例
*
* @return
* Gets the first EventPoller instance from the EventPollerPool singleton,
* This interface is preserved for compatibility with old code.
* @return singleton
* [AUTO-TRANSLATED:b536ebf6]
*/
static EventPoller &Instance();
/**
*
* @param fd
* @param event Event_Read | Event_Write
* @param cb functional
* @return -1:0:
* Adds an event listener
* @param fd The file descriptor to listen to
* @param event The event type, e.g. Event_Read | Event_Write
* @param cb The event callback function
* @return -1: failed, 0: success
* [AUTO-TRANSLATED:cfba4c75]
*/
int addEvent(int fd, int event, PollEventCB cb);
/**
*
* @param fd
* @param cb functional
* @return -1:0:
* Deletes an event listener
* @param fd The file descriptor to stop listening to
* @param cb The callback function for successful deletion
* @return -1: failed, 0: success
* [AUTO-TRANSLATED:be6fdf51]
*/
int delEvent(int fd, PollCompleteCB cb = nullptr);
/**
*
* @param fd
* @param event Event_Read | Event_Write
* @return -1:0:
* Modifies the event type being listened to
* @param fd The file descriptor to modify
* @param event The new event type, e.g. Event_Read | Event_Write
* @return -1: failed, 0: success
* [AUTO-TRANSLATED:becf3d09]
*/
int modifyEvent(int fd, int event, PollCompleteCB cb = nullptr);
/**
* fd事件
*/
size_t fdCount() const;
/**
*
* @param task
* @param may_sync 线线may_sync为true时就是同步执行任务
* @return true
* Executes a task asynchronously
* @param task The task to execute
* @param may_sync If the calling thread is the polling thread of this object,
* then if may_sync is true, the task will be executed synchronously
* @return Whether the task was executed successfully (always returns true)
* [AUTO-TRANSLATED:071f7ed8]
*/
Task::Ptr async(TaskIn task, bool may_sync = true) override;
/**
* async方法
* @param task
* @param may_sync 线线may_sync为true时就是同步执行任务
* @return true
* Similar to async, but adds the task to the head of the task queue,
* giving it the highest priority
* @param task The task to execute
* @param may_sync If the calling thread is the polling thread of this object,
* then if may_sync is true, the task will be executed synchronously
* @return Whether the task was executed successfully (always returns true)
* [AUTO-TRANSLATED:9ef5169b]
*/
Task::Ptr async_first(TaskIn task, bool may_sync = true) override;
/**
* 线线
* @return 线
* Checks if the thread calling this interface is the polling thread of this object
* @return Whether the calling thread is the polling thread
* [AUTO-TRANSLATED:db9a4916]
*/
bool isCurrentThread();
/**
*
* @param delay_ms
* @param task 0
* @return
* Delays the execution of a task
* @param delay_ms The delay in milliseconds
* @param task The task to execute, returns 0 to stop repeating the task,
* otherwise returns the delay for the next execution.
* If an exception is thrown in the task, it defaults to not repeating the task.
* @return A cancellable task label
* [AUTO-TRANSLATED:61f97e64]
*/
DelayTask::Ptr doDelayTask(uint64_t delay_ms, std::function<uint64_t()> task);
/**
* 线Poller实例
* Gets the Poller instance associated with the current thread
* [AUTO-TRANSLATED:debcf0e2]
*/
static EventPoller::Ptr getCurrentPoller();
/**
* 线socket共享的读缓存
* Gets the shared read buffer for all sockets in the current thread
* [AUTO-TRANSLATED:2796f458]
*/
SocketRecvBuffer::Ptr getSharedBuffer(bool is_udp);
/**
* poller线程id
* Get the poller thread ID
* [AUTO-TRANSLATED:1c968752]
*/
std::thread::id getThreadId() const;
/**
* 线
* Get the thread name
* [AUTO-TRANSLATED:842652d9]
*/
const std::string &getThreadName() const;
private:
/**
* EventPollerPool中构造
* This object can only be constructed in EventPollerPool
* [AUTO-TRANSLATED:0c9a8a28]
*/
EventPoller(std::string name);
/**
*
* @param blocked 线
* @param ref_self thread local变量
* Perform event polling
* @param blocked Whether to execute polling with the thread that calls this interface
* @param ref_self Whether to record this object to thread local variable
* [AUTO-TRANSLATED:b0ac803c]
*/
void runLoop(bool blocked, bool ref_self);
/**
* 线
* Internal pipe event, used to wake up the polling thread
* [AUTO-TRANSLATED:022754b9]
*/
void onPipeEvent(bool flush = false);
/**
* 线
* @param task
* @param may_sync
* @param first
* @return nullptr
* Switch threads and execute tasks
* @param task
* @param may_sync
* @param first
* @return The cancellable task itself, or nullptr if it has been executed synchronously
* [AUTO-TRANSLATED:e7019c4a]
*/
Task::Ptr async_l(TaskIn task, bool may_sync = true, bool first = false);
/**
*
* 线
* End event polling
* Note that once ended, the polling thread cannot be resumed
* [AUTO-TRANSLATED:4f232154]
*/
void shutdown();
/**
*
* Refresh delayed tasks
* [AUTO-TRANSLATED:88104b90]
*/
int64_t flushDelayTask(uint64_t now);
/**
* select或epoll休眠时间
* Get the sleep time for select or epoll
* [AUTO-TRANSLATED:34e0384e]
*/
int64_t getMinDelay();
/**
*
* Add pipe listening event
* [AUTO-TRANSLATED:06e5bc67]
*/
void addEventPipe();
private:
class ExitException : public std::exception {};
private:
// 标记loop线程是否退出 [AUTO-TRANSLATED:98250f84]
// 标记loop线程是否退出
// Mark the loop thread as exited
bool _exit_flag;
// 统计监听了多少个fd
size_t _fd_count = 0;
// 线程名 [AUTO-TRANSLATED:f1d62d9f]
// 线程名
// Thread name
std::string _name;
// 当前线程下所有socket共享的读缓存 [AUTO-TRANSLATED:6ce70017]
// 当前线程下所有socket共享的读缓存
// Shared read buffer for all sockets under the current thread
std::weak_ptr<SocketRecvBuffer> _shared_buffer[2];
// 执行事件循环的线程 [AUTO-TRANSLATED:2465cc75]
// 执行事件循环的线程
// Thread that executes the event loop
std::thread *_loop_thread = nullptr;
// 通知事件循环的线程已启动 [AUTO-TRANSLATED:61f478cf]
// 通知事件循环的线程已启动
// Notify the event loop thread that it has started
semaphore _sem_run_started;
// 内部事件管道 [AUTO-TRANSLATED:dc1d3a93]
// 内部事件管道
// Internal event pipe
PipeWrap _pipe;
// 从其他线程切换过来的任务 [AUTO-TRANSLATED:d16917d6]
// 从其他线程切换过来的任务
// Tasks switched from other threads
std::mutex _mtx_task;
List<Task::Ptr> _list_task;
// 保持日志可用 [AUTO-TRANSLATED:4a6c2438]
// 保持日志可用
// Keep the log available
Logger::Ptr _logger;
#if defined(HAS_EPOLL) || defined(HAS_KQUEUE)
// epoll和kqueue相关 [AUTO-TRANSLATED:84d2785e]
// epoll和kqueue相关
// epoll and kqueue related
epoll_fd _event_fd = INVALID_EVENT_FD;
std::unordered_map<int, std::shared_ptr<PollEventCB>> _event_map;
#else
// select相关 [AUTO-TRANSLATED:bf3e2edd]
// select相关
// select related
struct Poll_Record {
using Ptr = std::shared_ptr<Poll_Record>;
int fd;
int event;
int attach;
PollEventCB call_back;
};
std::unordered_map<int, Poll_Record::Ptr> _event_map;
#endif // HAS_EPOLL
std::unordered_set<int> _event_cache_expired;
// 定时器相关 [AUTO-TRANSLATED:fa2e84da]
// Timer related
std::multimap<uint64_t, DelayTask::Ptr> _delay_task_map;
};
class EventPollerPool : public std::enable_shared_from_this<EventPollerPool>, public TaskExecutorGetterImp {
public:
using Ptr = std::shared_ptr<EventPollerPool>;
static const std::string kOnStarted;
#define EventPollerPoolOnStartedArgs EventPollerPool &pool, size_t &size
~EventPollerPool() = default;
/**
*
* @return
* Get singleton
* @return
* [AUTO-TRANSLATED:1cb32aa7]
*/
static EventPollerPool &Instance();
/**
* EventPoller个数EventPollerPool单例创建前有效
* thread::hardware_concurrency()EventPoller实例
* @param size EventPoller个数0thread::hardware_concurrency()
* Set the number of EventPoller instances, effective before the EventPollerPool singleton is created
* If this method is not called, the default is to create thread::hardware_concurrency() EventPoller instances
* @param size Number of EventPoller instances, 0 means thread::hardware_concurrency()
* [AUTO-TRANSLATED:bdc02181]
*/
static void setPoolSize(size_t size = 0);
/**
* 线cpu亲和性cpu亲和性
* Whether to set CPU affinity for internal thread creation, default is to set CPU affinity
* [AUTO-TRANSLATED:46941c9f]
*/
static void enableCpuAffinity(bool enable);
/**
*
* @return
* Get the first instance
* @return
* [AUTO-TRANSLATED:a76aad3b]
*/
EventPoller::Ptr getFirstPoller();
/**
*
* 线线
* 线线
* @param prefer_current_thread 线
* Get a lightly loaded instance based on the load
* If prioritizing the current thread, it will return the current thread
* The purpose of returning the current thread is to improve thread safety
* @param prefer_current_thread Whether to prioritize getting the current thread
* [AUTO-TRANSLATED:f0830806]
*/
EventPoller::Ptr getPoller(bool prefer_current_thread = true);
/**
* getPoller() 线
* Socket对象时线
*
* @param flag 线
* Set whether getPoller() prioritizes returning the current thread
* When creating Socket objects in batches, if prioritizing the current thread,
* it will cause the load to be unbalanced, so it can be temporarily closed and then reopened
* @param flag Whether to prioritize returning the current thread
* [AUTO-TRANSLATED:c354e1d5]
*/
void preferCurrentThread(bool flag = true);
private:
EventPollerPool();
private:
bool _prefer_current_thread = true;
};
} // namespace toolkit
#endif /* EventPoller_h */

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include <fcntl.h>
#include "Pipe.h"
#include "Network/sockutil.h"
using namespace std;
namespace toolkit {
Pipe::Pipe(const onRead &cb, const EventPoller::Ptr &poller) {
_poller = poller;
if (!_poller) {
_poller = EventPollerPool::Instance().getPoller();
}
_pipe = std::make_shared<PipeWrap>();
auto pipe = _pipe;
_poller->addEvent(_pipe->readFD(), EventPoller::Event_Read, [cb, pipe](int event) {
#if defined(_WIN32)
unsigned long nread = 1024;
#else
int nread = 1024;
#endif //defined(_WIN32)
ioctl(pipe->readFD(), FIONREAD, &nread);
#if defined(_WIN32)
std::shared_ptr<char> buf(new char[nread + 1], [](char *ptr) {delete[] ptr; });
buf.get()[nread] = '\0';
nread = pipe->read(buf.get(), nread + 1);
if (cb) {
cb(nread, buf.get());
}
#else
char buf[nread + 1];
buf[nread] = '\0';
nread = pipe->read(buf, sizeof(buf));
if (cb) {
cb(nread, buf);
}
#endif // defined(_WIN32)
});
}
Pipe::~Pipe() {
if (_pipe) {
auto pipe = _pipe;
_poller->delEvent(pipe->readFD(), [pipe](bool success) {});
}
}
void Pipe::send(const char *buf, int size) {
_pipe->write(buf, size);
}
} // namespace toolkit

View File

@ -0,0 +1,35 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef Pipe_h
#define Pipe_h
#include <functional>
#include "PipeWrap.h"
#include "EventPoller.h"
namespace toolkit {
class Pipe {
public:
using onRead = std::function<void(int size, const char *buf)>;
Pipe(const onRead &cb = nullptr, const EventPoller::Ptr &poller = nullptr);
~Pipe();
void send(const char *send, int size = 0);
private:
std::shared_ptr<PipeWrap> _pipe;
EventPoller::Ptr _poller;
};
} // namespace toolkit
#endif /* Pipe_h */

View File

@ -0,0 +1,96 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include <stdexcept>
#include "PipeWrap.h"
#include "Util/util.h"
#include "Util/uv_errno.h"
#include "Network/sockutil.h"
using namespace std;
#define checkFD(fd) \
if (fd == -1) { \
clearFD(); \
throw runtime_error(StrPrinter << "Create windows pipe failed: " << get_uv_errmsg());\
}
#define closeFD(fd) \
if (fd != -1) { \
close(fd);\
fd = -1;\
}
namespace toolkit {
PipeWrap::PipeWrap() {
reOpen();
}
void PipeWrap::reOpen() {
clearFD();
#if defined(_WIN32)
const char *localip = "127.0.0.1";
auto listener_fd = SockUtil::listen(0, localip);
checkFD(listener_fd)
SockUtil::setNoBlocked(listener_fd,false);
auto localPort = SockUtil::get_local_port(listener_fd);
_pipe_fd[1] = SockUtil::connect(localip, localPort,false);
checkFD(_pipe_fd[1])
_pipe_fd[0] = (int)accept(listener_fd, nullptr, nullptr);
checkFD(_pipe_fd[0])
SockUtil::setNoDelay(_pipe_fd[0]);
SockUtil::setNoDelay(_pipe_fd[1]);
close(listener_fd);
#else
if (pipe(_pipe_fd) == -1) {
throw runtime_error(StrPrinter << "Create posix pipe failed: " << get_uv_errmsg());
}
#endif // defined(_WIN32)
SockUtil::setNoBlocked(_pipe_fd[0], true);
SockUtil::setNoBlocked(_pipe_fd[1], false);
SockUtil::setCloExec(_pipe_fd[0]);
SockUtil::setCloExec(_pipe_fd[1]);
}
void PipeWrap::clearFD() {
closeFD(_pipe_fd[0]);
closeFD(_pipe_fd[1]);
}
PipeWrap::~PipeWrap() {
clearFD();
}
int PipeWrap::write(const void *buf, int n) {
int ret;
do {
#if defined(_WIN32)
ret = send(_pipe_fd[1], (char *)buf, n, 0);
#else
ret = ::write(_pipe_fd[1], buf, n);
#endif // defined(_WIN32)
} while (-1 == ret && UV_EINTR == get_uv_error(true));
return ret;
}
int PipeWrap::read(void *buf, int n) {
int ret;
do {
#if defined(_WIN32)
ret = recv(_pipe_fd[0], (char *)buf, n, 0);
#else
ret = ::read(_pipe_fd[0], buf, n);
#endif // defined(_WIN32)
} while (-1 == ret && UV_EINTR == get_uv_error(true));
return ret;
}
} /* namespace toolkit*/

View File

@ -0,0 +1,35 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef PipeWarp_h
#define PipeWarp_h
namespace toolkit {
class PipeWrap {
public:
PipeWrap();
~PipeWrap();
int write(const void *buf, int n);
int read(void *buf, int n);
int readFD() const { return _pipe_fd[0]; }
int writeFD() const { return _pipe_fd[1]; }
void reOpen();
private:
void clearFD();
private:
int _pipe_fd[2] = { -1, -1 };
};
} /* namespace toolkit */
#endif // !PipeWarp_h

View File

@ -0,0 +1,52 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include "SelectWrap.h"
using namespace std;
namespace toolkit {
FdSet::FdSet() {
_ptr = new fd_set;
}
FdSet::~FdSet() {
delete (fd_set *)_ptr;
}
void FdSet::fdZero() {
FD_ZERO((fd_set *)_ptr);
}
void FdSet::fdClr(int fd) {
FD_CLR(fd, (fd_set *)_ptr);
}
void FdSet::fdSet(int fd) {
FD_SET(fd, (fd_set *)_ptr);
}
bool FdSet::isSet(int fd) {
return FD_ISSET(fd, (fd_set *)_ptr);
}
int zl_select(int cnt, FdSet *read, FdSet *write, FdSet *err, struct timeval *tv) {
void *rd, *wt, *er;
rd = read ? read->_ptr : nullptr;
wt = write ? write->_ptr : nullptr;
er = err ? err->_ptr : nullptr;
return ::select(cnt, (fd_set *) rd, (fd_set *) wt, (fd_set *) er, tv);
}
} /* namespace toolkit */

View File

@ -0,0 +1,32 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef SRC_POLLER_SELECTWRAP_H_
#define SRC_POLLER_SELECTWRAP_H_
#include "Util/util.h"
namespace toolkit {
class FdSet {
public:
FdSet();
~FdSet();
void fdZero();
void fdSet(int fd);
void fdClr(int fd);
bool isSet(int fd);
void *_ptr;
};
int zl_select(int cnt, FdSet *read, FdSet *write, FdSet *err, struct timeval *tv);
} /* namespace toolkit */
#endif /* SRC_POLLER_SELECTWRAP_H_ */

View File

@ -0,0 +1,44 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include "Timer.h"
namespace toolkit {
Timer::Timer(float second, const std::function<bool()> &cb, const EventPoller::Ptr &poller) {
_poller = poller;
if (!_poller) {
_poller = EventPollerPool::Instance().getPoller();
}
_tag = _poller->doDelayTask((uint64_t) (second * 1000), [cb, second]() {
try {
if (cb()) {
//重复的任务 [AUTO-TRANSLATED:2d440b54]
//Recurring task
return (uint64_t) (1000 * second);
}
//该任务不再重复 [AUTO-TRANSLATED:4249fc53]
//This task no longer recurs
return (uint64_t) 0;
} catch (std::exception &ex) {
ErrorL << "Exception occurred when do timer task: " << ex.what();
return (uint64_t) (1000 * second);
}
});
}
Timer::~Timer() {
auto tag = _tag.lock();
if (tag) {
tag->cancel();
}
}
} // namespace toolkit

View File

@ -0,0 +1,46 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef Timer_h
#define Timer_h
#include <functional>
#include "EventPoller.h"
namespace toolkit {
class Timer {
public:
using Ptr = std::shared_ptr<Timer>;
/**
*
* @param second
* @param cb true表示重复下次任务
* @param poller EventPoller对象nullptr
* Constructs a timer
* @param second Timer repeat interval in seconds
* @param cb Timer task, returns true to repeat the next task, otherwise does not repeat. If an exception is thrown in the task, it defaults to repeating the next task
* @param poller EventPoller object, can be nullptr
* [AUTO-TRANSLATED:7dc94698]
*/
Timer(float second, const std::function<bool()> &cb, const EventPoller::Ptr &poller);
~Timer();
private:
std::weak_ptr<EventPoller::DelayTask> _tag;
//定时器保持EventPoller的强引用 [AUTO-TRANSLATED:d171cd2f]
//Timer keeps a strong reference to EventPoller
EventPoller::Ptr _poller;
};
} // namespace toolkit
#endif /* Timer_h */

View File

@ -0,0 +1,64 @@
源代码放置在`src`文件夹下,里面有若干模块:
```
src
|
|-- NetWork # 网络模块
| |-- Socket.cpp # 套接字抽象封装包含了TCP服务器/客户端UDP套接字
| |-- Socket.h
| |-- sockutil.cpp # 系统网络相关API的统一封装
| |-- sockutil.h
| |-- TcpClient.cpp # TCP客户端封装派生该类可以很容易实现客户端程序
| |-- TcpClient.h
| |-- TcpServer.h # TCP服务器模板类可以很容易就实现一个高性能私有协议服务器
| |-- Session.h # TCP/UDP服务私有协议实现会话基类用于处理TCP/UDP长连接数据及响应
|
|-- Poller # 主线程事件轮询模块
| |-- EventPoller.cpp # 主线程,所有网络事件由此线程轮询并触发
| |-- EventPoller.h
| |-- Pipe.cpp # 管道的对象封装
| |-- Pipe.h
| |-- PipeWrap.cpp # 管道的包装windows下由socket模拟
| |-- SelectWrap.cpp # select 模型的简单包装
| |-- SelectWrap.h
| |-- Timer.cpp # 在主线程触发的定时器
| |-- Timer.h
|
|-- Thread # 线程模块
| |-- AsyncTaskThread.cpp # 后台异步任务线程,可以提交一个可定时重复的任务后台执行
| |-- AsyncTaskThread.h
| |-- rwmutex.h # 读写锁,实验性质的
| |-- semaphore.h # 信号量,由条件变量实现
| |-- spin_mutex.h # 自旋锁,在低延时临界区适用,单核/低性能设备慎用
| |-- TaskQueue.h # functional的任务列队
| |-- threadgroup.h # 线程组移植自boost
| |-- ThreadPool.h # 线程池可以输入functional任务至后台线程执行
| |-- WorkThreadPool.cpp # 获取一个可用的线程池(可以加入线程负载均衡分配算法)
| |-- WorkThreadPool.h
|
|-- Util # 工具模块
|-- File.cpp # 文件/目录操作模块
|-- File.h
|-- function_traits.h # 函数、lambda转functional
|-- logger.h # 日志模块
|-- MD5.cpp # md5加密模块
|-- MD5.h
|-- mini.h # ini配置文件读写模块支持unix/windows格式的回车符
|-- NoticeCenter.h # 消息广播器,可以广播传递任意个数任意类型参数
|-- onceToken.h # 使用RAII模式实现可以在对象构造和析构时执行一段代码
|-- ResourcePool.h # 基于智能指针实现的一个循环池,不需要手动回收对象
|-- RingBuffer.h # 环形缓冲可以自适应大小适用于GOP缓存等
|-- SqlConnection.cpp # mysql客户端
|-- SqlConnection.h
|-- SqlPool.h # mysql连接池以及简单易用的sql语句生成工具
|-- SSLBox.cpp # openssl的黑盒封装屏蔽了ssl握手细节支持多线程
|-- SSLBox.h
|-- TimeTicker.h # 计时器,可以用于统计函数执行时间
|-- util.cpp # 其他一些工具代码,适配了多种系统
|-- util.h
|-- uv_errno.cpp # 提取自libuv的错误代码系统主要是为了兼容windows
|-- uv_errno.h
```

View File

@ -0,0 +1,256 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include <memory>
#include <atomic>
#include "TaskExecutor.h"
#include "Poller/EventPoller.h"
#include "Util/onceToken.h"
#include "Util/TimeTicker.h"
using namespace std;
namespace toolkit {
ThreadLoadCounter::ThreadLoadCounter(uint64_t max_size, uint64_t max_usec) {
_last_sleep_time = _last_wake_time = getCurrentMicrosecond();
_max_size = max_size;
_max_usec = max_usec;
}
void ThreadLoadCounter::startSleep() {
lock_guard<mutex> lck(_mtx);
_sleeping = true;
auto current_time = getCurrentMicrosecond();
auto run_time = current_time - _last_wake_time;
_last_sleep_time = current_time;
_time_list.emplace_back(run_time, false);
if (_time_list.size() > _max_size) {
_time_list.pop_front();
}
}
void ThreadLoadCounter::sleepWakeUp() {
lock_guard<mutex> lck(_mtx);
_sleeping = false;
auto current_time = getCurrentMicrosecond();
auto sleep_time = current_time - _last_sleep_time;
_last_wake_time = current_time;
_time_list.emplace_back(sleep_time, true);
if (_time_list.size() > _max_size) {
_time_list.pop_front();
}
}
int ThreadLoadCounter::load() {
lock_guard<mutex> lck(_mtx);
uint64_t totalSleepTime = 0;
uint64_t totalRunTime = 0;
_time_list.for_each([&](const TimeRecord &rcd) {
if (rcd._sleep) {
totalSleepTime += rcd._time;
} else {
totalRunTime += rcd._time;
}
});
if (_sleeping) {
totalSleepTime += (getCurrentMicrosecond() - _last_sleep_time);
} else {
totalRunTime += (getCurrentMicrosecond() - _last_wake_time);
}
uint64_t totalTime = totalRunTime + totalSleepTime;
while ((_time_list.size() != 0) && (totalTime > _max_usec || _time_list.size() > _max_size)) {
TimeRecord &rcd = _time_list.front();
if (rcd._sleep) {
totalSleepTime -= rcd._time;
} else {
totalRunTime -= rcd._time;
}
totalTime -= rcd._time;
_time_list.pop_front();
}
if (totalTime == 0) {
return 0;
}
return (int) (totalRunTime * 100 / totalTime);
}
////////////////////////////////////////////////////////////////////////////
Task::Ptr TaskExecutorInterface::async_first(TaskIn task, bool may_sync) {
return async(std::move(task), may_sync);
}
void TaskExecutorInterface::sync(const TaskIn &task) {
semaphore sem;
auto ret = async([&]() {
onceToken token(nullptr, [&]() {
//通过RAII原理防止抛异常导致不执行这句代码 [AUTO-TRANSLATED:206bd80e]
//Prevent this code from not being executed due to an exception being thrown through RAII principle
sem.post();
});
task();
});
if (ret && *ret) {
sem.wait();
}
}
void TaskExecutorInterface::sync_first(const TaskIn &task) {
semaphore sem;
auto ret = async_first([&]() {
onceToken token(nullptr, [&]() {
//通过RAII原理防止抛异常导致不执行这句代码 [AUTO-TRANSLATED:206bd80e]
//Prevent this code from not being executed due to an exception being thrown through RAII principle
sem.post();
});
task();
});
if (ret && *ret) {
sem.wait();
}
}
//////////////////////////////////////////////////////////////////
TaskExecutor::TaskExecutor(uint64_t max_size, uint64_t max_usec) : ThreadLoadCounter(max_size, max_usec) {}
//////////////////////////////////////////////////////////////////
TaskExecutor::Ptr TaskExecutorGetterImp::getExecutor() {
auto thread_pos = _thread_pos;
if (thread_pos >= _threads.size()) {
thread_pos = 0;
}
TaskExecutor::Ptr executor_min_load = _threads[thread_pos];
auto min_load = executor_min_load->load();
for (size_t i = 0; i < _threads.size(); ++i) {
++thread_pos;
if (thread_pos >= _threads.size()) {
thread_pos = 0;
}
auto th = _threads[thread_pos];
auto load = th->load();
if (load < min_load) {
min_load = load;
executor_min_load = th;
}
if (min_load == 0) {
break;
}
}
_thread_pos = thread_pos;
return executor_min_load;
}
vector<int> TaskExecutorGetterImp::getExecutorLoad() {
vector<int> vec(_threads.size());
int i = 0;
for (auto &executor : _threads) {
vec[i++] = executor->load();
}
return vec;
}
void TaskExecutorGetterImp::getExecutorDelay(const function<void(const vector<int> &)> &callback) {
std::shared_ptr<vector<int> > delay_vec = std::make_shared<vector<int>>(_threads.size());
shared_ptr<void> finished(nullptr, [callback, delay_vec](void *) {
//此析构回调触发时说明已执行完毕所有async任务 [AUTO-TRANSLATED:8adf8212]
//When this destructor callback is triggered, it means all async tasks have been executed
callback((*delay_vec));
});
int index = 0;
for (auto &th : _threads) {
std::shared_ptr<Ticker> delay_ticker = std::make_shared<Ticker>();
th->async([finished, delay_vec, index, delay_ticker]() {
(*delay_vec)[index] = (int) delay_ticker->elapsedTime();
}, false);
++index;
}
}
using onGetExecutor = std::function<void(const TaskExecutor::Ptr &)>;
class onGetExecutorCB {
public:
onGetExecutorCB(onGetExecutor cb): _cb(std::move(cb)) {}
void operator()(const TaskExecutor::Ptr &exe) {
bool expected = false;
if (_done.compare_exchange_strong(expected, true)) {
_cb(exe);
_cb = nullptr;
}
}
private:
std::atomic<bool> _done { false };
std::function<void(const TaskExecutor::Ptr &)> _cb;
};
void TaskExecutorGetterImp::getExecutor(const onGetExecutor &cb) {
auto callback = std::make_shared<onGetExecutorCB>(cb);
auto thread_pos = _thread_pos;
if (thread_pos >= _threads.size()) {
thread_pos = 0;
}
for (size_t i = 0; i < _threads.size(); ++i) {
++thread_pos;
if (thread_pos >= _threads.size()) {
thread_pos = 0;
}
auto &th = _threads[thread_pos];
th->async([th, callback]() mutable { (*callback)(th); }, false);
}
_thread_pos = thread_pos;
}
void TaskExecutorGetterImp::for_each(const function<void(const TaskExecutor::Ptr &)> &cb) {
for (auto &th : _threads) {
cb(th);
}
}
size_t TaskExecutorGetterImp::getExecutorSize() const {
return _threads.size();
}
size_t TaskExecutorGetterImp::addPoller(const string &name, size_t size, int priority, bool register_thread, bool enable_cpu_affinity) {
auto cpus = thread::hardware_concurrency();
size = size > 0 ? size : cpus;
for (size_t i = 0; i < size; ++i) {
auto full_name = name + " " + to_string(i);
auto cpu_index = i % cpus;
EventPoller::Ptr poller(new EventPoller(full_name));
poller->runLoop(false, register_thread);
poller->async([cpu_index, full_name, priority, enable_cpu_affinity]() {
// 设置线程优先级 [AUTO-TRANSLATED:2966f860]
//Set thread priority
ThreadPool::setPriority((ThreadPool::Priority)priority);
// 设置线程名 [AUTO-TRANSLATED:f5eb4704]
//Set thread name
setThreadName(full_name.data());
// 设置cpu亲和性 [AUTO-TRANSLATED:ba213aed]
//Set CPU affinity
if (enable_cpu_affinity) {
setThreadAffinity(cpu_index);
}
});
_threads.emplace_back(std::move(poller));
}
return size;
}
}//toolkit

View File

@ -0,0 +1,308 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef ZLTOOLKIT_TASKEXECUTOR_H
#define ZLTOOLKIT_TASKEXECUTOR_H
#include <mutex>
#include <memory>
#include <functional>
#include "Util/List.h"
#include "Util/util.h"
namespace toolkit {
/**
* cpu负载计算器
* CPU Load Calculator
* [AUTO-TRANSLATED:46dad663]
*/
class ThreadLoadCounter {
public:
/**
*
* @param max_size
* @param max_usec ,{max_usec}cpu负载率
* Constructor
* @param max_size Number of statistical samples
* @param max_usec Statistical time window, i.e., the CPU load rate for the most recent {max_usec}
* [AUTO-TRANSLATED:718cb173]
*/
ThreadLoadCounter(uint64_t max_size, uint64_t max_usec);
~ThreadLoadCounter() = default;
/**
* 线
* Thread enters sleep
* [AUTO-TRANSLATED:d831fad1]
*/
void startSleep();
/**
* ,
* Wake up from sleep, end sleep
* [AUTO-TRANSLATED:361831f8]
*/
void sleepWakeUp();
/**
* 线cpu使用率 0 ~ 100
* @return 线cpu使用率
* Returns the current thread's CPU usage rate, ranging from 0 to 100
* @return Current thread's CPU usage rate
* [AUTO-TRANSLATED:c9953342]
*/
int load();
private:
struct TimeRecord {
TimeRecord(uint64_t tm, bool slp) {
_time = tm;
_sleep = slp;
}
bool _sleep;
uint64_t _time;
};
private:
bool _sleeping = true;
uint64_t _last_sleep_time;
uint64_t _last_wake_time;
uint64_t _max_size;
uint64_t _max_usec;
std::mutex _mtx;
List<TimeRecord> _time_list;
};
class TaskCancelable : public noncopyable {
public:
TaskCancelable() = default;
virtual ~TaskCancelable() = default;
virtual void cancel() = 0;
};
template<class R, class... ArgTypes>
class TaskCancelableImp;
template<class R, class... ArgTypes>
class TaskCancelableImp<R(ArgTypes...)> : public TaskCancelable {
public:
using Ptr = std::shared_ptr<TaskCancelableImp>;
using func_type = std::function<R(ArgTypes...)>;
~TaskCancelableImp() = default;
template<typename FUNC>
TaskCancelableImp(FUNC &&task) {
_strongTask = std::make_shared<func_type>(std::forward<FUNC>(task));
_weakTask = _strongTask;
}
void cancel() override {
_strongTask = nullptr;
}
operator bool() {
return _strongTask && *_strongTask;
}
void operator=(std::nullptr_t) {
_strongTask = nullptr;
}
R operator()(ArgTypes ...args) const {
auto strongTask = _weakTask.lock();
if (strongTask && *strongTask) {
return (*strongTask)(std::forward<ArgTypes>(args)...);
}
return defaultValue<R>();
}
template<typename T>
static typename std::enable_if<std::is_void<T>::value, void>::type
defaultValue() {}
template<typename T>
static typename std::enable_if<std::is_pointer<T>::value, T>::type
defaultValue() {
return nullptr;
}
template<typename T>
static typename std::enable_if<std::is_integral<T>::value, T>::type
defaultValue() {
return 0;
}
protected:
std::weak_ptr<func_type> _weakTask;
std::shared_ptr<func_type> _strongTask;
};
using TaskIn = std::function<void()>;
using Task = TaskCancelableImp<void()>;
class TaskExecutorInterface {
public:
TaskExecutorInterface() = default;
virtual ~TaskExecutorInterface() = default;
/**
*
* @param task
* @param may_sync
* @return
* Asynchronously execute a task
* @param task Task
* @param may_sync Whether to allow synchronous execution of the task
* @return Whether the task was added successfully
* [AUTO-TRANSLATED:271d48a2]
*/
virtual Task::Ptr async(TaskIn task, bool may_sync = true) = 0;
/**
*
* @param task
* @param may_sync
* @return
* Asynchronously execute a task with the highest priority
* @param task Task
* @param may_sync Whether to allow synchronous execution of the task
* @return Whether the task was added successfully
* [AUTO-TRANSLATED:d52ce80b]
*/
virtual Task::Ptr async_first(TaskIn task, bool may_sync = true);
/**
*
* @param task
* @return
* Synchronously execute a task
* @param task
* @return
* [AUTO-TRANSLATED:24854b4a]
*/
void sync(const TaskIn &task);
/**
*
* @param task
* @return
* Synchronously execute a task with the highest priority
* @param task
* @return
* [AUTO-TRANSLATED:3d15452d]
*/
void sync_first(const TaskIn &task);
};
/**
*
* Task Executor
* [AUTO-TRANSLATED:630c364f]
*/
class TaskExecutor : public ThreadLoadCounter, public TaskExecutorInterface {
public:
using Ptr = std::shared_ptr<TaskExecutor>;
/**
*
* @param max_size cpu负载统计样本数
* @param max_usec cpu负载统计时间窗口大小
*/
TaskExecutor(uint64_t max_size = 32, uint64_t max_usec = 2 * 1000 * 1000);
~TaskExecutor() = default;
};
class TaskExecutorGetter {
public:
using Ptr = std::shared_ptr<TaskExecutorGetter>;
virtual ~TaskExecutorGetter() = default;
/**
*
* @return
*/
virtual TaskExecutor::Ptr getExecutor() = 0;
/**
*
* @param cb
*/
virtual void getExecutor(const std::function<void(const TaskExecutor::Ptr &)> &cb) = 0;
/**
*
*/
virtual size_t getExecutorSize() const = 0;
};
class TaskExecutorGetterImp : public TaskExecutorGetter {
public:
TaskExecutorGetterImp() = default;
~TaskExecutorGetterImp() = default;
/**
* 线
* @return
*/
TaskExecutor::Ptr getExecutor() override;
/**
*
* @param cb
*/
void getExecutor(const std::function<void(const TaskExecutor::Ptr &)> &cb) override;
/**
* 线
* @return 线
*/
std::vector<int> getExecutorLoad();
/**
* 线
* 线
* @return
*/
void getExecutorDelay(const std::function<void(const std::vector<int> &)> &callback);
/**
* 线
*/
void for_each(const std::function<void(const TaskExecutor::Ptr &)> &cb);
/**
* 线
*/
size_t getExecutorSize() const override;
protected:
size_t addPoller(const std::string &name, size_t size, int priority, bool register_thread, bool enable_cpu_affinity = true);
protected:
size_t _thread_pos = 0;
std::vector<TaskExecutor::Ptr> _threads;
};
}//toolkit
#endif //ZLTOOLKIT_TASKEXECUTOR_H

View File

@ -0,0 +1,76 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef TASKQUEUE_H_
#define TASKQUEUE_H_
#include <mutex>
#include "Util/List.h"
#include "semaphore.h"
namespace toolkit {
//实现了一个基于函数对象的任务列队,该列队是线程安全的,任务列队任务数由信号量控制 [AUTO-TRANSLATED:67e02e93]
//Implemented a task queue based on function objects, which is thread-safe, and the number of tasks in the task queue is controlled by a semaphore
template<typename T>
class TaskQueue {
public:
//打入任务至列队 [AUTO-TRANSLATED:d08b5817]
//Put a task into the queue
template<typename C>
void push_task(C &&task_func) {
{
std::lock_guard<decltype(_mutex)> lock(_mutex);
_queue.emplace_back(std::forward<C>(task_func));
}
_sem.post();
}
template<typename C>
void push_task_first(C &&task_func) {
{
std::lock_guard<decltype(_mutex)> lock(_mutex);
_queue.emplace_front(std::forward<C>(task_func));
}
_sem.post();
}
//清空任务列队 [AUTO-TRANSLATED:dbcd7fe9]
//Clear the task queue
void push_exit(size_t n) {
_sem.post(n);
}
//从列队获取一个任务,由执行线程执行 [AUTO-TRANSLATED:4a1143ae]
//Get a task from the queue and execute it by the executing thread
bool get_task(T &tsk) {
_sem.wait();
std::lock_guard<decltype(_mutex)> lock(_mutex);
if (_queue.empty()) {
return false;
}
tsk = std::move(_queue.front());
_queue.pop_front();
return true;
}
size_t size() const {
std::lock_guard<decltype(_mutex)> lock(_mutex);
return _queue.size();
}
private:
List <T> _queue;
mutable std::mutex _mutex;
semaphore _sem;
};
} /* namespace toolkit */
#endif /* TASKQUEUE_H_ */

View File

@ -0,0 +1,169 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef THREADPOOL_H_
#define THREADPOOL_H_
#include "threadgroup.h"
#include "TaskQueue.h"
#include "TaskExecutor.h"
#include "Util/util.h"
#include "Util/logger.h"
namespace toolkit {
class ThreadPool : public TaskExecutor {
public:
enum Priority {
PRIORITY_LOWEST = 0,
PRIORITY_LOW,
PRIORITY_NORMAL,
PRIORITY_HIGH,
PRIORITY_HIGHEST
};
ThreadPool(int num = 1, Priority priority = PRIORITY_HIGHEST, bool auto_run = true, bool set_affinity = true,
const std::string &pool_name = "thread pool") {
_thread_num = num;
_on_setup = [pool_name, priority, set_affinity](int index) {
std::string name = pool_name + ' ' + std::to_string(index);
setPriority(priority);
setThreadName(name.data());
if (set_affinity) {
setThreadAffinity(index % std::thread::hardware_concurrency());
}
};
_logger = Logger::Instance().shared_from_this();
if (auto_run) {
start();
}
}
~ThreadPool() {
shutdown();
wait();
}
//把任务打入线程池并异步执行 [AUTO-TRANSLATED:651c8d5a]
//Put the task into the thread pool and execute it asynchronously
Task::Ptr async(TaskIn task, bool may_sync = true) override {
if (may_sync && _thread_group.is_this_thread_in()) {
task();
return nullptr;
}
auto ret = std::make_shared<Task>(std::move(task));
_queue.push_task([ret](size_t) {
(*ret)();
});
return ret;
}
Task::Ptr async_first(TaskIn task, bool may_sync = true) override {
if (may_sync && _thread_group.is_this_thread_in()) {
task();
return nullptr;
}
auto ret = std::make_shared<Task>(std::move(task));
_queue.push_task_first([ret](size_t) {
(*ret)();
});
return ret;
}
void async2(std::function<void(size_t index)> task, bool may_sync = true) {
if (may_sync && _thread_group.is_this_thread_in()) {
task(0);
return;
}
_queue.push_task(std::move(task));
}
size_t size() {
return _queue.size();
}
static bool setPriority(Priority priority = PRIORITY_NORMAL, std::thread::native_handle_type threadId = 0) {
// set priority
#if defined(_WIN32)
static int Priorities[] = { THREAD_PRIORITY_LOWEST, THREAD_PRIORITY_BELOW_NORMAL, THREAD_PRIORITY_NORMAL, THREAD_PRIORITY_ABOVE_NORMAL, THREAD_PRIORITY_HIGHEST };
if (priority != PRIORITY_NORMAL && SetThreadPriority(GetCurrentThread(), Priorities[priority]) == 0) {
return false;
}
return true;
#else
static int Min = sched_get_priority_min(SCHED_FIFO);
if (Min == -1) {
return false;
}
static int Max = sched_get_priority_max(SCHED_FIFO);
if (Max == -1) {
return false;
}
static int Priorities[] = {Min, Min + (Max - Min) / 4, Min + (Max - Min) / 2, Min + (Max - Min) * 3 / 4, Max};
if (threadId == 0) {
threadId = pthread_self();
}
struct sched_param params;
params.sched_priority = Priorities[priority];
return pthread_setschedparam(threadId, SCHED_FIFO, &params) == 0;
#endif
}
void start() {
if (_thread_num <= 0) {
return;
}
size_t total = _thread_num - _thread_group.size();
for (size_t i = 0; i < total; ++i) {
_thread_group.create_thread([this, i]() {run(i);});
}
}
private:
void run(size_t index) {
_on_setup(index);
std::function<void(size_t index)> task;
while (true) {
startSleep();
if (!_queue.get_task(task)) {
// 空任务,退出线程 [AUTO-TRANSLATED:583e2f11]
// Empty task, exit the thread
break;
}
sleepWakeUp();
try {
task(index);
task = nullptr;
} catch (std::exception &ex) {
ErrorL << "ThreadPool catch a exception: " << ex.what();
}
}
}
void wait() {
_thread_group.join_all();
}
void shutdown() {
_queue.push_exit(_thread_num);
}
private:
size_t _thread_num;
Logger::Ptr _logger;
thread_group _thread_group;
TaskQueue<std::function<void(size_t index)>> _queue;
std::function<void(int)> _on_setup;
};
} /* namespace toolkit */
#endif /* THREADPOOL_H_ */

View File

@ -0,0 +1,43 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include "WorkThreadPool.h"
namespace toolkit {
static size_t s_pool_size = 0;
static bool s_enable_cpu_affinity = true;
INSTANCE_IMP(WorkThreadPool)
EventPoller::Ptr WorkThreadPool::getFirstPoller() {
return std::static_pointer_cast<EventPoller>(_threads.front());
}
EventPoller::Ptr WorkThreadPool::getPoller() {
return std::static_pointer_cast<EventPoller>(getExecutor());
}
WorkThreadPool::WorkThreadPool() {
//最低优先级 [AUTO-TRANSLATED:cd1f0dbc]
//Lowest priority
addPoller("work poller", s_pool_size, ThreadPool::PRIORITY_LOWEST, false, s_enable_cpu_affinity);
}
void WorkThreadPool::setPoolSize(size_t size) {
s_pool_size = size;
}
void WorkThreadPool::enableCpuAffinity(bool enable) {
s_enable_cpu_affinity = enable;
}
} /* namespace toolkit */

View File

@ -0,0 +1,82 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef UTIL_WORKTHREADPOOL_H_
#define UTIL_WORKTHREADPOOL_H_
#include <memory>
#include "Poller/EventPoller.h"
namespace toolkit {
class WorkThreadPool : public std::enable_shared_from_this<WorkThreadPool>, public TaskExecutorGetterImp {
public:
using Ptr = std::shared_ptr<WorkThreadPool>;
~WorkThreadPool() override = default;
/**
*
* Get the singleton instance
* [AUTO-TRANSLATED:c8852589]
*/
static WorkThreadPool &Instance();
/**
* EventPoller个数WorkThreadPool单例创建前有效
* thread::hardware_concurrency()EventPoller实例
* @param size EventPoller个数0thread::hardware_concurrency()
* Set the number of EventPoller instances, effective before the WorkThreadPool singleton is created
* If this method is not called, the default is to create thread::hardware_concurrency() EventPoller instances
* @param size The number of EventPoller instances, if 0 then use thread::hardware_concurrency()
* [AUTO-TRANSLATED:bb236d87]
*/
static void setPoolSize(size_t size = 0);
/**
* 线cpu亲和性cpu亲和性
* Whether to set CPU affinity when creating internal threads, CPU affinity is set by default
* [AUTO-TRANSLATED:46941c9f]
*/
static void enableCpuAffinity(bool enable);
/**
*
* @return
* Get the first instance
* @return
* [AUTO-TRANSLATED:a76aad3b]
*/
EventPoller::Ptr getFirstPoller();
/**
*
* 线线
* 线线
* @return
* Get a lightly loaded instance based on the load situation
* If priority is given to the current thread, it will return the current thread
* The purpose of returning the current thread is to improve thread safety
* @return
* [AUTO-TRANSLATED:1282b772]
*/
EventPoller::Ptr getPoller();
protected:
WorkThreadPool();
};
} /* namespace toolkit */
#endif /* UTIL_WORKTHREADPOOL_H_ */

View File

@ -0,0 +1,121 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef SEMAPHORE_H_
#define SEMAPHORE_H_
#include <mutex>
#include <chrono>
#include <condition_variable>
namespace toolkit {
class semaphore {
public:
explicit semaphore(size_t initial = 0) {
#if defined(HAVE_SEM)
sem_init(&_sem, 0, initial);
#else
_count = initial;
#endif
}
~semaphore() {
#if defined(HAVE_SEM)
sem_destroy(&_sem);
#endif
}
void post(size_t n = 1) {
#if defined(HAVE_SEM)
while (n--) {
sem_post(&_sem);
}
#else
std::unique_lock<std::recursive_mutex> lock(_mutex);
_count += n;
if (n == 1) {
_condition.notify_one();
} else {
_condition.notify_all();
}
#endif
}
void wait() {
#if defined(HAVE_SEM)
sem_wait(&_sem);
#else
std::unique_lock<std::recursive_mutex> lock(_mutex);
while (_count == 0) {
_condition.wait(lock);
}
--_count;
#endif
}
bool wait(unsigned int timeout_ms) {
#if defined(HAVE_SEM)
struct timespec ts;
// 获取当前时间
if (clock_gettime(CLOCK_REALTIME, &ts) == -1) {
perror("clock_gettime failed");
return -1;
}
// 添加超时时间到当前时间以得到绝对时间
ts.tv_sec += timeout_ms / 1000;
ts.tv_nsec += (timeout_ms % 1000) * 1000000;
if (ts.tv_nsec >= 1000000000) {
ts.tv_sec += ts.tv_nsec / 1000000000;
ts.tv_nsec = ts.tv_nsec % 1000000000;
}
sem_timedwait(&_sem, &ts);
struct timespec ts;
if (clock_gettime(CLOCK_REALTIME, &ts) == -1) {
return false;
}
ts.tv_sec += timeout_ms / 1000;
ts.tv_nsec += (timeout_ms % 1000) * 1000000;
if (ts.tv_nsec >= 1000000000) {
ts.tv_sec += ts.tv_nsec / 1000000000;
ts.tv_nsec = ts.tv_nsec % 1000000000;
}
int result = sem_timedwait(&_sem, &ts);
return result == 0; // 成功返回true超时/失败返回false
#else
std::unique_lock<std::recursive_mutex> lock(_mutex);
auto now = std::chrono::system_clock::now();
auto waitTime = now + std::chrono::milliseconds(timeout_ms);
bool success = _condition.wait_until(lock, waitTime, [this] { return _count > 0; });
if (success) {
--_count;
}
return success;
#endif
}
private:
#if defined(HAVE_SEM)
sem_t _sem;
#else
size_t _count;
std::recursive_mutex _mutex;
std::condition_variable_any _condition;
#endif
};
} /* namespace toolkit */
#endif /* SEMAPHORE_H_ */

View File

@ -0,0 +1,85 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef THREADGROUP_H_
#define THREADGROUP_H_
#include <stdexcept>
#include <thread>
#include <unordered_map>
namespace toolkit {
class thread_group {
private:
thread_group(thread_group const &);
thread_group &operator=(thread_group const &);
public:
thread_group() {}
~thread_group() {
_threads.clear();
}
bool is_this_thread_in() {
auto thread_id = std::this_thread::get_id();
if (_thread_id == thread_id) {
return true;
}
return _threads.find(thread_id) != _threads.end();
}
bool is_thread_in(std::thread *thrd) {
if (!thrd) {
return false;
}
auto it = _threads.find(thrd->get_id());
return it != _threads.end();
}
template<typename F>
std::thread *create_thread(F &&threadfunc) {
auto thread_new = std::make_shared<std::thread>(std::forward<F>(threadfunc));
_thread_id = thread_new->get_id();
_threads[_thread_id] = thread_new;
return thread_new.get();
}
void remove_thread(std::thread *thrd) {
auto it = _threads.find(thrd->get_id());
if (it != _threads.end()) {
_threads.erase(it);
}
}
void join_all() {
if (is_this_thread_in()) {
throw std::runtime_error("Trying joining itself in thread_group");
}
for (auto &it : _threads) {
if (it.second->joinable()) {
it.second->join(); //等待线程主动退出
}
}
_threads.clear();
}
size_t size() {
return _threads.size();
}
private:
std::thread::id _thread_id;
std::unordered_map<std::thread::id, std::shared_ptr<std::thread>> _threads;
};
} /* namespace toolkit */
#endif /* THREADGROUP_H_ */

View File

@ -0,0 +1,175 @@
/**
ISC License
Copyright © 2015, Iñaki Baz Castillo <ibc@aliax.net>
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
#ifndef SRC_UTIL_BYTE_H_
#define SRC_UTIL_BYTE_H_
#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
#pragma comment (lib, "Ws2_32.lib")
#else
#include <arpa/inet.h>
#endif // defined(_WIN32)
#include <cinttypes>// PRIu64, etc
#include <cstddef>// size_t
#include <cstdint>// uint8_t, etc
namespace toolkit {
class Byte {
public:
/**
* Getters below get value in Host Byte Order.
* Setters below set value in Network Byte Order.
*/
static uint8_t Get1Byte(const uint8_t *data, size_t i);
static uint16_t Get2Bytes(const uint8_t *data, size_t i);
static uint32_t Get3Bytes(const uint8_t *data, size_t i);
static uint32_t Get4Bytes(const uint8_t *data, size_t i);
static uint64_t Get8Bytes(const uint8_t *data, size_t i);
static void Set1Byte(uint8_t *data, size_t i, uint8_t value);
static void Set2Bytes(uint8_t *data, size_t i, uint16_t value);
static void Set3Bytes(uint8_t *data, size_t i, uint32_t value);
static void Set4Bytes(uint8_t *data, size_t i, uint32_t value);
static void Set8Bytes(uint8_t *data, size_t i, uint64_t value);
static uint16_t PadTo4Bytes(uint16_t size);
static uint32_t PadTo4Bytes(uint32_t size);
static uint16_t Get2BytesLE(const uint8_t *data, size_t i);
static uint32_t Get3BytesLE(const uint8_t *data, size_t i);
static uint32_t Get4BytesLE(const uint8_t *data, size_t i);
static uint64_t Get8BytesLE(const uint8_t *data, size_t i);
static void Set2BytesLE(uint8_t *data, size_t i, uint16_t value);
static void Set3BytesLE(uint8_t *data, size_t i, uint32_t value);
static void Set4BytesLE(uint8_t *data, size_t i, uint32_t value);
static void Set8BytesLE(uint8_t *data, size_t i, uint64_t value);
};
/* Inline static methods. */
inline uint8_t Byte::Get1Byte(const uint8_t *data, size_t i) { return data[i]; }
inline uint16_t Byte::Get2Bytes(const uint8_t *data, size_t i) {
return uint16_t{data[i + 1]} | uint16_t{data[i]} << 8;
}
inline uint32_t Byte::Get3Bytes(const uint8_t *data, size_t i) {
return uint32_t{data[i + 2]} | uint32_t{data[i + 1]} << 8 | uint32_t{data[i]} << 16;
}
inline uint32_t Byte::Get4Bytes(const uint8_t *data, size_t i) {
return uint32_t{data[i + 3]} | uint32_t{data[i + 2]} << 8 | uint32_t{data[i + 1]} << 16 |
uint32_t{data[i]} << 24;
}
inline uint64_t Byte::Get8Bytes(const uint8_t *data, size_t i) {
return uint64_t{Byte::Get4Bytes(data, i)} << 32 | Byte::Get4Bytes(data, i + 4);
}
inline void Byte::Set1Byte(uint8_t *data, size_t i, uint8_t value) { data[i] = value; }
inline void Byte::Set2Bytes(uint8_t *data, size_t i, uint16_t value) {
data[i + 1] = static_cast<uint8_t>(value);
data[i] = static_cast<uint8_t>(value >> 8);
}
inline void Byte::Set3Bytes(uint8_t *data, size_t i, uint32_t value) {
data[i + 2] = static_cast<uint8_t>(value);
data[i + 1] = static_cast<uint8_t>(value >> 8);
data[i] = static_cast<uint8_t>(value >> 16);
}
inline void Byte::Set4Bytes(uint8_t *data, size_t i, uint32_t value) {
data[i + 3] = static_cast<uint8_t>(value);
data[i + 2] = static_cast<uint8_t>(value >> 8);
data[i + 1] = static_cast<uint8_t>(value >> 16);
data[i] = static_cast<uint8_t>(value >> 24);
}
inline void Byte::Set8Bytes(uint8_t *data, size_t i, uint64_t value) {
data[i + 7] = static_cast<uint8_t>(value);
data[i + 6] = static_cast<uint8_t>(value >> 8);
data[i + 5] = static_cast<uint8_t>(value >> 16);
data[i + 4] = static_cast<uint8_t>(value >> 24);
data[i + 3] = static_cast<uint8_t>(value >> 32);
data[i + 2] = static_cast<uint8_t>(value >> 40);
data[i + 1] = static_cast<uint8_t>(value >> 48);
data[i] = static_cast<uint8_t>(value >> 56);
}
inline uint16_t Byte::Get2BytesLE(const uint8_t *data, size_t i) {
return uint16_t{data[i]} | uint16_t{data[i + 1]} << 8;
}
inline uint32_t Byte::Get3BytesLE(const uint8_t *data, size_t i) {
return uint32_t{data[i]} | uint32_t{data[i + 1]} << 8 | uint32_t{data[i + 2]} << 16;
}
inline uint32_t Byte::Get4BytesLE(const uint8_t *data, size_t i) {
return uint32_t{data[i]} | uint32_t{data[i + 1]} << 8 | uint32_t{data[i + 2]} << 16 |
uint32_t{data[i + 3]} << 24;
}
inline uint64_t Byte::Get8BytesLE(const uint8_t *data, size_t i) {
return uint64_t{Byte::Get4Bytes(data, i + 4)} << 32 | Byte::Get4Bytes(data, i);
}
inline void Byte::Set2BytesLE(uint8_t *data, size_t i, uint16_t value) {
data[i] = static_cast<uint8_t>(value);
data[i + 1] = static_cast<uint8_t>(value >> 8);
}
inline void Byte::Set3BytesLE(uint8_t *data, size_t i, uint32_t value) {
data[i] = static_cast<uint8_t>(value);
data[i + 1] = static_cast<uint8_t>(value >> 8);
data[i + 2] = static_cast<uint8_t>(value >> 16);
}
inline void Byte::Set4BytesLE(uint8_t *data, size_t i, uint32_t value) {
data[i] = static_cast<uint8_t>(value);
data[i + 1] = static_cast<uint8_t>(value >> 8);
data[i + 2] = static_cast<uint8_t>(value >> 16);
data[i + 3] = static_cast<uint8_t>(value >> 24);
}
inline void Byte::Set8BytesLE(uint8_t *data, size_t i, uint64_t value) {
data[i] = static_cast<uint8_t>(value);
data[i + 1] = static_cast<uint8_t>(value >> 8);
data[i + 2] = static_cast<uint8_t>(value >> 16);
data[i + 3] = static_cast<uint8_t>(value >> 24);
data[i + 4] = static_cast<uint8_t>(value >> 32);
data[i + 5] = static_cast<uint8_t>(value >> 40);
data[i + 6] = static_cast<uint8_t>(value >> 48);
data[i + 7] = static_cast<uint8_t>(value >> 56);
}
inline uint16_t Byte::PadTo4Bytes(uint16_t size) {
// If size is not multiple of 32 bits then pad it.
if (size & 0x03)
return (size & 0xFFFC) + 4;
else
return size;
}
}// namespace toolkit
#endif //SRC_UTIL_BYTE_H_

View File

@ -0,0 +1,124 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include "CMD.h"
#include "onceToken.h"
#if defined(_WIN32)
#include "win32/getopt.h"
#else
#include <getopt.h>
#endif // defined(_WIN32)
using namespace std;
namespace toolkit {
//默认注册exit/quit/help/clear命令 [AUTO-TRANSLATED:1411f05e]
//Default registration of exit/quit/help/clear commands
static onceToken s_token([]() {
REGIST_CMD(exit)
REGIST_CMD(quit)
REGIST_CMD(help)
REGIST_CMD(clear)
});
CMDRegister &CMDRegister::Instance() {
static CMDRegister instance;
return instance;
}
void OptionParser::operator()(mINI &all_args, int argc, char *argv[], const std::shared_ptr<ostream> &stream) {
vector<struct option> vec_long_opt;
string str_short_opt;
do {
struct option tmp;
for (auto &pr : _map_options) {
auto &opt = pr.second;
//long opt
tmp.name = (char *) opt._long_opt.data();
tmp.has_arg = opt._type;
tmp.flag = nullptr;
tmp.val = pr.first;
vec_long_opt.emplace_back(tmp);
//short opt
if (!opt._short_opt) {
continue;
}
str_short_opt.push_back(opt._short_opt);
switch (opt._type) {
case Option::ArgRequired: str_short_opt.append(":"); break;
case Option::ArgOptional: str_short_opt.append("::"); break;
default: break;
}
}
tmp.flag = 0;
tmp.name = 0;
tmp.has_arg = 0;
tmp.val = 0;
vec_long_opt.emplace_back(tmp);
} while (0);
static mutex s_mtx_opt;
lock_guard<mutex> lck(s_mtx_opt);
int index;
optind = 0;
opterr = 0;
while ((index = getopt_long(argc, argv, &str_short_opt[0], &vec_long_opt[0], nullptr)) != -1) {
stringstream ss;
ss << " 未识别的选项,输入\"-h\"获取帮助.";
if (index < 0xFF) {
//短参数 [AUTO-TRANSLATED:87b4c1df]
//Short parameters
auto it = _map_char_index.find(index);
if (it == _map_char_index.end()) {
throw std::invalid_argument(ss.str());
}
index = it->second;
}
auto it = _map_options.find(index);
if (it == _map_options.end()) {
throw std::invalid_argument(ss.str());
}
auto &opt = it->second;
auto pr = all_args.emplace(opt._long_opt, optarg ? optarg : "");
if (!opt(stream, pr.first->second)) {
return;
}
optarg = nullptr;
}
for (auto &pr : _map_options) {
if (pr.second._default_value && all_args.find(pr.second._long_opt) == all_args.end()) {
//有默认值,赋值默认值 [AUTO-TRANSLATED:9a82f49c]
//Has default value, assigns default value
all_args.emplace(pr.second._long_opt, *pr.second._default_value);
}
}
for (auto &pr : _map_options) {
if (pr.second._must_exist) {
if (all_args.find(pr.second._long_opt) == all_args.end()) {
stringstream ss;
ss << " 参数\"" << pr.second._long_opt << "\"必须提供,输入\"-h\"选项获取帮助";
throw std::invalid_argument(ss.str());
}
}
}
if (all_args.empty() && _map_options.size() > 1 && !_enable_empty_args) {
_helper(stream, "");
return;
}
if (_on_completed) {
_on_completed(stream, all_args);
}
}
}//namespace toolkit

View File

@ -0,0 +1,384 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef SRC_UTIL_CMD_H_
#define SRC_UTIL_CMD_H_
#include <map>
#include <mutex>
#include <string>
#include <memory>
#include <vector>
#include <iostream>
#include <functional>
#include "mini.h"
namespace toolkit{
class Option {
public:
using OptionHandler = std::function<bool(const std::shared_ptr<std::ostream> &stream, const std::string &arg)>;
enum ArgType {
ArgNone = 0,//no_argument,
ArgRequired = 1,//required_argument,
ArgOptional = 2,//optional_argument
};
Option() = default;
Option(char short_opt, const char *long_opt, enum ArgType type, const char *default_value, bool must_exist,
const char *des, const OptionHandler &cb) {
_short_opt = short_opt;
_long_opt = long_opt;
_type = type;
if (type != ArgNone) {
if (default_value) {
_default_value = std::make_shared<std::string>(default_value);
}
if (!_default_value && must_exist) {
_must_exist = true;
}
}
_des = des;
_cb = cb;
}
bool operator()(const std::shared_ptr<std::ostream> &stream, const std::string &arg) {
return _cb ? _cb(stream, arg) : true;
}
private:
friend class OptionParser;
bool _must_exist = false;
char _short_opt;
enum ArgType _type;
std::string _des;
std::string _long_opt;
OptionHandler _cb;
std::shared_ptr<std::string> _default_value;
};
class OptionParser {
public:
using OptionCompleted = std::function<void(const std::shared_ptr<std::ostream> &, mINI &)>;
OptionParser(const OptionCompleted &cb = nullptr, bool enable_empty_args = true) {
_on_completed = cb;
_enable_empty_args = enable_empty_args;
_helper = Option('h', "help", Option::ArgNone, nullptr, false, "打印此信息",
[this](const std::shared_ptr<std::ostream> &stream,const std::string &arg)->bool {
static const char *argsType[] = {"无参", "有参", "选参"};
static const char *mustExist[] = {"选填", "必填"};
static std::string defaultPrefix = "默认:";
static std::string defaultNull = "null";
std::stringstream printer;
size_t maxLen_longOpt = 0;
auto maxLen_default = defaultNull.size();
for (auto &pr : _map_options) {
auto &opt = pr.second;
if (opt._long_opt.size() > maxLen_longOpt) {
maxLen_longOpt = opt._long_opt.size();
}
if (opt._default_value) {
if (opt._default_value->size() > maxLen_default) {
maxLen_default = opt._default_value->size();
}
}
}
for (auto &pr : _map_options) {
auto &opt = pr.second;
//打印短参和长参名
if (opt._short_opt) {
printer << " -" << opt._short_opt << " --" << opt._long_opt;
} else {
printer << " " << " " << " --" << opt._long_opt;
}
for (size_t i = 0; i < maxLen_longOpt - opt._long_opt.size(); ++i) {
printer << " ";
}
//打印是否有参
printer << " " << argsType[opt._type];
//打印默认参数
std::string defaultValue = defaultNull;
if (opt._default_value) {
defaultValue = *opt._default_value;
}
printer << " " << defaultPrefix << defaultValue;
for (size_t i = 0; i < maxLen_default - defaultValue.size(); ++i) {
printer << " ";
}
//打印是否必填参数
printer << " " << mustExist[opt._must_exist];
//打印描述
printer << " " << opt._des << std::endl;
}
throw std::invalid_argument(printer.str());
});
(*this) << _helper;
}
OptionParser &operator<<(Option &&option) {
int index = 0xFF + (int) _map_options.size();
if (option._short_opt) {
_map_char_index.emplace(option._short_opt, index);
}
_map_options.emplace(index, std::forward<Option>(option));
return *this;
}
OptionParser &operator<<(const Option &option) {
int index = 0xFF + (int) _map_options.size();
if (option._short_opt) {
_map_char_index.emplace(option._short_opt, index);
}
_map_options.emplace(index, option);
return *this;
}
void delOption(const char *key) {
for (auto &pr : _map_options) {
if (pr.second._long_opt == key) {
if (pr.second._short_opt) {
_map_char_index.erase(pr.second._short_opt);
}
_map_options.erase(pr.first);
break;
}
}
}
void operator ()(mINI &all_args, int argc, char *argv[], const std::shared_ptr<std::ostream> &stream);
private:
bool _enable_empty_args;
Option _helper;
std::map<char, int> _map_char_index;
std::map<int, Option> _map_options;
OptionCompleted _on_completed;
};
class CMD : public mINI {
public:
virtual ~CMD() = default;
virtual const char *description() const {
return "description";
}
void operator()(int argc, char *argv[], const std::shared_ptr<std::ostream> &stream = nullptr) {
this->clear();
std::shared_ptr<std::ostream> coutPtr(&std::cout, [](std::ostream *) {});
(*_parser)(*this, argc, argv, stream ? stream : coutPtr);
}
bool hasKey(const char *key) {
return this->find(key) != this->end();
}
std::vector<variant> splitedVal(const char *key, const char *delim = ":") {
std::vector<variant> ret;
auto &val = (*this)[key];
split(val, delim, ret);
return ret;
}
void delOption(const char *key) {
if (_parser) {
_parser->delOption(key);
}
}
protected:
std::shared_ptr<OptionParser> _parser;
private:
void split(const std::string &s, const char *delim, std::vector<variant> &ret) {
size_t last = 0;
auto index = s.find(delim, last);
while (index != std::string::npos) {
if (index - last > 0) {
ret.push_back(s.substr(last, index - last));
}
last = index + strlen(delim);
index = s.find(delim, last);
}
if (s.size() - last > 0) {
ret.push_back(s.substr(last));
}
}
};
class CMDRegister {
public:
static CMDRegister &Instance();
void clear() {
std::lock_guard<std::recursive_mutex> lck(_mtx);
_cmd_map.clear();
}
void registCMD(const char *name, const std::shared_ptr<CMD> &cmd) {
std::lock_guard<std::recursive_mutex> lck(_mtx);
_cmd_map.emplace(name, cmd);
}
void unregistCMD(const char *name) {
std::lock_guard<std::recursive_mutex> lck(_mtx);
_cmd_map.erase(name);
}
std::shared_ptr<CMD> operator[](const char *name) {
std::lock_guard<std::recursive_mutex> lck(_mtx);
auto it = _cmd_map.find(name);
if (it == _cmd_map.end()) {
throw std::invalid_argument(std::string("CMD not existed: ") + name);
}
return it->second;
}
void operator()(const char *name, int argc, char *argv[], const std::shared_ptr<std::ostream> &stream = nullptr) {
auto cmd = (*this)[name];
if (!cmd) {
throw std::invalid_argument(std::string("CMD not existed: ") + name);
}
(*cmd)(argc, argv, stream);
}
void printHelp(const std::shared_ptr<std::ostream> &streamTmp = nullptr) {
auto stream = streamTmp;
if (!stream) {
stream.reset(&std::cout, [](std::ostream *) {});
}
std::lock_guard<std::recursive_mutex> lck(_mtx);
size_t maxLen = 0;
for (auto &pr : _cmd_map) {
if (pr.first.size() > maxLen) {
maxLen = pr.first.size();
}
}
for (auto &pr : _cmd_map) {
(*stream) << " " << pr.first;
for (size_t i = 0; i < maxLen - pr.first.size(); ++i) {
(*stream) << " ";
}
(*stream) << " " << pr.second->description() << std::endl;
}
}
void operator()(const std::string &line, const std::shared_ptr<std::ostream> &stream = nullptr) {
if (line.empty()) {
return;
}
std::vector<char *> argv;
size_t argc = getArgs((char *) line.data(), argv);
if (argc == 0) {
return;
}
std::string cmd = argv[0];
std::lock_guard<std::recursive_mutex> lck(_mtx);
auto it = _cmd_map.find(cmd);
if (it == _cmd_map.end()) {
std::stringstream ss;
ss << " 未识别的命令\"" << cmd << "\",输入 \"help\" 获取帮助.";
throw std::invalid_argument(ss.str());
}
(*it->second)((int) argc, &argv[0], stream);
}
private:
size_t getArgs(char *buf, std::vector<char *> &argv) {
size_t argc = 0;
bool start = false;
auto len = strlen(buf);
for (size_t i = 0; i < len; ++i) {
if (buf[i] != ' ' && buf[i] != '\t' && buf[i] != '\r' && buf[i] != '\n') {
if (!start) {
start = true;
if (argv.size() < argc + 1) {
argv.resize(argc + 1);
}
argv[argc++] = buf + i;
}
} else {
buf[i] = '\0';
start = false;
}
}
return argc;
}
private:
std::recursive_mutex _mtx;
std::map<std::string, std::shared_ptr<CMD> > _cmd_map;
};
//帮助命令(help),该命令默认已注册
class CMD_help : public CMD {
public:
CMD_help() {
_parser = std::make_shared<OptionParser>([](const std::shared_ptr<std::ostream> &stream, mINI &) {
CMDRegister::Instance().printHelp(stream);
});
}
const char *description() const override {
return "打印帮助信息";
}
};
class ExitException : public std::exception {};
//退出程序命令(exit),该命令默认已注册
class CMD_exit : public CMD {
public:
CMD_exit() {
_parser = std::make_shared<OptionParser>([](const std::shared_ptr<std::ostream> &, mINI &) {
throw ExitException();
});
}
const char *description() const override {
return "退出shell";
}
};
//退出程序命令(quit),该命令默认已注册
#define CMD_quit CMD_exit
//清空屏幕信息命令(clear),该命令默认已注册
class CMD_clear : public CMD {
public:
CMD_clear() {
_parser = std::make_shared<OptionParser>([this](const std::shared_ptr<std::ostream> &stream, mINI &args) {
clear(stream);
});
}
const char *description() const {
return "清空屏幕输出";
}
private:
void clear(const std::shared_ptr<std::ostream> &stream) {
(*stream) << "\x1b[2J\x1b[H";
stream->flush();
}
};
#define GET_CMD(name) (*(CMDRegister::Instance()[name]))
#define CMD_DO(name,...) (*(CMDRegister::Instance()[name]))(__VA_ARGS__)
#define REGIST_CMD(name) CMDRegister::Instance().registCMD(#name,std::make_shared<CMD_##name>());
}//namespace toolkit
#endif /* SRC_UTIL_CMD_H_ */

View File

@ -0,0 +1,384 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#if defined(_WIN32)
#include <io.h>
#include <direct.h>
#else
#include <dirent.h>
#include <limits.h>
#endif // WIN32
#include <sys/stat.h>
#include "File.h"
#include "util.h"
#include "logger.h"
#include "uv_errno.h"
using namespace std;
using namespace toolkit;
#if !defined(_WIN32)
#define _unlink unlink
#define _rmdir rmdir
#define _access access
#endif
#if defined(_WIN32)
int mkdir(const char *path, int mode) {
return _mkdir(path);
}
DIR *opendir(const char *name) {
char namebuf[512];
snprintf(namebuf, sizeof(namebuf), "%s\\*.*", name);
WIN32_FIND_DATAA FindData;
auto hFind = FindFirstFileA(namebuf, &FindData);
if (hFind == INVALID_HANDLE_VALUE) {
return nullptr;
}
DIR *dir = (DIR *)malloc(sizeof(DIR));
memset(dir, 0, sizeof(DIR));
dir->dd_fd = 0; // simulate return
dir->handle = hFind;
return dir;
}
struct dirent *readdir(DIR *d) {
HANDLE hFind = d->handle;
WIN32_FIND_DATAA FileData;
//fail or end
if (!FindNextFileA(hFind, &FileData)) {
return nullptr;
}
struct dirent *dir = (struct dirent *)malloc(sizeof(struct dirent) + sizeof(FileData.cFileName));
strcpy(dir->d_name, (char *)FileData.cFileName);
dir->d_reclen = (uint16_t)strlen(dir->d_name);
//check there is file or directory
if (FileData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) {
dir->d_type = 2;
}
else {
dir->d_type = 1;
}
if (d->index) {
//覆盖前释放内存 [AUTO-TRANSLATED:1cb478a1]
//Release memory before covering
free(d->index);
d->index = nullptr;
}
d->index = dir;
return dir;
}
int closedir(DIR *d) {
if (!d) {
return -1;
}
//关闭句柄 [AUTO-TRANSLATED:ec4f562d]
//Close handle
if (d->handle != INVALID_HANDLE_VALUE) {
FindClose(d->handle);
d->handle = INVALID_HANDLE_VALUE;
}
//释放内存 [AUTO-TRANSLATED:0f4046dc]
//Release memory
if (d->index) {
free(d->index);
d->index = nullptr;
}
free(d);
return 0;
}
#endif // defined(_WIN32)
namespace toolkit {
FILE *File::create_file(const std::string &file, const std::string &mode) {
std::string path = file;
std::string dir;
size_t index = 1;
FILE *ret = nullptr;
while (true) {
index = path.find('/', index) + 1;
dir = path.substr(0, index);
if (dir.length() == 0) {
break;
}
if (_access(dir.data(), 0) == -1) { //access函数是查看是不是存在
if (mkdir(dir.data(), 0777) == -1) { //如果不存在就用mkdir函数来创建
TraceL << "mkdir " << dir << " failed: " << get_uv_errmsg();
}
}
}
if (path[path.size() - 1] != '/') {
ret = fopen(file.data(), mode.data());
}
return ret;
}
bool File::create_path(const std::string &file, unsigned int mod) {
std::string path = file;
std::string dir;
size_t index = 1;
while (true) {
index = path.find('/', index) + 1;
dir = path.substr(0, index);
if (dir.length() == 0) {
break;
}
if (_access(dir.data(), 0) == -1) { //access函数是查看是不是存在
if (mkdir(dir.data(), mod) == -1) { //如果不存在就用mkdir函数来创建
WarnL << "mkdir " << dir << " failed: " << get_uv_errmsg();
return false;
}
}
}
return true;
}
//判断是否为目录 [AUTO-TRANSLATED:639e15fa]
//Determine if it is a directory
bool File::is_dir(const std::string &path) {
auto dir = opendir(path.data());
if (!dir) {
return false;
}
closedir(dir);
return true;
}
//判断是否为常规文件 [AUTO-TRANSLATED:59e6b610]
//Determine if it is a regular file
bool File::fileExist(const std::string &path) {
auto fp = fopen(path.data(), "rb");
if (!fp) {
return false;
}
fclose(fp);
return true;
}
//判断是否是特殊目录 [AUTO-TRANSLATED:cda5ed9f]
//Determine if it is a special directory
bool File::is_special_dir(const std::string &path) {
return path == "." || path == "..";
}
static int delete_file_l(const std::string &path_in) {
DIR *dir;
dirent *dir_info;
auto path = path_in;
if (path.back() == '/') {
path.pop_back();
}
if (File::is_dir(path)) {
if ((dir = opendir(path.data())) == nullptr) {
return _rmdir(path.data());
}
while ((dir_info = readdir(dir)) != nullptr) {
if (File::is_special_dir(dir_info->d_name)) {
continue;
}
File::delete_file(path + "/" + dir_info->d_name);
}
auto ret = _rmdir(path.data());
closedir(dir);
return ret;
}
return remove(path.data()) ? _unlink(path.data()) : 0;
}
int File::delete_file(const std::string &path, bool del_empty_dir, bool backtrace) {
auto ret = delete_file_l(path);
if (!ret && del_empty_dir) {
// delete success
File::deleteEmptyDir(parentDir(path), backtrace);
}
return ret;
}
string File::loadFile(const std::string &path) {
FILE *fp = fopen(path.data(), "rb");
if (!fp) {
return "";
}
fseek(fp, 0, SEEK_END);
auto len = ftell(fp);
fseek(fp, 0, SEEK_SET);
string str(len, '\0');
if (len != (decltype(len))fread((char *)str.data(), 1, str.size(), fp)) {
WarnL << "fread " << path << " failed: " << get_uv_errmsg();
}
fclose(fp);
return str;
}
bool File::saveFile(const string &data, const std::string &path) {
FILE *fp = fopen(path.data(), "wb");
if (!fp) {
return false;
}
fwrite(data.data(), data.size(), 1, fp);
fclose(fp);
return true;
}
string File::parentDir(const std::string &path) {
auto parent_dir = path;
if (parent_dir.back() == '/') {
parent_dir.pop_back();
}
auto pos = parent_dir.rfind('/');
if (pos != string::npos) {
parent_dir = parent_dir.substr(0, pos + 1);
}
return parent_dir;
}
string File::absolutePath(const std::string &path, const std::string &current_path, bool can_access_parent) {
string currentPath = current_path;
if (!currentPath.empty()) {
//当前目录不为空 [AUTO-TRANSLATED:5bf272ae]
//Current directory is not empty
if (currentPath.front() == '.') {
//如果当前目录是相对路径,那么先转换成绝对路径 [AUTO-TRANSLATED:3cc6469e]
//If the current directory is a relative path, convert it to an absolute path first
currentPath = absolutePath(current_path, exeDir(), true);
}
} else {
currentPath = exeDir();
}
if (path.empty()) {
//相对路径为空,那么返回当前目录 [AUTO-TRANSLATED:6dd21c11]
//Relative path is empty, return the current directory
return currentPath;
}
if (currentPath.back() != '/') {
//确保当前目录最后字节为'/' [AUTO-TRANSLATED:fc83fcfe]
//Ensure the last byte of the current directory is '/
currentPath.push_back('/');
}
auto rootPath = currentPath;
auto dir_vec = split(path, "/");
for (auto &dir : dir_vec) {
if (dir.empty() || dir == ".") {
//忽略空或本文件夹 [AUTO-TRANSLATED:3dd69d88]
//Ignore empty or current folder
continue;
}
if (dir == "..") {
//访问上级目录 [AUTO-TRANSLATED:d3c0b980]
//Access parent directory
if (!can_access_parent && currentPath.size() <= rootPath.size()) {
//不能访问根目录之外的目录, 返回根目录 [AUTO-TRANSLATED:9d79ec25]
//Cannot access directories outside the root, return to root
return rootPath;
}
currentPath = parentDir(currentPath);
continue;
}
currentPath.append(dir);
currentPath.append("/");
}
if (path.back() != '/' && currentPath.back() == '/') {
//在路径是文件的情况下,防止转换成目录 [AUTO-TRANSLATED:db91e611]
//Prevent conversion to directory when path is a file
currentPath.pop_back();
}
return currentPath;
}
void File::scanDir(const std::string &path_in, const function<bool(const string &path, bool is_dir)> &cb,
bool enter_subdirectory, bool show_hidden_file) {
string path = path_in;
if (path.back() == '/') {
path.pop_back();
}
DIR *pDir;
dirent *pDirent;
if ((pDir = opendir(path.data())) == nullptr) {
//文件夹无效 [AUTO-TRANSLATED:ee3339ea]
//Invalid folder
return;
}
while ((pDirent = readdir(pDir)) != nullptr) {
if (is_special_dir(pDirent->d_name)) {
continue;
}
if (!show_hidden_file && pDirent->d_name[0] == '.') {
//隐藏的文件 [AUTO-TRANSLATED:3b2eb642]
//Hidden file
continue;
}
string strAbsolutePath = path + "/" + pDirent->d_name;
bool isDir = is_dir(strAbsolutePath);
if (!cb(strAbsolutePath, isDir)) {
//不再继续扫描 [AUTO-TRANSLATED:991bdb3f]
//Stop scanning
break;
}
if (isDir && enter_subdirectory) {
//如果是文件夹并且扫描子文件夹,那么递归扫描 [AUTO-TRANSLATED:36773722]
//If it's a folder and scanning subfolders, then recursively scan
scanDir(strAbsolutePath, cb, enter_subdirectory);
}
}
closedir(pDir);
}
uint64_t File::fileSize(FILE *fp, bool remain_size) {
if (!fp) {
return 0;
}
auto current = ftell64(fp);
fseek64(fp, 0L, SEEK_END); /* 定位到文件末尾 */
auto end = ftell64(fp); /* 得到文件大小 */
fseek64(fp, current, SEEK_SET);
return end - (remain_size ? current : 0);
}
uint64_t File::fileSize(const std::string &path) {
if (path.empty()) {
return 0;
}
auto fp = std::unique_ptr<FILE, decltype(&fclose)>(fopen(path.data(), "rb"), fclose);
return fileSize(fp.get());
}
static bool isEmptyDir(const std::string &path) {
bool empty = true;
File::scanDir(path,[&](const std::string &path, bool isDir) {
empty = false;
return false;
}, true, true);
return empty;
}
void File::deleteEmptyDir(const std::string &dir, bool backtrace) {
if (!File::is_dir(dir) || !isEmptyDir(dir)) {
// 不是文件夹或者非空 [AUTO-TRANSLATED:fad1712d]
//Not a folder or not empty
return;
}
File::delete_file(dir);
if (backtrace) {
deleteEmptyDir(File::parentDir(dir), true);
}
}
} /* namespace toolkit */

View File

@ -0,0 +1,206 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef SRC_UTIL_FILE_H_
#define SRC_UTIL_FILE_H_
#include <cstdio>
#include <cstdlib>
#include <string>
#include "util.h"
#include <functional>
#if defined(__linux__)
#include <limits.h>
#endif
#if defined(_WIN32)
#ifndef PATH_MAX
#define PATH_MAX 1024
#endif // !PATH_MAX
struct dirent{
long d_ino; /* inode number*/
off_t d_off; /* offset to this dirent*/
unsigned short d_reclen; /* length of this d_name*/
unsigned char d_type; /* the type of d_name*/
char d_name[1]; /* file name (null-terminated)*/
};
typedef struct _dirdesc {
int dd_fd; /** file descriptor associated with directory */
long dd_loc; /** offset in current buffer */
long dd_size; /** amount of data returned by getdirentries */
char *dd_buf; /** data buffer */
int dd_len; /** size of data buffer */
long dd_seek; /** magic cookie returned by getdirentries */
HANDLE handle;
struct dirent *index;
} DIR;
# define __dirfd(dp) ((dp)->dd_fd)
int mkdir(const char *path, int mode);
DIR *opendir(const char *);
int closedir(DIR *);
struct dirent *readdir(DIR *);
#endif // defined(_WIN32)
#if defined(_WIN32) || defined(_WIN64)
#define fseek64 _fseeki64
#define ftell64 _ftelli64
#else
#define fseek64 fseek
#define ftell64 ftell
#endif
namespace toolkit {
class File {
public:
//创建路径 [AUTO-TRANSLATED:419b36b7]
//Create path
static bool create_path(const std::string &file, unsigned int mod);
//新建文件,目录文件夹自动生成 [AUTO-TRANSLATED:e605efe8]
//Create a new file, and the directory folder will be generated automatically
static FILE *create_file(const std::string &file, const std::string &mode);
//判断是否为目录 [AUTO-TRANSLATED:639e15fa]
//Determine if it is a directory
static bool is_dir(const std::string &path);
//判断是否是特殊目录(. or .. [AUTO-TRANSLATED:f61f7e33]
//Determine if it is a special directory (. or ..)
static bool is_special_dir(const std::string &path);
//删除目录或文件 [AUTO-TRANSLATED:79bed783]
//Delete a directory or file
static int delete_file(const std::string &path, bool del_empty_dir = false, bool backtrace = true);
//判断文件是否存在 [AUTO-TRANSLATED:edf3cf49]
//Determine if a file exists
static bool fileExist(const std::string &path);
/**
* string
* @param path
* @return
* Load file content to string
* @param path The path of the file to load
* @return The file content
* [AUTO-TRANSLATED:c2f0e9fa]
*/
static std::string loadFile(const std::string &path);
/**
*
* @param data
* @param path
* @return
* Save content to file
* @param data The file content
* @param path The path to save the file
* @return Whether the save was successful
* [AUTO-TRANSLATED:a919ad75]
*/
static bool saveFile(const std::string &data, const std::string &path);
/**
*
* @param path
* @return
* Get the parent folder
* @param path The path
* @return The folder
* [AUTO-TRANSLATED:3a584db5]
*/
static std::string parentDir(const std::string &path);
/**
* "../"
* @param path "../"
* @param current_path
* @param can_access_parent 访
* @return "../"
* Replace "../" and get the absolute path
* @param path The relative path, which may contain "../"
* @param current_path The current directory
* @param can_access_parent Whether it can access directories outside the parent directory
* @return The path after replacing "../"
* [AUTO-TRANSLATED:45686bfc]
*/
static std::string absolutePath(const std::string &path, const std::string &current_path, bool can_access_parent = false);
/**
*
* @param path
* @param cb path为绝对路径isDir为该路径是否为文件夹true代表继续扫描
* @param enter_subdirectory
* @param show_hidden_file
* Traverse all files under the folder
* @param path Folder path
* @param cb Callback object, path is the absolute path, isDir indicates whether the path is a folder, returns true to continue scanning, otherwise stops
* @param enter_subdirectory Whether to enter subdirectory scanning
* @param show_hidden_file Whether to display hidden files
* [AUTO-TRANSLATED:e97ab081]
*/
static void scanDir(const std::string &path, const std::function<bool(const std::string &path, bool isDir)> &cb,
bool enter_subdirectory = false, bool show_hidden_file = false);
/**
*
* @param fp
* @param remain_size true:false:
* Get file size
* @param fp File handle
* @param remain_size true: Get the remaining unread data size of the file, false: Get the total file size
* [AUTO-TRANSLATED:9abfdae9]
*/
static uint64_t fileSize(FILE *fp, bool remain_size = false);
/**
*
* @param path
* @return
* @warning
* Get file size
* @param path File path
* @return File size
* @warning The caller should ensure the file exists
* [AUTO-TRANSLATED:6985b813]
*/
static uint64_t fileSize(const std::string &path);
/**
*
* @param dir
* @param backtrace
* Attempt to delete an empty folder
* @param dir Folder path
* @param backtrace Whether to backtrack to the upper-level folder, if the upper-level folder is empty, it will also be deleted, and so on
* [AUTO-TRANSLATED:a1780506]
*/
static void deleteEmptyDir(const std::string &dir, bool backtrace = true);
private:
File();
~File();
};
} /* namespace toolkit */
#endif /* SRC_UTIL_FILE_H_ */

View File

@ -0,0 +1,218 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef ZLTOOLKIT_LIST_H
#define ZLTOOLKIT_LIST_H
#include <list>
#include <type_traits>
namespace toolkit {
#if 0
template<typename T>
class List;
template<typename T>
class ListNode
{
public:
friend class List<T>;
~ListNode(){}
template <class... Args>
ListNode(Args&&... args):_data(std::forward<Args>(args)...){}
private:
T _data;
ListNode *next = nullptr;
};
template<typename T>
class List {
public:
using NodeType = ListNode<T>;
List(){}
List(List &&that){
swap(that);
}
~List(){
clear();
}
void clear(){
auto ptr = _front;
auto last = ptr;
while(ptr){
last = ptr;
ptr = ptr->next;
delete last;
}
_size = 0;
_front = nullptr;
_back = nullptr;
}
template<typename FUN>
void for_each(FUN &&fun) const {
auto ptr = _front;
while (ptr) {
fun(ptr->_data);
ptr = ptr->next;
}
}
size_t size() const{
return _size;
}
bool empty() const{
return _size == 0;
}
template <class... Args>
void emplace_front(Args&&... args){
NodeType *node = new NodeType(std::forward<Args>(args)...);
if(!_front){
_front = node;
_back = node;
_size = 1;
}else{
node->next = _front;
_front = node;
++_size;
}
}
template <class...Args>
void emplace_back(Args&&... args){
NodeType *node = new NodeType(std::forward<Args>(args)...);
if(!_back){
_back = node;
_front = node;
_size = 1;
}else{
_back->next = node;
_back = node;
++_size;
}
}
T &front() const{
return _front->_data;
}
T &back() const{
return _back->_data;
}
T &operator[](size_t pos){
NodeType *front = _front ;
while(pos--){
front = front->next;
}
return front->_data;
}
void pop_front(){
if(!_front){
return;
}
auto ptr = _front;
_front = _front->next;
delete ptr;
if(!_front){
_back = nullptr;
}
--_size;
}
void swap(List &other){
NodeType *tmp_node;
tmp_node = _front;
_front = other._front;
other._front = tmp_node;
tmp_node = _back;
_back = other._back;
other._back = tmp_node;
size_t tmp_size = _size;
_size = other._size;
other._size = tmp_size;
}
void append(List<T> &other){
if(other.empty()){
return;
}
if(_back){
_back->next = other._front;
_back = other._back;
}else{
_front = other._front;
_back = other._back;
}
_size += other._size;
other._front = other._back = nullptr;
other._size = 0;
}
List &operator=(const List &that) {
that.for_each([&](const T &t) {
emplace_back(t);
});
return *this;
}
private:
NodeType *_front = nullptr;
NodeType *_back = nullptr;
size_t _size = 0;
};
#else
template<typename T>
class List : public std::list<T> {
public:
template<typename ... ARGS>
List(ARGS &&...args) : std::list<T>(std::forward<ARGS>(args)...) {};
~List() = default;
void append(List<T> &other) {
if (other.empty()) {
return;
}
this->insert(this->end(), other.begin(), other.end());
other.clear();
}
template<typename FUNC>
void for_each(FUNC &&func) {
for (auto &t : *this) {
func(t);
}
}
template<typename FUNC>
void for_each(FUNC &&func) const {
for (auto &t : *this) {
func(t);
}
}
};
#endif
} /* namespace toolkit */
#endif //ZLTOOLKIT_LIST_H

View File

@ -0,0 +1,372 @@
//MD5.cpp
/* MD5
converted to C++ class by Frank Thilo (thilo@unix-ag.org)
for bzflag (http://www.bzflag.org)
based on:
md5.h and md5.c
reference implemantion of RFC 1321
Copyright (C) 1991-2, RSA Data Security, Inc. Created 1991. All
rights reserved.
Copyright (c) 2016-2019 xiongziliang <771730766@qq.com>
License to copy and use this software is granted provided that it
is identified as the "RSA Data Security, Inc. MD5 Message-Digest
Algorithm" in all material mentioning or referencing this software
or this function.
License is also granted to make and use derivative works provided
that such works are identified as "derived from the RSA Data
Security, Inc. MD5 Message-Digest Algorithm" in all material
mentioning or referencing the derived work.
RSA Data Security, Inc. makes no representations concerning either
the merchantability of this software or the suitability of this
software for any particular purpose. It is provided "as is"
without express or implied warranty of any kind.
These notices must be retained in any copies of any part of this
documentation and/or software.
*/
/* interface header */
#include "MD5.h"
/* system implementation headers */
#include <cstdio>
#include <cstring>
namespace toolkit {
// Constants for MD5Transform routine.
#define S11 7
#define S12 12
#define S13 17
#define S14 22
#define S21 5
#define S22 9
#define S23 14
#define S24 20
#define S31 4
#define S32 11
#define S33 16
#define S34 23
#define S41 6
#define S42 10
#define S43 15
#define S44 21
///////////////////////////////////////////////
// F, G, H and I are basic MD5 functions.
inline MD5::uint4 MD5::F(uint4 x, uint4 y, uint4 z) {
return (x&y) | (~x&z);
}
inline MD5::uint4 MD5::G(uint4 x, uint4 y, uint4 z) {
return (x&z) | (y&~z);
}
inline MD5::uint4 MD5::H(uint4 x, uint4 y, uint4 z) {
return x^y^z;
}
inline MD5::uint4 MD5::I(uint4 x, uint4 y, uint4 z) {
return y ^ (x | ~z);
}
// rotate_left rotates x left n bits.
inline MD5::uint4 MD5::rotate_left(uint4 x, int n) {
return (x << n) | (x >> (32-n));
}
// FF, GG, HH, and II transformations for rounds 1, 2, 3, and 4.
// Rotation is separate from addition to prevent recomputation.
inline void MD5::FF(uint4 &a, uint4 b, uint4 c, uint4 d, uint4 x, uint4 s, uint4 ac) {
a = rotate_left(a+ F(b,c,d) + x + ac, s) + b;
}
inline void MD5::GG(uint4 &a, uint4 b, uint4 c, uint4 d, uint4 x, uint4 s, uint4 ac) {
a = rotate_left(a + G(b,c,d) + x + ac, s) + b;
}
inline void MD5::HH(uint4 &a, uint4 b, uint4 c, uint4 d, uint4 x, uint4 s, uint4 ac) {
a = rotate_left(a + H(b,c,d) + x + ac, s) + b;
}
inline void MD5::II(uint4 &a, uint4 b, uint4 c, uint4 d, uint4 x, uint4 s, uint4 ac) {
a = rotate_left(a + I(b,c,d) + x + ac, s) + b;
}
//////////////////////////////////////////////
// default ctor, just initailize
MD5::MD5()
{
init();
}
//////////////////////////////////////////////
// nifty shortcut ctor, compute MD5 for string and finalize it right away
MD5::MD5(const std::string &text)
{
init();
update(text.c_str(), text.length());
finalize();
}
//////////////////////////////
void MD5::init()
{
finalized=false;
count[0] = 0;
count[1] = 0;
// load magic initialization constants.
state[0] = 0x67452301;
state[1] = 0xefcdab89;
state[2] = 0x98badcfe;
state[3] = 0x10325476;
}
//////////////////////////////
// decodes input (unsigned char) into output (uint4). Assumes len is a multiple of 4.
void MD5::decode(uint4 output[], const uint1 input[], size_type len)
{
for (unsigned int i = 0, j = 0; j < len; i++, j += 4)
output[i] = ((uint4)input[j]) | (((uint4)input[j+1]) << 8) |
(((uint4)input[j+2]) << 16) | (((uint4)input[j+3]) << 24);
}
//////////////////////////////
// encodes input (uint4) into output (unsigned char). Assumes len is
// a multiple of 4.
void MD5::encode(uint1 output[], const uint4 input[], size_type len)
{
for (size_type i = 0, j = 0; j < len; i++, j += 4) {
output[j] = input[i] & 0xff;
output[j+1] = (input[i] >> 8) & 0xff;
output[j+2] = (input[i] >> 16) & 0xff;
output[j+3] = (input[i] >> 24) & 0xff;
}
}
//////////////////////////////
// apply MD5 algo on a block
void MD5::transform(const uint1 block[blocksize])
{
uint4 a = state[0], b = state[1], c = state[2], d = state[3], x[16];
decode (x, block, blocksize);
/* Round 1 */
FF (a, b, c, d, x[ 0], S11, 0xd76aa478); /* 1 */
FF (d, a, b, c, x[ 1], S12, 0xe8c7b756); /* 2 */
FF (c, d, a, b, x[ 2], S13, 0x242070db); /* 3 */
FF (b, c, d, a, x[ 3], S14, 0xc1bdceee); /* 4 */
FF (a, b, c, d, x[ 4], S11, 0xf57c0faf); /* 5 */
FF (d, a, b, c, x[ 5], S12, 0x4787c62a); /* 6 */
FF (c, d, a, b, x[ 6], S13, 0xa8304613); /* 7 */
FF (b, c, d, a, x[ 7], S14, 0xfd469501); /* 8 */
FF (a, b, c, d, x[ 8], S11, 0x698098d8); /* 9 */
FF (d, a, b, c, x[ 9], S12, 0x8b44f7af); /* 10 */
FF (c, d, a, b, x[10], S13, 0xffff5bb1); /* 11 */
FF (b, c, d, a, x[11], S14, 0x895cd7be); /* 12 */
FF (a, b, c, d, x[12], S11, 0x6b901122); /* 13 */
FF (d, a, b, c, x[13], S12, 0xfd987193); /* 14 */
FF (c, d, a, b, x[14], S13, 0xa679438e); /* 15 */
FF (b, c, d, a, x[15], S14, 0x49b40821); /* 16 */
/* Round 2 */
GG (a, b, c, d, x[ 1], S21, 0xf61e2562); /* 17 */
GG (d, a, b, c, x[ 6], S22, 0xc040b340); /* 18 */
GG (c, d, a, b, x[11], S23, 0x265e5a51); /* 19 */
GG (b, c, d, a, x[ 0], S24, 0xe9b6c7aa); /* 20 */
GG (a, b, c, d, x[ 5], S21, 0xd62f105d); /* 21 */
GG (d, a, b, c, x[10], S22, 0x2441453); /* 22 */
GG (c, d, a, b, x[15], S23, 0xd8a1e681); /* 23 */
GG (b, c, d, a, x[ 4], S24, 0xe7d3fbc8); /* 24 */
GG (a, b, c, d, x[ 9], S21, 0x21e1cde6); /* 25 */
GG (d, a, b, c, x[14], S22, 0xc33707d6); /* 26 */
GG (c, d, a, b, x[ 3], S23, 0xf4d50d87); /* 27 */
GG (b, c, d, a, x[ 8], S24, 0x455a14ed); /* 28 */
GG (a, b, c, d, x[13], S21, 0xa9e3e905); /* 29 */
GG (d, a, b, c, x[ 2], S22, 0xfcefa3f8); /* 30 */
GG (c, d, a, b, x[ 7], S23, 0x676f02d9); /* 31 */
GG (b, c, d, a, x[12], S24, 0x8d2a4c8a); /* 32 */
/* Round 3 */
HH (a, b, c, d, x[ 5], S31, 0xfffa3942); /* 33 */
HH (d, a, b, c, x[ 8], S32, 0x8771f681); /* 34 */
HH (c, d, a, b, x[11], S33, 0x6d9d6122); /* 35 */
HH (b, c, d, a, x[14], S34, 0xfde5380c); /* 36 */
HH (a, b, c, d, x[ 1], S31, 0xa4beea44); /* 37 */
HH (d, a, b, c, x[ 4], S32, 0x4bdecfa9); /* 38 */
HH (c, d, a, b, x[ 7], S33, 0xf6bb4b60); /* 39 */
HH (b, c, d, a, x[10], S34, 0xbebfbc70); /* 40 */
HH (a, b, c, d, x[13], S31, 0x289b7ec6); /* 41 */
HH (d, a, b, c, x[ 0], S32, 0xeaa127fa); /* 42 */
HH (c, d, a, b, x[ 3], S33, 0xd4ef3085); /* 43 */
HH (b, c, d, a, x[ 6], S34, 0x4881d05); /* 44 */
HH (a, b, c, d, x[ 9], S31, 0xd9d4d039); /* 45 */
HH (d, a, b, c, x[12], S32, 0xe6db99e5); /* 46 */
HH (c, d, a, b, x[15], S33, 0x1fa27cf8); /* 47 */
HH (b, c, d, a, x[ 2], S34, 0xc4ac5665); /* 48 */
/* Round 4 */
II (a, b, c, d, x[ 0], S41, 0xf4292244); /* 49 */
II (d, a, b, c, x[ 7], S42, 0x432aff97); /* 50 */
II (c, d, a, b, x[14], S43, 0xab9423a7); /* 51 */
II (b, c, d, a, x[ 5], S44, 0xfc93a039); /* 52 */
II (a, b, c, d, x[12], S41, 0x655b59c3); /* 53 */
II (d, a, b, c, x[ 3], S42, 0x8f0ccc92); /* 54 */
II (c, d, a, b, x[10], S43, 0xffeff47d); /* 55 */
II (b, c, d, a, x[ 1], S44, 0x85845dd1); /* 56 */
II (a, b, c, d, x[ 8], S41, 0x6fa87e4f); /* 57 */
II (d, a, b, c, x[15], S42, 0xfe2ce6e0); /* 58 */
II (c, d, a, b, x[ 6], S43, 0xa3014314); /* 59 */
II (b, c, d, a, x[13], S44, 0x4e0811a1); /* 60 */
II (a, b, c, d, x[ 4], S41, 0xf7537e82); /* 61 */
II (d, a, b, c, x[11], S42, 0xbd3af235); /* 62 */
II (c, d, a, b, x[ 2], S43, 0x2ad7d2bb); /* 63 */
II (b, c, d, a, x[ 9], S44, 0xeb86d391); /* 64 */
state[0] += a;
state[1] += b;
state[2] += c;
state[3] += d;
// Zeroize sensitive information.
memset(x, 0, sizeof x);
}
//////////////////////////////
// MD5 block update operation. Continues an MD5 message-digest
// operation, processing another message block
void MD5::update(const unsigned char input[], size_type length)
{
// compute number of bytes mod 64
size_type index = count[0] / 8 % blocksize;
// Update number of bits
if ((count[0] += (length << 3)) < (length << 3))
count[1]++;
count[1] += (length >> 29);
// number of bytes we need to fill in buffer
size_type firstpart = 64 - index;
size_type i;
// transform as many times as possible.
if (length >= firstpart)
{
// fill buffer first, transform
memcpy(&buffer[index], input, firstpart);
transform(buffer);
// transform chunks of blocksize (64 bytes)
for (i = firstpart; i + blocksize <= length; i += blocksize)
transform(&input[i]);
index = 0;
}
else
i = 0;
// buffer remaining input
memcpy(&buffer[index], &input[i], length-i);
}
//////////////////////////////
// for convenience provide a verson with signed char
void MD5::update(const char input[], size_type length)
{
update((const unsigned char*)input, length);
}
//////////////////////////////
// MD5 finalization. Ends an MD5 message-digest operation, writing the
// the message digest and zeroizing the context.
MD5& MD5::finalize()
{
static unsigned char padding[64] = {
0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
};
if (!finalized) {
// Save number of bits
unsigned char bits[8];
encode(bits, count, 8);
// pad out to 56 mod 64.
size_type index = count[0] / 8 % 64;
size_type padLen = (index < 56) ? (56 - index) : (120 - index);
update(padding, padLen);
// Append length (before padding)
update(bits, 8);
// Store state in digest
encode(digest, state, 16);
// Zeroize sensitive information.
memset(buffer, 0, sizeof buffer);
memset(count, 0, sizeof count);
finalized=true;
}
return *this;
}
//////////////////////////////
// return hex representation of digest as string
std::string MD5::hexdigest() const
{
if (!finalized)
return "";
char buf[33];
for (int i=0; i<16; i++)
sprintf(buf+i*2, "%02x", digest[i]);
buf[32]=0;
return std::string(buf);
}
std::string MD5::rawdigest() const{
return std::string((char *)digest, sizeof(digest));
}
//////////////////////////////
std::ostream& operator<<(std::ostream& out, MD5 md5)
{
return out << md5.hexdigest();
}
//////////////////////////////
std::string md5(const std::string str)
{
MD5 md5 = MD5(str);
return md5.hexdigest();
}
} /* namespace toolkit */

View File

@ -0,0 +1,97 @@
//MD5.cpp
/* MD5
converted to C++ class by Frank Thilo (thilo@unix-ag.org)
for bzflag (http://www.bzflag.org)
based on:
md5.h and md5.c
reference implemantion of RFC 1321
Copyright (C) 1991-2, RSA Data Security, Inc. Created 1991. All
rights reserved.
License to copy and use this software is granted provided that it
is identified as the "RSA Data Security, Inc. MD5 Message-Digest
Algorithm" in all material mentioning or referencing this software
or this function.
License is also granted to make and use derivative works provided
that such works are identified as "derived from the RSA Data
Security, Inc. MD5 Message-Digest Algorithm" in all material
mentioning or referencing the derived work.
RSA Data Security, Inc. makes no representations concerning either
the merchantability of this software or the suitability of this
software for any particular purpose. It is provided "as is"
without express or implied warranty of any kind.
These notices must be retained in any copies of any part of this
documentation and/or software.
*/
#ifndef SRC_UTIL_MD5_H_
#define SRC_UTIL_MD5_H_
#include <string>
#include <iostream>
#include <cstdint>
namespace toolkit {
// a small class for calculating MD5 hashes of strings or byte arrays
// it is not meant to be fast or secure
//
// usage: 1) feed it blocks of uchars with update()
// 2) finalize()
// 3) get hexdigest() string
// or
// MD5(std::string).hexdigest()
//
// assumes that char is 8 bit and int is 32 bit
class MD5
{
public:
typedef unsigned int size_type; // must be 32bit
MD5();
MD5(const std::string& text);
void update(const unsigned char *buf, size_type length);
void update(const char *buf, size_type length);
MD5& finalize();
std::string hexdigest() const;
std::string rawdigest() const;
friend std::ostream& operator<<(std::ostream&, MD5 md5);
private:
void init();
typedef uint8_t uint1; // 8bit
typedef uint32_t uint4; // 32bit
enum {blocksize = 64}; // VC6 won't eat a const static int here
void transform(const uint1 block[blocksize]);
static void decode(uint4 output[], const uint1 input[], size_type len);
static void encode(uint1 output[], const uint4 input[], size_type len);
bool finalized;
uint1 buffer[blocksize]; // bytes that didn't fit in last 64 byte chunk
uint4 count[2]; // 64bit counter for number of bits (lo, hi)
uint4 state[4]; // digest so far
uint1 digest[16]; // the result
// low level logic operations
static inline uint4 F(uint4 x, uint4 y, uint4 z);
static inline uint4 G(uint4 x, uint4 y, uint4 z);
static inline uint4 H(uint4 x, uint4 y, uint4 z);
static inline uint4 I(uint4 x, uint4 y, uint4 z);
static inline uint4 rotate_left(uint4 x, int n);
static inline void FF(uint4 &a, uint4 b, uint4 c, uint4 d, uint4 x, uint4 s, uint4 ac);
static inline void GG(uint4 &a, uint4 b, uint4 c, uint4 d, uint4 x, uint4 s, uint4 ac);
static inline void HH(uint4 &a, uint4 b, uint4 c, uint4 d, uint4 x, uint4 s, uint4 ac);
static inline void II(uint4 &a, uint4 b, uint4 c, uint4 d, uint4 x, uint4 s, uint4 ac);
};
} /* namespace toolkit */
#endif /* SRC_UTIL_MD5_H_ */

View File

@ -0,0 +1,19 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include "util.h"
#include "NoticeCenter.h"
namespace toolkit {
INSTANCE_IMP(NoticeCenter)
} /* namespace toolkit */

View File

@ -0,0 +1,206 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef SRC_UTIL_NOTICECENTER_H_
#define SRC_UTIL_NOTICECENTER_H_
#include <mutex>
#include <memory>
#include <string>
#include <exception>
#include <functional>
#include <unordered_map>
#include <stdexcept>
#include "util.h"
#include "function_traits.h"
namespace toolkit {
class EventDispatcher {
public:
friend class NoticeCenter;
using Ptr = std::shared_ptr<EventDispatcher>;
~EventDispatcher() = default;
private:
using MapType = std::unordered_multimap<void *, Any>;
EventDispatcher() = default;
class InterruptException : public std::runtime_error {
public:
InterruptException() : std::runtime_error("InterruptException") {}
~InterruptException() {}
};
template <typename... ArgsType>
int emitEvent(bool safe, ArgsType &&...args) {
using stl_func = std::function<void(decltype(std::forward<ArgsType>(args))...)>;
decltype(_mapListener) copy;
{
// 先拷贝(开销比较小),目的是防止在触发回调时还是上锁状态从而导致交叉互锁 [AUTO-TRANSLATED:62bff466]
//First copy (lower overhead), to prevent cross-locking when triggering callbacks while still locked
std::lock_guard<std::recursive_mutex> lck(_mtxListener);
copy = _mapListener;
}
int ret = 0;
for (auto &pr : copy) {
try {
pr.second.get<stl_func>(safe)(std::forward<ArgsType>(args)...);
++ret;
} catch (InterruptException &) {
++ret;
break;
}
}
return ret;
}
template <typename FUNC>
void addListener(void *tag, FUNC &&func) {
using stl_func = typename function_traits<typename std::remove_reference<FUNC>::type>::stl_function_type;
Any listener;
listener.set<stl_func>(std::forward<FUNC>(func));
std::lock_guard<std::recursive_mutex> lck(_mtxListener);
_mapListener.emplace(tag, listener);
}
void delListener(void *tag, bool &empty) {
std::lock_guard<std::recursive_mutex> lck(_mtxListener);
_mapListener.erase(tag);
empty = _mapListener.empty();
}
private:
std::recursive_mutex _mtxListener;
MapType _mapListener;
};
class NoticeCenter : public std::enable_shared_from_this<NoticeCenter> {
public:
using Ptr = std::shared_ptr<NoticeCenter>;
static NoticeCenter &Instance();
template <typename... ArgsType>
int emitEvent(const std::string &event, ArgsType &&...args) {
return emitEvent_l(false, event, std::forward<ArgsType>(args)...);
}
template <typename... ArgsType>
int emitEventSafe(const std::string &event, ArgsType &&...args) {
return emitEvent_l(true, event, std::forward<ArgsType>(args)...);
}
template <typename FUNC>
void addListener(void *tag, const std::string &event, FUNC &&func) {
getDispatcher(event, true)->addListener(tag, std::forward<FUNC>(func));
}
void delListener(void *tag, const std::string &event) {
auto dispatcher = getDispatcher(event);
if (!dispatcher) {
// 不存在该事件 [AUTO-TRANSLATED:d9014749]
//This event does not exist
return;
}
bool empty;
dispatcher->delListener(tag, empty);
if (empty) {
delDispatcher(event, dispatcher);
}
}
// 这个方法性能比较差 [AUTO-TRANSLATED:71ea304b]
//This method has poor performance
void delListener(void *tag) {
std::lock_guard<std::recursive_mutex> lck(_mtxListener);
bool empty;
for (auto it = _mapListener.begin(); it != _mapListener.end();) {
it->second->delListener(tag, empty);
if (empty) {
it = _mapListener.erase(it);
continue;
}
++it;
}
}
void clearAll() {
std::lock_guard<std::recursive_mutex> lck(_mtxListener);
_mapListener.clear();
}
private:
template <typename... ArgsType>
int emitEvent_l(bool safe, const std::string &event, ArgsType &&...args) {
auto dispatcher = getDispatcher(event);
if (!dispatcher) {
// 该事件无人监听 [AUTO-TRANSLATED:9196cf42]
//No one is listening to this event
return 0;
}
return dispatcher->emitEvent(safe, std::forward<ArgsType>(args)...);
}
EventDispatcher::Ptr getDispatcher(const std::string &event, bool create = false) {
std::lock_guard<std::recursive_mutex> lck(_mtxListener);
auto it = _mapListener.find(event);
if (it != _mapListener.end()) {
return it->second;
}
if (create) {
// 如果为空则创建一个 [AUTO-TRANSLATED:8412a9ae]
//Create one if it is empty
EventDispatcher::Ptr dispatcher(new EventDispatcher());
_mapListener.emplace(event, dispatcher);
return dispatcher;
}
return nullptr;
}
void delDispatcher(const std::string &event, const EventDispatcher::Ptr &dispatcher) {
std::lock_guard<std::recursive_mutex> lck(_mtxListener);
auto it = _mapListener.find(event);
if (it != _mapListener.end() && dispatcher == it->second) {
// 两者相同则删除 [AUTO-TRANSLATED:8d84179d]
//If both are the same, delete it
_mapListener.erase(it);
}
}
private:
std::recursive_mutex _mtxListener;
std::unordered_map<std::string, EventDispatcher::Ptr> _mapListener;
};
template <typename T>
struct NoticeHelper;
template <typename RET, typename... Args>
struct NoticeHelper<RET(Args...)> {
public:
template <typename... ArgsType>
static int emit(const std::string &event, ArgsType &&...args) {
return emit(NoticeCenter::Instance(), event, std::forward<ArgsType>(args)...);
}
template <typename... ArgsType>
static int emit(NoticeCenter &center, const std::string &event, ArgsType &&...args) {
return center.emitEventSafe(event, std::forward<Args>(args)...);
}
};
#define NOTICE_EMIT(types, ...) NoticeHelper<void(types)>::emit(__VA_ARGS__)
} /* namespace toolkit */
#endif /* SRC_UTIL_NOTICECENTER_H_ */

View File

@ -0,0 +1,227 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef UTIL_RECYCLEPOOL_H_
#define UTIL_RECYCLEPOOL_H_
#include "List.h"
#include <atomic>
#include <deque>
#include <functional>
#include <memory>
#include <mutex>
#include <unordered_set>
namespace toolkit {
#if (defined(__GNUC__) && (__GNUC__ >= 5 || (__GNUC__ >= 4 && __GNUC_MINOR__ >= 9))) || defined(__clang__) \
|| !defined(__GNUC__)
#define SUPPORT_DYNAMIC_TEMPLATE
#endif
template <typename C>
class ResourcePool_l;
template <typename C>
class ResourcePool;
template <typename C>
class shared_ptr_imp : public std::shared_ptr<C> {
public:
shared_ptr_imp() {}
/**
*
* @param ptr
* @param weakPool
* @param quit 使
* Constructs a smart pointer
* @param ptr Raw pointer
* @param weakPool Circular pool managing this pointer
* @param quit Whether to give up circular reuse
* [AUTO-TRANSLATED:5af6d6a5]
*/
shared_ptr_imp(
C *ptr, const std::weak_ptr<ResourcePool_l<C>> &weakPool, std::shared_ptr<std::atomic_bool> quit,
const std::function<void(C *)> &on_recycle);
/**
* 使
* @param flag
* Abandon or recover to continue using in the circular pool
* @param flag
* [AUTO-TRANSLATED:eda3e499]
*/
void quit(bool flag = true) {
if (_quit) {
*_quit = flag;
}
}
private:
std::shared_ptr<std::atomic_bool> _quit;
};
template <typename C>
class ResourcePool_l : public std::enable_shared_from_this<ResourcePool_l<C>> {
public:
using ValuePtr = shared_ptr_imp<C>;
friend class shared_ptr_imp<C>;
friend class ResourcePool<C>;
ResourcePool_l() {
_alloc = []() -> C * { return new C(); };
}
#if defined(SUPPORT_DYNAMIC_TEMPLATE)
template <typename... ArgTypes>
ResourcePool_l(ArgTypes &&...args) {
_alloc = [args...]() -> C * { return new C(args...); };
}
#endif // defined(SUPPORT_DYNAMIC_TEMPLATE)
~ResourcePool_l() {
for (auto ptr : _objs) {
delete ptr;
}
}
void setSize(size_t size) {
_pool_size = size;
_objs.reserve(size);
}
ValuePtr obtain(const std::function<void(C *)> &on_recycle = nullptr) {
return ValuePtr(getPtr(), _weak_self, std::make_shared<std::atomic_bool>(false), on_recycle);
}
std::shared_ptr<C> obtain2() {
auto weak_self = _weak_self;
return std::shared_ptr<C>(getPtr(), [weak_self](C *ptr) {
auto strongPool = weak_self.lock();
if (strongPool) {
//放入循环池 [AUTO-TRANSLATED:5ec73a78]
//Put into circular pool
strongPool->recycle(ptr);
} else {
delete ptr;
}
});
}
private:
void recycle(C *obj) {
auto is_busy = _busy.test_and_set();
if (!is_busy) {
//获取到锁 [AUTO-TRANSLATED:6eb7c6e9]
//Acquired lock
if (_objs.size() >= _pool_size) {
delete obj;
} else {
_objs.emplace_back(obj);
}
_busy.clear();
} else {
//未获取到锁 [AUTO-TRANSLATED:2b5e8adb]
//Failed to acquire lock
delete obj;
}
}
C *getPtr() {
C *ptr;
auto is_busy = _busy.test_and_set();
if (!is_busy) {
//获取到锁 [AUTO-TRANSLATED:6eb7c6e9]
//Acquired lock
if (_objs.size() == 0) {
ptr = _alloc();
} else {
ptr = _objs.back();
_objs.pop_back();
}
_busy.clear();
} else {
//未获取到锁 [AUTO-TRANSLATED:2b5e8adb]
//Failed to acquire lock
ptr = _alloc();
}
return ptr;
}
void setup() { _weak_self = this->shared_from_this(); }
private:
size_t _pool_size = 8;
std::vector<C *> _objs;
std::function<C *(void)> _alloc;
std::atomic_flag _busy { false };
std::weak_ptr<ResourcePool_l> _weak_self;
};
/**
* enable_shared_from_this
* @tparam C
* Circular pool, note that objects in the circular pool cannot inherit from enable_shared_from_this!
* @tparam C
* [AUTO-TRANSLATED:e08caac8]
*/
template <typename C>
class ResourcePool {
public:
using ValuePtr = shared_ptr_imp<C>;
ResourcePool() {
pool.reset(new ResourcePool_l<C>());
pool->setup();
}
#if defined(SUPPORT_DYNAMIC_TEMPLATE)
template <typename... ArgTypes>
ResourcePool(ArgTypes &&...args) {
pool = std::make_shared<ResourcePool_l<C>>(std::forward<ArgTypes>(args)...);
pool->setup();
}
#endif // defined(SUPPORT_DYNAMIC_TEMPLATE)
void setSize(size_t size) { pool->setSize(size); }
//获取一个对象,性能差些,但是功能丰富些 [AUTO-TRANSLATED:88b9a207]
//Get an object, performance is slightly worse, but with more features
ValuePtr obtain(const std::function<void(C *)> &on_recycle = nullptr) { return pool->obtain(on_recycle); }
//获取一个对象,性能好些 [AUTO-TRANSLATED:0032c7ca]
//Get an object, performance is slightly better
std::shared_ptr<C> obtain2() { return pool->obtain2(); }
private:
std::shared_ptr<ResourcePool_l<C>> pool;
};
template<typename C>
shared_ptr_imp<C>::shared_ptr_imp(C *ptr,
const std::weak_ptr<ResourcePool_l<C> > &weakPool,
std::shared_ptr<std::atomic_bool> quit,
const std::function<void(C *)> &on_recycle) :
std::shared_ptr<C>(ptr, [weakPool, quit, on_recycle](C *ptr) {
if (on_recycle) {
on_recycle(ptr);
}
auto strongPool = weakPool.lock();
if (strongPool && !(*quit)) {
//循环池还在并且不放弃放入循环池 [AUTO-TRANSLATED:96e856da]
//Loop pool is still in and does not give up putting into loop pool
strongPool->recycle(ptr);
} else {
delete ptr;
}
}), _quit(std::move(quit)) {}
} /* namespace toolkit */
#endif /* UTIL_RECYCLEPOOL_H_ */

View File

@ -0,0 +1,507 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef UTIL_RINGBUFFER_H_
#define UTIL_RINGBUFFER_H_
#include <assert.h>
#include <atomic>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <condition_variable>
#include <functional>
#include "util.h"
#include "List.h"
#include "Poller/EventPoller.h"
// GOP缓存最大长度下限值 [AUTO-TRANSLATED:63162058]
//GOP cache minimum length lower bound value
#define RING_MIN_SIZE 32
#define LOCK_GUARD(mtx) std::lock_guard<decltype(mtx)> lck(mtx)
namespace toolkit {
template <typename T>
class RingDelegate {
public:
using Ptr = std::shared_ptr<RingDelegate>;
RingDelegate() = default;
virtual ~RingDelegate() = default;
virtual void onWrite(T in, bool is_key = true) = 0;
};
template <typename T>
class _RingStorage;
template <typename T>
class _RingReaderDispatcher;
/**
*
* poller线程中执行
*
* poller线程中执行
* Circular cache reader
* All events triggered by this object will be executed in the bound poller thread
* So the lock is removed
* All operations on this object should be executed in the poller thread
* [AUTO-TRANSLATED:3d0f773d]
*/
template <typename T>
class _RingReader {
public:
using Ptr = std::shared_ptr<_RingReader>;
friend class _RingReaderDispatcher<T>;
_RingReader(std::shared_ptr<_RingStorage<T>> storage) {
_storage = std::move(storage);
setReadCB(nullptr);
setDetachCB(nullptr);
setGetInfoCB(nullptr);
setMessageCB(nullptr);
}
~_RingReader() = default;
void setReadCB(std::function<void(const T &)> cb) {
if (!cb) {
_read_cb = [](const T &) {};
} else {
_read_cb = std::move(cb);
flushGop();
}
}
void setDetachCB(std::function<void()> cb) {
_detach_cb = cb ? std::move(cb) : []() {};
}
void setGetInfoCB(std::function<Any()> cb) {
_info_cb = cb ? std::move(cb) : []() { return Any(); };
}
void setMessageCB(std::function<void(const Any &data)> cb) {
_msg_cb = cb ? std::move(cb) : [](const Any &data) {};
}
void flushGop() {
if (!_storage) {
return;
}
_storage->getCache().for_each([this](const List<std::pair<bool, T>> &lst) {
lst.for_each([this](const std::pair<bool, T> &pr) { onRead(pr.second, pr.first); });
});
}
private:
void onRead(const T &data, bool /*is_key*/) { _read_cb(data); }
void onMessage(const Any &data) { _msg_cb(data); }
void onDetach() const { _detach_cb(); }
Any getInfo() { return _info_cb(); }
private:
std::shared_ptr<_RingStorage<T>> _storage;
std::function<void(void)> _detach_cb;
std::function<void(const T &)> _read_cb;
std::function<Any()> _info_cb;
std::function<void(const Any &data)> _msg_cb;
};
template <typename T>
class _RingStorage {
public:
using Ptr = std::shared_ptr<_RingStorage>;
using GopType = List<List<std::pair<bool, T>>>;
_RingStorage(size_t max_size, size_t max_gop_size) {
// gop缓存个数不能小于32 [AUTO-TRANSLATED:63d52404]
//The number of GOP caches cannot be less than 32
if (max_size < RING_MIN_SIZE) {
max_size = RING_MIN_SIZE;
}
_max_size = max_size;
_max_gop_size = max_gop_size;
clearCache();
}
~_RingStorage() = default;
/**
*
* @param in
* @param is_key
* @return
* Write data to the circular cache
* @param in Data
* @param is_key Whether it is a key frame
* @return Whether to trigger a reset of the circular cache size
* [AUTO-TRANSLATED:8ccedd1d]
*/
void write(T in, bool is_key = true) {
if (is_key) {
_have_idr = true;
_started = true;
if (!_data_cache.back().empty()) {
//当前gop列队还没收到任意缓存 [AUTO-TRANSLATED:81e257d0]
//The current GOP queue has not received any cache
_data_cache.emplace_back();
}
if (_data_cache.size() > _max_gop_size) {
// GOP个数超过限制那么移除最早的GOP [AUTO-TRANSLATED:054ad5e4]
//The number of GOPs exceeds the limit, so remove the earliest GOP
popFrontGop();
}
}
if (!_have_idr && _started) {
//缓存中没有关键帧那么gop缓存无效 [AUTO-TRANSLATED:394a9170]
//There is no key frame in the cache, so the GOP cache is invalid
return;
}
_data_cache.back().emplace_back(std::make_pair(is_key, std::move(in)));
if (++_size > _max_size) {
// GOP缓存溢出 [AUTO-TRANSLATED:1cd0ddc4]
//GOP cache overflow
while (_data_cache.size() > 1) {
//先尝试清除老的GOP缓存 [AUTO-TRANSLATED:a01422a1]
//Try to clear the old GOP cache first
popFrontGop();
}
if (_size > _max_size) {
//还是大于最大缓冲限制那么清空所有GOP [AUTO-TRANSLATED:dec7aa9b]
//Still greater than the maximum buffer limit, so clear all GOPs
clearCache();
}
}
}
Ptr clone() const {
Ptr ret(new _RingStorage());
ret->_size = _size;
ret->_have_idr = _have_idr;
ret->_started = _started;
ret->_max_size = _max_size;
ret->_max_gop_size = _max_gop_size;
ret->_data_cache = _data_cache;
return ret;
}
const GopType &getCache() const { return _data_cache; }
void clearCache() {
_size = 0;
_have_idr = false;
_data_cache.clear();
_data_cache.emplace_back();
}
private:
_RingStorage() = default;
void popFrontGop() {
if (!_data_cache.empty()) {
_size -= _data_cache.front().size();
_data_cache.pop_front();
if (_data_cache.empty()) {
_data_cache.emplace_back();
}
}
}
private:
bool _started = false;
bool _have_idr;
size_t _size;
size_t _max_size;
size_t _max_gop_size;
GopType _data_cache;
};
template <typename T>
class RingBuffer;
/**
* poller线程操作它
* @tparam T
* Ring buffer event dispatcher, can only be operated by one poller thread
* @tparam T
* [AUTO-TRANSLATED:6c0d8449]
*/
template <typename T>
class _RingReaderDispatcher : public std::enable_shared_from_this<_RingReaderDispatcher<T>> {
public:
using Ptr = std::shared_ptr<_RingReaderDispatcher>;
using RingReader = _RingReader<T>;
using RingStorage = _RingStorage<T>;
using onChangeInfoCB = std::function<Any(Any &&info)>;
friend class RingBuffer<T>;
~_RingReaderDispatcher() {
decltype(_reader_map) reader_map;
reader_map.swap(_reader_map);
for (auto &pr : reader_map) {
auto reader = pr.second.lock();
if (reader) {
reader->onDetach();
}
}
}
private:
_RingReaderDispatcher(
const typename RingStorage::Ptr &storage, std::function<void(int, bool)> onSizeChanged) {
_reader_size = 0;
_storage = storage;
_on_size_changed = std::move(onSizeChanged);
assert(_on_size_changed);
}
void write(T in, bool is_key = true) {
for (auto it = _reader_map.begin(); it != _reader_map.end();) {
auto reader = it->second.lock();
if (!reader) {
it = _reader_map.erase(it);
--_reader_size;
onSizeChanged(false);
continue;
}
reader->onRead(in, is_key);
++it;
}
_storage->write(std::move(in), is_key);
}
void sendMessage(const Any &data) {
for (auto it = _reader_map.begin(); it != _reader_map.end();) {
auto reader = it->second.lock();
if (!reader) {
it = _reader_map.erase(it);
--_reader_size;
onSizeChanged(false);
continue;
}
reader->onMessage(data);
++it;
}
}
std::shared_ptr<RingReader> attach(const EventPoller::Ptr &poller, bool use_cache) {
if (!poller->isCurrentThread()) {
throw std::runtime_error("You can attach RingBuffer only in it's poller thread");
}
std::weak_ptr<_RingReaderDispatcher> weak_self = this->shared_from_this();
auto on_dealloc = [weak_self, poller](RingReader *ptr) {
poller->async([weak_self, ptr]() {
auto strong_self = weak_self.lock();
if (strong_self && strong_self->_reader_map.erase(ptr)) {
--strong_self->_reader_size;
strong_self->onSizeChanged(false);
}
delete ptr;
});
};
std::shared_ptr<RingReader> reader(new RingReader(use_cache ? _storage : nullptr), on_dealloc);
_reader_map[reader.get()] = reader;
++_reader_size;
onSizeChanged(true);
return reader;
}
void onSizeChanged(bool add_flag) { _on_size_changed(_reader_size, add_flag); }
void clearCache() {
if (_reader_size == 0) {
_storage->clearCache();
}
}
std::list<Any> getInfoList(const onChangeInfoCB &on_change) {
std::list<Any> ret;
for (auto &pr : _reader_map) {
auto reader = pr.second.lock();
if (!reader) {
continue;
}
auto info = reader->getInfo();
if (!info) {
continue;
}
ret.emplace_back(on_change(std::move(info)));
}
return ret;
}
private:
std::atomic_int _reader_size;
std::function<void(int, bool)> _on_size_changed;
typename RingStorage::Ptr _storage;
std::unordered_map<void *, std::weak_ptr<RingReader>> _reader_map;
};
template <typename T>
class RingBuffer : public std::enable_shared_from_this<RingBuffer<T>> {
public:
using Ptr = std::shared_ptr<RingBuffer>;
using RingReader = _RingReader<T>;
using RingStorage = _RingStorage<T>;
using RingReaderDispatcher = _RingReaderDispatcher<T>;
using onReaderChanged = std::function<void(int size)>;
using onGetInfoCB = std::function<void(std::list<Any> &info_list)>;
RingBuffer(size_t max_size = 1024, onReaderChanged cb = nullptr, size_t max_gop_size = 1) {
_storage = std::make_shared<RingStorage>(max_size, max_gop_size);
_on_reader_changed = cb ? std::move(cb) : [](int size) {};
//先触发无人观看 [AUTO-TRANSLATED:34c64fef]
//First trigger no one watching
_on_reader_changed(0);
}
~RingBuffer() = default;
void write(T in, bool is_key = true) {
if (_delegate) {
_delegate->onWrite(std::move(in), is_key);
return;
}
LOCK_GUARD(_mtx_map);
for (auto &pr : _dispatcher_map) {
auto &second = pr.second;
//切换线程后触发onRead事件 [AUTO-TRANSLATED:4ca6647d]
//Switch thread and trigger onRead event
pr.first->async([second, in, is_key]() mutable { second->write(std::move(in), is_key); }, false);
}
_storage->write(std::move(in), is_key);
}
void sendMessage(const Any &data) {
LOCK_GUARD(_mtx_map);
for (auto &pr : _dispatcher_map) {
auto &second = pr.second;
// 切换线程后触发sendMessage [AUTO-TRANSLATED:350138c9]
//Switch thread and trigger sendMessage
pr.first->async([second, data]() { second->sendMessage(data); }, false);
}
}
void setDelegate(const typename RingDelegate<T>::Ptr &delegate) { _delegate = delegate; }
std::shared_ptr<RingReader> attach(const EventPoller::Ptr &poller, bool use_cache = true) {
typename RingReaderDispatcher::Ptr dispatcher;
{
LOCK_GUARD(_mtx_map);
auto &ref = _dispatcher_map[poller];
if (!ref) {
std::weak_ptr<RingBuffer> weak_self = this->shared_from_this();
auto onSizeChanged = [weak_self, poller](int size, bool add_flag) {
if (auto strong_self = weak_self.lock()) {
strong_self->onSizeChanged(poller, size, add_flag);
}
};
auto onDealloc = [poller](RingReaderDispatcher *ptr) { poller->async([ptr]() { delete ptr; }); };
ref.reset(new RingReaderDispatcher(_storage->clone(), std::move(onSizeChanged)), std::move(onDealloc));
}
dispatcher = ref;
}
return dispatcher->attach(poller, use_cache);
}
int readerCount() { return _total_count; }
void clearCache() {
LOCK_GUARD(_mtx_map);
_storage->clearCache();
for (auto &pr : _dispatcher_map) {
auto &second = pr.second;
//切换线程后清空缓存 [AUTO-TRANSLATED:150f7fa4]
//Switch thread and clear cache
pr.first->async([second]() { second->clearCache(); }, false);
}
}
void flushGop(std::function<void(const T &)> cb) {
LOCK_GUARD(_mtx_map);
_storage->getCache().for_each([&](const List<std::pair<bool, T>> &lst) {
lst.for_each([&](const std::pair<bool, T> &pr) { cb(pr.second); });
});
}
void getInfoList(const onGetInfoCB &cb, const typename RingReaderDispatcher::onChangeInfoCB &on_change = nullptr) {
if (!cb) {
return;
}
if (!on_change) {
const_cast<typename RingReaderDispatcher::onChangeInfoCB &>(on_change) = [](Any &&info) { return std::move(info); };
}
LOCK_GUARD(_mtx_map);
auto info_vec = std::make_shared<std::vector<std::list<Any>>>();
// 1、最少确保一个元素 [AUTO-TRANSLATED:6dafe078]
//1. Ensure at least one element
info_vec->resize(_dispatcher_map.empty() ? 1 : _dispatcher_map.size());
std::shared_ptr<void> on_finished(nullptr, [cb, info_vec](void *) mutable {
// 2、防止这里为空 [AUTO-TRANSLATED:4484baf7]
//2. Prevent this from being empty
auto &lst = *info_vec->begin();
for (auto &item : *info_vec) {
if (&lst != &item) {
lst.insert(lst.end(), item.begin(), item.end());
}
}
cb(lst);
});
auto i = 0U;
for (auto &pr : _dispatcher_map) {
auto &second = pr.second;
pr.first->async([second, info_vec, on_finished, i, on_change]() { (*info_vec)[i] = second->getInfoList(on_change); });
++i;
}
}
private:
void onSizeChanged(const EventPoller::Ptr &poller, int size, bool add_flag) {
if (size == 0) {
LOCK_GUARD(_mtx_map);
_dispatcher_map.erase(poller);
}
if (add_flag) {
++_total_count;
} else {
--_total_count;
}
_on_reader_changed(_total_count);
}
private:
struct HashOfPtr {
std::size_t operator()(const EventPoller::Ptr &key) const { return (std::size_t)key.get(); }
};
private:
std::mutex _mtx_map;
std::atomic_int _total_count { 0 };
typename RingStorage::Ptr _storage;
typename RingDelegate<T>::Ptr _delegate;
onReaderChanged _on_reader_changed;
std::unordered_map<EventPoller::Ptr, typename RingReaderDispatcher::Ptr, HashOfPtr> _dispatcher_map;
};
} /* namespace toolkit */
#endif /* UTIL_RINGBUFFER_H_ */

View File

@ -0,0 +1,341 @@
// 100% Public Domain.
//
// Original C Code
// -- Steve Reid <steve@edmweb.com>
// Small changes to fit into bglibs
// -- Bruce Guenter <bruce@untroubled.org>
// Translation to simpler C++ Code
// -- Volker Grabsch <vog@notjusthosting.com>
// Safety fixes
// -- Eugene Hopkinson <slowriot at voxelstorm dot com>
// Adapt for project
// Dmitriy Khaustov <khaustov.dm@gmail.com>
//
// File created on: 2017.02.25
// SHA1.cpp
#include "SHA1.h"
#include <sstream>
#include <iomanip>
#include <fstream>
namespace toolkit {
static const size_t BLOCK_INTS = 16; /* number of 32bit integers per SHA1 block */
static const size_t BLOCK_BYTES = BLOCK_INTS * 4;
static void reset(uint32_t digest[], std::string &buffer, uint64_t &transforms)
{
/* SHA1 initialization constants */
digest[0] = 0x67452301;
digest[1] = 0xefcdab89;
digest[2] = 0x98badcfe;
digest[3] = 0x10325476;
digest[4] = 0xc3d2e1f0;
/* Reset counters */
buffer = "";
transforms = 0;
}
static uint32_t rol(const uint32_t value, const size_t bits)
{
return (value << bits) | (value >> (32 - bits));
}
static uint32_t blk(const uint32_t block[BLOCK_INTS], const size_t i)
{
return rol(block[(i+13)&15] ^ block[(i+8)&15] ^ block[(i+2)&15] ^ block[i], 1);
}
/*
* (R0+R1), R2, R3, R4 are the different operations used in SHA1
*/
static void R0(const uint32_t block[BLOCK_INTS], const uint32_t v, uint32_t &w, const uint32_t x, const uint32_t y, uint32_t &z, const size_t i)
{
z += ((w&(x^y))^y) + block[i] + 0x5a827999 + rol(v, 5);
w = rol(w, 30);
}
static void R1(uint32_t block[BLOCK_INTS], const uint32_t v, uint32_t &w, const uint32_t x, const uint32_t y, uint32_t &z, const size_t i)
{
block[i] = blk(block, i);
z += ((w&(x^y))^y) + block[i] + 0x5a827999 + rol(v, 5);
w = rol(w, 30);
}
static void R2(uint32_t block[BLOCK_INTS], const uint32_t v, uint32_t &w, const uint32_t x, const uint32_t y, uint32_t &z, const size_t i)
{
block[i] = blk(block, i);
z += (w^x^y) + block[i] + 0x6ed9eba1 + rol(v, 5);
w = rol(w, 30);
}
static void R3(uint32_t block[BLOCK_INTS], const uint32_t v, uint32_t &w, const uint32_t x, const uint32_t y, uint32_t &z, const size_t i)
{
block[i] = blk(block, i);
z += (((w|x)&y)|(w&x)) + block[i] + 0x8f1bbcdc + rol(v, 5);
w = rol(w, 30);
}
static void R4(uint32_t block[BLOCK_INTS], const uint32_t v, uint32_t &w, const uint32_t x, const uint32_t y, uint32_t &z, const size_t i)
{
block[i] = blk(block, i);
z += (w^x^y) + block[i] + 0xca62c1d6 + rol(v, 5);
w = rol(w, 30);
}
/*
* Hash a single 512-bit block. This is the core of the algorithm.
*/
static void transform(uint32_t digest[], uint32_t block[BLOCK_INTS], uint64_t &transforms)
{
/* Copy digest[] to working vars */
uint32_t a = digest[0];
uint32_t b = digest[1];
uint32_t c = digest[2];
uint32_t d = digest[3];
uint32_t e = digest[4];
/* 4 rounds of 20 operations each. Loop unrolled. */
R0(block, a, b, c, d, e, 0);
R0(block, e, a, b, c, d, 1);
R0(block, d, e, a, b, c, 2);
R0(block, c, d, e, a, b, 3);
R0(block, b, c, d, e, a, 4);
R0(block, a, b, c, d, e, 5);
R0(block, e, a, b, c, d, 6);
R0(block, d, e, a, b, c, 7);
R0(block, c, d, e, a, b, 8);
R0(block, b, c, d, e, a, 9);
R0(block, a, b, c, d, e, 10);
R0(block, e, a, b, c, d, 11);
R0(block, d, e, a, b, c, 12);
R0(block, c, d, e, a, b, 13);
R0(block, b, c, d, e, a, 14);
R0(block, a, b, c, d, e, 15);
R1(block, e, a, b, c, d, 0);
R1(block, d, e, a, b, c, 1);
R1(block, c, d, e, a, b, 2);
R1(block, b, c, d, e, a, 3);
R2(block, a, b, c, d, e, 4);
R2(block, e, a, b, c, d, 5);
R2(block, d, e, a, b, c, 6);
R2(block, c, d, e, a, b, 7);
R2(block, b, c, d, e, a, 8);
R2(block, a, b, c, d, e, 9);
R2(block, e, a, b, c, d, 10);
R2(block, d, e, a, b, c, 11);
R2(block, c, d, e, a, b, 12);
R2(block, b, c, d, e, a, 13);
R2(block, a, b, c, d, e, 14);
R2(block, e, a, b, c, d, 15);
R2(block, d, e, a, b, c, 0);
R2(block, c, d, e, a, b, 1);
R2(block, b, c, d, e, a, 2);
R2(block, a, b, c, d, e, 3);
R2(block, e, a, b, c, d, 4);
R2(block, d, e, a, b, c, 5);
R2(block, c, d, e, a, b, 6);
R2(block, b, c, d, e, a, 7);
R3(block, a, b, c, d, e, 8);
R3(block, e, a, b, c, d, 9);
R3(block, d, e, a, b, c, 10);
R3(block, c, d, e, a, b, 11);
R3(block, b, c, d, e, a, 12);
R3(block, a, b, c, d, e, 13);
R3(block, e, a, b, c, d, 14);
R3(block, d, e, a, b, c, 15);
R3(block, c, d, e, a, b, 0);
R3(block, b, c, d, e, a, 1);
R3(block, a, b, c, d, e, 2);
R3(block, e, a, b, c, d, 3);
R3(block, d, e, a, b, c, 4);
R3(block, c, d, e, a, b, 5);
R3(block, b, c, d, e, a, 6);
R3(block, a, b, c, d, e, 7);
R3(block, e, a, b, c, d, 8);
R3(block, d, e, a, b, c, 9);
R3(block, c, d, e, a, b, 10);
R3(block, b, c, d, e, a, 11);
R4(block, a, b, c, d, e, 12);
R4(block, e, a, b, c, d, 13);
R4(block, d, e, a, b, c, 14);
R4(block, c, d, e, a, b, 15);
R4(block, b, c, d, e, a, 0);
R4(block, a, b, c, d, e, 1);
R4(block, e, a, b, c, d, 2);
R4(block, d, e, a, b, c, 3);
R4(block, c, d, e, a, b, 4);
R4(block, b, c, d, e, a, 5);
R4(block, a, b, c, d, e, 6);
R4(block, e, a, b, c, d, 7);
R4(block, d, e, a, b, c, 8);
R4(block, c, d, e, a, b, 9);
R4(block, b, c, d, e, a, 10);
R4(block, a, b, c, d, e, 11);
R4(block, e, a, b, c, d, 12);
R4(block, d, e, a, b, c, 13);
R4(block, c, d, e, a, b, 14);
R4(block, b, c, d, e, a, 15);
/* Add the working vars back into digest[] */
digest[0] += a;
digest[1] += b;
digest[2] += c;
digest[3] += d;
digest[4] += e;
/* Count the number of transformations */
transforms++;
}
static void buffer_to_block(const std::string &buffer, uint32_t block[BLOCK_INTS])
{
/* Convert the std::string (byte buffer) to a uint32_t array (MSB) */
for (size_t i = 0; i < BLOCK_INTS; i++)
{
block[i] =
(buffer[4*i+3] & 0xFF)
| (buffer[4*i+2] & 0xFF) << 8
| (buffer[4*i+1] & 0xff) << 16
| (buffer[4*i+0] & 0xff) << 24;
}
}
SHA1::SHA1()
{
reset(digest, buffer, transforms);
}
void SHA1::update(const std::string &s)
{
std::istringstream is(s);
update(is);
}
void SHA1::update(std::istream &is)
{
while (true)
{
char sbuf[BLOCK_BYTES];
is.read(sbuf, BLOCK_BYTES - buffer.size());
buffer.append(sbuf, is.gcount());
if (buffer.size() != BLOCK_BYTES)
{
return;
}
uint32_t block[BLOCK_INTS];
buffer_to_block(buffer, block);
transform(digest, block, transforms);
buffer.clear();
}
}
/*
* Add padding and return the message digest.
*/
std::string SHA1::final()
{
auto str = final_bin();
std::ostringstream result;
for (size_t i = 0; i < str.size(); i++)
{
char b[3];
sprintf(b, "%02x", static_cast<unsigned char>(str[i]));
result << b;
}
return result.str();
}
std::string SHA1::final_bin()
{
/* Total number of hashed bits */
uint64_t total_bits = (transforms*BLOCK_BYTES + buffer.size()) * 8;
/* Padding */
buffer += 0x80;
size_t orig_size = buffer.size();
while (buffer.size() < BLOCK_BYTES)
{
buffer += (char)0x00;
}
uint32_t block[BLOCK_INTS];
buffer_to_block(buffer, block);
if (orig_size > BLOCK_BYTES - 8)
{
transform(digest, block, transforms);
for (size_t i = 0; i < BLOCK_INTS - 2; i++)
{
block[i] = 0;
}
}
/* Append total_bits, split this uint64_t into two uint32_t */
block[BLOCK_INTS - 1] = total_bits;
block[BLOCK_INTS - 2] = (total_bits >> 32);
transform(digest, block, transforms);
/* Hex std::string */
std::string result;
for (size_t i = 0; i < sizeof(digest) / sizeof(digest[0]); i++)
{
for (size_t b = 0; b < sizeof(digest[0])/sizeof(uint8_t); b++)
{
result.push_back((digest[i] >> (8 * (sizeof(digest[0]) / sizeof(uint8_t) - 1 - b))) & 0xFF);
}
}
/* Reset for next run */
reset(digest, buffer, transforms);
return result;
}
std::string SHA1::from_file(const std::string &filename)
{
std::ifstream stream(filename.c_str(), std::ios::binary);
SHA1 checksum;
checksum.update(stream);
return checksum.final();
}
std::string SHA1::encode(const std::string &s)
{
SHA1 sha1;
sha1.update(s);
return sha1.final();
}
std::string SHA1::encode_bin(const std::string &s)
{
SHA1 sha1;
sha1.update(s);
return sha1.final_bin();
}
} //namespace toolkit

View File

@ -0,0 +1,47 @@
// 100% Public Domain.
//
// Original C Code
// -- Steve Reid <steve@edmweb.com>
// Small changes to fit into bglibs
// -- Bruce Guenter <bruce@untroubled.org>
// Translation to simpler C++ Code
// -- Volker Grabsch <vog@notjusthosting.com>
// Safety fixes
// -- Eugene Hopkinson <slowriot at voxelstorm dot com>
// Adapt for project
// Dmitriy Khaustov <khaustov.dm@gmail.com>
//
// File created on: 2017.02.25
// SHA1.h
#pragma once
#include <cstdint>
#include <iostream>
#include <string>
namespace toolkit {
class SHA1 final
{
public:
SHA1();
void update(const std::string &s);
void update(std::istream &is);
std::string final();
std::string final_bin();
static std::string from_file(const std::string &filename);
static std::string encode(const std::string &s);
static std::string encode_bin(const std::string &s);
private:
uint32_t digest[5];
std::string buffer;
uint64_t transforms;
};
}//namespace toolkit

View File

@ -0,0 +1,550 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include "SSLBox.h"
#include "onceToken.h"
#include "SSLUtil.h"
#if defined(ENABLE_OPENSSL)
#include <openssl/ssl.h>
#include <openssl/rand.h>
#include <openssl/crypto.h>
#include <openssl/err.h>
#include <openssl/conf.h>
#include <openssl/bio.h>
#include <openssl/ossl_typ.h>
#endif //defined(ENABLE_OPENSSL)
#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
//openssl版本是否支持sni [AUTO-TRANSLATED:4c92a880]
//Is the OpenSSL version SNI supported
#define SSL_ENABLE_SNI
#endif
using namespace std;
namespace toolkit {
static bool s_ignore_invalid_cer = true;
SSL_Initor &SSL_Initor::Instance() {
static SSL_Initor obj;
return obj;
}
void SSL_Initor::ignoreInvalidCertificate(bool ignore) {
s_ignore_invalid_cer = ignore;
}
SSL_Initor::SSL_Initor() {
#if defined(ENABLE_OPENSSL)
SSL_library_init();
SSL_load_error_strings();
OpenSSL_add_all_digests();
OpenSSL_add_all_ciphers();
OpenSSL_add_all_algorithms();
CRYPTO_set_locking_callback([](int mode, int n, const char *file, int line) {
static mutex *s_mutexes = new mutex[CRYPTO_num_locks()];
static onceToken token(nullptr, []() {
delete[] s_mutexes;
});
if (mode & CRYPTO_LOCK) {
s_mutexes[n].lock();
} else {
s_mutexes[n].unlock();
}
});
CRYPTO_set_id_callback([]() -> unsigned long {
#if !defined(_WIN32)
return (unsigned long) pthread_self();
#else
return (unsigned long) GetCurrentThreadId();
#endif
});
setContext("", SSLUtil::makeSSLContext(vector<shared_ptr<X509> >(), nullptr, false), false);
setContext("", SSLUtil::makeSSLContext(vector<shared_ptr<X509> >(), nullptr, true), true);
#endif //defined(ENABLE_OPENSSL)
}
SSL_Initor::~SSL_Initor() {
#if defined(ENABLE_OPENSSL)
EVP_cleanup();
ERR_free_strings();
ERR_clear_error();
#if OPENSSL_VERSION_NUMBER >= 0x10000000L && OPENSSL_VERSION_NUMBER < 0x10100000L
ERR_remove_thread_state(nullptr);
#elif OPENSSL_VERSION_NUMBER < 0x10000000L
ERR_remove_state(0);
#endif
CRYPTO_set_locking_callback(nullptr);
//sk_SSL_COMP_free(SSL_COMP_get_compression_methods());
CRYPTO_cleanup_all_ex_data();
CONF_modules_unload(1);
CONF_modules_free();
#endif //defined(ENABLE_OPENSSL)
}
bool SSL_Initor::loadCertificate(const string &pem_or_p12, bool server_mode, const string &password, bool is_file, bool is_default) {
auto cers = SSLUtil::loadPublicKey(pem_or_p12, password, is_file);
auto key = SSLUtil::loadPrivateKey(pem_or_p12, password, is_file);
auto ssl_ctx = SSLUtil::makeSSLContext(cers, key, server_mode, true);
if (!ssl_ctx) {
return false;
}
for (auto &cer : cers) {
auto server_name = SSLUtil::getServerName(cer.get());
setContext(server_name, ssl_ctx, server_mode, is_default);
break;
}
return true;
}
int SSL_Initor::findCertificate(SSL *ssl, int *, void *arg) {
#if !defined(ENABLE_OPENSSL) || !defined(SSL_ENABLE_SNI)
return 0;
#else
if (!ssl) {
return SSL_TLSEXT_ERR_ALERT_FATAL;
}
SSL_CTX *ctx = nullptr;
static auto &ref = SSL_Initor::Instance();
const char *vhost = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
if (vhost && vhost[0] != '\0') {
//根据域名找到证书 [AUTO-TRANSLATED:783a55d8]
//Find the certificate based on the domain name
ctx = ref.getSSLCtx(vhost, (bool) (arg)).get();
if (!ctx) {
//未找到对应的证书 [AUTO-TRANSLATED:d4550e6f]
//No corresponding certificate found
std::lock_guard<std::recursive_mutex> lck(ref._mtx);
WarnL << "Can not find any certificate of host: " << vhost
<< ", select default certificate of: " << ref._default_vhost[(bool) (arg)];
}
}
if (!ctx) {
//客户端未指定域名或者指定的证书不存在,那么选择一个默认的证书 [AUTO-TRANSLATED:35115b5c]
//The client did not specify a domain name or the specified certificate does not exist, so a default certificate is selected
ctx = ref.getSSLCtx("", (bool) (arg)).get();
}
if (!ctx) {
//未有任何有效的证书 [AUTO-TRANSLATED:e1d7f5b7]
//No valid certificate available
WarnL << "Can not find any available certificate of host: " << (vhost ? vhost : "default host")
<< ", tls handshake failed";
return SSL_TLSEXT_ERR_ALERT_FATAL;
}
SSL_set_SSL_CTX(ssl, ctx);
return SSL_TLSEXT_ERR_OK;
#endif
}
bool SSL_Initor::setContext(const string &vhost, const shared_ptr<SSL_CTX> &ctx, bool server_mode, bool is_default) {
std::lock_guard<std::recursive_mutex> lck(_mtx);
if (!ctx) {
return false;
}
setupCtx(ctx.get());
#if defined(ENABLE_OPENSSL)
if (vhost.empty()) {
_ctx_empty[server_mode] = ctx;
#ifdef SSL_ENABLE_SNI
if (server_mode) {
SSL_CTX_set_tlsext_servername_callback(ctx.get(), findCertificate);
SSL_CTX_set_tlsext_servername_arg(ctx.get(), (void *) server_mode);
}
#endif // SSL_ENABLE_SNI
} else {
_ctxs[server_mode][vhost] = ctx;
if (is_default) {
_default_vhost[server_mode] = vhost;
}
if (vhost.find("*.") == 0) {
//通配符证书 [AUTO-TRANSLATED:faeefee7]
//Wildcard certificate
_ctxs_wildcards[server_mode][vhost.substr(1)] = ctx;
}
DebugL << "Add certificate of: " << vhost;
}
return true;
#else
WarnL << "ENABLE_OPENSSL disabled, you can not use any features based on openssl";
return false;
#endif //defined(ENABLE_OPENSSL)
}
void SSL_Initor::setupCtx(SSL_CTX *ctx) {
#if defined(ENABLE_OPENSSL)
//加载默认信任证书 [AUTO-TRANSLATED:4d98f092]
//Load default trusted certificate
SSLUtil::loadDefaultCAs(ctx);
SSL_CTX_set_cipher_list(ctx, "ALL:!ADH:!LOW:!EXP:!MD5:!3DES:!DES:!IDEA:!RC4:!SEED-SHA:@STRENGTH");
SSL_CTX_set_verify_depth(ctx, 9);
SSL_CTX_set_mode(ctx, SSL_MODE_AUTO_RETRY);
SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_OFF);
SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, [](int ok, X509_STORE_CTX *pStore) {
if (!ok) {
int depth = X509_STORE_CTX_get_error_depth(pStore);
int err = X509_STORE_CTX_get_error(pStore);
WarnL << "SSL_CTX_set_verify callback, depth: " << depth << " ,err: " << X509_verify_cert_error_string(err);
}
return s_ignore_invalid_cer ? 1 : ok;
});
#ifndef SSL_OP_NO_COMPRESSION
#define SSL_OP_NO_COMPRESSION 0
#endif
#ifndef SSL_MODE_RELEASE_BUFFERS /* OpenSSL >= 1.0.0 */
#define SSL_MODE_RELEASE_BUFFERS 0
#endif
unsigned long ssloptions = SSL_OP_ALL
| SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION
| SSL_OP_NO_COMPRESSION;
#ifdef SSL_OP_NO_RENEGOTIATION /* openssl 1.1.0 */
ssloptions |= SSL_OP_NO_RENEGOTIATION;
#endif
#ifdef SSL_OP_NO_SSLv2
ssloptions |= SSL_OP_NO_SSLv2;
#endif
#ifdef SSL_OP_NO_SSLv3
ssloptions |= SSL_OP_NO_SSLv3;
#endif
#ifdef SSL_OP_NO_TLSv1
ssloptions |= SSL_OP_NO_TLSv1;
#endif
#ifdef SSL_OP_NO_TLSv1_1 /* openssl 1.0.1 */
ssloptions |= SSL_OP_NO_TLSv1_1;
#endif
SSL_CTX_set_options(ctx, ssloptions);
#endif //defined(ENABLE_OPENSSL)
}
shared_ptr<SSL> SSL_Initor::makeSSL(bool server_mode) {
std::lock_guard<std::recursive_mutex> lck(_mtx);
#if defined(ENABLE_OPENSSL)
#ifdef SSL_ENABLE_SNI
//openssl 版本支持SNI [AUTO-TRANSLATED:b8029f6c]
//OpenSSL version supports SNI
return SSLUtil::makeSSL(_ctx_empty[server_mode].get());
#else
//openssl 版本不支持SNI选择默认证书 [AUTO-TRANSLATED:cedb5f02]
//OpenSSL version does not support SNI, select default certificate
return SSLUtil::makeSSL(getSSLCtx("",server_mode).get());
#endif//SSL_CTRL_SET_TLSEXT_HOSTNAME
#else
return nullptr;
#endif //defined(ENABLE_OPENSSL)
}
bool SSL_Initor::trustCertificate(X509 *cer, bool server_mode) {
std::lock_guard<std::recursive_mutex> lck(_mtx);
return SSLUtil::trustCertificate(_ctx_empty[server_mode].get(), cer);
}
bool SSL_Initor::trustCertificate(const string &pem_p12_cer, bool server_mode, const string &password, bool is_file) {
auto cers = SSLUtil::loadPublicKey(pem_p12_cer, password, is_file);
for (auto &cer : cers) {
trustCertificate(cer.get(), server_mode);
}
return true;
}
std::shared_ptr<SSL_CTX> SSL_Initor::getSSLCtx(const string &vhost, bool server_mode) {
auto ret = getSSLCtx_l(vhost, server_mode);
if (ret) {
return ret;
}
return getSSLCtxWildcards(vhost, server_mode);
}
std::shared_ptr<SSL_CTX> SSL_Initor::getSSLCtxWildcards(const string &vhost, bool server_mode) {
std::lock_guard<std::recursive_mutex> lck(_mtx);
for (auto &pr : _ctxs_wildcards[server_mode]) {
auto pos = strcasestr(vhost.data(), pr.first.data());
if (pos && pos + pr.first.size() == &vhost.back() + 1) {
return pr.second;
}
}
return nullptr;
}
std::shared_ptr<SSL_CTX> SSL_Initor::getSSLCtx_l(const string &vhost_in, bool server_mode) {
std::lock_guard<std::recursive_mutex> lck(_mtx);
auto vhost = vhost_in;
if (vhost.empty()) {
if (!_default_vhost[server_mode].empty()) {
vhost = _default_vhost[server_mode];
} else {
//没默认主机,选择空主机 [AUTO-TRANSLATED:99a7d8d4]
//No default host, select empty host
if (server_mode) {
WarnL << "Server with ssl must have certification and key";
}
return _ctx_empty[server_mode];
}
}
//根据主机名查找证书 [AUTO-TRANSLATED:dcc98736]
//Find certificate by hostname
auto it = _ctxs[server_mode].find(vhost);
if (it == _ctxs[server_mode].end()) {
return nullptr;
}
return it->second;
}
string SSL_Initor::defaultVhost(bool server_mode) {
std::lock_guard<std::recursive_mutex> lck(_mtx);
return _default_vhost[server_mode];
}
////////////////////////////////////////////////////SSL_Box////////////////////////////////////////////////////////////
SSL_Box::~SSL_Box() {}
SSL_Box::SSL_Box(bool server_mode, bool enable, int buff_size) {
#if defined(ENABLE_OPENSSL)
_read_bio = BIO_new(BIO_s_mem());
_server_mode = server_mode;
if (enable) {
_ssl = SSL_Initor::Instance().makeSSL(server_mode);
}
if (_ssl) {
_write_bio = BIO_new(BIO_s_mem());
SSL_set_bio(_ssl.get(), _read_bio, _write_bio);
_server_mode ? SSL_set_accept_state(_ssl.get()) : SSL_set_connect_state(_ssl.get());
} else {
WarnL << "makeSSL failed";
}
_send_handshake = false;
_buff_size = buff_size;
#endif //defined(ENABLE_OPENSSL)
}
void SSL_Box::shutdown() {
#if defined(ENABLE_OPENSSL)
_buffer_send.clear();
int ret = SSL_shutdown(_ssl.get());
if (ret != 1) {
ErrorL << "SSL_shutdown failed: " << SSLUtil::getLastError();
} else {
flush();
}
#endif //defined(ENABLE_OPENSSL)
}
void SSL_Box::onRecv(const Buffer::Ptr &buffer) {
if (!buffer->size()) {
return;
}
if (!_ssl) {
if (_on_dec) {
_on_dec(buffer);
}
return;
}
#if defined(ENABLE_OPENSSL)
uint32_t offset = 0;
while (offset < buffer->size()) {
auto nwrite = BIO_write(_read_bio, buffer->data() + offset, buffer->size() - offset);
if (nwrite > 0) {
//部分或全部写入bio完毕 [AUTO-TRANSLATED:baabfef4]
//Partial or full write to bio completed
offset += nwrite;
flush();
continue;
}
//nwrite <= 0,出现异常 [AUTO-TRANSLATED:986e8f36]
//nwrite <= 0, an error occurred
ErrorL << "Ssl error on BIO_write: " << SSLUtil::getLastError();
shutdown();
break;
}
#endif //defined(ENABLE_OPENSSL)
}
void SSL_Box::onSend(Buffer::Ptr buffer) {
if (!buffer->size()) {
return;
}
if (!_ssl) {
if (_on_enc) {
_on_enc(buffer);
}
return;
}
#if defined(ENABLE_OPENSSL)
if (!_server_mode && !_send_handshake) {
_send_handshake = true;
SSL_do_handshake(_ssl.get());
}
_buffer_send.emplace_back(std::move(buffer));
flush();
#endif //defined(ENABLE_OPENSSL)
}
void SSL_Box::setOnDecData(const function<void(const Buffer::Ptr &)> &cb) {
_on_dec = cb;
}
void SSL_Box::setOnEncData(const function<void(const Buffer::Ptr &)> &cb) {
_on_enc = cb;
}
void SSL_Box::flushWriteBio() {
#if defined(ENABLE_OPENSSL)
int total = 0;
int nread = 0;
auto buffer_bio = _buffer_pool.obtain2();
buffer_bio->setCapacity(_buff_size);
auto buf_size = buffer_bio->getCapacity() - 1;
do {
nread = BIO_read(_write_bio, buffer_bio->data() + total, buf_size - total);
if (nread > 0) {
total += nread;
}
} while (nread > 0 && buf_size - total > 0);
if (!total) {
//未有数据 [AUTO-TRANSLATED:9ae3aaa5]
//No data available
return;
}
//触发此次回调 [AUTO-TRANSLATED:dc10c264]
//Trigger this callback
buffer_bio->data()[total] = '\0';
buffer_bio->setSize(total);
if (_on_enc) {
_on_enc(buffer_bio);
}
if (nread > 0) {
//还有剩余数据,读取剩余数据 [AUTO-TRANSLATED:008f4187]
//Still have remaining data, read the remaining data
flushWriteBio();
}
#endif //defined(ENABLE_OPENSSL)
}
void SSL_Box::flushReadBio() {
#if defined(ENABLE_OPENSSL)
int total = 0;
int nread = 0;
auto buffer_bio = _buffer_pool.obtain2();
buffer_bio->setCapacity(_buff_size);
auto buf_size = buffer_bio->getCapacity() - 1;
do {
nread = SSL_read(_ssl.get(), buffer_bio->data() + total, buf_size - total);
if (nread > 0) {
total += nread;
}
} while (nread > 0 && buf_size - total > 0);
if (!total) {
//未有数据 [AUTO-TRANSLATED:9ae3aaa5]
//No data available
return;
}
//触发此次回调 [AUTO-TRANSLATED:dc10c264]
//Trigger this callback
buffer_bio->data()[total] = '\0';
buffer_bio->setSize(total);
if (_on_dec) {
_on_dec(buffer_bio);
}
if (nread > 0) {
//还有剩余数据,读取剩余数据 [AUTO-TRANSLATED:008f4187]
//Still have remaining data, read the remaining data
flushReadBio();
}
#endif //defined(ENABLE_OPENSSL)
}
void SSL_Box::flush() {
#if defined(ENABLE_OPENSSL)
if (_is_flush) {
return;
}
onceToken token([&] {
_is_flush = true;
}, [&]() {
_is_flush = false;
});
flushReadBio();
if (!SSL_is_init_finished(_ssl.get()) || _buffer_send.empty()) {
//ssl未握手结束或没有需要发送的数据 [AUTO-TRANSLATED:39f8490c]
//SSL handshake not finished or no data to send
flushWriteBio();
return;
}
//加密数据并发送 [AUTO-TRANSLATED:c09fdbd0]
//Encrypt data and send
while (!_buffer_send.empty()) {
auto &front = _buffer_send.front();
uint32_t offset = 0;
while (offset < front->size()) {
auto nwrite = SSL_write(_ssl.get(), front->data() + offset, front->size() - offset);
if (nwrite > 0) {
//部分或全部写入完毕 [AUTO-TRANSLATED:661163d2]
//Partial or complete write finished
offset += nwrite;
flushWriteBio();
continue;
}
//nwrite <= 0,出现异常 [AUTO-TRANSLATED:986e8f36]
//nwrite <= 0, an exception occurred
break;
}
if (offset != front->size()) {
//这个包未消费完毕,出现了异常,清空数据并断开ssl [AUTO-TRANSLATED:1823c65a]
//This package has not been fully consumed, an exception occurred, clear data and disconnect ssl
ErrorL << "Ssl error on SSL_write: " << SSLUtil::getLastError();
shutdown();
break;
}
//这个包消费完毕,开始消费下一个包 [AUTO-TRANSLATED:6fa31240]
//This package has been fully consumed, start consuming the next package
_buffer_send.pop_front();
}
#endif //defined(ENABLE_OPENSSL)
}
bool SSL_Box::setHost(const char *host) {
if (!_ssl) {
return false;
}
#ifdef SSL_ENABLE_SNI
return 0 != SSL_set_tlsext_host_name(_ssl.get(), host);
#else
return false;
#endif//SSL_ENABLE_SNI
}
} /* namespace toolkit */

View File

@ -0,0 +1,290 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef CRYPTO_SSLBOX_H_
#define CRYPTO_SSLBOX_H_
#include <mutex>
#include <string>
#include <functional>
#include "logger.h"
#include "List.h"
#include "util.h"
#include "Network/Buffer.h"
#include "ResourcePool.h"
typedef struct x509_st X509;
typedef struct evp_pkey_st EVP_PKEY;
typedef struct ssl_ctx_st SSL_CTX;
typedef struct ssl_st SSL;
typedef struct bio_st BIO;
namespace toolkit {
class SSL_Initor {
public:
friend class SSL_Box;
static SSL_Initor &Instance();
/**
*
* (cer格式的证书只包括公钥使)
* ()
* @param pem_or_p12 pem或p12文件路径或者文件内容字符串
* @param server_mode
* @param password
* @param is_file pem_or_p12是否为文件路径
* @param is_default
* Load public and private keys from a file or string
* The certificate file must contain both public and private keys (cer format certificates only include public keys, use the following method to load)
* The client can load the certificate by default (unless the server requires the client to provide a certificate)
* @param pem_or_p12 pem or p12 file path or file content string
* @param server_mode Whether it is in server mode
* @param password Private key encryption password
* @param is_file Whether the parameter pem_or_p12 is a file path
* @param is_default Whether it is the default certificate
* [AUTO-TRANSLATED:18cec755]
*/
bool loadCertificate(const std::string &pem_or_p12, bool server_mode = true, const std::string &password = "",
bool is_file = true, bool is_default = true);
/**
*
*
* @param ignore
* Whether to ignore invalid certificates
* Ignore by default, strongly not recommended!
* @param ignore Flag
* [AUTO-TRANSLATED:fd45125a]
*/
void ignoreInvalidCertificate(bool ignore = true);
/**
* ,CA签署的证书使用
*
* @param pem_p12_cer pem文件或p12文件或cer文件路径或内容
* @param server_mode
* @param password pem或p12证书的密码
* @param is_file
* @return
* Trust a certain certificate, generally used for clients to trust self-signed certificates or certificates signed by self-signed CAs
* For example, if my client wants to trust a certificate I issued myself, we can only trust this certificate
* @param pem_p12_cer pem file or p12 file or cer file path or content
* @param server_mode Whether it is in server mode
* @param password pem or p12 certificate password
* @param is_file Whether it is a file path
* @return Whether the loading is successful
* [AUTO-TRANSLATED:9ace5400]
*/
bool trustCertificate(const std::string &pem_p12_cer, bool server_mode = false, const std::string &password = "",
bool is_file = true);
/**
*
* @param cer
* @param server_mode
* @return
* Trust a certain certificate
* @param cer Certificate public key
* @param server_mode Whether it is in server mode
* @return Whether the loading is successful
* [AUTO-TRANSLATED:557120dd]
*/
bool trustCertificate(X509 *cer, bool server_mode = false);
/**
* SSL_CTX对象
* @param vhost
* @param server_mode
* @return SSL_CTX对象
* Get the SSL_CTX object based on the virtual host
* @param vhost Virtual host name
* @param server_mode Whether it is in server mode
* @return SSL_CTX object
* [AUTO-TRANSLATED:4d771109]
*/
std::shared_ptr<SSL_CTX> getSSLCtx(const std::string &vhost, bool server_mode);
private:
SSL_Initor();
~SSL_Initor();
/**
* SSL对象
* Create an SSL object
* [AUTO-TRANSLATED:047a0b4c]
*/
std::shared_ptr<SSL> makeSSL(bool server_mode);
/**
* ssl context
* @param vhost
* @param ctx ssl context
* @param server_mode ssl context
* @param is_default
* Set the ssl context
* @param vhost Virtual host name
* @param ctx ssl context
* @param server_mode ssl context
* @param is_default Whether it is the default certificate
* [AUTO-TRANSLATED:265f3049]
*/
bool setContext(const std::string &vhost, const std::shared_ptr<SSL_CTX> &ctx, bool server_mode, bool is_default = true);
/**
* SSL_CTX的默认配置
* @param ctx
* Set the default configuration for SSL_CTX
* @param ctx Object pointer
* [AUTO-TRANSLATED:1b3438d0]
*/
static void setupCtx(SSL_CTX *ctx);
std::shared_ptr<SSL_CTX> getSSLCtx_l(const std::string &vhost, bool server_mode);
std::shared_ptr<SSL_CTX> getSSLCtxWildcards(const std::string &vhost, bool server_mode);
/**
*
* Get the default virtual host
* [AUTO-TRANSLATED:e2430399]
*/
std::string defaultVhost(bool server_mode);
/**
* vhost name
* Callback function for completing vhost name matching
* [AUTO-TRANSLATED:f9973cfa]
*/
static int findCertificate(SSL *ssl, int *ad, void *arg);
private:
struct less_nocase {
bool operator()(const std::string &x, const std::string &y) const {
return strcasecmp(x.data(), y.data()) < 0;
}
};
private:
std::recursive_mutex _mtx;
std::string _default_vhost[2];
std::shared_ptr<SSL_CTX> _ctx_empty[2];
std::map<std::string, std::shared_ptr<SSL_CTX>, less_nocase> _ctxs[2];
std::map<std::string, std::shared_ptr<SSL_CTX>, less_nocase> _ctxs_wildcards[2];
};
////////////////////////////////////////////////////////////////////////////////////
class SSL_Box {
public:
SSL_Box(bool server_mode = true, bool enable = true, int buff_size = 32 * 1024);
~SSL_Box();
/**
*
* @param buffer
* Decrypts the received ciphertext after calling this function
* @param buffer Received ciphertext data
* [AUTO-TRANSLATED:7e8b1fc6]
*/
void onRecv(const Buffer::Ptr &buffer);
/**
*
* @param buffer
* Calls this function to encrypt the plaintext that needs to be encrypted
* @param buffer Plaintext data that needs to be encrypted
* [AUTO-TRANSLATED:9d386695]
*/
void onSend(Buffer::Ptr buffer);
/**
*
* @param cb
* Sets the callback to get the plaintext after decryption
* @param cb Callback object
* [AUTO-TRANSLATED:897359bc]
*/
void setOnDecData(const std::function<void(const Buffer::Ptr &)> &cb);
/**
*
* @param cb
* Sets the callback to get the ciphertext after encryption
* @param cb Callback object
* [AUTO-TRANSLATED:bb31b34b]
*/
void setOnEncData(const std::function<void(const Buffer::Ptr &)> &cb);
/**
* ssl
* Terminates SSL
* [AUTO-TRANSLATED:2ab06469]
*/
void shutdown();
/**
*
* Clears data
* [AUTO-TRANSLATED:62d4f400]
*/
void flush();
/**
*
* @param host
* @return
* Sets the virtual host name
* @param host Virtual host name
* @return Whether the operation was successful
* [AUTO-TRANSLATED:eebc1e2f]
*/
bool setHost(const char *host);
private:
void flushWriteBio();
void flushReadBio();
private:
bool _server_mode;
bool _send_handshake;
bool _is_flush = false;
int _buff_size;
BIO *_read_bio;
BIO *_write_bio;
std::shared_ptr<SSL> _ssl;
List <Buffer::Ptr> _buffer_send;
ResourcePool <BufferRaw> _buffer_pool;
std::function<void(const Buffer::Ptr &)> _on_dec;
std::function<void(const Buffer::Ptr &)> _on_enc;
};
} /* namespace toolkit */
#endif /* CRYPTO_SSLBOX_H_ */

View File

@ -0,0 +1,399 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include "SSLUtil.h"
#include "onceToken.h"
#include "logger.h"
#if defined(ENABLE_OPENSSL)
#include <openssl/bio.h>
#include <openssl/ossl_typ.h>
#include <openssl/pkcs12.h>
#include <openssl/ssl.h>
#include <openssl/rand.h>
#include <openssl/crypto.h>
#include <openssl/err.h>
#include <openssl/conf.h>
#endif //defined(ENABLE_OPENSSL)
using namespace std;
namespace toolkit {
std::string SSLUtil::getLastError() {
#if defined(ENABLE_OPENSSL)
unsigned long errCode = ERR_get_error();
if (errCode != 0) {
char buffer[256];
ERR_error_string_n(errCode, buffer, sizeof(buffer));
return buffer;
} else
#endif //defined(ENABLE_OPENSSL)
{
return "No error";
}
}
#if defined(ENABLE_OPENSSL)
static int getCerType(BIO *bio, const char *passwd, X509 **x509, int type) {
//尝试pem格式 [AUTO-TRANSLATED:8debedc8]
//Try pem format
if (type == 1 || type == 0) {
if (type == 0) {
BIO_reset(bio);
}
// 尝试PEM格式 [AUTO-TRANSLATED:311e0a11]
//Try PEM format
*x509 = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);
if (*x509) {
return 1;
}
}
if (type == 2 || type == 0) {
if (type == 0) {
BIO_reset(bio);
}
//尝试DER格式 [AUTO-TRANSLATED:97ea1386]
//Try DER format
*x509 = d2i_X509_bio(bio, nullptr);
if (*x509) {
return 2;
}
}
if (type == 3 || type == 0) {
if (type == 0) {
BIO_reset(bio);
}
//尝试p12格式 [AUTO-TRANSLATED:32331d1d]
//Try p12 format
PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr);
if (p12) {
EVP_PKEY *pkey = nullptr;
PKCS12_parse(p12, passwd, &pkey, x509, nullptr);
PKCS12_free(p12);
if (pkey) {
EVP_PKEY_free(pkey);
}
if (*x509) {
return 3;
}
}
}
return 0;
}
#endif //defined(ENABLE_OPENSSL)
vector<shared_ptr<X509> > SSLUtil::loadPublicKey(const string &file_path_or_data, const string &passwd, bool isFile) {
vector<shared_ptr<X509> > ret;
#if defined(ENABLE_OPENSSL)
BIO *bio = isFile ? BIO_new_file((char *) file_path_or_data.data(), "r") :
BIO_new_mem_buf((char *) file_path_or_data.data(), file_path_or_data.size());
if (!bio) {
WarnL << (isFile ? "BIO_new_file" : "BIO_new_mem_buf") << " failed: " << getLastError();
return ret;
}
onceToken token0(nullptr, [&]() {
BIO_free(bio);
});
int cer_type = 0;
X509 *x509 = nullptr;
do {
cer_type = getCerType(bio, passwd.data(), &x509, cer_type);
if (cer_type) {
ret.push_back(shared_ptr<X509>(x509, [](X509 *ptr) { X509_free(ptr); }));
}
} while (cer_type != 0);
return ret;
#else
return ret;
#endif //defined(ENABLE_OPENSSL)
}
shared_ptr<EVP_PKEY> SSLUtil::loadPrivateKey(const string &file_path_or_data, const string &passwd, bool isFile) {
#if defined(ENABLE_OPENSSL)
BIO *bio = isFile ?
BIO_new_file((char *) file_path_or_data.data(), "r") :
BIO_new_mem_buf((char *) file_path_or_data.data(), file_path_or_data.size());
if (!bio) {
WarnL << (isFile ? "BIO_new_file" : "BIO_new_mem_buf") << " failed: " << getLastError();
return nullptr;
}
pem_password_cb *cb = [](char *buf, int size, int rwflag, void *userdata) -> int {
const string *passwd = (const string *) userdata;
size = size < (int) passwd->size() ? size : (int) passwd->size();
memcpy(buf, passwd->data(), size);
return size;
};
onceToken token0(nullptr, [&]() {
BIO_free(bio);
});
//尝试pem格式 [AUTO-TRANSLATED:8debedc8]
//Try pem format
EVP_PKEY *evp_key = PEM_read_bio_PrivateKey(bio, nullptr, cb, (void *) &passwd);
if (!evp_key) {
//尝试p12格式 [AUTO-TRANSLATED:32331d1d]
//Try p12 format
BIO_reset(bio);
PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr);
if (!p12) {
return nullptr;
}
X509 *x509 = nullptr;
PKCS12_parse(p12, passwd.data(), &evp_key, &x509, nullptr);
PKCS12_free(p12);
if (x509) {
X509_free(x509);
}
if (!evp_key) {
return nullptr;
}
}
return shared_ptr<EVP_PKEY>(evp_key, [](EVP_PKEY *ptr) {
EVP_PKEY_free(ptr);
});
#else
return nullptr;
#endif //defined(ENABLE_OPENSSL)
}
shared_ptr<SSL_CTX> SSLUtil::makeSSLContext(const vector<shared_ptr<X509> > &cers, const shared_ptr<EVP_PKEY> &key, bool serverMode, bool checkKey) {
#if defined(ENABLE_OPENSSL)
SSL_CTX *ctx = SSL_CTX_new(serverMode ? SSLv23_server_method() : SSLv23_client_method());
if (!ctx) {
WarnL << "SSL_CTX_new " << (serverMode ? "SSLv23_server_method" : "SSLv23_client_method") << " failed: " << getLastError();
return nullptr;
}
int i = 0;
for (auto &cer : cers) {
//加载公钥 [AUTO-TRANSLATED:d3cadbdf]
//Load public key
if (i++ == 0) {
//SSL_CTX_use_certificate内部会调用X509_up_ref,所以这里不用X509_dup [AUTO-TRANSLATED:610aca57]
//SSL_CTX_use_certificate internally calls X509_up_ref, so no need to use X509_dup here
SSL_CTX_use_certificate(ctx, cer.get());
} else {
//需要先拷贝X509对象否则指针会失效 [AUTO-TRANSLATED:c6cb5ebf]
//Need to copy X509 object first, otherwise the pointer will be invalid
SSL_CTX_add_extra_chain_cert(ctx, X509_dup(cer.get()));
}
}
if (key) {
//提供了私钥 [AUTO-TRANSLATED:1b23bc8c]
//Provided private key
if (SSL_CTX_use_PrivateKey(ctx, key.get()) != 1) {
WarnL << "SSL_CTX_use_PrivateKey failed: " << getLastError();
SSL_CTX_free(ctx);
return nullptr;
}
}
if (key || checkKey) {
//加载私钥成功 [AUTO-TRANSLATED:80e96abb]
//Private key loaded successfully
if (SSL_CTX_check_private_key(ctx) != 1) {
WarnL << "SSL_CTX_check_private_key failed: " << getLastError();
SSL_CTX_free(ctx);
return nullptr;
}
}
//公钥私钥匹配或者没有公私钥 [AUTO-TRANSLATED:b12ac3e6]
//Public and private key match or no public and private key
return shared_ptr<SSL_CTX>(ctx, [](SSL_CTX *ptr) { SSL_CTX_free(ptr); });
#else
return nullptr;
#endif //defined(ENABLE_OPENSSL)
}
shared_ptr<SSL> SSLUtil::makeSSL(SSL_CTX *ctx) {
#if defined(ENABLE_OPENSSL)
auto *ssl = SSL_new(ctx);
if (!ssl) {
return nullptr;
}
return shared_ptr<SSL>(ssl, [](SSL *ptr) {
SSL_free(ptr);
});
#else
return nullptr;
#endif //defined(ENABLE_OPENSSL)
}
bool SSLUtil::loadDefaultCAs(SSL_CTX *ctx) {
#if defined(ENABLE_OPENSSL)
if (!ctx) {
return false;
}
if (SSL_CTX_set_default_verify_paths(ctx) != 1) {
WarnL << "SSL_CTX_set_default_verify_paths failed: " << getLastError();
return false;
}
return true;
#else
return false;
#endif //defined(ENABLE_OPENSSL)
}
bool SSLUtil::trustCertificate(SSL_CTX *ctx, X509 *cer) {
#if defined(ENABLE_OPENSSL)
X509_STORE *store = SSL_CTX_get_cert_store(ctx);
if (store && cer) {
if (X509_STORE_add_cert(store, cer) != 1) {
WarnL << "X509_STORE_add_cert failed: " << getLastError();
return false;
}
return true;
}
#endif //defined(ENABLE_OPENSSL)
return false;
}
bool SSLUtil::verifyX509(X509 *cer, ...) {
#if defined(ENABLE_OPENSSL)
va_list args;
va_start(args, cer);
X509_STORE *store = X509_STORE_new();
do {
X509 *ca;
if ((ca = va_arg(args, X509*)) == nullptr) {
break;
}
X509_STORE_add_cert(store, ca);
} while (true);
va_end(args);
X509_STORE_CTX *store_ctx = X509_STORE_CTX_new();
X509_STORE_CTX_init(store_ctx, store, cer, nullptr);
auto ret = X509_verify_cert(store_ctx);
if (ret != 1) {
int depth = X509_STORE_CTX_get_error_depth(store_ctx);
int err = X509_STORE_CTX_get_error(store_ctx);
WarnL << "X509_verify_cert failed, depth: " << depth << ", err: " << X509_verify_cert_error_string(err);
}
X509_STORE_CTX_free(store_ctx);
X509_STORE_free(store);
return ret == 1;
#else
WarnL << "ENABLE_OPENSSL disabled, you can not use any features based on openssl";
return false;
#endif //defined(ENABLE_OPENSSL)
}
#ifdef ENABLE_OPENSSL
#ifndef X509_F_X509_PUBKEY_GET0
EVP_PKEY *X509_get0_pubkey(X509 *x){
EVP_PKEY *ret = X509_get_pubkey(x);
if(ret){
EVP_PKEY_free(ret);
}
return ret;
}
#endif //X509_F_X509_PUBKEY_GET0
#ifndef EVP_F_EVP_PKEY_GET0_RSA
RSA *EVP_PKEY_get0_RSA(EVP_PKEY *pkey){
RSA *ret = EVP_PKEY_get1_RSA(pkey);
if(ret){
RSA_free(ret);
}
return ret;
}
#endif //EVP_F_EVP_PKEY_GET0_RSA
#endif //ENABLE_OPENSSL
string SSLUtil::cryptWithRsaPublicKey(X509 *cer, const string &in_str, bool enc_or_dec) {
#if defined(ENABLE_OPENSSL)
EVP_PKEY *public_key = X509_get0_pubkey(cer);
if (!public_key) {
return "";
}
auto rsa = EVP_PKEY_get1_RSA(public_key);
if (!rsa) {
return "";
}
string out_str(RSA_size(rsa), '\0');
int ret = 0;
if (enc_or_dec) {
ret = RSA_public_encrypt(in_str.size(), (uint8_t *) in_str.data(), (uint8_t *) out_str.data(), rsa,
RSA_PKCS1_PADDING);
} else {
ret = RSA_public_decrypt(in_str.size(), (uint8_t *) in_str.data(), (uint8_t *) out_str.data(), rsa,
RSA_PKCS1_PADDING);
}
if (ret > 0) {
out_str.resize(ret);
return out_str;
}
WarnL << (enc_or_dec ? "RSA_public_encrypt" : "RSA_public_decrypt") << " failed: " << getLastError();
return "";
#else
WarnL << "ENABLE_OPENSSL disabled, you can not use any features based on openssl";
return "";
#endif //defined(ENABLE_OPENSSL)
}
string SSLUtil::cryptWithRsaPrivateKey(EVP_PKEY *private_key, const string &in_str, bool enc_or_dec) {
#if defined(ENABLE_OPENSSL)
auto rsa = EVP_PKEY_get1_RSA(private_key);
if (!rsa) {
return "";
}
string out_str(RSA_size(rsa), '\0');
int ret = 0;
if (enc_or_dec) {
ret = RSA_private_encrypt(in_str.size(), (uint8_t *) in_str.data(), (uint8_t *) out_str.data(), rsa,
RSA_PKCS1_PADDING);
} else {
ret = RSA_private_decrypt(in_str.size(), (uint8_t *) in_str.data(), (uint8_t *) out_str.data(), rsa,
RSA_PKCS1_PADDING);
}
if (ret > 0) {
out_str.resize(ret);
return out_str;
}
WarnL << getLastError();
return "";
#else
WarnL << "ENABLE_OPENSSL disabled, you can not use any features based on openssl";
return "";
#endif //defined(ENABLE_OPENSSL)
}
string SSLUtil::getServerName(X509 *cer) {
#if defined(ENABLE_OPENSSL) && defined(SSL_CTRL_SET_TLSEXT_HOSTNAME)
if (!cer) {
return "";
}
//获取证书里的域名 [AUTO-TRANSLATED:97830946]
//Get domain name from certificate
X509_NAME *name = X509_get_subject_name(cer);
char ret[256] = {0};
X509_NAME_get_text_by_NID(name, NID_commonName, ret, sizeof(ret));
return ret;
#else
return "";
#endif
}
}//namespace toolkit

View File

@ -0,0 +1,193 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef ZLTOOLKIT_SSLUTIL_H
#define ZLTOOLKIT_SSLUTIL_H
#include <memory>
#include <string>
#include <vector>
typedef struct x509_st X509;
typedef struct evp_pkey_st EVP_PKEY;
typedef struct ssl_ctx_st SSL_CTX;
typedef struct ssl_st SSL;
typedef struct bio_st BIO;
namespace toolkit {
/**
* ssl证书后缀一般分为以下几种
* pem:base64的字符编码串
* cer:pem的私钥配合使用
* p12:
* SSL certificate suffixes are generally divided into the following types
* pem: This is a base64 character encoded string, which may contain a public key, private key, or both
* cer: Only and must be a public key, can be used with pem private key
* p12: Must include both private key and public key
* [AUTO-TRANSLATED:1cae2cfa]
*/
class SSLUtil {
public:
static std::string getLastError();
/**
* pem,p12,cer后缀
* openssl加载p12证书时会校验公钥和私钥是否匹对p12的公钥时可能需要传入证书密码
* @param file_path_or_data
* @param isFile
* @return
* Load public key certificate, support pem, p12, cer suffixes
* When openssl loads p12 certificate, it will verify whether the public key and private key match,
* so when loading p12 public key, you may need to pass in the certificate password
* @param file_path_or_data File path or file content
* @param isFile Whether it is a file
* @return Public key certificate list
* [AUTO-TRANSLATED:d9dbac61]
*/
static std::vector<std::shared_ptr<X509> > loadPublicKey(const std::string &file_path_or_data, const std::string &passwd = "", bool isFile = true);
/**
* pem,p12后缀
* @param file_path_or_data
* @param passwd
* @param isFile
* @return
* Load private key certificate, support pem, p12 suffixes
* @param file_path_or_data File path or file content
* @param passwd Password
* @param isFile Whether it is a file
* @return Private key certificate
* [AUTO-TRANSLATED:73c495c8]
*/
static std::shared_ptr<EVP_PKEY> loadPrivateKey(const std::string &file_path_or_data, const std::string &passwd = "", bool isFile = true);
/**
* SSL_CTX对象
* @param cer
* @param key
* @param serverMode
* @return SSL_CTX对象
* Create SSL_CTX object
* @param cer Public key array
* @param key Private key
* @param serverMode Whether it is server mode or client mode
* @return SSL_CTX object
* [AUTO-TRANSLATED:d0faa6a4]
*/
static std::shared_ptr<SSL_CTX> makeSSLContext(const std::vector<std::shared_ptr<X509> > &cers, const std::shared_ptr<EVP_PKEY> &key, bool serverMode = true, bool checkKey = false);
/**
* ssl对象
* @param ctx SSL_CTX对象
* Create ssl object
* @param ctx SSL_CTX object
* [AUTO-TRANSLATED:2e3eb193]
*/
static std::shared_ptr<SSL> makeSSL(SSL_CTX *ctx);
/**
* specifies that the default locations from which CA certificates are loaded should be used.
* There is one default directory and one default file.
* The default CA certificates directory is called "certs" in the default OpenSSL directory.
* Alternatively the SSL_CERT_DIR environment variable can be defined to override this location.
* The default CA certificates file is called "cert.pem" in the default OpenSSL directory.
* Alternatively the SSL_CERT_FILE environment variable can be defined to override this location.
* /usr/local/ssl/certs//usr/local/ssl/cert.pem的证书
* SSL_CERT_FILE将替换/usr/local/ssl/cert.pem的路径
* specifies that the default locations from which CA certificates are loaded should be used.
* There is one default directory and one default file.
* The default CA certificates directory is called "certs" in the default OpenSSL directory.
* Alternatively the SSL_CERT_DIR environment variable can be defined to override this location.
* The default CA certificates file is called "cert.pem" in the default OpenSSL directory.
* Alternatively the SSL_CERT_FILE environment variable can be defined to override this location.
* Trust all certificates in the /usr/local/ssl/certs/ directory and /usr/local/ssl/cert.pem
* The environment variable SSL_CERT_FILE will replace the path of /usr/local/ssl/cert.pem
* [AUTO-TRANSLATED:f13fc4c5]
*/
static bool loadDefaultCAs(SSL_CTX *ctx);
/**
*
* Trust a public key
* [AUTO-TRANSLATED:08987c7e]
*/
static bool trustCertificate(SSL_CTX *ctx, X509 *cer);
/**
*
* @param cer
* @param ... CA根证书X509类型nullptr结尾
* @return
* Verify the validity of the certificate
* @param cer Certificate to be verified
* @param ... Trusted CA root certificates, X509 type, ending with nullptr
* @return Whether it is valid
* [AUTO-TRANSLATED:1b026a8f]
*/
static bool verifyX509(X509 *cer, ...);
/**
* 使
* @param cer ras的公钥
* @param in_str 245256
* @param enc_or_dec true:,false:
* @return
* Use public key to encrypt and decrypt data
* @param cer Public key, must be ras public key
* @param in_str Original data to be encrypted or decrypted, tested to support up to 245 bytes,
* encrypted data length is fixed at 256 bytes
* @param enc_or_dec true: Encrypt, false: Decrypt
* @return Encrypted or decrypted data
* [AUTO-TRANSLATED:77bc2939]
*/
static std::string cryptWithRsaPublicKey(X509 *cer, const std::string &in_str, bool enc_or_dec);
/**
* 使
* @param private_key ras的私钥
* @param in_str 245256
* @param enc_or_dec true:,false:
* @return
* Use private key to encrypt and decrypt data
* @param private_key Private key, must be ras private key
* @param in_str Original data to be encrypted or decrypted, tested to support up to 245 bytes,
* encrypted data length is fixed at 256 bytes
* @param enc_or_dec true: Encrypt, false: Decrypt
* @return Encrypted or decrypted data
* [AUTO-TRANSLATED:a6e4aeb0]
*/
static std::string cryptWithRsaPrivateKey(EVP_PKEY *private_key, const std::string &in_str, bool enc_or_dec);
/**
*
* @param cer
* @return
* Get certificate domain name
* @param cer Certificate public key
* @return Certificate domain name
* [AUTO-TRANSLATED:b3806b53]
*/
static std::string getServerName(X509 *cer);
};
}//namespace toolkit
#endif //ZLTOOLKIT_SSLUTIL_H

View File

@ -0,0 +1,79 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef SPEED_STATISTIC_H_
#define SPEED_STATISTIC_H_
#include "TimeTicker.h"
namespace toolkit {
class BytesSpeed {
public:
BytesSpeed() = default;
~BytesSpeed() = default;
/**
*
* Add statistical bytes
* [AUTO-TRANSLATED:d6697ac9]
*/
BytesSpeed &operator+=(size_t bytes) {
_bytes += bytes;
if (_bytes > 1024 * 1024) {
// 数据大于1MB就计算一次网速 [AUTO-TRANSLATED:897af4d6]
// Data greater than 1MB is calculated once for network speed
computeSpeed();
}
_total_bytes += bytes;
return *this;
}
/**
* bytes/s
* Get speed, unit bytes/s
* [AUTO-TRANSLATED:41e26e29]
*/
size_t getSpeed() {
if (_ticker.elapsedTime() < 1000) {
// 获取频率小于1秒那么返回上次计算结果 [AUTO-TRANSLATED:b687b762]
// Get frequency less than 1 second, return the last calculation result
return _speed;
}
return computeSpeed();
}
size_t getTotalBytes() const {
return _total_bytes;
}
private:
size_t computeSpeed() {
auto elapsed = _ticker.elapsedTime();
if (!elapsed) {
return _speed;
}
_speed = (size_t)(_bytes * 1000 / elapsed);
_ticker.resetTime();
_bytes = 0;
return _speed;
}
private:
size_t _speed = 0;
size_t _bytes = 0;
size_t _total_bytes = 0;
Ticker _ticker;
};
} /* namespace toolkit */
#endif /* SPEED_STATISTIC_H_ */

View File

@ -0,0 +1,287 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef SQL_SQLCONNECTION_H_
#define SQL_SQLCONNECTION_H_
#include <cstdio>
#include <cstdarg>
#include <cstring>
#include <string>
#include <vector>
#include <list>
#include <deque>
#include <sstream>
#include <iostream>
#include <stdexcept>
#include "logger.h"
#include "util.h"
#include <mysql.h>
#if defined(_WIN32)
#pragma comment (lib,"libmysql")
#endif
namespace toolkit {
/**
*
* Database exception class
* [AUTO-TRANSLATED:f92df85e]
*/
class SqlException : public std::exception {
public:
SqlException(const std::string &sql, const std::string &err) {
_sql = sql;
_err = err;
}
virtual const char *what() const noexcept {
return _err.data();
}
const std::string &getSql() const {
return _sql;
}
private:
std::string _sql;
std::string _err;
};
/**
* mysql连接
* MySQL connection
* [AUTO-TRANSLATED:a2deb48d]
*/
class SqlConnection {
public:
/**
*
* @param url
* @param port
* @param dbname
* @param username
* @param password
* @param character
* Constructor
* @param url Database address
* @param port Database port number
* @param dbname Database name
* @param username Username
* @param password User password
* @param character Character set
* [AUTO-TRANSLATED:410a33a6]
*/
SqlConnection(const std::string &url, unsigned short port,
const std::string &dbname, const std::string &username,
const std::string &password, const std::string &character = "utf8mb4") {
mysql_init(&_sql);
unsigned int timeout = 3;
mysql_options(&_sql, MYSQL_OPT_CONNECT_TIMEOUT, &timeout);
if (!mysql_real_connect(&_sql, url.data(), username.data(),
password.data(), dbname.data(), port, nullptr, 0)) {
mysql_close(&_sql);
throw SqlException("mysql_real_connect", mysql_error(&_sql));
}
//兼容bool与my_bool [AUTO-TRANSLATED:7d8d4190]
//Compatible with bool and my_bool
uint32_t reconnect = 0x01010101;
mysql_options(&_sql, MYSQL_OPT_RECONNECT, &reconnect);
mysql_set_character_set(&_sql, character.data());
}
~SqlConnection() {
mysql_close(&_sql);
}
/**
* printf样式执行sql,
* @param rowId insert时的插入rowid
* @param fmt printf类型fmt
* @param arg
* @return
* Execute SQL in printf style, no data returned
* @param rowId Insert rowid when inserting
* @param fmt printf type fmt
* @param arg Variable argument list
* @return Affected rows
* [AUTO-TRANSLATED:7c72ab80]
*/
template<typename Fmt, typename ...Args>
int64_t query(int64_t &rowId, Fmt &&fmt, Args &&...arg) {
check();
auto tmp = queryString(std::forward<Fmt>(fmt), std::forward<Args>(arg)...);
if (doQuery(tmp)) {
throw SqlException(tmp, mysql_error(&_sql));
}
rowId = mysql_insert_id(&_sql);
return mysql_affected_rows(&_sql);
}
/**
* printf样式执行sql,list类型的结果()
* @param rowId insert时的插入rowid
* @param ret
* @param fmt printf类型fmt
* @param arg
* @return
* Execute SQL in printf style, and return list type result (excluding column names)
* @param rowId Insert rowid when inserting
* @param ret Returned data list
* @param fmt printf type fmt
* @param arg Variable argument list
* @return Affected rows
* [AUTO-TRANSLATED:57baa44e]
*/
template<typename Fmt, typename ...Args>
int64_t query(int64_t &rowId, std::vector<std::vector<std::string> > &ret, Fmt &&fmt, Args &&...arg) {
return queryList(rowId, ret, std::forward<Fmt>(fmt), std::forward<Args>(arg)...);
}
template<typename Fmt, typename... Args>
int64_t query(int64_t &rowId, std::vector<std::list<std::string>> &ret, Fmt &&fmt, Args &&...arg) {
return queryList(rowId, ret, std::forward<Fmt>(fmt), std::forward<Args>(arg)...);
}
template<typename Fmt, typename ...Args>
int64_t query(int64_t &rowId, std::vector<std::deque<std::string> > &ret, Fmt &&fmt, Args &&...arg) {
return queryList(rowId, ret, std::forward<Fmt>(fmt), std::forward<Args>(arg)...);
}
/**
* printf样式执行sql,Map类型的结果()
* @param rowId insert时的插入rowid
* @param ret
* @param fmt printf类型fmt
* @param arg
* @return
* Execute SQL in printf style, and return Map type result (including column names)
* @param rowId Insert rowid when inserting
* @param ret Returned data list
* @param fmt printf type fmt
* @param arg Variable argument list
* @return Affected rows
* [AUTO-TRANSLATED:a12a695e]
*/
template<typename Map, typename Fmt, typename ...Args>
int64_t query(int64_t &rowId, std::vector<Map> &ret, Fmt &&fmt, Args &&...arg) {
check();
auto tmp = queryString(std::forward<Fmt>(fmt), std::forward<Args>(arg)...);
if (doQuery(tmp)) {
throw SqlException(tmp, mysql_error(&_sql));
}
ret.clear();
MYSQL_RES *res = mysql_store_result(&_sql);
if (!res) {
rowId = mysql_insert_id(&_sql);
return mysql_affected_rows(&_sql);
}
MYSQL_ROW row;
unsigned int column = mysql_num_fields(res);
MYSQL_FIELD *fields = mysql_fetch_fields(res);
while ((row = mysql_fetch_row(res)) != nullptr) {
ret.emplace_back();
auto &back = ret.back();
for (unsigned int i = 0; i < column; i++) {
back[std::string(fields[i].name, fields[i].name_length)] = (row[i] ? row[i] : "");
}
}
mysql_free_result(res);
rowId = mysql_insert_id(&_sql);
return mysql_affected_rows(&_sql);
}
std::string escape(const std::string &str) {
char *out = new char[str.length() * 2 + 1];
mysql_real_escape_string(&_sql, out, str.c_str(), str.size());
std::string ret(out);
delete[] out;
return ret;
}
template<typename ...Args>
static std::string queryString(const char *fmt, Args &&...arg) {
char *ptr_out = nullptr;
if (asprintf(&ptr_out, fmt, arg...) > 0 && ptr_out) {
std::string ret(ptr_out);
free(ptr_out);
return ret;
}
return "";
}
template<typename ...Args>
static std::string queryString(const std::string &fmt, Args &&...args) {
return queryString(fmt.data(), std::forward<Args>(args)...);
}
static const char *queryString(const char *fmt) {
return fmt;
}
static const std::string &queryString(const std::string &fmt) {
return fmt;
}
private:
template<typename List, typename Fmt, typename... Args>
int64_t queryList(int64_t &rowId, std::vector<List> &ret, Fmt &&fmt, Args &&...arg) {
check();
auto tmp = queryString(std::forward<Fmt>(fmt), std::forward<Args>(arg)...);
if (doQuery(tmp)) {
throw SqlException(tmp, mysql_error(&_sql));
}
ret.clear();
MYSQL_RES *res = mysql_store_result(&_sql);
if (!res) {
rowId = mysql_insert_id(&_sql);
return mysql_affected_rows(&_sql);
}
MYSQL_ROW row;
unsigned int column = mysql_num_fields(res);
while ((row = mysql_fetch_row(res)) != nullptr) {
ret.emplace_back();
auto &back = ret.back();
for (unsigned int i = 0; i < column; i++) {
back.emplace_back(row[i] ? row[i] : "");
}
}
mysql_free_result(res);
rowId = mysql_insert_id(&_sql);
return mysql_affected_rows(&_sql);
}
inline void check() {
if (mysql_ping(&_sql) != 0) {
throw SqlException("mysql_ping", "Mysql connection ping failed");
}
}
int doQuery(const std::string &sql) {
return mysql_query(&_sql, sql.data());
}
int doQuery(const char *sql) {
return mysql_query(&_sql, sql);
}
private:
MYSQL _sql;
};
} /* namespace toolkit */
#endif /* SQL_SQLCONNECTION_H_ */

View File

@ -0,0 +1,27 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#if defined(ENABLE_MYSQL)
#include <memory>
#include "util.h"
#include "onceToken.h"
#include "SqlPool.h"
using namespace std;
namespace toolkit {
INSTANCE_IMP(SqlPool)
} /* namespace toolkit */
#endif// defined(ENABLE_MYSQL)

View File

@ -0,0 +1,383 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef SQL_SQLPOOL_H_
#define SQL_SQLPOOL_H_
#include <deque>
#include <mutex>
#include <memory>
#include <sstream>
#include <functional>
#include "logger.h"
#include "Poller/Timer.h"
#include "SqlConnection.h"
#include "Thread/WorkThreadPool.h"
#include "ResourcePool.h"
namespace toolkit {
class SqlPool : public std::enable_shared_from_this<SqlPool> {
public:
using Ptr = std::shared_ptr<SqlPool>;
using PoolType = ResourcePool<SqlConnection>;
using SqlRetType = std::vector<std::vector<std::string> >;
static SqlPool &Instance();
~SqlPool() {
_timer.reset();
flushError();
_threadPool.reset();
_pool.reset();
InfoL;
}
/**
*
* @param size
* Sets the number of loop pool objects
* @param size
* [AUTO-TRANSLATED:a3f3a8ac]
*/
void setSize(int size) {
checkInited();
_pool->setSize(size);
}
/**
*
* @tparam Args
* @param arg
* Initializes the loop pool and sets the database connection parameters
* @tparam Args
* @param arg
* [AUTO-TRANSLATED:fff84748]
*/
template<typename ...Args>
void Init(Args &&...arg) {
_pool.reset(new PoolType(std::forward<Args>(arg)...));
_pool->obtain();
}
/**
* sql
* @param str sql语句
* @param tryCnt
* Asynchronously executes SQL
* @param str SQL statement
* @param tryCnt Number of retries
* [AUTO-TRANSLATED:9d2414e1]
*/
template<typename ...Args>
void asyncQuery(Args &&...args) {
asyncQuery_l(SqlConnection::queryString(std::forward<Args>(args)...));
}
/**
* sql
* @tparam Args
* @param arg
* @return
* Synchronously executes SQL
* @tparam Args Variable parameter type list
* @param arg Variable parameter list
* @return Number of affected rows
* [AUTO-TRANSLATED:ba47cb3c]
*/
template<typename ...Args>
int64_t syncQuery(Args &&...arg) {
checkInited();
typename PoolType::ValuePtr mysql;
try {
//捕获执行异常 [AUTO-TRANSLATED:ae8a1093]
//Capture execution exceptions
mysql = _pool->obtain();
return mysql->query(std::forward<Args>(arg)...);
} catch (std::exception &e) {
mysql.quit();
throw;
}
}
/**
* sql转义
* @param str
* @return
* Escapes SQL
* @param str
* @return
* [AUTO-TRANSLATED:e7a99a20]
*/
std::string escape(const std::string &str) {
checkInited();
return _pool->obtain()->escape(const_cast<std::string &>(str));
}
private:
SqlPool() {
_threadPool = WorkThreadPool::Instance().getExecutor();
_timer = std::make_shared<Timer>(30, [this]() {
flushError();
return true;
}, nullptr);
}
/**
* sql
* @param sql sql语句
* @param tryCnt
* Asynchronously executes SQL
* @param sql SQL statement
* @param tryCnt Number of retries
* [AUTO-TRANSLATED:6f585bf1]
*/
void asyncQuery_l(const std::string &sql, int tryCnt = 3) {
auto lam = [this, sql, tryCnt]() {
int64_t rowID;
auto cnt = tryCnt - 1;
try {
syncQuery(rowID, sql);
} catch (std::exception &ex) {
if (cnt > 0) {
//失败重试 [AUTO-TRANSLATED:ef091479]
//Retry on failure
std::lock_guard<std::mutex> lk(_error_query_mutex);
sqlQuery query(sql, cnt);
_error_query.push_back(query);
} else {
WarnL << "SqlPool::syncQuery failed: " << ex.what();
}
}
};
_threadPool->async(lam);
}
/**
* sql
* Periodically retries failed SQL
* [AUTO-TRANSLATED:33048898]
*/
void flushError() {
decltype(_error_query) query_copy;
{
std::lock_guard<std::mutex> lck(_error_query_mutex);
query_copy.swap(_error_query);
}
for (auto &query : query_copy) {
asyncQuery(query.sql_str, query.tryCnt);
}
}
/**
*
* Checks if the database connection pool is initialized
* [AUTO-TRANSLATED:176fceed]
*/
void checkInited() {
if (!_pool) {
throw SqlException("SqlPool::checkInited", "Mysql connection pool not initialized");
}
}
private:
struct sqlQuery {
sqlQuery(const std::string &sql, int cnt) : sql_str(sql), tryCnt(cnt) {}
std::string sql_str;
int tryCnt = 0;
};
private:
std::deque<sqlQuery> _error_query;
TaskExecutor::Ptr _threadPool;
std::mutex _error_query_mutex;
std::shared_ptr<PoolType> _pool;
Timer::Ptr _timer;
};
/**
* Sql语句生成器''sql语句
* SQL statement generator, generates SQL statements through the '?' placeholder
* [AUTO-TRANSLATED:12f34981]
*/
class SqlStream {
public:
SqlStream(const char *sql) : _sql(sql) {}
~SqlStream() {}
template<typename T>
SqlStream &operator<<(T &&data) {
auto pos = _sql.find('?', _startPos);
if (pos == std::string::npos) {
return *this;
}
_str_tmp.str("");
_str_tmp << std::forward<T>(data);
std::string str = SqlPool::Instance().escape(_str_tmp.str());
_startPos = pos + str.size();
_sql.replace(pos, 1, str);
return *this;
}
const std::string &operator<<(std::ostream &(*f)(std::ostream &)) const {
return _sql;
}
operator std::string() {
return _sql;
}
private:
std::stringstream _str_tmp;
std::string _sql;
std::string::size_type _startPos = 0;
};
/**
* sql查询器
* SQL query executor
* [AUTO-TRANSLATED:50396624]
*/
class SqlWriter {
public:
/**
*
* @param sql ''sql模板
* @param throwAble
* Constructor
* @param sql SQL template with '?' placeholder
* @param throwAble Whether to throw exceptions
* [AUTO-TRANSLATED:97c6d354]
*/
SqlWriter(const char *sql, bool throwAble = true) : _sqlstream(sql), _throwAble(throwAble) {}
~SqlWriter() {}
/**
* ''便sql语句
* @tparam T
* @param data
* @return
* Replaces '?' placeholders with input parameters to generate SQL statements; may throw exceptions
* @tparam T Parameter type
* @param data Parameter
* @return Self-reference
* [AUTO-TRANSLATED:9bdc6917]
*/
template<typename T>
SqlWriter &operator<<(T &&data) {
try {
_sqlstream << std::forward<T>(data);
} catch (std::exception &ex) {
//在转义sql时可能抛异常 [AUTO-TRANSLATED:ce6314cc]
//May throw exceptions when escaping SQL
if (!_throwAble) {
WarnL << "Commit sql failed: " << ex.what();
} else {
throw;
}
}
return *this;
}
/**
* sql
* @param f std::endl
* Asynchronously executes SQL, does not throw exceptions
* @param f std::endl
* [AUTO-TRANSLATED:e203d266]
*/
void operator<<(std::ostream &(*f)(std::ostream &)) {
//异步执行sql不会抛异常 [AUTO-TRANSLATED:4370797e]
//Asynchronously executes SQL, does not throw exceptions
SqlPool::Instance().asyncQuery((std::string) _sqlstream);
}
/**
* sql
* @tparam Row vector<string>/list<string> obj.emplace_back("value")
* map<string,string>/Json::Value obj["key"] = "value"
* @param ret
* @return
* Synchronously executes SQL, may throw exceptions
* @tparam Row Data row type, can be vector<string>/list<string> or other types that support obj.emplace_back("value") operations
* Can also be map<string,string>/Json::Value or other types that support obj["key"] = "value" operations
* @param ret Data storage object
* @return Number of affected rows
* [AUTO-TRANSLATED:d8e40f96]
*/
template<typename Row>
int64_t operator<<(std::vector<Row> &ret) {
try {
_affectedRows = SqlPool::Instance().syncQuery(_rowId, ret, (std::string) _sqlstream);
} catch (std::exception &ex) {
if (!_throwAble) {
WarnL << "SqlPool::syncQuery failed: " << ex.what();
} else {
throw;
}
}
return _affectedRows;
}
/**
* insert数据库时返回插入的rowid
* @return
* Returns the rowid inserted into the database when inserting data
* @return
* [AUTO-TRANSLATED:699edcc4]
*/
int64_t getRowID() const {
return _rowId;
}
/**
*
* @return
* Returns the number of rows affected in the database
* @return
* [AUTO-TRANSLATED:81af02d9]
*/
int64_t getAffectedRows() const {
return _affectedRows;
}
private:
SqlStream _sqlstream;
int64_t _rowId = -1;
int64_t _affectedRows = -1;
bool _throwAble = true;
};
} /* namespace toolkit */
#endif /* SQL_SQLPOOL_H_ */

View File

@ -0,0 +1,172 @@
/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef UTIL_TIMETICKER_H_
#define UTIL_TIMETICKER_H_
#include <cassert>
#include "logger.h"
namespace toolkit {
class Ticker {
public:
/**
*
* @param min_ms
* @param ctx
* @param print_log
* This object can be used for code execution time statistics, and can be used for general timing
* @param min_ms When the code execution time statistics is enabled, if the code execution time exceeds this parameter, a warning log is printed
* @param ctx Log context capture, used to capture the current log code location
* @param print_log Whether to print the code execution time
* [AUTO-TRANSLATED:4436cf19]
*/
Ticker(uint64_t min_ms = 0,
LogContextCapture ctx = LogContextCapture(Logger::Instance(), LWarn, __FILE__, "", __LINE__),
bool print_log = false) : _ctx(std::move(ctx)) {
if (!print_log) {
_ctx.clear();
}
_created = _begin = getCurrentMillisecond();
_min_ms = min_ms;
}
~Ticker() {
uint64_t tm = createdTime();
if (tm > _min_ms) {
_ctx << "take time: " << tm << "ms" << ", thread may be overloaded";
} else {
_ctx.clear();
}
}
/**
* resetTime后至今的时间
* Get the time from the last resetTime to now, in milliseconds
* [AUTO-TRANSLATED:739ad90a]
*/
uint64_t elapsedTime() const {
return getCurrentMillisecond() - _begin;
}
/**
*
* Get the time from creation to now, in milliseconds
* [AUTO-TRANSLATED:83a189e2]
*/
uint64_t createdTime() const {
return getCurrentMillisecond() - _created;
}
/**
*
* Reset the timer
* [AUTO-TRANSLATED:2500c6f1]
*/
void resetTime() {
_begin = getCurrentMillisecond();
}
private:
uint64_t _min_ms;
uint64_t _begin;
uint64_t _created;
LogContextCapture _ctx;
};
class SmoothTicker {
public:
/**
*
* @param reset_ms reset_ms毫秒,
* This object is used to generate smooth timestamps
* @param reset_ms Timestamp reset interval, every reset_ms milliseconds, the generated timestamp will be synchronized with the system timestamp
* [AUTO-TRANSLATED:0ff567e7]
*/
SmoothTicker(uint64_t reset_ms = 10000) {
_reset_ms = reset_ms;
_ticker.resetTime();
}
~SmoothTicker() {}
/**
*
* Return a smooth timestamp, to prevent the timestamp from being unsmooth due to network jitter
* [AUTO-TRANSLATED:26f78ae3]
*/
uint64_t elapsedTime() {
auto now_time = _ticker.elapsedTime();
if (_first_time == 0) {
if (now_time < _last_time) {
auto last_time = _last_time - _time_inc;
double elapse_time = (now_time - last_time);
_time_inc += (elapse_time / ++_pkt_count) / 3;
auto ret_time = last_time + _time_inc;
_last_time = (uint64_t) ret_time;
return (uint64_t) ret_time;
}
_first_time = now_time;
_last_time = now_time;
_pkt_count = 0;
_time_inc = 0;
return now_time;
}
auto elapse_time = (now_time - _first_time);
_time_inc += elapse_time / ++_pkt_count;
auto ret_time = _first_time + _time_inc;
if (elapse_time > _reset_ms) {
_first_time = 0;
}
_last_time = (uint64_t) ret_time;
return (uint64_t) ret_time;
}
/**
* 0
* Reset the timestamp to start from 0
* [AUTO-TRANSLATED:ca42c3d1]
*/
void resetTime() {
_first_time = 0;
_pkt_count = 0;
_ticker.resetTime();
}
private:
double _time_inc = 0;
uint64_t _first_time = 0;
uint64_t _last_time = 0;
uint64_t _pkt_count = 0;
uint64_t _reset_ms;
Ticker _ticker;
};
#if !defined(NDEBUG)
#define TimeTicker() Ticker __ticker(5,WarnL,true)
#define TimeTicker1(tm) Ticker __ticker1(tm,WarnL,true)
#define TimeTicker2(tm, log) Ticker __ticker2(tm,log,true)
#else
#define TimeTicker()
#define TimeTicker1(tm)
#define TimeTicker2(tm,log)
#endif
} /* namespace toolkit */
#endif /* UTIL_TIMETICKER_H_ */

View File

@ -0,0 +1,202 @@
/*
* Copyright (c) 2006 Ryan Martell. (rdm4@martellventures.com)
*
* This file is part of FFmpeg.
*
* FFmpeg is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* FFmpeg is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with FFmpeg; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
/**
* @file
* @brief Base64 encode/decode
* @author Ryan Martell <rdm4@martellventures.com> (with lots of Michael)
*/
//#include "common.h"
#include "stdio.h"
#include "base64.h"
#include <memory>
#include <limits.h>
using namespace std;
/* ---------------- private code */
static const uint8_t map2[] =
{
0x3e, 0xff, 0xff, 0xff, 0x3f, 0x34, 0x35, 0x36,
0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x01,
0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09,
0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11,
0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x1a, 0x1b,
0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23,
0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b,
0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33
};
int av_base64_decode(uint8_t *out, const char *in, int out_size)
{
int i, v;
uint8_t *dst = out;
v = 0;
for (i = 0; in[i] && in[i] != '='; i++) {
unsigned int index= in[i]-43;
if (index>=FF_ARRAY_ELEMS(map2) || map2[index] == 0xff)
return -1;
v = (v << 6) + map2[index];
if (i & 3) {
if (dst - out < out_size) {
*dst++ = v >> (6 - 2 * (i & 3));
}
}
}
return dst - out;
}
/*****************************************************************************
* b64_encode: Stolen from VLC's http.c.
* Simplified by Michael.
* Fixed edge cases and made it work from data (vs. strings) by Ryan.
*****************************************************************************/
char *av_base64_encode_l(char *out, int *out_size, const uint8_t *in, int in_size) {
static const char b64[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
char *ret, *dst;
unsigned i_bits = 0;
int i_shift = 0;
int bytes_remaining = in_size;
if ((size_t)in_size >= UINT_MAX / 4 || *out_size < AV_BASE64_SIZE(in_size)) {
return nullptr;
}
ret = dst = out;
while (bytes_remaining) {
i_bits = (i_bits << 8) + *in++;
bytes_remaining--;
i_shift += 8;
do {
*dst++ = b64[(i_bits << 6 >> i_shift) & 0x3f];
i_shift -= 6;
} while (i_shift > 6 || (bytes_remaining == 0 && i_shift > 0));
}
while ((dst - ret) & 3)
*dst++ = '=';
*dst = '\0';
*out_size = dst - out;
return ret;
}
char *av_base64_encode(char *out, int out_size, const uint8_t *in, int in_size) {
return av_base64_encode_l(out, &out_size, in, in_size);
}
string encodeBase64(const string &txt) {
if (txt.empty()) {
return "";
}
int size = AV_BASE64_SIZE(txt.size()) + 10;
string ret;
ret.resize(size);
if (!av_base64_encode_l((char *) ret.data(), &size, (const uint8_t *) txt.data(), txt.size())) {
return "";
}
ret.resize(size);
return ret;
}
string decodeBase64(const string &txt){
if (txt.empty()) {
return "";
}
string ret;
ret.resize(txt.size() * 3 / 4 + 10);
auto size = av_base64_decode((uint8_t *) ret.data(), txt.data(), ret.size());
if (size <= 0) {
return "";
}
ret.resize(size);
return ret;
}
#ifdef TEST
#undef printf
#define MAX_DATA_SIZE 1024
#define MAX_ENCODED_SIZE 2048
static int test_encode_decode(const uint8_t *data, unsigned int data_size,
const char *encoded_ref)
{
char encoded[MAX_ENCODED_SIZE];
uint8_t data2[MAX_DATA_SIZE];
int data2_size, max_data2_size = MAX_DATA_SIZE;
if (!av_base64_encode(encoded, MAX_ENCODED_SIZE, data, data_size)) {
printf("Failed: cannot encode the input data\n");
return 1;
}
if (encoded_ref && strcmp(encoded, encoded_ref)) {
printf("Failed: encoded string differs from reference\n"
"Encoded:\n%s\nReference:\n%s\n", encoded, encoded_ref);
return 1;
}
if ((data2_size = av_base64_decode(data2, encoded, max_data2_size)) < 0) {
printf("Failed: cannot decode the encoded string\n"
"Encoded:\n%s\n", encoded);
return 1;
}
if (memcmp(data2, data, data_size)) {
printf("Failed: encoded/decoded data differs from original data\n");
return 1;
}
printf("Passed!\n");
return 0;
}
int main(void)
{
int i, error_count = 0;
struct test {
const uint8_t *data;
const char *encoded_ref;
} tests[] = {
{ "", ""},
{ "1", "MQ=="},
{ "22", "MjI="},
{ "333", "MzMz"},
{ "4444", "NDQ0NA=="},
{ "55555", "NTU1NTU="},
{ "666666", "NjY2NjY2"},
{ "abc:def", "YWJjOmRlZg=="},
};
printf("Encoding/decoding tests\n");
for (i = 0; i < FF_ARRAY_ELEMS(tests); i++)
error_count += test_encode_decode(tests[i].data, strlen(tests[i].data), tests[i].encoded_ref);
return error_count;
}
#endif

Some files were not shown because too many files have changed in this diff Show More