diff --git a/Robust.Shared.Tests/Robust.Shared.Tests.csproj b/Robust.Shared.Tests/Robust.Shared.Tests.csproj new file mode 100644 index 00000000000..5b22ba23274 --- /dev/null +++ b/Robust.Shared.Tests/Robust.Shared.Tests.csproj @@ -0,0 +1,31 @@ + + + + + enable + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/Robust.Shared/CVars.cs b/Robust.Shared/CVars.cs index 3019547cc5b..1e5507c934b 100644 --- a/Robust.Shared/CVars.cs +++ b/Robust.Shared/CVars.cs @@ -337,6 +337,31 @@ protected CVars() public static readonly CVarDef NetLidgrenAppIdentifier = CVarDef.Create("net.lidgren_app_identifier", "RobustToolbox"); + /// + /// Whether to disconnect clients that exceed the decryption failure threshold. + /// + public static readonly CVarDef NetDecryptFailKick = + CVarDef.Create("net.dos_fail_kick", true, CVar.SERVERONLY); + + /// + /// Number of decryption failures from a single IP (or /64 subnet for IPv6) before logging a ban warning and optionally disconnecting. + /// + public static readonly CVarDef NetDecryptFailBanThreshold = + CVarDef.Create("net.dos_fail_ban_threshold", 10, CVar.SERVERONLY); + + /// + /// How often (in minutes) to clean up stale decryption failure records. + /// Records are only removed if they have not been seen for this many minutes. + /// + public static readonly CVarDef NetDecryptFailCleanupInterval = + CVarDef.Create("net.dos_fail_cleanup_interval", 10, CVar.SERVERONLY); + + /// + /// Maximum number of IPs tracked for decryption failures. Prevents memory exhaustion from botnet attacks. + /// + public static readonly CVarDef NetDecryptFailMaxTracked = + CVarDef.Create("net.dos_fail_max_tracked", 10000, CVar.SERVERONLY); + /// /// Add random fake network loss to all outgoing UDP network packets, as a ratio of how many packets to drop. /// 0 = no packet loss, 1 = all packets dropped diff --git a/Robust.Shared/Network/NetEncryption.cs b/Robust.Shared/Network/NetEncryption.cs index 1aef24d602f..8040cfc101a 100644 --- a/Robust.Shared/Network/NetEncryption.cs +++ b/Robust.Shared/Network/NetEncryption.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Buffers; using System.Buffers.Binary; using System.Threading; @@ -84,7 +84,12 @@ public unsafe void Encrypt(NetOutgoingMessage message) ArrayPool.Shared.Return(returnPool); } - public unsafe void Decrypt(NetIncomingMessage message) + /// + /// Attempts to decrypt an incoming network message, falliably. + /// + /// The message to decrypt in-place. This will be mutated with the decrypted results. + /// Whether the operation was successful. If this fails, you likely want to drop the connection. + public unsafe bool TryDecrypt(NetIncomingMessage message) { var nonce = message.ReadUInt64(); var cipherText = message.Data.AsSpan(sizeof(ulong), message.LengthBytes - sizeof(ulong)); @@ -109,12 +114,13 @@ public unsafe void Decrypt(NetIncomingMessage message) // key _key); - message.Position = 0; - message.LengthBytes = messageLength; - ArrayPool.Shared.Return(buffer); if (!result) - throw new SodiumException("Decryption operation failed!"); + return false; + + message.Position = 0; + message.LengthBytes = messageLength; + return true; } } diff --git a/Robust.Shared/Network/NetManager.ClientConnect.cs b/Robust.Shared/Network/NetManager.ClientConnect.cs index 238bc9d0aa6..14cf54f3ddb 100644 --- a/Robust.Shared/Network/NetManager.ClientConnect.cs +++ b/Robust.Shared/Network/NetManager.ClientConnect.cs @@ -220,13 +220,20 @@ private async Task CCDoHandshake( // Expect login success here. response = await AwaitData(connection, cancel); - encryption?.Decrypt(response); + + // Attempt to decrypt the message, only logging if we fail to decrypt and we actually have encryption. + if ((!encryption?.TryDecrypt(response)) ?? false) + { + const string msg = "Failed to decrypt login success."; + connection.Disconnect(msg); + throw new Exception(msg); + } } var msgSuc = new MsgLoginSuccess(); msgSuc.ReadFromBuffer(response, _serializer); - var channel = new NetChannel(this, connection, msgSuc.UserData with { HWId = [..legacyHwid] }, msgSuc.Type); + var channel = new NetChannel(this, connection, msgSuc.UserData with { HWId = [.. legacyHwid] }, msgSuc.Type); _channels.Add(connection, channel); peer.AddChannel(channel); @@ -440,7 +447,7 @@ private Task AwaitData( if (ipAddress.AddressFamily == AddressFamily.InterNetwork || ipAddress.AddressFamily == AddressFamily.InterNetworkV6) { - return new[] {ipAddress}; + return new[] { ipAddress }; } throw new ArgumentException("This method will not currently resolve other than IPv4 or IPv6 addresses"); diff --git a/Robust.Shared/Network/NetManager.cs b/Robust.Shared/Network/NetManager.cs index 7d9474fc95d..701a13c566d 100644 --- a/Robust.Shared/Network/NetManager.cs +++ b/Robust.Shared/Network/NetManager.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; @@ -113,6 +114,11 @@ public sealed partial class NetManager : IClientNetManager, IServerNetManager, I [Dependency] private readonly HttpClientHolder _http = default!; [Dependency] private readonly IHWId _hwId = default!; + /// + /// Whether we bother to log problematic packets. Set by . + /// + private bool _logPacketIssues = false; + /// /// Holds lookup table for NetMessage.Id -> NetMessage.Type /// @@ -140,6 +146,13 @@ public sealed partial class NetManager : IClientNetManager, IServerNetManager, I private ISawmill _loggerPacket = default!; private ISawmill _authLogger = default!; + private readonly ConcurrentDictionary _decryptFailCounts = new(); + private DateTime _lastDecryptFailCleanup = DateTime.UtcNow; + + private bool _clientSerializerComplete; + private bool _clientTransferComplete; + private bool _clientResetPending; + /// public int Port => _config.GetCVar(CVars.NetPort); @@ -258,6 +271,7 @@ public void Initialize(bool isServer) _config.OnValueChanged(CVars.NetLidgrenLogError, LidgrenLogErrorChanged); _config.OnValueChanged(CVars.NetVerbose, NetVerboseChanged); + _config.OnValueChanged(CVars.NetLogging, NetLoggingChanged); if (isServer) { _config.OnValueChanged(CVars.AuthMode, OnAuthModeChanged, invokeImmediately: true); @@ -280,6 +294,11 @@ public void Initialize(bool isServer) } } + private void NetLoggingChanged(bool obj) + { + _logPacketIssues = obj; + } + private void LidgrenLogWarningChanged(bool newValue) { foreach (var netPeer in _netPeers) @@ -377,7 +396,7 @@ public void StartServer() if (UpnpCompatible(config) && upnp) config.EnableUPnP = true; - var peer = IsServer ? (NetPeer) new NetServer(config) : new NetClient(config); + var peer = IsServer ? (NetPeer)new NetServer(config) : new NetClient(config); peer.Start(); _netPeers.Add(new NetPeerData(peer)); } @@ -460,8 +479,28 @@ public void Shutdown(string reason) _initialized = false; } + private static IPAddress NormalizeIp(IPAddress ip) + { + if (ip.AddressFamily != AddressFamily.InterNetworkV6) return ip; + var bytes = ip.GetAddressBytes(); + for (var i = 8; i < 16; i++) bytes[i] = 0; + return new IPAddress(bytes); + } + + private void CleanupDecryptFailCounts() + { + if (!IsServer) return; + var now = DateTime.UtcNow; + var intervalMinutes = _config.GetCVar(CVars.NetDecryptFailCleanupInterval); + if ((now - _lastDecryptFailCleanup).TotalMinutes < intervalMinutes) return; + _lastDecryptFailCleanup = now; + foreach (var (ip, (_, lastSeen)) in _decryptFailCounts) + { if ((now - lastSeen).TotalMinutes >= intervalMinutes) _decryptFailCounts.TryRemove(ip, out _); } + } + public void ProcessPackets() { + CleanupDecryptFailCounts(); var sentMessages = 0L; var recvMessages = 0L; var sentBytes = 0L; @@ -739,7 +778,7 @@ private void HandleStatusChanged(NetPeerData peer, NetIncomingMessage msg) var sender = msg.SenderConnection; DebugTools.Assert(sender != null); - var newStatus = (NetConnectionStatus) msg.ReadByte(); + var newStatus = (NetConnectionStatus)msg.ReadByte(); var reason = msg.ReadString(); _logger.Debug("{ConnectionEndpoint}: Status changed to {ConnectionStatus}, reason: {ConnectionStatusReason}", sender.RemoteEndPoint, newStatus, reason); @@ -835,7 +874,7 @@ private void HandleDisconnect(NetPeerData peer, NetConnection connection, string try { #endif - OnDisconnected(channel, reason); + OnDisconnected(channel, reason); #if EXCEPTION_TOLERANCE } catch (Exception e) @@ -871,19 +910,25 @@ private bool DispatchNetMessage(NetIncomingMessage msg) var peer = msg.SenderConnection.Peer; if (peer.Status == NetPeerStatus.ShutdownRequested) { - _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, but shutdown is requested."); + if (_logPacketIssues) + _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, but shutdown is requested."); + return true; } if (peer.Status == NetPeerStatus.NotRunning) { - _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, peer is not running."); + if (_logPacketIssues) + _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, peer is not running."); + return true; } if (!IsConnected) { - _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, but not connected."); + if (_logPacketIssues) + _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, but not connected."); + return true; } @@ -898,19 +943,57 @@ private bool DispatchNetMessage(NetIncomingMessage msg) if (msg.LengthBytes < 1) { - _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received empty packet."); + if (_logPacketIssues) + _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received empty packet."); + + msg.SenderConnection.Disconnect("Received empty/weird packet", false); return true; } if (!_channels.TryGetValue(msg.SenderConnection, out var channel)) { - _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got unexpected data packet before handshake completion."); + if (_logPacketIssues) + _logger.Debug($"{msg.SenderConnection.RemoteEndPoint}: Got unexpected data packet before handshake completion."); + - msg.SenderConnection.Disconnect("Unexpected packet before handshake completion"); + msg.SenderConnection.Disconnect("Unexpected packet before handshake completion", false); return true; } - channel.Encryption?.Decrypt(msg); + // Attempt to decrypt the message, only logging if we fail to decrypt and we actually have encryption. + if ((!channel.Encryption?.TryDecrypt(msg)) ?? false) + { + var remoteEndPoint = msg.SenderConnection.RemoteEndPoint; + if (IsServer) + { + var remoteIp = NormalizeIp(remoteEndPoint.Address); + var now = DateTime.UtcNow; + var maxTracked = _config.GetCVar(CVars.NetDecryptFailMaxTracked); + + // Drop silently if tracking limit reached + if (_decryptFailCounts.Count >= maxTracked && !_decryptFailCounts.ContainsKey(remoteIp)) + return true; + var (failCount, _) = _decryptFailCounts.AddOrUpdate( + remoteIp, + _ => (1, now), + (_, old) => (old.TotalCount + 1, now)); + if (failCount == 1 && _logPacketIssues) + _logger.Debug($"{remoteEndPoint}: Got a packet that fails to decrypt."); + + var banThreshold = _config.GetCVar(CVars.NetDecryptFailBanThreshold); + if (failCount >= banThreshold) + { + _authLogger.Warning($"[DECRYPTBAN] {remoteIp} reached {failCount} decryption failures. Consider banning this IP."); + if (_config.GetCVar(CVars.NetDecryptFailKick)) + msg.SenderConnection.Disconnect("Failed to decrypt packet.", false); + return true; + } + } + else if (_logPacketIssues) + { _logger.Debug($"{remoteEndPoint}: Got a packet that fails to decrypt."); } + msg.SenderConnection.Disconnect("Failed to decrypt packet.", false); + return true; + } var id = msg.ReadByte(); @@ -918,9 +1001,10 @@ private bool DispatchNetMessage(NetIncomingMessage msg) if (entry == null) { - _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got net message with invalid ID {id}."); + if (_logPacketIssues) + _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got net message with invalid ID {id}."); - channel.Disconnect("Got NetMessage with invalid ID"); + channel.Disconnect("Got NetMessage with invalid ID", false); return true; } @@ -928,15 +1012,16 @@ private bool DispatchNetMessage(NetIncomingMessage msg) if (!channel.IsHandshakeComplete && !entry.IsHandshake) { - _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got non-handshake message {entry.Type.Name} before handshake completion."); + if (_logPacketIssues) + _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got non-handshake message {entry.Type.Name} before handshake completion."); - channel.Disconnect("Got unacceptable net message before handshake completion"); + channel.Disconnect("Got unacceptable net message before handshake completion", false); return true; } var type = entry.Type; - var instance = (NetMessage) Activator.CreateInstance(type)!; + var instance = (NetMessage)Activator.CreateInstance(type)!; instance.MsgChannel = channel; #if DEBUG @@ -956,12 +1041,16 @@ private bool DispatchNetMessage(NetIncomingMessage msg) } catch (InvalidCastException ice) { - _logger.Error($"{msg.SenderConnection.RemoteEndPoint}: Wrong deserialization of {type.Name} packet:\n{ice}"); + if (_logPacketIssues) + _logger.Error($"{msg.SenderConnection.RemoteEndPoint}: Wrong deserialization of {type.Name} packet:\n{ice}"); + channel.Disconnect("Failed to deserialize packet.", false); return true; } catch (Exception e) // yes, we want to catch ALL exeptions for security { - _logger.Error($"{msg.SenderConnection.RemoteEndPoint}: Failed to deserialize {type.Name} packet:\n{e}"); + if (_logPacketIssues) + _logger.Error($"{msg.SenderConnection.RemoteEndPoint}: Failed to deserialize {type.Name} packet:\n{e}"); + channel.Disconnect("Failed to deserialize packet.", false); return true; } @@ -1021,7 +1110,7 @@ public void RegisterNetMessage(ProcessMessage? rxCallback = null, if (rxCallback != null && (accept & thisSide) != 0) { - data.Callback = msg => rxCallback((T) msg); + data.Callback = msg => rxCallback((T)msg); if (id != -1) CacheNetMsgIndex(id, name); @@ -1043,7 +1132,7 @@ private NetOutgoingMessage BuildMessage(NetMessage message, NetPeer peer) throw new NetManagerException( $"[NET] No string in table with name {message.MsgName}. Was it registered?"); - packet.Write((byte) msgId); + packet.Write((byte)msgId); message.WriteToBuffer(packet, _serializer); return packet; } diff --git a/Robust.UnitTesting/Shared/Networking/NetEcnryptionDoSTest.cs b/Robust.UnitTesting/Shared/Networking/NetEcnryptionDoSTest.cs new file mode 100644 index 00000000000..abcb834edaa --- /dev/null +++ b/Robust.UnitTesting/Shared/Networking/NetEcnryptionDoSTest.cs @@ -0,0 +1,175 @@ +using Lidgren.Network; +using NUnit.Framework; +using Robust.Shared.Network; + +namespace Robust.Shared.Tests.Networking; + +public sealed class NetEncryptionDoSTest +{ + private const ulong Magic = 0x13377777_77777777; + + [Test] + [Description("A control test that ensures connecting in a test works.")] + public void ConnectionWorks() + { + var (client, server) = MakeConnectionPair(); + + var message = client.CreateMessage(); + + message.WriteVariableUInt64(Magic); + + client.SendMessage(message, NetDeliveryMethod.ReliableOrdered); + + var packet = Receive(server); + + Assert.That(packet, Is.Not.Null); + + Assert.That(packet.ReadVariableUInt64(), Is.EqualTo(Magic)); + } + + [Test] + [Description("A control test that just ensures encryption works as other tests expect.")] + public void EncryptionWorks() + { + var (clientEnc, serverEnc) = MakeEncryptionPair(); + var (client, server) = MakeConnectionPair(); + + var message = client.CreateMessage(); + + message.WriteVariableUInt64(Magic); + + clientEnc.Encrypt(message); + + client.SendMessage(message, NetDeliveryMethod.ReliableOrdered); + + var packet = Receive(server); + + Assert.That(packet, Is.Not.Null); + + Assert.That(serverEnc.TryDecrypt(packet), Is.True); + } + + [Test] + [Description("Attempt to decrypt a packet that is using the wrong encryption keys, ensuring it doesn't throw.")] + public void WrongKeyFailureDoesNotThrow() + { + var (clientEnc, serverEnc) = MakeEncryptionPair(disjointKey: true); + var (client, server) = MakeConnectionPair(); + + var message = client.CreateMessage(); + + message.WriteVariableUInt64(Magic); + + clientEnc.Encrypt(message); + + client.SendMessage(message, NetDeliveryMethod.ReliableOrdered); + + var packet = server.WaitMessage(1000); + + Assert.That(packet, Is.Not.Null); + + Assert.That(serverEnc.TryDecrypt(packet), Is.False); + } + + private static byte[][] _badMessages = + [ + [1, 1, 1, 1, 1], + [1, 2], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,] + ]; + + [Test] + [Description("Attempt to decrypt a packet that is bogus, ensuring it doesn't throw.")] + [TestCaseSource(nameof(_badMessages))] + public void BadMessageDoesNotThrow(byte[] badMessage) + { + var (_, serverEnc) = MakeEncryptionPair(disjointKey: true); + var (client, server) = MakeConnectionPair(); + + var message = client.CreateMessage(); + + message.Write(badMessage); + + // Don't encrypt at all. + + client.SendMessage(message, NetDeliveryMethod.ReliableOrdered); + + var packet = server.WaitMessage(1000); + + Assert.That(packet, Is.Not.Null); + + Assert.That(serverEnc.TryDecrypt(packet), Is.False); + } + + + // TODO: Generalize all this for other low level network tests. + + private (NetClient client, NetServer server) MakeConnectionPair() + { + const string id = "test"; + var client = new NetClient(new NetPeerConfiguration(id)); + + var server = new NetServer(new NetPeerConfiguration(id)); + + client.Start(); + // Lidgren has no facilities for mocking this nicely. + // So we just use an actual socket. + server.Start(); + + client.Connect("localhost", server.Port); + + var ready = false; + + while (!ready) + { + switch (server.WaitMessage(1000)) + { + case { MessageType: NetIncomingMessageType.StatusChanged } msg: + { + // hello there. + var status = (NetConnectionStatus)msg.ReadByte(); + + if (status == NetConnectionStatus.Connected) + ready = true; + + break; + } + } + } + + return (client, server); + } + + private NetIncomingMessage Receive(NetPeer peer) + { + NetIncomingMessage? found = null; + + while (found == null) + { + switch (peer.WaitMessage(1000)) + { + case { MessageType: NetIncomingMessageType.Data } msg: + { + found = msg; + break; + } + } + } + + return found; + } + + private (NetEncryption client, NetEncryption server) MakeEncryptionPair(bool disjointKey = false) + { + var serverKey = new byte[32]; + + System.Random.Shared.NextBytes(serverKey.AsSpan()); + + var clientKey = (byte[])serverKey.Clone(); + + if (disjointKey) + System.Random.Shared.NextBytes(clientKey.AsSpan()); + + return (new NetEncryption(clientKey, false), new NetEncryption(serverKey, true)); + } +}