From 2fd567b8b0aa94ff49dc3bf1490e3eed39394b42 Mon Sep 17 00:00:00 2001 From: xiongziliang <771730766@qq.com> Date: Sat, 8 Aug 2020 12:17:06 +0800 Subject: [PATCH] =?UTF-8?q?1=E3=80=81ws-flv=E7=9B=B4=E6=92=AD=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E5=AE=A2=E6=88=B7=E7=AB=AF=E4=B8=BB=E5=8A=A8=E5=85=B3?= =?UTF-8?q?=E9=97=AD=E8=AF=B7=E6=B1=82:#430=202=E3=80=81=E5=85=BC=E5=AE=B9?= =?UTF-8?q?CONTINUATION=E7=B1=BB=E5=9E=8B=E7=9A=84websocket=E5=8C=85=203?= =?UTF-8?q?=E3=80=81=E4=BF=AE=E5=A4=8Dwebsocket=E5=AE=A2=E6=88=B7=E7=AB=AF?= =?UTF-8?q?=E5=9C=A8=E5=A4=84=E7=90=86Content-Length=E6=97=B6=E7=9A=84?= =?UTF-8?q?=E7=9B=B8=E5=85=B3bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/Http/HttpSession.cpp | 40 +++++++++++++++------- src/Http/HttpSession.h | 11 +++++- src/Http/WebSocketClient.h | 63 ++++++++++++++++++++++------------ src/Http/WebSocketSession.h | 66 ++++++++++++++++++++++-------------- src/Http/WebSocketSplitter.h | 30 ++++++++++++++-- 5 files changed, 149 insertions(+), 61 deletions(-) diff --git a/src/Http/HttpSession.cpp b/src/Http/HttpSession.cpp index f161342f..7e26502a 100644 --- a/src/Http/HttpSession.cpp +++ b/src/Http/HttpSession.cpp @@ -132,36 +132,37 @@ void HttpSession::onManager() { bool HttpSession::checkWebSocket(){ auto Sec_WebSocket_Key = _parser["Sec-WebSocket-Key"]; - if(Sec_WebSocket_Key.empty()){ + if (Sec_WebSocket_Key.empty()) { return false; } - auto Sec_WebSocket_Accept = encodeBase64(SHA1::encode_bin(Sec_WebSocket_Key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")); + auto Sec_WebSocket_Accept = encodeBase64( + SHA1::encode_bin(Sec_WebSocket_Key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")); KeyValue headerOut; headerOut["Upgrade"] = "websocket"; headerOut["Connection"] = "Upgrade"; headerOut["Sec-WebSocket-Accept"] = Sec_WebSocket_Accept; - if(!_parser["Sec-WebSocket-Protocol"].empty()){ + if (!_parser["Sec-WebSocket-Protocol"].empty()) { headerOut["Sec-WebSocket-Protocol"] = _parser["Sec-WebSocket-Protocol"]; } - auto res_cb = [this,headerOut](){ + auto res_cb = [this, headerOut]() { _flv_over_websocket = true; - sendResponse("101 Switching Protocols",false,nullptr,headerOut,nullptr, true); + sendResponse("101 Switching Protocols", false, nullptr, headerOut, nullptr, true); }; //判断是否为websocket-flv - if(checkLiveFlvStream(res_cb)){ + if (checkLiveFlvStream(res_cb)) { //这里是websocket-flv直播请求 return true; } //如果checkLiveFlvStream返回false,则代表不是websocket-flv,而是普通的websocket连接 - if(!onWebSocketConnect(_parser)){ - sendResponse("501 Not Implemented",true, nullptr, headerOut); + if (!onWebSocketConnect(_parser)) { + sendResponse("501 Not Implemented", true, nullptr, headerOut); return true; } - sendResponse("101 Switching Protocols",false, nullptr,headerOut); + sendResponse("101 Switching Protocols", false, nullptr, headerOut, nullptr, true); return true; } @@ -389,7 +390,7 @@ void HttpSession::sendResponse(const char *pcStatus, const char *pcContentType, const HttpSession::KeyValue &header, const HttpBody::Ptr &body, - bool is_http_flv ){ + bool no_content_length ){ GET_CONFIG(string,charSet,Http::kCharSet); GET_CONFIG(uint32_t,keepAliveSec,Http::kKeepAliveSecond); @@ -400,7 +401,7 @@ void HttpSession::sendResponse(const char *pcStatus, size = body->remainSize(); } - if(is_http_flv){ + if(no_content_length){ //http-flv直播是Keep-Alive类型 bClose = false; }else if(size >= INT64_MAX){ @@ -425,7 +426,7 @@ void HttpSession::sendResponse(const char *pcStatus, headerOut.emplace(kAccessControlAllowCredentials, "true"); } - if(!is_http_flv && size >= 0 && size < INT64_MAX){ + if(!no_content_length && size >= 0 && size < INT64_MAX){ //文件长度为固定值,且不是http-flv强制设置Content-Length headerOut[kContentLength] = to_string(size); } @@ -645,6 +646,21 @@ void HttpSession::onWebSocketEncodeData(const Buffer::Ptr &buffer){ send(buffer); } +void HttpSession::onWebSocketDecodeComplete(const WebSocketHeader &header_in){ + WebSocketHeader& header = const_cast(header_in); + header._mask_flag = false; + + switch (header._opcode) { + case WebSocketHeader::CLOSE: { + encode(header, nullptr); + shutdown(SockException(Err_shutdown, "recv close request from client")); + break; + } + + default : break; + } +} + void HttpSession::onDetach() { shutdown(SockException(Err_shutdown,"rtmp ring buffer detached")); } diff --git a/src/Http/HttpSession.h b/src/Http/HttpSession.h index 09bcfca5..b842457b 100644 --- a/src/Http/HttpSession.h +++ b/src/Http/HttpSession.h @@ -47,6 +47,7 @@ public: void onError(const SockException &err) override; void onManager() override; static string urlDecode(const string &str); + protected: //FlvMuxer override void onWrite(const Buffer::Ptr &data, bool flush) override ; @@ -90,6 +91,13 @@ protected: * @param buffer websocket协议数据 */ void onWebSocketEncodeData(const Buffer::Ptr &buffer) override; + + /** + * 接收到完整的一个webSocket数据包后回调 + * @param header 数据包包头 + */ + void onWebSocketDecodeComplete(const WebSocketHeader &header_in) override; + private: void Handle_Req_GET(int64_t &content_len); void Handle_Req_GET_l(int64_t &content_len, bool sendBody); @@ -103,10 +111,11 @@ private: void sendNotFound(bool bClose); void sendResponse(const char *pcStatus, bool bClose, const char *pcContentType = nullptr, const HttpSession::KeyValue &header = HttpSession::KeyValue(), - const HttpBody::Ptr &body = nullptr,bool is_http_flv = false); + const HttpBody::Ptr &body = nullptr, bool no_content_length = false); //设置socket标志 void setSocketFlags(); + private: string _origin; Parser _parser; diff --git a/src/Http/WebSocketClient.h b/src/Http/WebSocketClient.h index 93f140be..d92f1e33 100644 --- a/src/Http/WebSocketClient.h +++ b/src/Http/WebSocketClient.h @@ -38,11 +38,10 @@ public: template ClientTypeImp(ArgsType &&...args): ClientType(std::forward(args)...){} ~ClientTypeImp() override {}; + protected: /** * 发送前拦截并打包为websocket协议 - * @param buf - * @return */ int send(const Buffer::Ptr &buf) override{ if(_beforeSendCB){ @@ -50,6 +49,7 @@ protected: } return ClientType::send(buf); } + /** * 设置发送数据截取回调函数 * @param cb 截取回调函数 @@ -57,6 +57,7 @@ protected: void setOnBeforeSendCB(const onBeforeSendCB &cb){ _beforeSendCB = cb; } + private: onBeforeSendCB _beforeSendCB; }; @@ -108,6 +109,7 @@ public: header._mask_flag = true; WebSocketSplitter::encode(header, nullptr); } + protected: //HttpClientImp override @@ -124,6 +126,8 @@ protected: if(Sec_WebSocket_Accept == const_cast(headers)["Sec-WebSocket-Accept"]){ //success onWebSocketException(SockException()); + //防止ws服务器返回Content-Length + const_cast(headers).erase("Content-Length"); //后续全是websocket负载数据 return -1; } @@ -180,7 +184,6 @@ protected: /** * tcp连接结果 - * @param ex */ void onConnect(const SockException &ex) override{ if(ex){ @@ -194,7 +197,6 @@ protected: /** * tcp连接断开 - * @param ex */ void onErr(const SockException &ex) override{ //tcp断开或者shutdown导致的断开 @@ -208,7 +210,7 @@ protected: * @param header 数据包头 */ void onWebSocketDecodeHeader(const WebSocketHeader &header) override{ - _payload.clear(); + _payload_section.clear(); } /** @@ -219,10 +221,9 @@ protected: * @param recved 已接收数据长度(包含本次数据长度),等于header._payload_len时则接受完毕 */ void onWebSocketDecodePayload(const WebSocketHeader &header, const uint8_t *ptr, uint64_t len, uint64_t recved) override{ - _payload.append((char *)ptr,len); + _payload_section.append((char *)ptr,len); } - /** * 接收到完整的一个webSocket数据包后回调 * @param header 数据包包头 @@ -238,28 +239,46 @@ protected: //服务器主动关闭 WebSocketSplitter::encode(header,nullptr); shutdown(SockException(Err_eof,"websocket server close the connection")); - } break; + } + case WebSocketHeader::PING:{ //心跳包 header._opcode = WebSocketHeader::PONG; - WebSocketSplitter::encode(header,std::make_shared(std::move(_payload))); - } + WebSocketSplitter::encode(header,std::make_shared(std::move(_payload_section))); break; - case WebSocketHeader::CONTINUATION:{ + } - } - break; + case WebSocketHeader::CONTINUATION: case WebSocketHeader::TEXT: case WebSocketHeader::BINARY:{ - //接收完毕websocket数据包,触发onRecv事件 - _delegate.onRecv(std::make_shared(std::move(_payload))); + if (!header._fin) { + //还有后续分片数据, 我们先缓存数据,所有分片收集完成才一次性输出 + _payload_cache.append(std::move(_payload_section)); + if (_payload_cache.size() < MAX_WS_PACKET) { + //还有内存容量缓存分片数据 + break; + } + //分片缓存太大,需要清空 + } + + //最后一个包 + if (_payload_cache.empty()) { + //这个包是唯一个分片 + _delegate.onRecv(std::make_shared(header._opcode, header._fin, std::move(_payload_section))); + break; + } + + //这个包由多个分片组成 + _payload_cache.append(std::move(_payload_section)); + _delegate.onRecv(std::make_shared(header._opcode, header._fin, std::move(_payload_cache))); + _payload_cache.clear(); + break; } - break; - default: - break; + + default: break; } - _payload.clear(); + _payload_section.clear(); header._mask_flag = flag; } @@ -271,6 +290,7 @@ protected: void onWebSocketEncodeData(const Buffer::Ptr &buffer) override{ HttpClientImp::send(buffer); } + private: void onWebSocketException(const SockException &ex){ if(!ex){ @@ -319,10 +339,10 @@ private: string _Sec_WebSocket_Key; function _onRecv; ClientTypeImp &_delegate; - string _payload; + string _payload_section; + string _payload_cache; }; - /** * Tcp客户端转WebSocket客户端模板, * 通过该模板,开发者再不修改TcpClient派生类任何代码的情况下快速实现WebSocket协议的包装 @@ -365,6 +385,7 @@ public: void startWebSocket(const string &ws_url,float fTimeOutSec = 3){ _wsClient->startWsClient(ws_url,fTimeOutSec); } + private: typename HttpWsClient::Ptr _wsClient; }; diff --git a/src/Http/WebSocketSession.h b/src/Http/WebSocketSession.h index 440a99e7..65840250 100644 --- a/src/Http/WebSocketSession.h +++ b/src/Http/WebSocketSession.h @@ -78,7 +78,6 @@ public: } }; - /** * 通过该模板类可以透明化WebSocket协议, * 用户只要实现WebSock协议下的具体业务协议,譬如基于WebSocket协议的Rtmp协议等 @@ -107,8 +106,9 @@ public: void attachServer(const TcpServer &server) override{ HttpSessionType::attachServer(server); - _weakServer = const_cast(server).shared_from_this(); + _weak_server = const_cast(server).shared_from_this(); } + protected: /** * websocket客户端连接上事件 @@ -122,7 +122,7 @@ protected: //此url不允许创建websocket连接 return false; } - auto strongServer = _weakServer.lock(); + auto strongServer = _weak_server.lock(); if(strongServer){ _session->attachServer(*strongServer); } @@ -145,24 +145,20 @@ protected: //允许websocket客户端 return true; } + /** * 开始收到一个webSocket数据包 - * @param packet */ void onWebSocketDecodeHeader(const WebSocketHeader &packet) override{ //新包,原来的包残余数据清空掉 - _remian_data.clear(); + _payload_section.clear(); } /** * 收到websocket数据包负载 - * @param packet - * @param ptr - * @param len - * @param recved */ void onWebSocketDecodePayload(const WebSocketHeader &packet,const uint8_t *ptr,uint64_t len,uint64_t recved) override { - _remian_data.append((char *)ptr,len); + _payload_section.append((char *)ptr,len); } /** @@ -178,39 +174,59 @@ protected: case WebSocketHeader::CLOSE:{ HttpSessionType::encode(header,nullptr); HttpSessionType::shutdown(SockException(Err_shutdown, "recv close request from client")); - } break; + } + case WebSocketHeader::PING:{ header._opcode = WebSocketHeader::PONG; - HttpSessionType::encode(header,std::make_shared(_remian_data)); - } + HttpSessionType::encode(header,std::make_shared(_payload_section)); break; - case WebSocketHeader::CONTINUATION:{ - } - break; + + case WebSocketHeader::CONTINUATION: case WebSocketHeader::TEXT: case WebSocketHeader::BINARY:{ - _session->onRecv(std::make_shared(_remian_data)); + if (!header._fin) { + //还有后续分片数据, 我们先缓存数据,所有分片收集完成才一次性输出 + _payload_cache.append(std::move(_payload_section)); + if (_payload_cache.size() < MAX_WS_PACKET) { + //还有内存容量缓存分片数据 + break; + } + //分片缓存太大,需要清空 + } + + //最后一个包 + if (_payload_cache.empty()) { + //这个包是唯一个分片 + _session->onRecv(std::make_shared(header._opcode, header._fin, std::move(_payload_section))); + break; + } + + //这个包由多个分片组成 + _payload_cache.append(std::move(_payload_section)); + _session->onRecv(std::make_shared(header._opcode, header._fin, std::move(_payload_cache))); + _payload_cache.clear(); + break; } - break; - default: - break; + + default: break; } - _remian_data.clear(); + _payload_section.clear(); header._mask_flag = flag; } /** - * 发送数据进行websocket协议打包后回调 - * @param buffer + * 发送数据进行websocket协议打包后回调 */ void onWebSocketEncodeData(const Buffer::Ptr &buffer) override{ HttpSessionType::send(buffer); } + private: - string _remian_data; - weak_ptr _weakServer; + string _payload_cache; + string _payload_section; + weak_ptr _weak_server; TcpSession::Ptr _session; Creator _creator; }; diff --git a/src/Http/WebSocketSplitter.h b/src/Http/WebSocketSplitter.h index 9b2bbea3..e241406c 100644 --- a/src/Http/WebSocketSplitter.h +++ b/src/Http/WebSocketSplitter.h @@ -16,10 +16,12 @@ #include #include #include "Network/Buffer.h" - using namespace std; using namespace toolkit; +//websocket组合包最大不得超过4MB(防止内存爆炸) +#define MAX_WS_PACKET (4 * 1024 * 1024) + namespace mediakit { class WebSocketHeader { @@ -44,6 +46,7 @@ public: CONTROL_RSVF = 0xF } Type; public: + WebSocketHeader() : _mask(4){ //获取_mask内部buffer的内存地址,该内存是malloc开辟的,地址为随机 uint64_t ptr = (uint64_t)(&_mask[0]); @@ -51,6 +54,7 @@ public: _mask.assign((uint8_t*)(&ptr), (uint8_t*)(&ptr) + 4); } virtual ~WebSocketHeader(){} + public: bool _fin; uint8_t _reserved; @@ -60,6 +64,26 @@ public: vector _mask; }; +//websocket协议收到的字符串类型缓存,用户协议层获取该数据传输的方式 +class WebSocketBuffer : public BufferString { +public: + typedef std::shared_ptr Ptr; + + template + WebSocketBuffer(WebSocketHeader::Type headType, bool fin, ARGS &&...args) + : _head_type(headType), _fin(fin), BufferString(std::forward(args)...) {} + + ~WebSocketBuffer() override {} + + WebSocketHeader::Type headType() const { return _head_type; } + + bool isFinished() const { return _fin; }; + +private: + WebSocketHeader::Type _head_type; + bool _fin; +}; + class WebSocketSplitter : public WebSocketHeader{ public: WebSocketSplitter(){} @@ -80,6 +104,7 @@ public: * @param buffer 负载数据 */ void encode(const WebSocketHeader &header,const Buffer::Ptr &buffer); + protected: /** * 收到一个webSocket数据包包头,后续将继续触发onWebSocketDecodePayload回调 @@ -96,7 +121,6 @@ protected: */ virtual void onWebSocketDecodePayload(const WebSocketHeader &header, const uint8_t *ptr, uint64_t len, uint64_t recved) {}; - /** * 接收到完整的一个webSocket数据包后回调 * @param header 数据包包头 @@ -109,8 +133,10 @@ protected: * @param len 数据指针长度 */ virtual void onWebSocketEncodeData(const Buffer::Ptr &buffer){}; + private: void onPayloadData(uint8_t *data, uint64_t len); + private: string _remain_data; int _mask_offset = 0;