初步添加rtsp转webrtc相关功能

This commit is contained in:
ziyue 2021-03-24 16:52:41 +08:00
parent 2f0bdf2724
commit 65e470e060
20 changed files with 4000 additions and 0 deletions

View File

@ -55,6 +55,7 @@ option(ENABLE_TESTS "Enable Tests" true)
option(ENABLE_SERVER "Enable Server" true) option(ENABLE_SERVER "Enable Server" true)
option(ENABLE_MEM_DEBUG "Enable Memory Debug" false) option(ENABLE_MEM_DEBUG "Enable Memory Debug" false)
option(ENABLE_ASAN "Enable Address Sanitize" false) option(ENABLE_ASAN "Enable Address Sanitize" false)
option(ENABLE_WEBRTC "Enable WebRTC" true)
if (ENABLE_MEM_DEBUG) if (ENABLE_MEM_DEBUG)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-wrap,free -Wl,-wrap,malloc -Wl,-wrap,realloc -Wl,-wrap,calloc") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-wrap,free -Wl,-wrap,malloc -Wl,-wrap,realloc -Wl,-wrap,calloc")
@ -224,6 +225,11 @@ if(ENABLE_API)
add_subdirectory(api) add_subdirectory(api)
endif() endif()
if(ENABLE_WEBRTC)
add_definitions(-DENABLE_WEBRTC)
add_subdirectory(webrtc)
endif()
if (NOT IOS) if (NOT IOS)
# #
if(ENABLE_TESTS) if(ENABLE_TESTS)

55
cmake/FindSRTP.cmake Normal file
View File

@ -0,0 +1,55 @@
############################################################################
# FindSRTP.txt
# Copyright (C) 2014 Belledonne Communications, Grenoble France
#
############################################################################
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
#
############################################################################
#
# - Find the SRTP include file and library
#
# SRTP_FOUND - system has SRTP
# SRTP_INCLUDE_DIRS - the SRTP include directory
# SRTP_LIBRARIES - The libraries needed to use SRTP
set(_SRTP_ROOT_PATHS
${CMAKE_INSTALL_PREFIX}
)
find_path(SRTP_INCLUDE_DIRS
NAMES srtp2/srtp.h
HINTS _SRTP_ROOT_PATHS
PATH_SUFFIXES include
)
if(SRTP_INCLUDE_DIRS)
set(HAVE_SRTP_SRTP_H 1)
endif()
find_library(SRTP_LIBRARIES
NAMES srtp2
HINTS ${_SRTP_ROOT_PATHS}
PATH_SUFFIXES bin lib
)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(SRTP
DEFAULT_MSG
SRTP_INCLUDE_DIRS SRTP_LIBRARIES HAVE_SRTP_SRTP_H
)
mark_as_advanced(SRTP_INCLUDE_DIRS SRTP_LIBRARIES HAVE_SRTP_SRTP_H)

View File

@ -49,3 +49,4 @@ else()
endif() endif()
target_link_libraries(MediaServer jsoncpp ${LINK_LIB_LIST}) target_link_libraries(MediaServer jsoncpp ${LINK_LIB_LIST})
message(${LINK_LIB_LIST})

View File

@ -35,6 +35,9 @@
#if defined(ENABLE_RTPPROXY) #if defined(ENABLE_RTPPROXY)
#include "Rtp/RtpServer.h" #include "Rtp/RtpServer.h"
#endif #endif
#ifdef ENABLE_WEBRTC
#include "../webrtc/webrtc_transport.h"
#endif
using namespace toolkit; using namespace toolkit;
using namespace mediakit; using namespace mediakit;
@ -1049,6 +1052,27 @@ void installWebApi() {
#endif #endif
}); });
#ifdef ENABLE_WEBRTC
static list<WebRtcTransportImp::Ptr> rtcs;
api_regist("/webrtc",[](API_ARGS_MAP_ASYNC){
CHECK_ARGS("app", "stream");
auto src = dynamic_pointer_cast<RtspMediaSource>(MediaSource::find(RTSP_SCHEMA, DEFAULT_VHOST, allArgs["app"], allArgs["stream"]));
if (!src) {
throw ApiRetException("流不存在", API::NotFound);
}
headerOut["Content-Type"] = "text/plain";
headerOut["Access-Control-Allow-Origin"] = "*";
auto poller = EventPollerPool::Instance().getFirstPoller();
auto rtc = std::make_shared<WebRtcTransportImp>(poller);
poller->async([invoker, rtc, headerOut, src]() {
rtc->attach(src);
auto sdp = rtc->GetLocalSdp();
invoker(200, headerOut, sdp);
rtcs.emplace_back(rtc);
});
});
#endif
////////////以下是注册的Hook API//////////// ////////////以下是注册的Hook API////////////
api_regist("/index/hook/on_publish",[](API_ARGS_MAP){ api_regist("/index/hook/on_publish",[](API_ARGS_MAP){
//开始推流事件 //开始推流事件

16
webrtc/CMakeLists.txt Normal file
View File

@ -0,0 +1,16 @@
list(APPEND LINK_LIB_LIST webrtc)
#srtp
find_package(SRTP QUIET)
if (SRTP_FOUND)
message(STATUS "found library:${SRTP_LIBRARIES}")
include_directories(${SRTP_INCLUDE_DIRS})
list(APPEND LINK_LIB_LIST ${SRTP_LIBRARIES})
else ()
message(FATAL_ERROR "srtp未找到!")
endif ()
include_directories(./)
file(GLOB SRC_LIST ./*.*)
add_library(webrtc ${SRC_LIST})
set(LINK_LIB_LIST ${LINK_LIB_LIST} PARENT_SCOPE)

75
webrtc/dtls_transport.cc Normal file
View File

@ -0,0 +1,75 @@
//
// Created by xueyuegui on 19-12-7.
//
#include "dtls_transport.h"
#include <iostream>
DtlsTransport::DtlsTransport(bool is_server) : is_server_(is_server) {
dtls_transport_.reset(new RTC::DtlsTransport(this));
}
DtlsTransport::~DtlsTransport() {}
void DtlsTransport::Start() {
if (is_server_) {
dtls_transport_->Run(RTC::DtlsTransport::Role::SERVER);
} else {
dtls_transport_->Run(RTC::DtlsTransport::Role::CLIENT);
}
}
void DtlsTransport::Close() {}
void DtlsTransport::OnDtlsTransportConnecting(const RTC::DtlsTransport *dtlsTransport) {}
void DtlsTransport::OnDtlsTransportConnected(const RTC::DtlsTransport *dtlsTransport,
RTC::CryptoSuite srtp_crypto_suite,
uint8_t *srtpLocalKey, size_t srtpLocalKeyLen,
uint8_t *srtpRemoteKey, size_t srtpRemoteKeyLen,
std::string &remoteCert) {
std::string client_key;
std::string server_key;
server_key.assign((char *) srtpLocalKey, srtpLocalKeyLen);
client_key.assign((char *) srtpRemoteKey, srtpRemoteKeyLen);
if (is_server_) {
// If we are server, we swap the keys
client_key.swap(server_key);
}
if (handshake_completed_callback_) {
handshake_completed_callback_(client_key, server_key, srtp_crypto_suite);
}
}
void DtlsTransport::OnDtlsTransportFailed(const RTC::DtlsTransport *dtlsTransport) {
if (handshake_failed_callback_) {
handshake_failed_callback_();
}
}
void DtlsTransport::OnDtlsTransportClosed(const RTC::DtlsTransport *dtlsTransport) {}
void DtlsTransport::OnDtlsTransportSendData(const RTC::DtlsTransport *dtlsTransport,
const uint8_t *data, size_t len) {
if (output_callback_) {
output_callback_((char *) data, len);
}
}
void DtlsTransport::OutputData(char *buf, size_t len) {
if (output_callback_) {
output_callback_(buf, len);
}
}
void DtlsTransport::OnDtlsTransportApplicationDataReceived(const RTC::DtlsTransport *dtlsTransport,
const uint8_t *data, size_t len) {}
bool DtlsTransport::IsDtlsPacket(const char *buf, size_t len) {
return RTC::DtlsTransport::IsDtls((uint8_t *) buf, len);
}
void DtlsTransport::InputData(char *buf, size_t len) {
dtls_transport_->ProcessDtlsData((uint8_t *) buf, len);
}

58
webrtc/dtls_transport.h Normal file
View File

@ -0,0 +1,58 @@
//
// Created by xueyuegui on 19-12-7.
//
#ifndef MYWEBRTC_MYDTLSTRANSPORT_H
#define MYWEBRTC_MYDTLSTRANSPORT_H
#include <functional>
#include <memory>
#include "rtc_dtls_transport.h"
class DtlsTransport : RTC::DtlsTransport::Listener {
public:
typedef std::shared_ptr<DtlsTransport> Ptr;
DtlsTransport(bool bServer);
~DtlsTransport();
void Start();
void Close();
void InputData(char *buf, size_t len);
void OutputData(char *buf, size_t len);
static bool IsDtlsPacket(const char *buf, size_t len);
std::string GetMyFingerprint() {
auto finger_prints = dtls_transport_->GetLocalFingerprints();
for (size_t i = 0; i < finger_prints.size(); i++) {
if (finger_prints[i].algorithm == RTC::DtlsTransport::FingerprintAlgorithm::SHA256) {
return finger_prints[i].value;
}
}
return "";
};
void SetHandshakeCompletedCB(std::function<void(std::string clientKey, std::string serverKey, RTC::CryptoSuite)> cb) {
handshake_completed_callback_ = std::move(cb);
}
void SetHandshakeFailedCB(std::function<void()> cb) { handshake_failed_callback_ = std::move(cb); }
void SetOutPutCB(std::function<void(char *buf, size_t len)> cb) { output_callback_ = std::move(cb); }
/* Pure virtual methods inherited from RTC::DtlsTransport::Listener. */
public:
void OnDtlsTransportConnecting(const RTC::DtlsTransport *dtlsTransport) override;
void OnDtlsTransportConnected(const RTC::DtlsTransport *dtlsTransport, RTC::CryptoSuite srtpCryptoSuite, uint8_t *srtpLocalKey, size_t srtpLocalKeyLen, uint8_t *srtpRemoteKey, size_t srtpRemoteKeyLen, std::string &remoteCert) override;
void OnDtlsTransportFailed(const RTC::DtlsTransport *dtlsTransport) override;
void OnDtlsTransportClosed(const RTC::DtlsTransport *dtlsTransport) override;
void OnDtlsTransportSendData(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data,size_t len) override;
void OnDtlsTransportApplicationDataReceived(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data, size_t len) override;
private:
bool is_server_ = false;
std::function<void()> handshake_failed_callback_;
std::shared_ptr<RTC::DtlsTransport> dtls_transport_;
std::function<void(char *buf, size_t len)> output_callback_;
std::function<void(std::string client_key, std::string server_key, RTC::CryptoSuite srtp_crypto_suite)> handshake_completed_callback_;
};
#endif// MYWEBRTC_MYDTLSTRANSPORT_H

201
webrtc/ice_server.cc Normal file
View File

@ -0,0 +1,201 @@
#include "ice_server.h"
#include <iostream>
static constexpr size_t StunSerializeBufferSize{65536};
static uint8_t StunSerializeBuffer[StunSerializeBufferSize];
IceServer::IceServer() {}
IceServer::~IceServer() {}
IceServer::IceServer(const std::string &username_fragment, const std::string &password)
: username_fragment_(username_fragment), password_(password) {}
void IceServer::ProcessStunPacket(RTC::StunPacket *packet, sockaddr_in *remote_address) {
// Must be a Binding method.
if (packet->GetMethod() != RTC::StunPacket::Method::BINDING) {
if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) {
ELOG_WARN("unknown method %#.3x in STUN Request => 400",
static_cast<unsigned int>(packet->GetMethod()));
ELOG_WARN("unknown method %#.3x in STUN Request => 400",
static_cast<unsigned int>(packet->GetMethod()));
// Reply 400.
RTC::StunPacket *response = packet->CreateErrorResponse(400);
response->Serialize(StunSerializeBuffer);
if (send_callback_) {
send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address);
}
delete response;
} else {
ELOG_WARN("ignoring STUN Indication or Response with unknown method %#.3x",
static_cast<unsigned int>(packet->GetMethod()));
}
return;
}
// Must use FINGERPRINT (optional for ICE STUN indications).
if (!packet->HasFingerprint() && packet->GetClass() != RTC::StunPacket::Class::INDICATION) {
if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) {
ELOG_WARN("STUN Binding Request without FINGERPRINT => 400");
// Reply 400.
RTC::StunPacket *response = packet->CreateErrorResponse(400);
response->Serialize(StunSerializeBuffer);
if (send_callback_) {
send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address);
}
delete response;
} else {
ELOG_WARN("ignoring STUN Binding Response without FINGERPRINT");
}
return;
}
switch (packet->GetClass()) {
case RTC::StunPacket::Class::REQUEST: {
// USERNAME, MESSAGE-INTEGRITY and PRIORITY are required.
if (!packet->HasMessageIntegrity() || (packet->GetPriority() == 0u) ||
packet->GetUsername().empty()) {
ELOG_WARN("mising required attributes in STUN Binding Request => 400");
// Reply 400.
RTC::StunPacket *response = packet->CreateErrorResponse(400);
response->Serialize(StunSerializeBuffer);
if (send_callback_) {
send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address);
}
delete response;
return;
}
// Check authentication.
switch (packet->CheckAuthentication(this->username_fragment_, this->password_)) {
case RTC::StunPacket::Authentication::OK: {
if (!this->old_password_.empty()) {
ELOG_DEBUG("kNew ICE credentials applied");
this->old_username_fragment_.clear();
this->old_password_.clear();
}
break;
}
case RTC::StunPacket::Authentication::UNAUTHORIZED: {
// We may have changed our username_fragment_ and password_, so check
// the old ones.
// clang-format off
if (!this->old_username_fragment_.empty() &&
!this->old_password_.empty() &&
packet->CheckAuthentication(this->old_username_fragment_, this->old_password_) ==
RTC::StunPacket::Authentication::OK) {
ELOG_DEBUG("using old ICE credentials");
break;
}
ELOG_WARN("wrong authentication in STUN Binding Request => 401");
// Reply 401.
RTC::StunPacket *response = packet->CreateErrorResponse(401);
response->Serialize(StunSerializeBuffer);
if (send_callback_) {
send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address);
}
delete response;
return;
}
case RTC::StunPacket::Authentication::BAD_REQUEST: {
ELOG_WARN("cannot check authentication in STUN Binding Request => 400");
// Reply 400.
RTC::StunPacket *response = packet->CreateErrorResponse(400);
response->Serialize(StunSerializeBuffer);
if (send_callback_) {
send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address);
}
delete response;
return;
}
}
#if 0
// NOTE: Should be rejected with 487, but this makes Chrome happy:
// https://bugs.chromium.org/p/webrtc/issues/detail?id=7478
// The remote peer must be ICE controlling.
if (packet->GetIceControlled()) {
MS_WARN_TAG(ice, "peer indicates ICE-CONTROLLED in STUN Binding Request => 487");
// Reply 487 (Role Conflict).
RTC::StunPacket *response = packet->CreateErrorResponse(487);
response->Serialize(StunSerializeBuffer);
if (send_callback_) {
send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address);
}
delete response;
return;
}
#endif
ELOG_DEBUG("processing STUN Binding Request [Priority:%d, UseCandidate:%s]",
static_cast<uint32_t>(packet->GetPriority()),
(packet->HasUseCandidate() ? "true" : "false"));
// Create a success response.
RTC::StunPacket *response = packet->CreateSuccessResponse();
// Add XOR-MAPPED-ADDRESS.
// response->SetXorMappedAddress(tuple->GetRemoteAddress());
response->SetXorMappedAddress((struct sockaddr *) remote_address);
// Authenticate the response.
if (this->old_password_.empty()) {
response->Authenticate(this->password_);
} else {
response->Authenticate(this->old_password_);
}
// Send back.
response->Serialize(StunSerializeBuffer);
if (send_callback_) {
send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address);
}
delete response;
// Handle the tuple.
HandleTuple(remote_address, packet->HasUseCandidate());
break;
}
case RTC::StunPacket::Class::INDICATION: {
ELOG_DEBUG("STUN Binding Indication processed");
break;
}
case RTC::StunPacket::Class::SUCCESS_RESPONSE: {
ELOG_DEBUG("STUN Binding Success Response processed");
break;
}
case RTC::StunPacket::Class::ERROR_RESPONSE: {
ELOG_DEBUG("STUN Binding Error Response processed");
break;
}
}
}
void IceServer::HandleTuple(sockaddr_in *remote_address, bool has_use_candidate) {
remote_address_ = *remote_address;
if (has_use_candidate) {
this->state = IceState::kCompleted;
}
if (ice_server_completed_callback_) {
ice_server_completed_callback_();
ice_server_completed_callback_ = nullptr;
}
}
const std::string &IceServer::GetUsernameFragment() const { return this->username_fragment_; }
const std::string &IceServer::GetPassword() const { return this->password_; }
inline void IceServer::SetUsernameFragment(const std::string &username_fragment) {
this->old_username_fragment_ = this->username_fragment_;
this->username_fragment_ = username_fragment;
}
inline void IceServer::SetPassword(const std::string &password) {
this->old_password_ = this->password_;
this->password_ = password;
}
inline IceServer::IceState IceServer::GetState() const { return this->state; }

