Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 7 additions & 37 deletions src/extension/network/NUClearNetwork.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,37 +112,11 @@ namespace extension {
next_event_callback = std::move(f);
}

std::array<uint16_t, 9> NUClearNetwork::udp_key(const sock_t& address) {

// Get our keys for our maps, it will be the ip and then port
std::array<uint16_t, 9> key = {0};

switch (address.sock.sa_family) {
case AF_INET:
// The first chars are 0 (ipv6) and after that is our address and then port
std::memcpy(&key[6], &address.ipv4.sin_addr, sizeof(address.ipv4.sin_addr));
key[8] = address.ipv4.sin_port;
break;

case AF_INET6:
// IPv6 address then port
std::memcpy(key.data(), &address.ipv6.sin6_addr, sizeof(address.ipv6.sin6_addr));
key[8] = address.ipv6.sin6_port;
break;

default: throw std::invalid_argument("Unknown address family");
}

return key;
}


void NUClearNetwork::remove_target(const std::shared_ptr<NetworkTarget>& target) {

// Erase udp
auto key = udp_key(target->target);
if (udp_target.find(key) != udp_target.end()) {
udp_target.erase(udp_target.find(key));
if (udp_target.find(target->target) != udp_target.end()) {
udp_target.erase(udp_target.find(target->target));
}

// Erase name
Expand All @@ -161,7 +135,6 @@ namespace extension {
}
}


void NUClearNetwork::open_data(const sock_t& bind_address) {

// Create the "join any" address for this address family
Expand Down Expand Up @@ -383,7 +356,7 @@ namespace extension {
auto all_target = std::make_shared<NetworkTarget>("", announce_target);
targets.push_front(all_target);
name_target.insert(std::make_pair("", all_target));
udp_target.insert(std::make_pair(udp_key(announce_target), all_target));
udp_target.insert(std::make_pair(announce_target, all_target));

// Work out our MTU for udp packets
packet_data_mtu = network_mtu; // Start with the total mtu
Expand Down Expand Up @@ -570,14 +543,11 @@ namespace extension {
// This is a real packet! get our header information
const PacketHeader& header = *reinterpret_cast<const PacketHeader*>(payload.data());

// Get the map key for this device
auto key = udp_key(address);

// From here on, we are doing things with our target lists that if changed would make us sad
std::shared_ptr<NetworkTarget> remote;
/* Mutex scope */ {
const std::lock_guard<std::mutex> lock(target_mutex);
auto r = udp_target.find(key);
auto r = udp_target.find(address);
remote = r == udp_target.end() ? nullptr : r->second;
}

Expand All @@ -601,10 +571,10 @@ namespace extension {
const std::lock_guard<std::mutex> lock(target_mutex);

// Double check they are new
if (udp_target.count(key) == 0) {
if (udp_target.count(address) == 0) {
new_connection = true;
targets.push_back(ptr);
udp_target.insert(std::make_pair(key, ptr));
udp_target.insert(std::make_pair(address, ptr));
name_target.insert(std::make_pair(name, ptr));

// Say hi back!
Expand Down Expand Up @@ -639,7 +609,7 @@ namespace extension {
const std::lock_guard<std::mutex> lock(target_mutex);

// Double check they are gone after locking before removal
if (udp_target.count(key) > 0) {
if (udp_target.count(address) > 0) {
left = true;
remove_target(remote);
}
Expand Down
4 changes: 2 additions & 2 deletions src/extension/network/NUClearNetwork.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,8 @@ namespace extension {
/// A map of string names to targets with that name
std::multimap<std::string, std::shared_ptr<NetworkTarget>, std::less<>> name_target;

/// A map of ip/port pairs to the network target they belong to
std::map<std::array<uint16_t, 9>, std::shared_ptr<NetworkTarget>> udp_target;
/// A map of socket addresses to the network target they belong to
std::map<sock_t, std::shared_ptr<NetworkTarget>> udp_target;
};

} // namespace network
Expand Down
65 changes: 65 additions & 0 deletions src/util/network/sock_t.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* MIT License
*
* Copyright (c) 2025 NUClear Contributors
*
* This file is part of the NUClear codebase.
* See https://github.com/Fastcode/NUClear for further info.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
* Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
* WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "sock_t.hpp"

#include <netdb.h>

#include <cstdlib>

namespace NUClear {
namespace util {
namespace network {

socklen_t sock_t::size() const {
switch (sock.sa_family) {
case AF_INET: return sizeof(sockaddr_in);
case AF_INET6: return sizeof(sockaddr_in6);
default: throw std::system_error(EAFNOSUPPORT, std::system_category(), "Unsupported address family");
}
}

std::pair<std::string, in_port_t> sock_t::address(bool numeric) const {
std::array<char, NI_MAXHOST> host{};
std::array<char, NI_MAXSERV> service{};

auto flags = NI_NUMERICSERV | (numeric ? NI_NUMERICHOST : 0);

if (::getnameinfo(&sock,
size(),
host.data(),
static_cast<socklen_t>(host.size()),
service.data(),
static_cast<socklen_t>(service.size()),
flags)
!= 0) {
throw std::system_error(
network_errno,
std::system_category(),
"Cannot get address for socket address family " + std::to_string(sock.sa_family));
}

return {host.data(), static_cast<in_port_t>(std::stoi(service.data()))};
}

} // namespace network
} // namespace util
} // namespace NUClear
138 changes: 111 additions & 27 deletions src/util/network/sock_t.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,51 +24,135 @@
#define NUCLEAR_UTIL_NETWORK_SOCK_T_HPP

#include <array>
#include <cstdint>
#include <cstring>
#include <ostream>
#include <stdexcept>
#include <string>
#include <system_error>
#include <tuple>

#include "../platform.hpp"

namespace NUClear {
namespace util {
namespace network {

/**
* A struct representing a socket address, supporting both IPv4 and IPv6.
* This struct provides a unified interface for handling socket addresses across different address families.
*/
struct sock_t {
/**
* A union of socket address structures, allowing access to the address in different formats.
* This union includes sockaddr_storage, sockaddr, sockaddr_in, and sockaddr_in6.
*/
union {
sockaddr_storage storage;
sockaddr sock;
sockaddr_in ipv4;
sockaddr_in6 ipv6;
sockaddr_storage storage; //< The storage for the socket address
sockaddr sock; //< The socket address
sockaddr_in ipv4; //< The IPv4 address
sockaddr_in6 ipv6; //< The IPv6 address
};

socklen_t size() const {
switch (sock.sa_family) {
case AF_INET: return sizeof(sockaddr_in);
case AF_INET6: return sizeof(sockaddr_in6);
default:
throw std::runtime_error("Cannot get size for socket address family "
+ std::to_string(sock.sa_family));
/**
* Equality comparison operator for sock_t.
*
* @param a The first sock_t object
* @param b The second sock_t object
*
* @return true if the addresses are equal, false otherwise.
*/
friend bool operator==(const sock_t& a, const sock_t& b) {
if ((a.sock.sa_family != AF_INET && a.sock.sa_family != AF_INET6)
|| (b.sock.sa_family != AF_INET && b.sock.sa_family != AF_INET6)) {
throw std::system_error(EAFNOSUPPORT, std::system_category(), "Unsupported address family");
}

if (a.sock.sa_family != b.sock.sa_family) {
return false;
}

if (a.sock.sa_family == AF_INET) {
return a.ipv4.sin_port == b.ipv4.sin_port && a.ipv4.sin_addr.s_addr == b.ipv4.sin_addr.s_addr;
}

return a.ipv6.sin6_port == b.ipv6.sin6_port
&& std::memcmp(&a.ipv6.sin6_addr.s6_addr, &b.ipv6.sin6_addr.s6_addr, sizeof(in6_addr)) == 0;
}

std::pair<std::string, in_port_t> address(bool numeric_host = false) const {
std::array<char, NI_MAXHOST> host{};
std::array<char, NI_MAXSERV> service{};
const int result = ::getnameinfo(reinterpret_cast<const sockaddr*>(&storage),
size(),
host.data(),
static_cast<socklen_t>(host.size()),
service.data(),
static_cast<socklen_t>(service.size()),
NI_NUMERICSERV | (numeric_host ? NI_NUMERICHOST : 0));
if (result != 0) {
throw std::system_error(
network_errno,
std::system_category(),
"Cannot get address for socket address family " + std::to_string(sock.sa_family));
/**
* Inequality comparison operator for sock_t.
*
* @param a The first sock_t object
* @param b The second sock_t object
*
* @return true if the addresses are not equal, false otherwise
*/
friend bool operator!=(const sock_t& a, const sock_t& b) {
return !(a == b);
}

/**
* Less-than comparison operator for sock_t.
*
* @param a The first sock_t object
* @param b The second sock_t object
*
* @return true if a is less than b, false otherwise
*/
friend bool operator<(const sock_t& a, const sock_t& b) {
if ((a.sock.sa_family != AF_INET && a.sock.sa_family != AF_INET6)
|| (b.sock.sa_family != AF_INET && b.sock.sa_family != AF_INET6)) {
throw std::system_error(EAFNOSUPPORT, std::system_category(), "Unsupported address family");
}

if (a.sock.sa_family != b.sock.sa_family) {
return a.sock.sa_family < b.sock.sa_family;
}
return std::make_pair(std::string(host.data()), static_cast<in_port_t>(std::stoi(service.data())));
if (a.sock.sa_family == AF_INET) {
return std::forward_as_tuple(ntohl(a.ipv4.sin_addr.s_addr), ntohs(a.ipv4.sin_port))
< std::forward_as_tuple(ntohl(b.ipv4.sin_addr.s_addr), ntohs(b.ipv4.sin_port));
}

auto cmp = std::memcmp(a.ipv6.sin6_addr.s6_addr, b.ipv6.sin6_addr.s6_addr, sizeof(a.ipv6.sin6_addr));
if (cmp != 0) {
return cmp < 0;
}

return ntohs(a.ipv6.sin6_port) < ntohs(b.ipv6.sin6_port);
}

/**
* Returns the size of the socket address structure.
*
* @return The size of the socket address structure
*
* @throws std::system_error if the address family is unsupported
*/
socklen_t size() const;

/**
* Resolves the socket address to a hostname and port.
*
* @param numeric If true, returns the numeric IP address instead of the hostname
*
* @return A pair containing the hostname (or numeric IP) and the port
*
* @throws std::system_error if the address cannot be resolved
*/
std::pair<std::string, in_port_t> address(bool numeric = false) const;

/**
* Output stream operator for sock_t.
* Outputs the address in the format "{host}:{port}"
*
* @param os The output stream to write to
* @param addr The socket address to output
* @return The output stream
*/
friend std::ostream& operator<<(std::ostream& os, const sock_t& addr) {
auto addr_pair = addr.address(true);
return os << addr_pair.first << ":" << addr_pair.second;
}
};

Expand Down
Loading
Loading