diff --git a/include/ydb-cpp-sdk/client/types/ydb.h b/include/ydb-cpp-sdk/client/types/ydb.h index c62652c609..ae689699df 100644 --- a/include/ydb-cpp-sdk/client/types/ydb.h +++ b/include/ydb-cpp-sdk/client/types/ydb.h @@ -54,6 +54,9 @@ class TBalancingPolicy { //! location is a name of datacenter (VLA, MAN), if location is nullopt local datacenter is used static TBalancingPolicy UsePreferableLocation(const std::optional& location = {}); + //! Use detected local dc + static TBalancingPolicy UseDetectedLocalDC(); + //! Use all available cluster nodes regardless datacenter locality static TBalancingPolicy UseAllNodes(); diff --git a/src/client/impl/internal/CMakeLists.txt b/src/client/impl/internal/CMakeLists.txt index 5370c34f57..56bfbc0045 100644 --- a/src/client/impl/internal/CMakeLists.txt +++ b/src/client/impl/internal/CMakeLists.txt @@ -1,6 +1,7 @@ add_subdirectory(common) add_subdirectory(db_driver_state) add_subdirectory(grpc_connections) +add_subdirectory(local_dc_detector) add_subdirectory(logger) add_subdirectory(make_request) add_subdirectory(plain_status) diff --git a/src/client/impl/internal/common/balancing_policies.cpp b/src/client/impl/internal/common/balancing_policies.cpp index 22ec50a622..7d2301231e 100644 --- a/src/client/impl/internal/common/balancing_policies.cpp +++ b/src/client/impl/internal/common/balancing_policies.cpp @@ -16,6 +16,12 @@ TBalancingPolicy::TImpl TBalancingPolicy::TImpl::UsePreferableLocation(const std return impl; } +TBalancingPolicy::TImpl TBalancingPolicy::TImpl::UseDetectedLocalDC() { + TBalancingPolicy::TImpl impl; + impl.PolicyType = EPolicyType::UseDetectedLocalDC; + return impl; +} + TBalancingPolicy::TImpl TBalancingPolicy::TImpl::UsePreferablePileState(EPileState pileState) { TBalancingPolicy::TImpl impl; impl.PolicyType = EPolicyType::UsePreferablePileState; diff --git a/src/client/impl/internal/common/balancing_policies.h b/src/client/impl/internal/common/balancing_policies.h index f1180f37ed..49a3ae505b 100644 --- a/src/client/impl/internal/common/balancing_policies.h +++ b/src/client/impl/internal/common/balancing_policies.h @@ -14,6 +14,7 @@ class TBalancingPolicy::TImpl { enum class EPolicyType { UseAllNodes, UsePreferableLocation, + UseDetectedLocalDC, UsePreferablePileState }; @@ -21,6 +22,8 @@ class TBalancingPolicy::TImpl { static TImpl UsePreferableLocation(const std::optional& location); + static TImpl UseDetectedLocalDC(); + static TImpl UsePreferablePileState(EPileState pileState); EPolicyType PolicyType; diff --git a/src/client/impl/internal/db_driver_state/CMakeLists.txt b/src/client/impl/internal/db_driver_state/CMakeLists.txt index 09089e4c10..cc6ec36d7b 100644 --- a/src/client/impl/internal/db_driver_state/CMakeLists.txt +++ b/src/client/impl/internal/db_driver_state/CMakeLists.txt @@ -7,6 +7,7 @@ target_link_libraries(impl-internal-db_driver_state PUBLIC client-impl-ydb_endpoints impl-internal-logger impl-internal-plain_status + impl-internal-local_dc_detector client-types-credentials ) diff --git a/src/client/impl/internal/db_driver_state/endpoint_pool.cpp b/src/client/impl/internal/db_driver_state/endpoint_pool.cpp index 8bdbd19262..a6435a580b 100644 --- a/src/client/impl/internal/db_driver_state/endpoint_pool.cpp +++ b/src/client/impl/internal/db_driver_state/endpoint_pool.cpp @@ -41,6 +41,10 @@ std::pair, bool> TEndpointPool::Updat TListEndpointsResult result = future.GetValue(); std::vector removed; if (result.DiscoveryStatus.Status == EStatus::SUCCESS) { + if (BalancingPolicy_.PolicyType == TBalancingPolicy::TImpl::EPolicyType::UseDetectedLocalDC) { + LocalDCDetector_.DetectLocalDC(result.Result); + } + std::vector records; // Is used to convert float to integer load factor // same integer values will be selected randomly. @@ -182,6 +186,8 @@ bool TEndpointPool::IsPreferredEndpoint(const Ydb::Discovery::EndpointInfo& endp return true; case TBalancingPolicy::TImpl::EPolicyType::UsePreferableLocation: return endpoint.location() == BalancingPolicy_.Location.value_or(selfLocation); + case TBalancingPolicy::TImpl::EPolicyType::UseDetectedLocalDC: + return LocalDCDetector_.IsLocalDC(endpoint.location()); case TBalancingPolicy::TImpl::EPolicyType::UsePreferablePileState: if (auto it = pileStates.find(endpoint.bridge_pile_name()); it != pileStates.end()) { return GetPileState(it->second.state()) == BalancingPolicy_.PileState; diff --git a/src/client/impl/internal/db_driver_state/endpoint_pool.h b/src/client/impl/internal/db_driver_state/endpoint_pool.h index b534593337..a1ec7263b0 100644 --- a/src/client/impl/internal/db_driver_state/endpoint_pool.h +++ b/src/client/impl/internal/db_driver_state/endpoint_pool.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -57,7 +58,9 @@ class TEndpointPool { TEndpointElectorSafe Elector_; NThreading::TPromise DiscoveryPromise_; std::atomic_uint64_t LastUpdateTime_; + const TBalancingPolicy::TImpl BalancingPolicy_; + TLocalDCDetector LocalDCDetector_; NSdkStats::TStatCollector* StatCollector_ = nullptr; diff --git a/src/client/impl/internal/local_dc_detector/CMakeLists.txt b/src/client/impl/internal/local_dc_detector/CMakeLists.txt new file mode 100644 index 0000000000..71a33f6aad --- /dev/null +++ b/src/client/impl/internal/local_dc_detector/CMakeLists.txt @@ -0,0 +1,13 @@ +_ydb_sdk_add_library(impl-internal-local_dc_detector) + +target_link_libraries(impl-internal-local_dc_detector PUBLIC + yutil + api-protos +) + +target_sources(impl-internal-local_dc_detector PRIVATE + local_dc_detector.cpp + pinger.cpp +) + +_ydb_sdk_install_targets(TARGETS impl-internal-local_dc_detector) diff --git a/src/client/impl/internal/local_dc_detector/local_dc_detector.cpp b/src/client/impl/internal/local_dc_detector/local_dc_detector.cpp new file mode 100644 index 0000000000..b9fc78b35e --- /dev/null +++ b/src/client/impl/internal/local_dc_detector/local_dc_detector.cpp @@ -0,0 +1,75 @@ +#define INCLUDE_YDB_INTERNAL_H +#include "local_dc_detector.h" + +namespace NYdb::inline V3 { + +TLocalDCDetector::TLocalDCDetector(TPinger pingEndpoint) + : PingEndpoint_(std::move(pingEndpoint)) +{} + +void TLocalDCDetector::DetectLocalDC(const Ydb::Discovery::ListEndpointsResult& endpointsList) { + auto endpointsByLocation = GroupEndpointsByLocation(endpointsList); + SampleEndpoints(endpointsByLocation); + + if (endpointsByLocation.size() > 1) { + Location_ = FindNearestLocation(endpointsByLocation); + } else { + Location_.clear(); + } +} + +bool TLocalDCDetector::IsLocalDC(const std::string& location) const { + return Location_.empty() || Location_ == location; +} + +TLocalDCDetector::TEndpointsByLocation TLocalDCDetector::GroupEndpointsByLocation(const Ydb::Discovery::ListEndpointsResult& endpointsList) const { + TEndpointsByLocation endpointsByLocation; + for (const auto& endpoint : endpointsList.endpoints()) { + endpointsByLocation[endpoint.location()].emplace_back(endpoint); + } + return endpointsByLocation; +} + +void TLocalDCDetector::SampleEndpoints(TEndpointsByLocation& endpointsByLocation) const { + std::mt19937 gen(std::random_device{}()); + for (auto& [location, endpoints] : endpointsByLocation) { + if (endpoints.size() > MAX_ENDPOINTS_PER_LOCATION) { + std::vector sample; + sample.reserve(MAX_ENDPOINTS_PER_LOCATION); + std::sample(endpoints.begin(), endpoints.end(), std::back_inserter(sample), MAX_ENDPOINTS_PER_LOCATION, gen); + endpoints.swap(sample); + } + } +} + +std::uint64_t TLocalDCDetector::MeasureLocationRtt(const std::vector& endpoints) const { + if (endpoints.empty()) { + return std::numeric_limits::max(); + } + + std::vector timings; + timings.reserve(PING_COUNT); + for (size_t i = 0; i < PING_COUNT; ++i) { + const auto& ep = endpoints[i % endpoints.size()].get(); + timings.push_back(PingEndpoint_(ep).MicroSeconds()); + } + std::sort(timings.begin(), timings.end()); + + return std::midpoint(timings[(PING_COUNT - 1) / 2], timings[PING_COUNT / 2]); +} + + +std::string TLocalDCDetector::FindNearestLocation(const TEndpointsByLocation& endpointsByLocation) { + auto minRtt = std::numeric_limits::max(); + std::string nearestLocation; + for (const auto& [location, endpoints] : endpointsByLocation) { + auto rtt = MeasureLocationRtt(endpoints); + if (rtt < minRtt) { + minRtt = rtt; + nearestLocation = location; + } + } + return nearestLocation; +} + +} // namespace NYdb diff --git a/src/client/impl/internal/local_dc_detector/local_dc_detector.h b/src/client/impl/internal/local_dc_detector/local_dc_detector.h new file mode 100644 index 0000000000..d255f8c047 --- /dev/null +++ b/src/client/impl/internal/local_dc_detector/local_dc_detector.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace NYdb::inline V3 { + +class TLocalDCDetector { +public: + using TPinger = std::function; + explicit TLocalDCDetector(TPinger pingEndpoint = PingEndpoint); + + void DetectLocalDC(const Ydb::Discovery::ListEndpointsResult& endpoints); + bool IsLocalDC(const std::string& location) const; + +private: + using TEndpoint = Ydb::Discovery::EndpointInfo; + using TEndpointRef = std::reference_wrapper; + using TEndpointsByLocation = std::unordered_map>; + + TEndpointsByLocation GroupEndpointsByLocation(const Ydb::Discovery::ListEndpointsResult& endpointsList) const; + void SampleEndpoints(TEndpointsByLocation& endpointsByLocation) const; + std::uint64_t MeasureLocationRtt(const std::vector& endpoints) const; + std::string FindNearestLocation(const TEndpointsByLocation& endpointsByLocation); + +private: + static constexpr std::size_t MAX_ENDPOINTS_PER_LOCATION = 3; + static constexpr std::size_t PING_COUNT = 2 * MAX_ENDPOINTS_PER_LOCATION; + + TPinger PingEndpoint_; + std::string Location_; +}; + +} // namespace NYdb diff --git a/src/client/impl/internal/local_dc_detector/pinger.cpp b/src/client/impl/internal/local_dc_detector/pinger.cpp new file mode 100644 index 0000000000..1acb49c9b3 --- /dev/null +++ b/src/client/impl/internal/local_dc_detector/pinger.cpp @@ -0,0 +1,17 @@ +#define INCLUDE_YDB_INTERNAL_H +#include "pinger.h" + +namespace NYdb::inline V3 { + +TDuration PingEndpoint(const Ydb::Discovery::EndpointInfo& endpoint) { + try { + TNetworkAddress addr(endpoint.address().data(), static_cast(endpoint.port())); + auto start = TInstant::Now(); + TSocket sock(addr, TDuration::Seconds(PING_TIMEOUT_SECONDS)); + return TInstant::Now() - start; + } catch (...) { + return TDuration::Max(); + } +} + +} // namespace NYdb diff --git a/src/client/impl/internal/local_dc_detector/pinger.h b/src/client/impl/internal/local_dc_detector/pinger.h new file mode 100644 index 0000000000..627f7a368f --- /dev/null +++ b/src/client/impl/internal/local_dc_detector/pinger.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include + +#include +#include + +namespace NYdb::inline V3 { + +static constexpr std::uint32_t PING_TIMEOUT_SECONDS = 5; + +TDuration PingEndpoint(const Ydb::Discovery::EndpointInfo& endpoint); + +} // namespace NYdb diff --git a/src/client/table/impl/table_client.cpp b/src/client/table/impl/table_client.cpp index 26594dfb0b..6d311308a1 100644 --- a/src/client/table/impl/table_client.cpp +++ b/src/client/table/impl/table_client.cpp @@ -229,8 +229,10 @@ void TTableClient::TImpl::StartPeriodicHostScanTask() { const auto balancingPolicy = strongClient->DbDriverState_->GetBalancingPolicyType(); // Try to find any host at foreign locations if prefer local dc - const ui64 foreignHost = (balancingPolicy == TBalancingPolicy::TImpl::EPolicyType::UsePreferableLocation) ? - ScanForeignLocations(strongClient) : 0; + const ui64 foreignHost = + balancingPolicy == TBalancingPolicy::TImpl::EPolicyType::UsePreferableLocation || + balancingPolicy == TBalancingPolicy::TImpl::EPolicyType::UseDetectedLocalDC ? + ScanForeignLocations(strongClient) : 0; std::unordered_map hostMap; diff --git a/src/client/types/ydb.cpp b/src/client/types/ydb.cpp index f6083ec25e..2f7eef2b3a 100644 --- a/src/client/types/ydb.cpp +++ b/src/client/types/ydb.cpp @@ -22,6 +22,10 @@ TBalancingPolicy TBalancingPolicy::UsePreferableLocation(const std::optional(TImpl::UsePreferableLocation(location))); } +TBalancingPolicy TBalancingPolicy::UseDetectedLocalDC() { + return TBalancingPolicy(std::make_unique(TImpl::UseDetectedLocalDC())); +} + TBalancingPolicy TBalancingPolicy::UseAllNodes() { return TBalancingPolicy(std::make_unique(TImpl::UseAllNodes())); } diff --git a/tests/unit/client/CMakeLists.txt b/tests/unit/client/CMakeLists.txt index 8c3b142ee7..f3452a0c45 100644 --- a/tests/unit/client/CMakeLists.txt +++ b/tests/unit/client/CMakeLists.txt @@ -45,6 +45,19 @@ add_ydb_test(NAME client-impl-ydb_endpoints_ut unit ) +add_ydb_test(NAME client-impl-internal-local_dc_detector_ut + INCLUDE_DIRS + ${YDB_SDK_SOURCE_DIR}/src/client/impl/internal/local_dc_detector + SOURCES + local_dc_detector/local_dc_detector_ut.cpp + LINK_LIBRARIES + yutil + api-protos + impl-internal-local_dc_detector + LABELS + unit +) + add_ydb_test(NAME client-oauth2_ut SOURCES oauth2_token_exchange/credentials_ut.cpp diff --git a/tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp b/tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp new file mode 100644 index 0000000000..d68f0ebec7 --- /dev/null +++ b/tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp @@ -0,0 +1,246 @@ +#include + +#define INCLUDE_YDB_INTERNAL_H +#include +#undef INCLUDE_YDB_INTERNAL_H + +#include +#include + +using namespace NYdb; + +class TMockedEndpoint { +public: + explicit TMockedEndpoint(std::vector measures) + : Measures_(std::move(measures)) + , Idx_(0) + {} + + TDuration Ping() { + std::size_t idx = Idx_++; + + if (idx < Measures_.size()) { + return Measures_.at(idx); + } + return TDuration::Max(); + } + +private: + const std::vector Measures_; + std::size_t Idx_; +}; + +class TMockedPinger { +public: + explicit TMockedPinger(std::unordered_map> measuresByAdress) { + EndpointByAdress_.reserve(measuresByAdress.size()); + + for (auto& [adress, measures] : measuresByAdress) { + EndpointByAdress_.emplace(std::move(adress), std::move(measures)); + } + } + + TDuration operator()(const Ydb::Discovery::EndpointInfo& endpoint) const { + auto it = EndpointByAdress_.find(endpoint.address()); + if (it == EndpointByAdress_.end() || Blacklist_.contains(endpoint.address())) { + return TDuration::Max(); + } + return it->second.Ping(); + } + + void BanEndpoint(const std::string& adress) { + Blacklist_.insert(adress); + } + + void UnbanEndpoint(const std::string& adress) { + Blacklist_.erase(adress); + } + +private: + mutable std::unordered_map EndpointByAdress_; + std::unordered_set Blacklist_; +}; + +std::vector GenerateMeasures(size_t count, int minMs, int maxMs, std::mt19937& gen) { + std::vector measures; + measures.reserve(count); + std::uniform_int_distribution distrib(minMs, maxMs); + for (size_t i = 0; i < count; ++i) { + measures.push_back(TDuration::MicroSeconds(distrib(gen))); + } + return measures; +} + +Y_UNIT_TEST_SUITE(LocalDCDetectionTest) { + Y_UNIT_TEST(Basic) { + Ydb::Discovery::ListEndpointsResult endpoints; + std::unordered_map> mockData; + std::mt19937 gen(std::random_device{}()); + + const std::vector endpointsA = {"A1", "A2", "A3", "A4", "A5", "A6", "A7"}; + const std::vector endpointsB = {"B1", "B2", "B3"}; + const std::vector endpointsC = {"C1", "C2", "C3", "C4", "C5", "C6", "C7"}; + + const std::size_t epoches = 3; + const std::size_t measuresAmount = 10 * epoches; + + for (const auto& ep : endpointsA) { + mockData[ep] = GenerateMeasures(measuresAmount, 20, 30, gen); + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("A"); + endpoint.set_address(ep); + } + for (const auto& ep : endpointsB) { + mockData[ep] = GenerateMeasures(measuresAmount, 30, 45, gen); + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("B"); + endpoint.set_address(ep); + } + for (const auto& ep : endpointsC) { + mockData[ep] = GenerateMeasures(measuresAmount, 50, 70, gen); + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("C"); + endpoint.set_address(ep); + } + + std::function pinger = TMockedPinger(mockData); + TLocalDCDetector detector(pinger); + + for (std::size_t i = 0; i < epoches; ++i) { + detector.DetectLocalDC(endpoints); + UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("A")); + UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("B")); + UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("C")); + } + } + + Y_UNIT_TEST(SingleLocation) { + Ydb::Discovery::ListEndpointsResult endpoints; + std::unordered_map> mockData; + std::mt19937 gen(std::random_device{}()); + + const std::vector endpointsA = {"A1", "A2", "A3"}; + + for (const auto& ep : endpointsA) { + mockData[ep] = GenerateMeasures(10, 20, 30, gen); + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("A"); + endpoint.set_address(ep); + } + + std::function pinger = TMockedPinger(mockData); + TLocalDCDetector detector(pinger); + + detector.DetectLocalDC(endpoints); + + UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("A")); + } + + Y_UNIT_TEST(UnavailableLocalDC) { + Ydb::Discovery::ListEndpointsResult endpoints; + std::unordered_map> mockData; + std::mt19937 gen(std::random_device{}()); + + const std::vector endpointsA = {"A1", "A2", "A3", "A4", "A5", "A6", "A7"}; + const std::vector endpointsB = {"B1", "B2", "B3", "B4", "B5", "B6", "B7"}; + const std::vector endpointsC = {"C1", "C2", "C3", "C4", "C5", "C6", "C7"}; + + const std::size_t epoches = 3; + const std::size_t measuresAmount = 10 * epoches; + + for (const auto& ep : endpointsA) { + mockData[ep] = GenerateMeasures(measuresAmount, 20, 30, gen); + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("A"); + endpoint.set_address(ep); + } + for (const auto& ep : endpointsB) { + mockData[ep] = GenerateMeasures(measuresAmount, 30, 45, gen); + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("B"); + endpoint.set_address(ep); + } + for (const auto& ep : endpointsC) { + mockData[ep] = GenerateMeasures(measuresAmount, 50, 70, gen); + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("C"); + endpoint.set_address(ep); + } + + TMockedPinger mockPinger(mockData); + std::function pinger = + [&mockPinger](const Ydb::Discovery::EndpointInfo& endpoint) { return mockPinger(endpoint); }; + + TLocalDCDetector detector(pinger); + + detector.DetectLocalDC(endpoints); + UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("A")); + UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("B")); + UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("C")); + + for (const auto& ep : endpointsA) { + mockPinger.BanEndpoint(ep); + } + + detector.DetectLocalDC(endpoints); + UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("A")); + UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("B")); + UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("C")); + + for (const auto& ep : endpointsA) { + mockPinger.UnbanEndpoint(ep); + } + + detector.DetectLocalDC(endpoints); + UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("A")); + UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("B")); + UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("C")); + } + + Y_UNIT_TEST(OfflineDCs) { + Ydb::Discovery::ListEndpointsResult endpoints; + std::unordered_map> mockData; + std::mt19937 gen(std::random_device{}()); + + const std::vector endpointsA = {"A1", "A2", "A3", "A4", "A5", "A6", "A7"}; + const std::vector endpointsB = {"B1", "B2", "B3", "B4", "B5", "B6", "B7"}; + const std::vector endpointsC = {"C1", "C2", "C3", "C4", "C5", "C6", "C7"}; + + for (const auto& ep : endpointsA) { + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("A"); + endpoint.set_address(ep); + } + for (const auto& ep : endpointsB) { + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("B"); + endpoint.set_address(ep); + } + for (const auto& ep : endpointsC) { + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("C"); + endpoint.set_address(ep); + } + + TMockedPinger mockPinger(mockData); + std::function pinger = + [&mockPinger](const Ydb::Discovery::EndpointInfo& endpoint) { return mockPinger(endpoint); }; + + TLocalDCDetector detector(pinger); + + for (const auto& ep : endpointsA) { + mockPinger.BanEndpoint(ep); + } + for (const auto& ep : endpointsB) { + mockPinger.BanEndpoint(ep); + } + for (const auto& ep : endpointsC) { + mockPinger.BanEndpoint(ep); + } + + detector.DetectLocalDC(endpoints); + UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("A")); + UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("B")); + UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("C")); + } +}