40
webrtc/ice_server.h Normal file
View File

@ -0,0 +1,40 @@
#pragma once
#include <functional>
#include <memory>
#include "logger.h"
#include "stun_packet.h"
typedef std::function<void(char *buf, size_t len, struct sockaddr_in *remote_address)> UdpSendCallback;
class IceServer {
public:
enum class IceState { kNew = 1, kConnect, kCompleted, kDisconnected };
typedef std::shared_ptr<IceServer> Ptr;
IceServer();
IceServer(const std::string &username_fragment, const std::string &password);
const std::string &GetUsernameFragment() const;
const std::string &GetPassword() const;
void SetUsernameFragment(const std::string &username_fragment);
void SetPassword(const std::string &password);
IceState GetState() const;
void ProcessStunPacket(RTC::StunPacket *packet, struct sockaddr_in *remote_address);
void HandleTuple(struct sockaddr_in *remote_address, bool has_use_candidate);
~IceServer();
void SetSendCB(UdpSendCallback send_cb) { send_callback_ = send_cb; }
void SetIceServerCompletedCB(std::function<void()> cb) { ice_server_completed_callback_ = cb; };
struct sockaddr_in *GetSelectAddr() {
return &remote_address_;
}
private:
UdpSendCallback send_callback_;
std::function<void()> ice_server_completed_callback_;
std::string username_fragment_;
std::string password_;
std::string old_username_fragment_;
std::string old_password_;
IceState state{IceState::kNew};
struct sockaddr_in remote_address_;
};

18
webrtc/logger.h Normal file
View File

@ -0,0 +1,18 @@
#pragma once
#include <stdio.h>
#include <assert.h>
#define ELOG_DEBUG(fmt, ...) printf(fmt "\n", ##__VA_ARGS__)
#define ELOG_WARN(fmt, ...) printf(fmt "\n", ##__VA_ARGS__)
#define MS_TRACE()
#define MS_ERROR(fmt, ...) printf("error:" fmt "\n", ##__VA_ARGS__)
#define MS_THROW_ERROR(fmt, ...) do{ printf("throw:" fmt "\n", ##__VA_ARGS__); throw std::runtime_error("error"); } while(false);
#define MS_DUMP(fmt, ...) printf("dump:" fmt "\n", ##__VA_ARGS__)
#define MS_DEBUG_2TAGS(tag1, tag2,fmt, ...) printf("debug:" fmt "\n", ##__VA_ARGS__)
#define MS_WARN_2TAGS(tag1, tag2,fmt, ...) printf("warn:" fmt "\n", ##__VA_ARGS__)
#define MS_DEBUG_TAG(tag,fmt, ...) printf("debug:" fmt "\n", ##__VA_ARGS__)
#define MS_ASSERT(con, log) assert(con)
#define MS_ABORT(fmt, ...) do{ printf("abort:" fmt "\n", ##__VA_ARGS__); abort(); } while(false);
#define MS_WARN_TAG(tag,fmt, ...) printf("warn:" fmt "\n", ##__VA_ARGS__)
#define MS_DEBUG_DEV(fmt, ...) printf("debug:" fmt "\n", ##__VA_ARGS__)

1323
webrtc/rtc_dtls_transport.cc Normal file

File diff suppressed because it is too large Load Diff

187
webrtc/rtc_dtls_transport.h Normal file
View File

@ -0,0 +1,187 @@
#ifndef MS_RTC_DTLS_TRANSPORT_HPP
#define MS_RTC_DTLS_TRANSPORT_HPP
#include <openssl/bio.h>
#include <openssl/ssl.h>
#include <openssl/x509.h>
#include <map>
#include <string>
#include <vector>
namespace RTC {
enum class CryptoSuite {
NONE = 0,
AES_CM_128_HMAC_SHA1_80 = 1,
AES_CM_128_HMAC_SHA1_32,
AEAD_AES_256_GCM,
AEAD_AES_128_GCM
};
class DtlsTransport {
public:
enum class DtlsState { NEW = 1, CONNECTING, CONNECTED, FAILED, CLOSED };
public:
enum class Role { NONE = 0, AUTO = 1, CLIENT, SERVER };
public:
enum class FingerprintAlgorithm { NONE = 0, SHA1 = 1, SHA224, SHA256, SHA384, SHA512 };
public:
struct Fingerprint {
FingerprintAlgorithm algorithm{FingerprintAlgorithm::NONE};
std::string value;
};
private:
struct SrtpCryptoSuiteMapEntry {
RTC::CryptoSuite cryptoSuite;
const char* name;
};
public:
class Listener {
public:
// DTLS is in the process of negotiating a secure connection. Incoming
// media can flow through.
// NOTE: The caller MUST NOT call any method during this callback.
virtual void OnDtlsTransportConnecting(const RTC::DtlsTransport* dtlsTransport) = 0;
// DTLS has completed negotiation of a secure connection (including DTLS-SRTP
// and remote fingerprint verification). Outgoing media can now flow through.
// NOTE: The caller MUST NOT call any method during this callback.
virtual void OnDtlsTransportConnected(const RTC::DtlsTransport* dtlsTransport,
RTC::CryptoSuite srtpCryptoSuite, uint8_t* srtpLocalKey,
size_t srtpLocalKeyLen, uint8_t* srtpRemoteKey,
size_t srtpRemoteKeyLen, std::string& remoteCert) = 0;
// The DTLS connection has been closed as the result of an error (such as a
// DTLS alert or a failure to validate the remote fingerprint).
virtual void OnDtlsTransportFailed(const RTC::DtlsTransport* dtlsTransport) = 0;
// The DTLS connection has been closed due to receipt of a close_notify alert.
virtual void OnDtlsTransportClosed(const RTC::DtlsTransport* dtlsTransport) = 0;
// Need to send DTLS data to the peer.
virtual void OnDtlsTransportSendData(const RTC::DtlsTransport* dtlsTransport,
const uint8_t* data, size_t len) = 0;
// DTLS application data received.
virtual void OnDtlsTransportApplicationDataReceived(const RTC::DtlsTransport* dtlsTransport,
const uint8_t* data, size_t len) = 0;
};
public:
static void ClassInit();
static void ClassDestroy();
static Role StringToRole(const std::string& role) {
auto it = DtlsTransport::string2Role.find(role);
if (it != DtlsTransport::string2Role.end())
return it->second;
else
return DtlsTransport::Role::NONE;
}
static FingerprintAlgorithm GetFingerprintAlgorithm(const std::string& fingerprint) {
auto it = DtlsTransport::string2FingerprintAlgorithm.find(fingerprint);
if (it != DtlsTransport::string2FingerprintAlgorithm.end())
return it->second;
else
return DtlsTransport::FingerprintAlgorithm::NONE;
}
static std::string& GetFingerprintAlgorithmString(FingerprintAlgorithm fingerprint) {
auto it = DtlsTransport::fingerprintAlgorithm2String.find(fingerprint);
return it->second;
}
static bool IsDtls(const uint8_t* data, size_t len) {
// clang-format off
return (
// Minimum DTLS record length is 13 bytes.
(len >= 13) &&
// DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes
(data[0] > 19 && data[0] < 64)
);
// clang-format on
}
private:
static void GenerateCertificateAndPrivateKey();
static void ReadCertificateAndPrivateKeyFromFiles();
static void CreateSslCtx();
static void GenerateFingerprints();
private:
static X509* certificate;
static EVP_PKEY* privateKey;
static SSL_CTX* sslCtx;
static uint8_t sslReadBuffer[];
static std::map<std::string, Role> string2Role;
static std::map<std::string, FingerprintAlgorithm> string2FingerprintAlgorithm;
static std::map<FingerprintAlgorithm, std::string> fingerprintAlgorithm2String;
static std::vector<Fingerprint> localFingerprints;
static std::vector<SrtpCryptoSuiteMapEntry> srtpCryptoSuites;
public:
explicit DtlsTransport(Listener* listener);
~DtlsTransport();
public:
void Dump() const;
void Run(Role localRole);
std::vector<Fingerprint>& GetLocalFingerprints() const {
return DtlsTransport::localFingerprints;
}
bool SetRemoteFingerprint(Fingerprint fingerprint);
void ProcessDtlsData(const uint8_t* data, size_t len);
DtlsState GetState() const { return this->state; }
Role GetLocalRole() const { return this->localRole; }
void SendApplicationData(const uint8_t* data, size_t len);
private:
bool IsRunning() const {
switch (this->state) {
case DtlsState::NEW:
return false;
case DtlsState::CONNECTING:
case DtlsState::CONNECTED:
return true;
case DtlsState::FAILED:
case DtlsState::CLOSED:
return false;
}
// Make GCC 4.9 happy.
return false;
}
void Reset();
bool CheckStatus(int returnCode);
void SendPendingOutgoingDtlsData();
bool SetTimeout();
bool ProcessHandshake();
bool CheckRemoteFingerprint();
void ExtractSrtpKeys(RTC::CryptoSuite srtpCryptoSuite);
RTC::CryptoSuite GetNegotiatedSrtpCryptoSuite();
/* Callbacks fired by OpenSSL events. */
public:
void OnSslInfo(int where, int ret);
/* Pure virtual methods inherited from Timer::Listener. */
public:
void OnTimer();
private:
// Passed by argument.
Listener* listener{nullptr};
// Allocated by this.
SSL* ssl{nullptr};
BIO* sslBioFromNetwork{nullptr}; // The BIO from which ssl reads.
BIO* sslBioToNetwork{nullptr}; // The BIO in which ssl writes.
// Others.
DtlsState state{DtlsState::NEW};
Role localRole{Role::NONE};
Fingerprint remoteFingerprint;
bool handshakeDone{false};
bool handshakeDoneNow{false};
std::string remoteCert;
};
} // namespace RTC
#endif

269
webrtc/srtp_session.cc Normal file
View File

@ -0,0 +1,269 @@
#define MS_CLASS "RTC::SrtpSession"
// #define MS_LOG_DEV_LEVEL 3
#include "srtp_session.h"
#include <cstring> // std::memset(), std::memcpy()
#include <iostream>
#include "logger.h"
namespace RTC {
/* Static. */
static constexpr size_t EncryptBufferSize{65536};
static uint8_t EncryptBuffer[EncryptBufferSize];
/* Class methods. */
std::vector<const char *> DepLibSRTP::errors = {
// From 0 (srtp_err_status_ok) to 24 (srtp_err_status_pfkey_err).
"success (srtp_err_status_ok)",
"unspecified failure (srtp_err_status_fail)",
"unsupported parameter (srtp_err_status_bad_param)",
"couldn't allocate memory (srtp_err_status_alloc_fail)",
"couldn't deallocate memory (srtp_err_status_dealloc_fail)",
"couldn't initialize (srtp_err_status_init_fail)",
"cant process as much data as requested (srtp_err_status_terminus)",
"authentication failure (srtp_err_status_auth_fail)",
"cipher failure (srtp_err_status_cipher_fail)",
"replay check failed (bad index) (srtp_err_status_replay_fail)",
"replay check failed (index too old) (srtp_err_status_replay_old)",
"algorithm failed test routine (srtp_err_status_algo_fail)",
"unsupported operation (srtp_err_status_no_such_op)",
"no appropriate context found (srtp_err_status_no_ctx)",
"unable to perform desired validation (srtp_err_status_cant_check)",
"cant use key any more (srtp_err_status_key_expired)",
"error in use of socket (srtp_err_status_socket_err)",
"error in use POSIX signals (srtp_err_status_signal_err)",
"nonce check failed (srtp_err_status_nonce_bad)",
"couldnt read data (srtp_err_status_read_fail)",
"couldnt write data (srtp_err_status_write_fail)",
"error parsing data (srtp_err_status_parse_err)",
"error encoding data (srtp_err_status_encode_err)",
"error while using semaphores (srtp_err_status_semaphore_err)",
"error while using pfkey (srtp_err_status_pfkey_err)"};
// clang-format on
/* Static methods. */
void DepLibSRTP::ClassInit() {
MS_TRACE();
MS_DEBUG_TAG(info, "libsrtp version: \"%s\"", srtp_get_version_string());
srtp_err_status_t err = srtp_init();
if (DepLibSRTP::IsError(err))
MS_THROW_ERROR("srtp_init() failed: %s", DepLibSRTP::GetErrorString(err));
}
void DepLibSRTP::ClassDestroy() {
MS_TRACE();
srtp_shutdown();
}
void SrtpSession::ClassInit() {
// Set libsrtp event handler.
srtp_err_status_t err =
srtp_install_event_handler(static_cast<srtp_event_handler_func_t *>(OnSrtpEvent));
if (DepLibSRTP::IsError(err)) {
MS_THROW_ERROR("srtp_install_event_handler() failed: %s", DepLibSRTP::GetErrorString(err));
std::cout << "srtp_install_event_handler() failed :" << DepLibSRTP::GetErrorString(err);
}
}
void SrtpSession::OnSrtpEvent(srtp_event_data_t *data) {
MS_TRACE();
switch (data->event) {
case event_ssrc_collision:
MS_WARN_TAG(srtp, "SSRC collision occurred");
break;
case event_key_soft_limit:
MS_WARN_TAG(srtp, "stream reached the soft key usage limit and will expire soon");
break;
case event_key_hard_limit:
MS_WARN_TAG(srtp, "stream reached the hard key usage limit and has expired");
break;
case event_packet_index_limit:
MS_WARN_TAG(srtp, "stream reached the hard packet limit (2^48 packets)");
break;
}
}
/* Instance methods. */
SrtpSession::SrtpSession(Type type, CryptoSuite cryptoSuite, uint8_t *key, size_t keyLen) {
MS_TRACE();
srtp_policy_t policy;// NOLINT(cppcoreguidelines-pro-type-member-init)
// Set all policy fields to 0.
std::memset(&policy, 0, sizeof(srtp_policy_t));
switch (cryptoSuite) {
case CryptoSuite::AES_CM_128_HMAC_SHA1_80: {
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtp);
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp);
break;
}
case CryptoSuite::AES_CM_128_HMAC_SHA1_32: {
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_32(&policy.rtp);
// NOTE: Must be 80 for RTCP.
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp);
break;
}
case CryptoSuite::AEAD_AES_256_GCM: {
srtp_crypto_policy_set_aes_gcm_256_16_auth(&policy.rtp);
srtp_crypto_policy_set_aes_gcm_256_16_auth(&policy.rtcp);
break;
}
case CryptoSuite::AEAD_AES_128_GCM: {
srtp_crypto_policy_set_aes_gcm_128_16_auth(&policy.rtp);
srtp_crypto_policy_set_aes_gcm_128_16_auth(&policy.rtcp);
break;
}
default: {
MS_ABORT("unknown SRTP crypto suite");
}
}
MS_ASSERT((int) keyLen == policy.rtp.cipher_key_len,
"given keyLen does not match policy.rtp.cipher_keyLen");
switch (type) {
case Type::INBOUND:
policy.ssrc.type = ssrc_any_inbound;
break;
case Type::OUTBOUND:
policy.ssrc.type = ssrc_any_outbound;
break;
}
policy.ssrc.value = 0;
policy.key = key;
// Required for sending RTP retransmission without RTX.
policy.allow_repeat_tx = 1;
policy.window_size = 1024;
policy.next = nullptr;
// Set the SRTP session.
srtp_err_status_t err = srtp_create(&this->session, &policy);
if (DepLibSRTP::IsError(err)) {
is_init = false;
MS_THROW_ERROR("srtp_create() failed: %s", DepLibSRTP::GetErrorString(err));
} else {
is_init = true;
}
}
SrtpSession::~SrtpSession() {
MS_TRACE();
if (this->session != nullptr) {
srtp_err_status_t err = srtp_dealloc(this->session);
if (DepLibSRTP::IsError(err))
MS_ABORT("srtp_dealloc() failed: %s", DepLibSRTP::GetErrorString(err));
}
}
bool SrtpSession::EncryptRtp(const uint8_t **data, size_t *len) {
MS_TRACE();
if (!is_init) {
return false;
}
// Ensure that the resulting SRTP packet fits into the encrypt buffer.
if (*len + SRTP_MAX_TRAILER_LEN > EncryptBufferSize) {
MS_WARN_TAG(srtp, "cannot encrypt RTP packet, size too big (%zu bytes)", *len);
return false;
}
std::memcpy(EncryptBuffer, *data, *len);
srtp_err_status_t err =
srtp_protect(this->session, static_cast<void *>(EncryptBuffer), reinterpret_cast<int *>(len));
if (DepLibSRTP::IsError(err)) {
MS_WARN_TAG(srtp, "srtp_protect() failed: %s", DepLibSRTP::GetErrorString(err));
return false;
}
// Update the given data pointer.
*data = (const uint8_t *) EncryptBuffer;
return true;
}
bool SrtpSession::DecryptSrtp(uint8_t *data, size_t *len) {
MS_TRACE();
srtp_err_status_t err =
srtp_unprotect(this->session, static_cast<void *>(data), reinterpret_cast<int *>(len));
if (DepLibSRTP::IsError(err)) {
MS_DEBUG_TAG(srtp, "srtp_unprotect() failed: %s", DepLibSRTP::GetErrorString(err));
return false;
}
return true;
}
bool SrtpSession::EncryptRtcp(const uint8_t **data, size_t *len) {
MS_TRACE();
// Ensure that the resulting SRTCP packet fits into the encrypt buffer.
if (*len + SRTP_MAX_TRAILER_LEN > EncryptBufferSize) {
MS_WARN_TAG(srtp, "cannot encrypt RTCP packet, size too big (%zu bytes)", *len);
return false;
}
std::memcpy(EncryptBuffer, *data, *len);
srtp_err_status_t err = srtp_protect_rtcp(this->session, static_cast<void *>(EncryptBuffer),
reinterpret_cast<int *>(len));
if (DepLibSRTP::IsError(err)) {
MS_WARN_TAG(srtp, "srtp_protect_rtcp() failed: %s", DepLibSRTP::GetErrorString(err));
return false;
}
// Update the given data pointer.
*data = (const uint8_t *) EncryptBuffer;
return true;
}
bool SrtpSession::DecryptSrtcp(uint8_t *data, size_t *len) {
MS_TRACE();
srtp_err_status_t err =
srtp_unprotect_rtcp(this->session, static_cast<void *>(data), reinterpret_cast<int *>(len));
if (DepLibSRTP::IsError(err)) {
MS_DEBUG_TAG(srtp, "srtp_unprotect_rtcp() failed: %s", DepLibSRTP::GetErrorString(err));
return false;
}
return true;
}
}// namespace RTC

54
webrtc/srtp_session.h Normal file
View File

@ -0,0 +1,54 @@
#ifndef MS_RTC_SRTP_SESSION_HPP
#define MS_RTC_SRTP_SESSION_HPP
#include "rtc_dtls_transport.h"
#include "utils.h"
#include <srtp2/srtp.h>
#include <vector>
namespace RTC {
class DepLibSRTP {
public:
static void ClassInit();
static void ClassDestroy();
static bool IsError(srtp_err_status_t code) { return (code != srtp_err_status_ok); }
static const char *GetErrorString(srtp_err_status_t code) {
// This throws out_of_range if the given index is not in the vector.
return DepLibSRTP::errors.at(code);
}
private:
static std::vector<const char *> errors;
};
class SrtpSession {
public:
public:
enum class Type { INBOUND = 1, OUTBOUND };
public:
static void ClassInit();
private:
static void OnSrtpEvent(srtp_event_data_t *data);
public:
SrtpSession(Type type, CryptoSuite cryptoSuite, uint8_t *key, size_t keyLen);
~SrtpSession();
public:
bool EncryptRtp(const uint8_t **data, size_t *len);
bool DecryptSrtp(uint8_t *data, size_t *len);
bool EncryptRtcp(const uint8_t **data, size_t *len);
bool DecryptSrtcp(uint8_t *data, size_t *len);
void RemoveStream(uint32_t ssrc) { srtp_remove_stream(this->session, uint32_t{htonl(ssrc)}); }
private:
bool is_init = false;
// Allocated by this.
srtp_t session{nullptr};
};
}// namespace RTC
#endif

710
webrtc/stun_packet.cc Normal file
View File

@ -0,0 +1,710 @@
#define MS_CLASS "RTC::StunPacket"
// #define MS_LOG_DEV
#include "stun_packet.h"
#include <cstdio> // std::snprintf()
#include <cstring> // std::memcmp(), std::memcpy()
#include "utils.h"
namespace RTC {
/* Class variables. */
const uint8_t StunPacket::kMagicCookie[] = {0x21, 0x12, 0xA4, 0x42};
/* Class methods. */
StunPacket* StunPacket::Parse(const uint8_t* data, size_t len) {
if (!StunPacket::IsStun(data, len)) return nullptr;
/*
The message type field is decomposed further into the following
structure:
0 1
2 3 4 5 6 7 8 9 0 1 2 3 4 5
+--+--+-+-+-+-+-+-+-+-+-+-+-+-+
|M |M |M|M|M|C|M|M|M|C|M|M|M|M|
|11|10|9|8|7|1|6|5|4|0|3|2|1|0|
+--+--+-+-+-+-+-+-+-+-+-+-+-+-+
Figure 3: Format of STUN Message Type Field
Here the bits in the message type field are shown as most significant
(M11) through least significant (M0). M11 through M0 represent a 12-
bit encoding of the method. C1 and C0 represent a 2-bit encoding of
the class.
*/
// Get type field.
uint16_t msgType = Utils::Byte::Get2Bytes(data, 0);
// Get length field.
uint16_t msgLength = Utils::Byte::Get2Bytes(data, 2);
// length field must be total size minus header's 20 bytes, and must be multiple of 4 Bytes.
if ((static_cast<size_t>(msgLength) != len - 20) || ((msgLength & 0x03) != 0)) {
ELOG_DEBUG(
"length field + 20 does not match total size (or it is not multiple of 4 bytes), "
"packet discarded");
return nullptr;
}
// Get STUN method.
uint16_t msgMethod = (msgType & 0x000f) | ((msgType & 0x00e0) >> 1) | ((msgType & 0x3E00) >> 2);
// Get STUN class.
uint16_t msgClass = ((data[0] & 0x01) << 1) | ((data[1] & 0x10) >> 4);
// Create a new StunPacket (data + 8 points to the received TransactionID field).
auto packet = new StunPacket(static_cast<Class>(msgClass), static_cast<Method>(msgMethod),
data + 8, data, len);
/*
STUN Attributes
After the STUN header are zero or more attributes. Each attribute
MUST be TLV encoded, with a 16-bit type, 16-bit length, and value.
Each STUN attribute MUST end on a 32-bit boundary. As mentioned
above, all fields in an attribute are transmitted most significant
bit first.
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type | Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Value (variable) ....
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
*/
// Start looking for attributes after STUN header (Byte #20).
size_t pos{20};
// Flags (positions) for special MESSAGE-INTEGRITY and FINGERPRINT attributes.
bool hasMessageIntegrity{false};
bool hasFingerprint{false};
size_t fingerprintAttrPos; // Will point to the beginning of the attribute.
uint32_t fingerprint; // Holds the value of the FINGERPRINT attribute.
// Ensure there are at least 4 remaining bytes (attribute with 0 length).
while (pos + 4 <= len) {
// Get the attribute type.
auto attrType = static_cast<Attribute>(Utils::Byte::Get2Bytes(data, pos));
// Get the attribute length.
uint16_t attrLength = Utils::Byte::Get2Bytes(data, pos + 2);
// Ensure the attribute length is not greater than the remaining size.
if ((pos + 4 + attrLength) > len) {
ELOG_DEBUG("the attribute length exceeds the remaining size, packet discarded");
delete packet;
return nullptr;
}
// FINGERPRINT must be the last attribute.
if (hasFingerprint) {
ELOG_DEBUG("attribute after FINGERPRINT is not allowed, packet discarded");
delete packet;
return nullptr;
}
// After a MESSAGE-INTEGRITY attribute just FINGERPRINT is allowed.
if (hasMessageIntegrity && attrType != Attribute::FINGERPRINT) {
ELOG_DEBUG(
"attribute after MESSAGE-INTEGRITY other than FINGERPRINT is not allowed, "
"packet discarded");
delete packet;
return nullptr;
}
const uint8_t* attrValuePos = data + pos + 4;
switch (attrType) {
case Attribute::USERNAME: {
packet->SetUsername(reinterpret_cast<const char*>(attrValuePos),
static_cast<size_t>(attrLength));
break;
}
case Attribute::PRIORITY: {
// Ensure attribute length is 4 bytes.
if (attrLength != 4) {
ELOG_DEBUG("attribute PRIORITY must be 4 bytes length, packet discarded");
delete packet;
return nullptr;
}
packet->SetPriority(Utils::Byte::Get4Bytes(attrValuePos, 0));
break;
}
case Attribute::ICE_CONTROLLING: {
// Ensure attribute length is 8 bytes.
if (attrLength != 8) {
ELOG_DEBUG("attribute ICE-CONTROLLING must be 8 bytes length, packet discarded");
delete packet;
return nullptr;
}
packet->SetIceControlling(Utils::Byte::Get8Bytes(attrValuePos, 0));
break;
}
case Attribute::ICE_CONTROLLED: {
// Ensure attribute length is 8 bytes.
if (attrLength != 8) {
ELOG_DEBUG("attribute ICE-CONTROLLED must be 8 bytes length, packet discarded");
delete packet;
return nullptr;
}
packet->SetIceControlled(Utils::Byte::Get8Bytes(attrValuePos, 0));
break;
}
case Attribute::USE_CANDIDATE: {
// Ensure attribute length is 0 bytes.
if (attrLength != 0) {
ELOG_DEBUG("attribute USE-CANDIDATE must be 0 bytes length, packet discarded");
delete packet;
return nullptr;
}
packet->SetUseCandidate();
break;
}
case Attribute::MESSAGE_INTEGRITY: {
// Ensure attribute length is 20 bytes.
if (attrLength != 20) {
ELOG_DEBUG("attribute MESSAGE-INTEGRITY must be 20 bytes length, packet discarded");
delete packet;
return nullptr;
}
hasMessageIntegrity = true;
packet->SetMessageIntegrity(attrValuePos);
break;
}
case Attribute::FINGERPRINT: {
// Ensure attribute length is 4 bytes.
if (attrLength != 4) {
ELOG_DEBUG("attribute FINGERPRINT must be 4 bytes length, packet discarded");
delete packet;
return nullptr;
}
hasFingerprint = true;
fingerprintAttrPos = pos;
fingerprint = Utils::Byte::Get4Bytes(attrValuePos, 0);
packet->SetFingerprint();
break;
}
case Attribute::ERROR_CODE: {
// Ensure attribute length >= 4bytes.
if (attrLength < 4) {
ELOG_DEBUG("attribute ERROR-CODE must be >= 4bytes length, packet discarded");
delete packet;
return nullptr;
}
uint8_t errorClass = Utils::Byte::Get1Byte(attrValuePos, 2);
uint8_t errorNumber = Utils::Byte::Get1Byte(attrValuePos, 3);
auto errorCode = static_cast<uint16_t>(errorClass * 100 + errorNumber);
packet->SetErrorCode(errorCode);
break;
}
default:;
}
// Set next attribute position.
pos = static_cast<size_t>(Utils::Byte::PadTo4Bytes(static_cast<uint16_t>(pos + 4 + attrLength)));
}
// Ensure current position matches the total length.
if (pos != len) {
ELOG_DEBUG("computed packet size does not match total size, packet discarded");
delete packet;
return nullptr;
}
// If it has FINGERPRINT attribute then verify it.
if (hasFingerprint) {
// Compute the CRC32 of the received packet up to (but excluding) the
// FINGERPRINT attribute and XOR it with 0x5354554e.
uint32_t computedFingerprint = Utils::Crypto::GetCRC32(data, fingerprintAttrPos) ^ 0x5354554e;
// Compare with the FINGERPRINT value in the packet.
if (fingerprint != computedFingerprint) {
ELOG_DEBUG(
"computed FINGERPRINT value does not match the value in the packet, "
"packet discarded");
delete packet;
return nullptr;
}
}
return packet;
}
/* Instance methods. */
StunPacket::StunPacket(Class klass, Method method, const uint8_t* transactionId,
const uint8_t* data, size_t size)
: klass(klass),
method(method),
transactionId(transactionId),
data(const_cast<uint8_t*>(data)),
size(size) {
// MS_TRACE();
}
StunPacket::~StunPacket() {
// MS_TRACE();
}
void StunPacket::Dump() const {
// MS_TRACE();
// MS_DUMP("<StunPacket>");
std::string klass;
switch (this->klass) {
case Class::REQUEST:
klass = "Request";
break;
case Class::INDICATION:
klass = "Indication";
break;
case Class::SUCCESS_RESPONSE:
klass = "SuccessResponse";
break;
case Class::ERROR_RESPONSE:
klass = "ErrorResponse";
break;
}
if (this->method == Method::BINDING) {
// MS_DUMP(" Binding %s", klass.c_str());
} else {
// This prints the unknown method number. Example: TURN Allocate => 0x003.
// MS_DUMP(" %s with unknown method %#.3x", klass.c_str(),
// static_cast<uint16_t>(this->method));
}
// MS_DUMP(" size: %zu bytes", this->size);
static char transactionId[25];
for (int i{0}; i < 12; ++i) {
// NOTE: n must be 3 because snprintf adds a \0 after printed chars.
std::snprintf(transactionId + (i * 2), 3, "%.2x", this->transactionId[i]);
}
// MS_DUMP(" transactionId: %s", transactionId);
if (this->errorCode != 0u)
// MS_DUMP(" errorCode: %" PRIu16, this->errorCode);
if (!this->username.empty())
// MS_DUMP(" username: %s", this->username.c_str());
if (this->priority != 0u)
// MS_DUMP(" priority: %" PRIu32, this->priority);
if (this->iceControlling != 0u)
// MS_DUMP(" iceControlling: %" PRIu64, this->iceControlling);
if (this->iceControlled != 0u)
// MS_DUMP(" iceControlled: %" PRIu64, this->iceControlled);
if (this->hasUseCandidate)
// MS_DUMP(" useCandidate");
if (this->xorMappedAddress != nullptr) {
int family;
uint16_t port;
std::string ip;
Utils::IP::GetAddressInfo(this->xorMappedAddress, family, ip, port);
// MS_DUMP(" xorMappedAddress: %s : %" PRIu16, ip.c_str(), port);
}
if (this->messageIntegrity != nullptr) {
static char messageIntegrity[41];
for (int i{0}; i < 20; ++i) {
std::snprintf(messageIntegrity + (i * 2), 3, "%.2x", this->messageIntegrity[i]);
}
// MS_DUMP(" messageIntegrity: %s", messageIntegrity);
}
if (this->hasFingerprint) {
}
// MS_DUMP(" has fingerprint");
// MS_DUMP("</StunPacket>");
}
StunPacket::Authentication StunPacket::CheckAuthentication(const std::string& localUsername,
const std::string& localPassword) {
// MS_TRACE();
switch (this->klass) {
case Class::REQUEST:
case Class::INDICATION: {
// Both USERNAME and MESSAGE-INTEGRITY must be present.
if (this->messageIntegrity == nullptr || this->username.empty())
return Authentication::BAD_REQUEST;
// Check that USERNAME attribute begins with our local username plus ":".
size_t localUsernameLen = localUsername.length();
if (this->username.length() <= localUsernameLen ||
this->username.at(localUsernameLen) != ':' ||
(this->username.compare(0, localUsernameLen, localUsername) != 0)) {
return Authentication::UNAUTHORIZED;
}
break;
}
// This method cannot check authentication in received responses (as we
// are ICE-Lite and don't generate requests).
case Class::SUCCESS_RESPONSE:
case Class::ERROR_RESPONSE: {
// MS_ERROR("cannot check authentication for a STUN response");
return Authentication::BAD_REQUEST;
}
}
// If there is FINGERPRINT it must be discarded for MESSAGE-INTEGRITY calculation,
// so the header length field must be modified (and later restored).
if (this->hasFingerprint)
// Set the header length field: full size - header length (20) - FINGERPRINT length (8).
Utils::Byte::Set2Bytes(this->data, 2, static_cast<uint16_t>(this->size - 20 - 8));
// Calculate the HMAC-SHA1 of the message according to MESSAGE-INTEGRITY rules.
const uint8_t* computedMessageIntegrity = Utils::Crypto::GetHmacShA1(
localPassword, this->data, (this->messageIntegrity - 4) - this->data);
Authentication result;
// Compare the computed HMAC-SHA1 with the MESSAGE-INTEGRITY in the packet.
if (std::memcmp(this->messageIntegrity, computedMessageIntegrity, 20) == 0)
result = Authentication::OK;
else
result = Authentication::UNAUTHORIZED;
// Restore the header length field.
if (this->hasFingerprint)
Utils::Byte::Set2Bytes(this->data, 2, static_cast<uint16_t>(this->size - 20));
return result;
}
StunPacket* StunPacket::CreateSuccessResponse() {
// MS_TRACE();
// MS_ASSERT(
// this->klass == Class::REQUEST,
// "attempt to create a success response for a non Request STUN packet");
return new StunPacket(Class::SUCCESS_RESPONSE, this->method, this->transactionId, nullptr, 0);
}
StunPacket* StunPacket::CreateErrorResponse(uint16_t errorCode) {
// MS_TRACE();
// MS_ASSERT(
// this->klass == Class::REQUEST,
// "attempt to create an error response for a non Request STUN packet");
auto response =
new StunPacket(Class::ERROR_RESPONSE, this->method, this->transactionId, nullptr, 0);
response->SetErrorCode(errorCode);
return response;
}
void StunPacket::Authenticate(const std::string& password) {
// Just for Request, Indication and SuccessResponse messages.
if (this->klass == Class::ERROR_RESPONSE) {
// MS_ERROR("cannot set password for ErrorResponse messages");
return;
}
this->password = password;
}
void StunPacket::Serialize(uint8_t* buffer) {
// MS_TRACE();
// Some useful variables.
uint16_t usernamePaddedLen{0};
uint16_t xorMappedAddressPaddedLen{0};
bool addXorMappedAddress =
((this->xorMappedAddress != nullptr) && this->method == StunPacket::Method::BINDING &&
this->klass == Class::SUCCESS_RESPONSE);
bool addErrorCode = ((this->errorCode != 0u) && this->klass == Class::ERROR_RESPONSE);
bool addMessageIntegrity = (this->klass != Class::ERROR_RESPONSE && !this->password.empty());
bool addFingerprint{true}; // Do always.
// Update data pointer.
this->data = buffer;
// First calculate the total required size for the entire packet.
this->size = 20; // Header.
if (!this->username.empty()) {
usernamePaddedLen = Utils::Byte::PadTo4Bytes(static_cast<uint16_t>(this->username.length()));
this->size += 4 + usernamePaddedLen;
}
if (this->priority != 0u) this->size += 4 + 4;
if (this->iceControlling != 0u) this->size += 4 + 8;
if (this->iceControlled != 0u) this->size += 4 + 8;
if (this->hasUseCandidate) this->size += 4;
if (addXorMappedAddress) {
switch (this->xorMappedAddress->sa_family) {
case AF_INET: {
xorMappedAddressPaddedLen = 8;
this->size += 4 + 8;
break;
}
case AF_INET6: {
xorMappedAddressPaddedLen = 20;
this->size += 4 + 20;
break;
}
default: {
// MS_ERROR("invalid inet family in XOR-MAPPED-ADDRESS attribute");
addXorMappedAddress = false;
}
}
}
if (addErrorCode) this->size += 4 + 4;
if (addMessageIntegrity) this->size += 4 + 20;
if (addFingerprint) this->size += 4 + 4;
// Merge class and method fields into type.
uint16_t typeField = (static_cast<uint16_t>(this->method) & 0x0f80) << 2;
typeField |= (static_cast<uint16_t>(this->method) & 0x0070) << 1;
typeField |= (static_cast<uint16_t>(this->method) & 0x000f);
typeField |= (static_cast<uint16_t>(this->klass) & 0x02) << 7;
typeField |= (static_cast<uint16_t>(this->klass) & 0x01) << 4;
// Set type field.
Utils::Byte::Set2Bytes(buffer, 0, typeField);
// Set length field.
Utils::Byte::Set2Bytes(buffer, 2, static_cast<uint16_t>(this->size) - 20);
// Set magic cookie.
std::memcpy(buffer + 4, StunPacket::kMagicCookie, 4);
// Set TransactionId field.
std::memcpy(buffer + 8, this->transactionId, 12);
// Update the transaction ID pointer.
this->transactionId = buffer + 8;
// Add atributes.
size_t pos{20};
// Add USERNAME.
if (usernamePaddedLen != 0u) {
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::USERNAME));
Utils::Byte::Set2Bytes(buffer, pos + 2, static_cast<uint16_t>(this->username.length()));
std::memcpy(buffer + pos + 4, this->username.c_str(), this->username.length());
pos += 4 + usernamePaddedLen;
}
// Add PRIORITY.
if (this->priority != 0u) {
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::PRIORITY));
Utils::Byte::Set2Bytes(buffer, pos + 2, 4);
Utils::Byte::Set4Bytes(buffer, pos + 4, this->priority);
pos += 4 + 4;
}
// Add ICE-CONTROLLING.
if (this->iceControlling != 0u) {
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::ICE_CONTROLLING));
Utils::Byte::Set2Bytes(buffer, pos + 2, 8);
Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlling);
pos += 4 + 8;
}
// Add ICE-CONTROLLED.
if (this->iceControlled != 0u) {
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::ICE_CONTROLLED));
Utils::Byte::Set2Bytes(buffer, pos + 2, 8);
Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlled);
pos += 4 + 8;
}
// Add USE-CANDIDATE.
if (this->hasUseCandidate) {
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::USE_CANDIDATE));
Utils::Byte::Set2Bytes(buffer, pos + 2, 0);
pos += 4;
}
// Add XOR-MAPPED-ADDRESS
if (addXorMappedAddress) {
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::XOR_MAPPED_ADDRESS));
Utils::Byte::Set2Bytes(buffer, pos + 2, xorMappedAddressPaddedLen);
uint8_t* attrValue = buffer + pos + 4;
switch (this->xorMappedAddress->sa_family) {
case AF_INET: {
// Set first byte to 0.
attrValue[0] = 0;
// Set inet family.
attrValue[1] = 0x01;
// Set port and XOR it.
std::memcpy(attrValue + 2,
&(reinterpret_cast<const sockaddr_in*>(this->xorMappedAddress))->sin_port, 2);
attrValue[2] ^= StunPacket::kMagicCookie[0];
attrValue[3] ^= StunPacket::kMagicCookie[1];
// Set address and XOR it.
std::memcpy(
attrValue + 4,
&(reinterpret_cast<const sockaddr_in*>(this->xorMappedAddress))->sin_addr.s_addr, 4);
attrValue[4] ^= StunPacket::kMagicCookie[0];
attrValue[5] ^= StunPacket::kMagicCookie[1];
attrValue[6] ^= StunPacket::kMagicCookie[2];
attrValue[7] ^= StunPacket::kMagicCookie[3];
pos += 4 + 8;
break;
}
case AF_INET6: {
// Set first byte to 0.
attrValue[0] = 0;
// Set inet family.
attrValue[1] = 0x02;
// Set port and XOR it.
std::memcpy(attrValue + 2,
&(reinterpret_cast<const sockaddr_in6*>(this->xorMappedAddress))->sin6_port, 2);
attrValue[2] ^= StunPacket::kMagicCookie[0];
attrValue[3] ^= StunPacket::kMagicCookie[1];
// Set address and XOR it.
std::memcpy(
attrValue + 4,
&(reinterpret_cast<const sockaddr_in6*>(this->xorMappedAddress))->sin6_addr.s6_addr,
16);
attrValue[4] ^= StunPacket::kMagicCookie[0];
attrValue[5] ^= StunPacket::kMagicCookie[1];
attrValue[6] ^= StunPacket::kMagicCookie[2];
attrValue[7] ^= StunPacket::kMagicCookie[3];
attrValue[8] ^= this->transactionId[0];
attrValue[9] ^= this->transactionId[1];
attrValue[10] ^= this->transactionId[2];
attrValue[11] ^= this->transactionId[3];
attrValue[12] ^= this->transactionId[4];
attrValue[13] ^= this->transactionId[5];
attrValue[14] ^= this->transactionId[6];
attrValue[15] ^= this->transactionId[7];
attrValue[16] ^= this->transactionId[8];
attrValue[17] ^= this->transactionId[9];
attrValue[18] ^= this->transactionId[10];
attrValue[19] ^= this->transactionId[11];
pos += 4 + 20;
break;
}
}
}
// Add ERROR-CODE.
if (addErrorCode) {
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::ERROR_CODE));
Utils::Byte::Set2Bytes(buffer, pos + 2, 4);
auto codeClass = static_cast<uint8_t>(this->errorCode / 100);
uint8_t codeNumber = static_cast<uint8_t>(this->errorCode) - (codeClass * 100);
Utils::Byte::Set2Bytes(buffer, pos + 4, 0);
Utils::Byte::Set1Byte(buffer, pos + 6, codeClass);
Utils::Byte::Set1Byte(buffer, pos + 7, codeNumber);
pos += 4 + 4;
}
// Add MESSAGE-INTEGRITY.
if (addMessageIntegrity) {
// Ignore FINGERPRINT.
if (addFingerprint)
Utils::Byte::Set2Bytes(buffer, 2, static_cast<uint16_t>(this->size - 20 - 8));
// Calculate the HMAC-SHA1 of the packet according to MESSAGE-INTEGRITY rules.
const uint8_t* computedMessageIntegrity =
Utils::Crypto::GetHmacShA1(this->password, buffer, pos);
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::MESSAGE_INTEGRITY));
Utils::Byte::Set2Bytes(buffer, pos + 2, 20);
std::memcpy(buffer + pos + 4, computedMessageIntegrity, 20);
// Update the pointer.
this->messageIntegrity = buffer + pos + 4;
pos += 4 + 20;
// Restore length field.
if (addFingerprint) Utils::Byte::Set2Bytes(buffer, 2, static_cast<uint16_t>(this->size - 20));
} else {
// Unset the pointer (if it was set).
this->messageIntegrity = nullptr;
}
// Add FINGERPRINT.
if (addFingerprint) {
// Compute the CRC32 of the packet up to (but excluding) the FINGERPRINT
// attribute and XOR it with 0x5354554e.
uint32_t computedFingerprint = Utils::Crypto::GetCRC32(buffer, pos) ^ 0x5354554e;
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::FINGERPRINT));
Utils::Byte::Set2Bytes(buffer, pos + 2, 4);
Utils::Byte::Set4Bytes(buffer, pos + 4, computedFingerprint);
pos += 4 + 4;
// Set flag.
this->hasFingerprint = true;
} else {
this->hasFingerprint = false;
}
// MS_ASSERT(pos == this->size, "pos != this->size");
}
} // namespace RTC

179
webrtc/stun_packet.h Normal file
View File

@ -0,0 +1,179 @@
#ifndef MS_RTC_STUN_PACKET_HPP
#define MS_RTC_STUN_PACKET_HPP
#include "logger.h"
#include "utils.h"
#include <string>
namespace RTC {
class StunPacket {
public:
// STUN message class.
enum class Class : uint16_t {
REQUEST = 0,
INDICATION = 1,
SUCCESS_RESPONSE = 2,
ERROR_RESPONSE = 3
};
// STUN message method.
enum class Method : uint16_t { BINDING = 1 };
// Attribute type.
enum class Attribute : uint16_t {
MAPPED_ADDRESS = 0x0001,
USERNAME = 0x0006,
MESSAGE_INTEGRITY = 0x0008,
ERROR_CODE = 0x0009,
UNKNOWN_ATTRIBUTES = 0x000A,
REALM = 0x0014,
NONCE = 0x0015,
XOR_MAPPED_ADDRESS = 0x0020,
PRIORITY = 0x0024,
USE_CANDIDATE = 0x0025,
SOFTWARE = 0x8022,
ALTERNATE_SERVER = 0x8023,
FINGERPRINT = 0x8028,
ICE_CONTROLLED = 0x8029,
ICE_CONTROLLING = 0x802A
};
// Authentication result.
enum class Authentication { OK = 0, UNAUTHORIZED = 1, BAD_REQUEST = 2 };
public:
static bool IsStun(const uint8_t *data, size_t len);
static StunPacket *Parse(const uint8_t *data, size_t len);
private:
static const uint8_t kMagicCookie[];
public:
StunPacket(Class klass, Method method, const uint8_t *transactionId, const uint8_t *data,
size_t size);
~StunPacket();
void Dump() const;
Class GetClass() const;
Method GetMethod() const;
const uint8_t *GetData() const;
size_t GetSize() const;
void SetUsername(const char *username, size_t len);
void SetPriority(uint32_t priority);
void SetIceControlling(uint64_t iceControlling);
void SetIceControlled(uint64_t iceControlled);
void SetUseCandidate();
void SetXorMappedAddress(const struct sockaddr *xorMappedAddress);
void SetErrorCode(uint16_t errorCode);
void SetMessageIntegrity(const uint8_t *messageIntegrity);
void SetFingerprint();
const std::string &GetUsername() const;
uint32_t GetPriority() const;
uint64_t GetIceControlling() const;
uint64_t GetIceControlled() const;
bool HasUseCandidate() const;
uint16_t GetErrorCode() const;
bool HasMessageIntegrity() const;
bool HasFingerprint() const;
Authentication CheckAuthentication(const std::string &localUsername,
const std::string &localPassword);
StunPacket *CreateSuccessResponse();
StunPacket *CreateErrorResponse(uint16_t errorCode);
void Authenticate(const std::string &password);
void Serialize(uint8_t *buffer);
private:
// Passed by argument.
Class klass; // 2 bytes.
Method method; // 2 bytes.
const uint8_t *transactionId{nullptr};// 12 bytes.
uint8_t *data{nullptr}; // Pointer to binary data.
size_t size{0}; // The full message size (including header).
// STUN attributes.
std::string username; // Less than 513 bytes.
uint32_t priority{0}; // 4 bytes unsigned integer.
uint64_t iceControlling{0}; // 8 bytes unsigned integer.
uint64_t iceControlled{0}; // 8 bytes unsigned integer.
bool hasUseCandidate{false}; // 0 bytes.
const uint8_t *messageIntegrity{nullptr}; // 20 bytes.
bool hasFingerprint{false}; // 4 bytes.
const struct sockaddr *xorMappedAddress{nullptr};// 8 or 20 bytes.
uint16_t errorCode{0}; // 4 bytes (no reason phrase).
std::string password;
};
/* Inline class methods. */
inline bool StunPacket::IsStun(const uint8_t *data, size_t len) {
// clang-format off
return (
// STUN headers are 20 bytes.
(len >= 20) &&
// DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes
(data[0] < 3) &&
// Magic cookie must match.
(data[4] == StunPacket::kMagicCookie[0]) && (data[5] == StunPacket::kMagicCookie[1]) &&
(data[6] == StunPacket::kMagicCookie[2]) && (data[7] == StunPacket::kMagicCookie[3])
);
// clang-format on
}
/* Inline instance methods. */
inline StunPacket::Class StunPacket::GetClass() const { return this->klass; }
inline StunPacket::Method StunPacket::GetMethod() const { return this->method; }
inline const uint8_t *StunPacket::GetData() const { return this->data; }
inline size_t StunPacket::GetSize() const { return this->size; }
inline void StunPacket::SetUsername(const char *username, size_t len) {
this->username.assign(username, len);
}
inline void StunPacket::SetPriority(const uint32_t priority) { this->priority = priority; }
inline void StunPacket::SetIceControlling(const uint64_t iceControlling) {
this->iceControlling = iceControlling;
}
inline void StunPacket::SetIceControlled(const uint64_t iceControlled) {
this->iceControlled = iceControlled;
}
inline void StunPacket::SetUseCandidate() { this->hasUseCandidate = true; }
inline void StunPacket::SetXorMappedAddress(const struct sockaddr *xorMappedAddress) {
this->xorMappedAddress = xorMappedAddress;
}
inline void StunPacket::SetErrorCode(uint16_t errorCode) { this->errorCode = errorCode; }
inline void StunPacket::SetMessageIntegrity(const uint8_t *messageIntegrity) {
this->messageIntegrity = messageIntegrity;
}
inline void StunPacket::SetFingerprint() { this->hasFingerprint = true; }
inline const std::string &StunPacket::GetUsername() const { return this->username; }
inline uint32_t StunPacket::GetPriority() const { return this->priority; }
inline uint64_t StunPacket::GetIceControlling() const { return this->iceControlling; }
inline uint64_t StunPacket::GetIceControlled() const { return this->iceControlled; }
inline bool StunPacket::HasUseCandidate() const { return this->hasUseCandidate; }
inline uint16_t StunPacket::GetErrorCode() const { return this->errorCode; }
inline bool StunPacket::HasMessageIntegrity() const {
return (this->messageIntegrity ? true : false);
}
inline bool StunPacket::HasFingerprint() const { return this->hasFingerprint; }
}// namespace RTC
#endif

139
webrtc/utils.cc Normal file
View File

@ -0,0 +1,139 @@
#define MS_CLASS "Utils::Crypto"
// #define MS_LOG_DEV
#include "utils.h"
#include "openssl/sha.h"
namespace Utils {
/* Static variables. */
uint32_t Crypto::seed;
HMAC_CTX *Crypto::hmacSha1Ctx{nullptr};
uint8_t Crypto::hmacSha1Buffer[20];// SHA-1 result is 20 bytes long.
// clang-format off
const uint32_t Crypto::crc32Table[] =
{
0x00000000, 0x77073096, 0xee0e612c, 0x990951ba, 0x076dc419, 0x706af48f, 0xe963a535, 0x9e6495a3,
0x0edb8832, 0x79dcb8a4, 0xe0d5e91e, 0x97d2d988, 0x09b64c2b, 0x7eb17cbd, 0xe7b82d07, 0x90bf1d91,
0x1db71064, 0x6ab020f2, 0xf3b97148, 0x84be41de, 0x1adad47d, 0x6ddde4eb, 0xf4d4b551, 0x83d385c7,
0x136c9856, 0x646ba8c0, 0xfd62f97a, 0x8a65c9ec, 0x14015c4f, 0x63066cd9, 0xfa0f3d63, 0x8d080df5,
0x3b6e20c8, 0x4c69105e, 0xd56041e4, 0xa2677172, 0x3c03e4d1, 0x4b04d447, 0xd20d85fd, 0xa50ab56b,
0x35b5a8fa, 0x42b2986c, 0xdbbbc9d6, 0xacbcf940, 0x32d86ce3, 0x45df5c75, 0xdcd60dcf, 0xabd13d59,
0x26d930ac, 0x51de003a, 0xc8d75180, 0xbfd06116, 0x21b4f4b5, 0x56b3c423, 0xcfba9599, 0xb8bda50f,
0x2802b89e, 0x5f058808, 0xc60cd9b2, 0xb10be924, 0x2f6f7c87, 0x58684c11, 0xc1611dab, 0xb6662d3d,
0x76dc4190, 0x01db7106, 0x98d220bc, 0xefd5102a, 0x71b18589, 0x06b6b51f, 0x9fbfe4a5, 0xe8b8d433,
0x7807c9a2, 0x0f00f934, 0x9609a88e, 0xe10e9818, 0x7f6a0dbb, 0x086d3d2d, 0x91646c97, 0xe6635c01,
0x6b6b51f4, 0x1c6c6162, 0x856530d8, 0xf262004e, 0x6c0695ed, 0x1b01a57b, 0x8208f4c1, 0xf50fc457,
0x65b0d9c6, 0x12b7e950, 0x8bbeb8ea, 0xfcb9887c, 0x62dd1ddf, 0x15da2d49, 0x8cd37cf3, 0xfbd44c65,
0x4db26158, 0x3ab551ce, 0xa3bc0074, 0xd4bb30e2, 0x4adfa541, 0x3dd895d7, 0xa4d1c46d, 0xd3d6f4fb,
0x4369e96a, 0x346ed9fc, 0xad678846, 0xda60b8d0, 0x44042d73, 0x33031de5, 0xaa0a4c5f, 0xdd0d7cc9,
0x5005713c, 0x270241aa, 0xbe0b1010, 0xc90c2086, 0x5768b525, 0x206f85b3, 0xb966d409, 0xce61e49f,
0x5edef90e, 0x29d9c998, 0xb0d09822, 0xc7d7a8b4, 0x59b33d17, 0x2eb40d81, 0xb7bd5c3b, 0xc0ba6cad,
0xedb88320, 0x9abfb3b6, 0x03b6e20c, 0x74b1d29a, 0xead54739, 0x9dd277af, 0x04db2615, 0x73dc1683,
0xe3630b12, 0x94643b84, 0x0d6d6a3e, 0x7a6a5aa8, 0xe40ecf0b, 0x9309ff9d, 0x0a00ae27, 0x7d079eb1,
0xf00f9344, 0x8708a3d2, 0x1e01f268, 0x6906c2fe, 0xf762575d, 0x806567cb, 0x196c3671, 0x6e6b06e7,
0xfed41b76, 0x89d32be0, 0x10da7a5a, 0x67dd4acc, 0xf9b9df6f, 0x8ebeeff9, 0x17b7be43, 0x60b08ed5,
0xd6d6a3e8, 0xa1d1937e, 0x38d8c2c4, 0x4fdff252, 0xd1bb67f1, 0xa6bc5767, 0x3fb506dd, 0x48b2364b,
0xd80d2bda, 0xaf0a1b4c, 0x36034af6, 0x41047a60, 0xdf60efc3, 0xa867df55, 0x316e8eef, 0x4669be79,
0xcb61b38c, 0xbc66831a, 0x256fd2a0, 0x5268e236, 0xcc0c7795, 0xbb0b4703, 0x220216b9, 0x5505262f,
0xc5ba3bbe, 0xb2bd0b28, 0x2bb45a92, 0x5cb36a04, 0xc2d7ffa7, 0xb5d0cf31, 0x2cd99e8b, 0x5bdeae1d,
0x9b64c2b0, 0xec63f226, 0x756aa39c, 0x026d930a, 0x9c0906a9, 0xeb0e363f, 0x72076785, 0x05005713,
0x95bf4a82, 0xe2b87a14, 0x7bb12bae, 0x0cb61b38, 0x92d28e9b, 0xe5d5be0d, 0x7cdcefb7, 0x0bdbdf21,
0x86d3d2d4, 0xf1d4e242, 0x68ddb3f8, 0x1fda836e, 0x81be16cd, 0xf6b9265b, 0x6fb077e1, 0x18b74777,
0x88085ae6, 0xff0f6a70, 0x66063bca, 0x11010b5c, 0x8f659eff, 0xf862ae69, 0x616bffd3, 0x166ccf45,
0xa00ae278, 0xd70dd2ee, 0x4e048354, 0x3903b3c2, 0xa7672661, 0xd06016f7, 0x4969474d, 0x3e6e77db,
0xaed16a4a, 0xd9d65adc, 0x40df0b66, 0x37d83bf0, 0xa9bcae53, 0xdebb9ec5, 0x47b2cf7f, 0x30b5ffe9,
0xbdbdf21c, 0xcabac28a, 0x53b39330, 0x24b4a3a6, 0xbad03605, 0xcdd70693, 0x54de5729, 0x23d967bf,
0xb3667a2e, 0xc4614ab8, 0x5d681b02, 0x2a6f2b94, 0xb40bbe37, 0xc30c8ea1, 0x5a05df1b, 0x2d02ef8d
};
// clang-format on
/* Static methods. */
void Crypto::ClassInit() {
// MS_TRACE();
// Init the vrypto seed with a random number taken from the address
// of the seed variable itself (which is random).
Crypto::seed = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(std::addressof(Crypto::seed)));
// Create an OpenSSL HMAC_CTX context for HMAC SHA1 calculation.
// Crypto::hmacSha1Ctx = HMAC_CTX_new();
if (Crypto::hmacSha1Ctx == nullptr) {
Crypto::hmacSha1Ctx = HMAC_CTX_new();
}
}
void Crypto::ClassDestroy() {
// MS_TRACE();
if (Crypto::hmacSha1Ctx != nullptr) {
HMAC_CTX_free(Crypto::hmacSha1Ctx);
}
}
const uint8_t *Crypto::GetHmacShA1(const std::string &key, const uint8_t *data, size_t len) {
// MS_TRACE();
size_t ret;
ret = HMAC_Init_ex(Crypto::hmacSha1Ctx, key.c_str(), key.length(), EVP_sha1(), nullptr);
// MS_ASSERT(ret == 1, "OpenSSL HMAC_Init_ex() failed with key '%s'", key.c_str());
ret = HMAC_Update(Crypto::hmacSha1Ctx, data, static_cast<int>(len));
/*
MS_ASSERT(
ret == 1,
"OpenSSL HMAC_Update() failed with key '%s' and data length %zu bytes",
key.c_str(),
len);
*/
uint32_t resultLen;
ret = HMAC_Final(Crypto::hmacSha1Ctx, (uint8_t *) Crypto::hmacSha1Buffer, &resultLen);
/*
MS_ASSERT(
ret == 1, "OpenSSL HMAC_Final() failed with key '%s' and data length %zu bytes", key.c_str(),
len); MS_ASSERT(resultLen == 20, "OpenSSL HMAC_Final() resultLen is %u instead of 20", resultLen);
*/
return Crypto::hmacSha1Buffer;
}
}// namespace Utils
namespace Utils {
static std::string inet_ntoa(struct in_addr in) {
char buf[20];
unsigned char *p = (unsigned char *) &(in);
snprintf(buf, sizeof(buf), "%u.%u.%u.%u", p[0], p[1], p[2], p[3]);
return buf;
}
void IP::GetAddressInfo(const struct sockaddr *addr, int &family, std::string &ip, uint16_t &port) {
char ipBuffer[INET6_ADDRSTRLEN + 1];
switch (addr->sa_family) {
case AF_INET: {
ip = Utils::inet_ntoa(reinterpret_cast<const struct sockaddr_in *>(addr)->sin_addr);
port = static_cast<uint16_t>(ntohs(reinterpret_cast<const struct sockaddr_in *>(addr)->sin_port));
break;
}
case AF_INET6: {
port = static_cast<uint16_t>(ntohs(reinterpret_cast<const struct sockaddr_in6 *>(addr)->sin6_port));
break;
}
default: {
// MS_ABORT("unknown network family: %d", static_cast<int>(addr->sa_family));
}
}
family = addr->sa_family;
ip.assign(ipBuffer);
}
}// namespace Utils

318
webrtc/utils.h Normal file
View File

@ -0,0 +1,318 @@
#ifndef MS_UTILS_HPP
#define MS_UTILS_HPP
#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
#include <Iphlpapi.h>
#pragma comment (lib, "Ws2_32.lib")
#pragma comment(lib,"Iphlpapi.lib")
#else
#include <netdb.h>
#include <arpa/inet.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <net/if.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#endif // defined(_WIN32)
#include <algorithm>// std::transform(), std::find(), std::min(), std::max()
#include <cinttypes>// PRIu64, etc
#include <cmath>
#include <cstddef>// size_t
#include <cstdint>// uint8_t, etc
#include <cstring>// std::memcmp(), std::memcpy()
#include <memory>
#include <openssl/bio.h>
#include <openssl/ssl.h>
#include <openssl/x509.h>
#include <string>
namespace Utils {
class IP {
public:
static int GetFamily(const char *ip, size_t ipLen);
static int GetFamily(const std::string &ip);
static void GetAddressInfo(const struct sockaddr *addr, int &family, std::string &ip,
uint16_t &port);
static bool CompareAddresses(const struct sockaddr *addr1, const struct sockaddr *addr2);
static struct sockaddr_storage CopyAddress(const struct sockaddr *addr);
static void NormalizeIp(std::string &ip);
};
/* Inline static methods. */
inline int IP::GetFamily(const std::string &ip) { return GetFamily(ip.c_str(), ip.size()); }
inline bool IP::CompareAddresses(const struct sockaddr *addr1, const struct sockaddr *addr2) {
// Compare family.
if (addr1->sa_family != addr2->sa_family ||
(addr1->sa_family != AF_INET && addr1->sa_family != AF_INET6)) {
return false;
}
// Compare port.
if (reinterpret_cast<const struct sockaddr_in *>(addr1)->sin_port !=
reinterpret_cast<const struct sockaddr_in *>(addr2)->sin_port) {
return false;
}
// Compare IP.
switch (addr1->sa_family) {
case AF_INET: {
return (reinterpret_cast<const struct sockaddr_in *>(addr1)->sin_addr.s_addr ==
reinterpret_cast<const struct sockaddr_in *>(addr2)->sin_addr.s_addr);
}
case AF_INET6: {
return (std::memcmp(
std::addressof(reinterpret_cast<const struct sockaddr_in6 *>(addr1)->sin6_addr),
std::addressof(reinterpret_cast<const struct sockaddr_in6 *>(addr2)->sin6_addr),
16) == 0
? true
: false);
}
default: {
return false;
}
}
}
inline struct sockaddr_storage IP::CopyAddress(const struct sockaddr *addr) {
struct sockaddr_storage copiedAddr;
switch (addr->sa_family) {
case AF_INET:
std::memcpy(std::addressof(copiedAddr), addr, sizeof(struct sockaddr_in));
break;
case AF_INET6:
std::memcpy(std::addressof(copiedAddr), addr, sizeof(struct sockaddr_in6));
break;
}
return copiedAddr;
}
class File {
public:
static void CheckFile(const char *file);
};
class Byte {
public:
/**
* Getters below get value in Host Byte Order.
* Setters below set value in Network Byte Order.
*/
static uint8_t Get1Byte(const uint8_t *data, size_t i);
static uint16_t Get2Bytes(const uint8_t *data, size_t i);
static uint32_t Get3Bytes(const uint8_t *data, size_t i);
static uint32_t Get4Bytes(const uint8_t *data, size_t i);
static uint64_t Get8Bytes(const uint8_t *data, size_t i);
static void Set1Byte(uint8_t *data, size_t i, uint8_t value);
static void Set2Bytes(uint8_t *data, size_t i, uint16_t value);
static void Set3Bytes(uint8_t *data, size_t i, uint32_t value);
static void Set4Bytes(uint8_t *data, size_t i, uint32_t value);
static void Set8Bytes(uint8_t *data, size_t i, uint64_t value);
static uint16_t PadTo4Bytes(uint16_t size);
static uint32_t PadTo4Bytes(uint32_t size);
};
/* Inline static methods. */
inline uint8_t Byte::Get1Byte(const uint8_t *data, size_t i) { return data[i]; }
inline uint16_t Byte::Get2Bytes(const uint8_t *data, size_t i) {
return uint16_t{data[i + 1]} | uint16_t{data[i]} << 8;
}
inline uint32_t Byte::Get3Bytes(const uint8_t *data, size_t i) {
return uint32_t{data[i + 2]} | uint32_t{data[i + 1]} << 8 | uint32_t{data[i]} << 16;
}
inline uint32_t Byte::Get4Bytes(const uint8_t *data, size_t i) {
return uint32_t{data[i + 3]} | uint32_t{data[i + 2]} << 8 | uint32_t{data[i + 1]} << 16 |
uint32_t{data[i]} << 24;
}
inline uint64_t Byte::Get8Bytes(const uint8_t *data, size_t i) {
return uint64_t{Byte::Get4Bytes(data, i)} << 32 | Byte::Get4Bytes(data, i + 4);
}
inline void Byte::Set1Byte(uint8_t *data, size_t i, uint8_t value) { data[i] = value; }
inline void Byte::Set2Bytes(uint8_t *data, size_t i, uint16_t value) {
data[i + 1] = static_cast<uint8_t>(value);
data[i] = static_cast<uint8_t>(value >> 8);
}
inline void Byte::Set3Bytes(uint8_t *data, size_t i, uint32_t value) {
data[i + 2] = static_cast<uint8_t>(value);
data[i + 1] = static_cast<uint8_t>(value >> 8);
data[i] = static_cast<uint8_t>(value >> 16);
}
inline void Byte::Set4Bytes(uint8_t *data, size_t i, uint32_t value) {
data[i + 3] = static_cast<uint8_t>(value);
data[i + 2] = static_cast<uint8_t>(value >> 8);
data[i + 1] = static_cast<uint8_t>(value >> 16);
data[i] = static_cast<uint8_t>(value >> 24);
}
inline void Byte::Set8Bytes(uint8_t *data, size_t i, uint64_t value) {
data[i + 7] = static_cast<uint8_t>(value);
data[i + 6] = static_cast<uint8_t>(value >> 8);
data[i + 5] = static_cast<uint8_t>(value >> 16);
data[i + 4] = static_cast<uint8_t>(value >> 24);
data[i + 3] = static_cast<uint8_t>(value >> 32);
data[i + 2] = static_cast<uint8_t>(value >> 40);
data[i + 1] = static_cast<uint8_t>(value >> 48);
data[i] = static_cast<uint8_t>(value >> 56);
}
inline uint16_t Byte::PadTo4Bytes(uint16_t size) {
// If size is not multiple of 32 bits then pad it.
if (size & 0x03)
return (size & 0xFFFC) + 4;
else
return size;
}
inline uint32_t Byte::PadTo4Bytes(uint32_t size) {
// If size is not multiple of 32 bits then pad it.
if (size & 0x03)
return (size & 0xFFFFFFFC) + 4;
else
return size;
}
class Bits {
public:
static size_t CountSetBits(const uint16_t mask);
};
/* Inline static methods. */
class Crypto {
public:
static void ClassInit();
static void ClassDestroy();
static uint32_t GetRandomUInt(uint32_t min, uint32_t max);
static const std::string GetRandomString(size_t len);
static uint32_t GetCRC32(const uint8_t *data, size_t size);
static const uint8_t *GetHmacShA1(const std::string &key, const uint8_t *data, size_t len);
private:
static uint32_t seed;
static HMAC_CTX *hmacSha1Ctx;
static uint8_t hmacSha1Buffer[];
static const uint32_t crc32Table[256];
};
/* Inline static methods. */
inline uint32_t Crypto::GetRandomUInt(uint32_t min, uint32_t max) {
// NOTE: This is the original, but produces very small values.
// Crypto::seed = (214013 * Crypto::seed) + 2531011;
// return (((Crypto::seed>>16)&0x7FFF) % (max - min + 1)) + min;
// This seems to produce better results.
Crypto::seed = uint32_t{((214013 * Crypto::seed) + 2531011)};
return (((Crypto::seed >> 4) & 0x7FFF7FFF) % (max - min + 1)) + min;
}
inline const std::string Crypto::GetRandomString(size_t len) {
static char buffer[64];
static const char chars[] = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b',
'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'};
if (len > 64) len = 64;
for (size_t i{0}; i < len; ++i) {
buffer[i] = chars[GetRandomUInt(0, sizeof(chars) - 1)];
}
return std::string(buffer, len);
}
inline uint32_t Crypto::GetCRC32(const uint8_t *data, size_t size) {
uint32_t crc{0xFFFFFFFF};
const uint8_t *p = data;
while (size--) {
crc = Crypto::crc32Table[(crc ^ *p++) & 0xFF] ^ (crc >> 8);
}
return crc ^ ~0U;
}
class String {
public:
static void ToLowerCase(std::string &str);
};
inline void String::ToLowerCase(std::string &str) {
std::transform(str.begin(), str.end(), str.begin(), ::tolower);
}
class Time {
// Seconds from Jan 1, 1900 to Jan 1, 1970.
static constexpr uint32_t UnixNtpOffset{0x83AA7E80};
// NTP fractional unit.
static constexpr uint64_t NtpFractionalUnit{1LL << 32};
public:
struct Ntp {
uint32_t seconds;
uint32_t fractions;
};
static Time::Ntp TimeMs2Ntp(uint64_t ms);
static uint64_t Ntp2TimeMs(Time::Ntp ntp);
static bool IsNewerTimestamp(uint32_t timestamp, uint32_t prevTimestamp);
static uint32_t LatestTimestamp(uint32_t timestamp1, uint32_t timestamp2);
};
inline Time::Ntp Time::TimeMs2Ntp(uint64_t ms) {
Time::Ntp ntp;// NOLINT(cppcoreguidelines-pro-type-member-init)
ntp.seconds = uint32_t(ms / 1000);
ntp.fractions =
static_cast<uint32_t>((static_cast<double>(ms % 1000) / 1000) * NtpFractionalUnit);
return ntp;
}
inline uint64_t Time::Ntp2TimeMs(Time::Ntp ntp) {
// clang-format off
return (
static_cast<uint64_t>(ntp.seconds) * 1000 +
static_cast<uint64_t>(std::round((static_cast<double>(ntp.fractions) * 1000) / NtpFractionalUnit))
);
// clang-format on
}
inline bool Time::IsNewerTimestamp(uint32_t timestamp, uint32_t prevTimestamp) {
// Distinguish between elements that are exactly 0x80000000 apart.
// If t1>t2 and |t1-t2| = 0x80000000: IsNewer(t1,t2)=true,
// IsNewer(t2,t1)=false
// rather than having IsNewer(t1,t2) = IsNewer(t2,t1) = false.
if (static_cast<uint32_t>(timestamp - prevTimestamp) == 0x80000000)
return timestamp > prevTimestamp;
return timestamp != prevTimestamp &&
static_cast<uint32_t>(timestamp - prevTimestamp) < 0x80000000;
}
inline uint32_t Time::LatestTimestamp(uint32_t timestamp1, uint32_t timestamp2) {
return IsNewerTimestamp(timestamp1, timestamp2) ? timestamp1 : timestamp2;
}
}// namespace Utils
#endif

215
webrtc/webrtc_transport.cc Normal file
View File

@ -0,0 +1,215 @@
#include "webrtc_transport.h"
#include <iostream>
#include "Rtcp/Rtcp.h"
WebRtcTransport::WebRtcTransport() {
static onceToken token([](){
Utils::Crypto::ClassInit();
RTC::DtlsTransport::ClassInit();
RTC::DepLibSRTP::ClassInit();
RTC::SrtpSession::ClassInit();
});
ice_server_ = std::make_shared<IceServer>(Utils::Crypto::GetRandomString(4), Utils::Crypto::GetRandomString(24));
ice_server_->SetIceServerCompletedCB([this]() {
this->OnIceServerCompleted();
});
ice_server_->SetSendCB([this](char *buf, size_t len, struct sockaddr_in *remote_address) {
this->WritePacket(buf, len, remote_address);
});
// todo dtls服务器或客户端模式
dtls_transport_ = std::make_shared<DtlsTransport>(true);
dtls_transport_->SetHandshakeCompletedCB([this](std::string client_key, std::string server_key, RTC::CryptoSuite srtp_crypto_suite) {
this->OnDtlsCompleted(client_key, server_key, srtp_crypto_suite);
});
dtls_transport_->SetOutPutCB([this](char *buf, size_t len) { this->WritePacket(buf, len); });
}
WebRtcTransport::~WebRtcTransport() {}
std::string WebRtcTransport::GetLocalSdp() {
char sdp[1024 * 10] = {0};
auto ssrc = getSSRC();
auto ip = getIP();
auto pt = getPayloadType();
auto port = getPort();
sprintf(sdp,
"v=0\r\n"
"o=- 1495799811084970 1495799811084970 IN IP4 %s\r\n"
"s=Streaming Test\r\n"
"t=0 0\r\n"
"a=group:BUNDLE video\r\n"
"a=msid-semantic: WMS janus\r\n"
"m=video %u RTP/SAVPF %u\r\n"
"c=IN IP4 %s\r\n"
"a=mid:video\r\n"
"a=sendonly\r\n"
"a=rtcp-mux\r\n"
"a=ice-ufrag:%s\r\n"
"a=ice-pwd:%s\r\n"
"a=ice-options:trickle\r\n"
"a=fingerprint:sha-256 %s\r\n"
"a=setup:actpass\r\n"
"a=connection:new\r\n"
"a=rtpmap:%u H264/90000\r\n"
"a=ssrc:%u cname:janusvideo\r\n"
"a=ssrc:%u msid:janus janusv0\r\n"
"a=ssrc:%u mslabel:janus\r\n"
"a=ssrc:%u label:janusv0\r\n"
"a=candidate:%s 1 udp %u %s %u typ %s\r\n",
ip.c_str(), port, pt, ip.c_str(),
ice_server_->GetUsernameFragment().c_str(),ice_server_->GetPassword().c_str(),
dtls_transport_->GetMyFingerprint().c_str(), pt, ssrc, ssrc, ssrc, ssrc, "4", ssrc, ip.c_str(), port, "host");
return sdp;
}
void WebRtcTransport::OnIceServerCompleted() {
InfoL;
dtls_transport_->Start();
onIceConnected();
}
void WebRtcTransport::OnDtlsCompleted(std::string client_key, std::string server_key, RTC::CryptoSuite srtp_crypto_suite) {
InfoL << client_key << " " << server_key << " " << (int)srtp_crypto_suite;
srtp_session_ = std::make_shared<RTC::SrtpSession>(RTC::SrtpSession::Type::OUTBOUND, srtp_crypto_suite, (uint8_t *) client_key.c_str(), client_key.size());
onDtlsCompleted();
}
bool is_dtls(char *buf) {
return ((*buf > 19) && (*buf < 64));
}
bool is_rtp(char *buf) {
RtpHeader *header = (RtpHeader *) buf;
return ((header->pt < 64) || (header->pt >= 96));
}
bool is_rtcp(char *buf) {
RtpHeader *header = (RtpHeader *) buf;
return ((header->pt >= 64) && (header->pt < 96));
}
void WebRtcTransport::OnInputDataPacket(char *buf, size_t len, struct sockaddr_in *remote_address) {
if (RTC::StunPacket::IsStun((const uint8_t *) buf, len)) {
InfoL << "stun:" << hexdump(buf, len);
RTC::StunPacket *packet = RTC::StunPacket::Parse((const uint8_t *) buf, len);
if (packet == nullptr) {
WarnL << "parse stun error" << std::endl;
return;
}
ice_server_->ProcessStunPacket(packet, remote_address);
return;
}
if (DtlsTransport::IsDtlsPacket(buf, len)) {
InfoL << "dtls:" << hexdump(buf, len);
dtls_transport_->InputData(buf, len);
return;
}
if (is_rtp(buf)) {
RtpHeader *header = (RtpHeader *) buf;
InfoL << "rtp:" << header->dumpString(len);
return;
}
if (is_rtcp(buf)) {
RtcpHeader *header = (RtcpHeader *) buf;
// InfoL << "rtcp:" << header->dumpString();
return;
}
}
void WebRtcTransport::WritePacket(char *buf, size_t len, struct sockaddr_in *remote_address) {
onWrite(buf, len, remote_address ? remote_address : (ice_server_ ? ice_server_->GetSelectAddr() : nullptr));
}
void WebRtcTransport::WritRtpPacket(char *buf, size_t len) {
const uint8_t *p = (uint8_t *) buf;
bool ret = false;
if (srtp_session_) {
ret = srtp_session_->EncryptRtp(&p, &len);
}
if (ret) {
onWrite((char *) p, len, ice_server_->GetSelectAddr());
}
}
///////////////////////////////////////////////////////////////////////////////////
WebRtcTransportImp::WebRtcTransportImp(const EventPoller::Ptr &poller) {
_socket = Socket::createSocket(poller, false);
//随机端口,绑定全部网卡
_socket->bindUdpSock(0);
_socket->setOnRead([this](const Buffer::Ptr &buf, struct sockaddr *addr, int addr_len){
OnInputDataPacket(buf->data(), buf->size(), (struct sockaddr_in*)addr);
});
}
void WebRtcTransportImp::attach(const RtspMediaSource::Ptr &src) {
assert(src);
_src = src;
}
void WebRtcTransportImp::onDtlsCompleted() {
_reader = _src->getRing()->attach(_socket->getPoller(), true);
weak_ptr<WebRtcTransportImp> weak_self = shared_from_this();
_reader->setReadCB([weak_self](const RtspMediaSource::RingDataType &pkt){
auto strongSelf = weak_self.lock();
if (!strongSelf) {
return;
}
pkt->for_each([&](const RtpPacket::Ptr &rtp) {
if(rtp->type == TrackVideo) {
//目前只支持视频
strongSelf->WritRtpPacket(rtp->data() + RtpPacket::kRtpTcpHeaderSize,
rtp->size() - RtpPacket::kRtpTcpHeaderSize);
}
});
});
}
void WebRtcTransportImp::onIceConnected(){
}
void WebRtcTransportImp::onWrite(const char *buf, size_t len, struct sockaddr_in *dst) {
auto ptr = BufferRaw::create();
ptr->assign(buf, len);
// InfoL << len << " " << SockUtil::inet_ntoa(dst->sin_addr) << " " << ntohs(dst->sin_port);
_socket->send(ptr, (struct sockaddr *)(dst), sizeof(struct sockaddr));
}
uint32_t WebRtcTransportImp::getSSRC() const {
return _src->getSsrc(TrackVideo);
}
int WebRtcTransportImp::getPayloadType() const{
auto sdp = SdpParser(_src->getSdp());
auto track = sdp.getTrack(TrackVideo);
assert(track);
return track ? track->_pt : 0;
}
uint16_t WebRtcTransportImp::getPort() const {
//todo udp端口号应该与外网映射端口相同
return _socket->get_local_port();
}
std::string WebRtcTransportImp::getIP() const {
//todo 替换为外网ip
return SockUtil::get_local_ip();
}
///////////////////////////////////////////////////////////////////
INSTANCE_IMP(WebRtcManager)
WebRtcManager::WebRtcManager() {
}
WebRtcManager::~WebRtcManager() {
}

112
webrtc/webrtc_transport.h Normal file
View File

@ -0,0 +1,112 @@
#pragma once
#include <memory>
#include <string>
#include "dtls_transport.h"
#include "ice_server.h"
#include "srtp_session.h"
#include "stun_packet.h"
class WebRtcTransport {
public:
using Ptr = std::shared_ptr<WebRtcTransport>;
WebRtcTransport();
virtual ~WebRtcTransport();
/// 获取本地sdp
/// \return
std::string GetLocalSdp();
/// 收到udp数据
/// \param buf
/// \param len
/// \param remote_address
void OnInputDataPacket(char *buf, size_t len, struct sockaddr_in *remote_address);
/// 发送rtp
/// \param buf
/// \param len
void WritRtpPacket(char *buf, size_t len);
protected:
/// 输出udp数据
/// \param buf
/// \param len
/// \param dst
virtual void onWrite(const char *buf, size_t len, struct sockaddr_in *dst) = 0;
virtual uint32_t getSSRC() const = 0;
virtual uint16_t getPort() const = 0;
virtual std::string getIP() const = 0;
virtual int getPayloadType() const = 0;
virtual void onIceConnected() = 0;
virtual void onDtlsCompleted() = 0;
private:
void OnIceServerCompleted();
void OnDtlsCompleted(std::string client_key, std::string server_key, RTC::CryptoSuite srtp_crypto_suite);
void WritePacket(char *buf, size_t len, struct sockaddr_in *remote_address = nullptr);
private:
IceServer::Ptr ice_server_;
DtlsTransport::Ptr dtls_transport_;
std::shared_ptr<RTC::SrtpSession> srtp_session_;
};
#include "Poller/EventPoller.h"
#include "Network/Socket.h"
#include "Rtsp/RtspMediaSource.h"
using namespace toolkit;
using namespace mediakit;
class WebRtcTransportImp : public WebRtcTransport, public std::enable_shared_from_this<WebRtcTransportImp>{
public:
using Ptr = std::shared_ptr<WebRtcTransportImp>;
WebRtcTransportImp(const EventPoller::Ptr &poller);
~WebRtcTransportImp() override = default;
void attach(const RtspMediaSource::Ptr &src);
protected:
void onWrite(const char *buf, size_t len, struct sockaddr_in *dst) override;
int getPayloadType() const ;
uint32_t getSSRC() const override;
uint16_t getPort() const override;
std::string getIP() const override;
void onIceConnected() override;
void onDtlsCompleted() override;
private:
Socket::Ptr _socket;
RtspMediaSource::Ptr _src;
RtspMediaSource::RingType::RingReader::Ptr _reader;
};
class WebRtcManager : public std::enable_shared_from_this<WebRtcManager> {
public:
~WebRtcManager();
static WebRtcManager& Instance();
private:
WebRtcManager();
};