diff --git a/Robust.Shared.Tests/Robust.Shared.Tests.csproj b/Robust.Shared.Tests/Robust.Shared.Tests.csproj
new file mode 100644
index 00000000000..5b22ba23274
--- /dev/null
+++ b/Robust.Shared.Tests/Robust.Shared.Tests.csproj
@@ -0,0 +1,31 @@
+
+
+
+
+ enable
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/Robust.Shared/CVars.cs b/Robust.Shared/CVars.cs
index 3019547cc5b..1e5507c934b 100644
--- a/Robust.Shared/CVars.cs
+++ b/Robust.Shared/CVars.cs
@@ -337,6 +337,31 @@ protected CVars()
public static readonly CVarDef NetLidgrenAppIdentifier =
CVarDef.Create("net.lidgren_app_identifier", "RobustToolbox");
+ ///
+ /// Whether to disconnect clients that exceed the decryption failure threshold.
+ ///
+ public static readonly CVarDef NetDecryptFailKick =
+ CVarDef.Create("net.dos_fail_kick", true, CVar.SERVERONLY);
+
+ ///
+ /// Number of decryption failures from a single IP (or /64 subnet for IPv6) before logging a ban warning and optionally disconnecting.
+ ///
+ public static readonly CVarDef NetDecryptFailBanThreshold =
+ CVarDef.Create("net.dos_fail_ban_threshold", 10, CVar.SERVERONLY);
+
+ ///
+ /// How often (in minutes) to clean up stale decryption failure records.
+ /// Records are only removed if they have not been seen for this many minutes.
+ ///
+ public static readonly CVarDef NetDecryptFailCleanupInterval =
+ CVarDef.Create("net.dos_fail_cleanup_interval", 10, CVar.SERVERONLY);
+
+ ///
+ /// Maximum number of IPs tracked for decryption failures. Prevents memory exhaustion from botnet attacks.
+ ///
+ public static readonly CVarDef NetDecryptFailMaxTracked =
+ CVarDef.Create("net.dos_fail_max_tracked", 10000, CVar.SERVERONLY);
+
///
/// Add random fake network loss to all outgoing UDP network packets, as a ratio of how many packets to drop.
/// 0 = no packet loss, 1 = all packets dropped
diff --git a/Robust.Shared/Network/NetEncryption.cs b/Robust.Shared/Network/NetEncryption.cs
index 1aef24d602f..8040cfc101a 100644
--- a/Robust.Shared/Network/NetEncryption.cs
+++ b/Robust.Shared/Network/NetEncryption.cs
@@ -1,4 +1,4 @@
-using System;
+using System;
using System.Buffers;
using System.Buffers.Binary;
using System.Threading;
@@ -84,7 +84,12 @@ public unsafe void Encrypt(NetOutgoingMessage message)
ArrayPool.Shared.Return(returnPool);
}
- public unsafe void Decrypt(NetIncomingMessage message)
+ ///
+ /// Attempts to decrypt an incoming network message, falliably.
+ ///
+ /// The message to decrypt in-place. This will be mutated with the decrypted results.
+ /// Whether the operation was successful. If this fails, you likely want to drop the connection.
+ public unsafe bool TryDecrypt(NetIncomingMessage message)
{
var nonce = message.ReadUInt64();
var cipherText = message.Data.AsSpan(sizeof(ulong), message.LengthBytes - sizeof(ulong));
@@ -109,12 +114,13 @@ public unsafe void Decrypt(NetIncomingMessage message)
// key
_key);
- message.Position = 0;
- message.LengthBytes = messageLength;
-
ArrayPool.Shared.Return(buffer);
if (!result)
- throw new SodiumException("Decryption operation failed!");
+ return false;
+
+ message.Position = 0;
+ message.LengthBytes = messageLength;
+ return true;
}
}
diff --git a/Robust.Shared/Network/NetManager.ClientConnect.cs b/Robust.Shared/Network/NetManager.ClientConnect.cs
index 238bc9d0aa6..14cf54f3ddb 100644
--- a/Robust.Shared/Network/NetManager.ClientConnect.cs
+++ b/Robust.Shared/Network/NetManager.ClientConnect.cs
@@ -220,13 +220,20 @@ private async Task CCDoHandshake(
// Expect login success here.
response = await AwaitData(connection, cancel);
- encryption?.Decrypt(response);
+
+ // Attempt to decrypt the message, only logging if we fail to decrypt and we actually have encryption.
+ if ((!encryption?.TryDecrypt(response)) ?? false)
+ {
+ const string msg = "Failed to decrypt login success.";
+ connection.Disconnect(msg);
+ throw new Exception(msg);
+ }
}
var msgSuc = new MsgLoginSuccess();
msgSuc.ReadFromBuffer(response, _serializer);
- var channel = new NetChannel(this, connection, msgSuc.UserData with { HWId = [..legacyHwid] }, msgSuc.Type);
+ var channel = new NetChannel(this, connection, msgSuc.UserData with { HWId = [.. legacyHwid] }, msgSuc.Type);
_channels.Add(connection, channel);
peer.AddChannel(channel);
@@ -440,7 +447,7 @@ private Task AwaitData(
if (ipAddress.AddressFamily == AddressFamily.InterNetwork
|| ipAddress.AddressFamily == AddressFamily.InterNetworkV6)
{
- return new[] {ipAddress};
+ return new[] { ipAddress };
}
throw new ArgumentException("This method will not currently resolve other than IPv4 or IPv6 addresses");
diff --git a/Robust.Shared/Network/NetManager.cs b/Robust.Shared/Network/NetManager.cs
index 7d9474fc95d..701a13c566d 100644
--- a/Robust.Shared/Network/NetManager.cs
+++ b/Robust.Shared/Network/NetManager.cs
@@ -1,4 +1,5 @@
using System;
+using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
@@ -113,6 +114,11 @@ public sealed partial class NetManager : IClientNetManager, IServerNetManager, I
[Dependency] private readonly HttpClientHolder _http = default!;
[Dependency] private readonly IHWId _hwId = default!;
+ ///
+ /// Whether we bother to log problematic packets. Set by .
+ ///
+ private bool _logPacketIssues = false;
+
///
/// Holds lookup table for NetMessage.Id -> NetMessage.Type
///
@@ -140,6 +146,13 @@ public sealed partial class NetManager : IClientNetManager, IServerNetManager, I
private ISawmill _loggerPacket = default!;
private ISawmill _authLogger = default!;
+ private readonly ConcurrentDictionary _decryptFailCounts = new();
+ private DateTime _lastDecryptFailCleanup = DateTime.UtcNow;
+
+ private bool _clientSerializerComplete;
+ private bool _clientTransferComplete;
+ private bool _clientResetPending;
+
///
public int Port => _config.GetCVar(CVars.NetPort);
@@ -258,6 +271,7 @@ public void Initialize(bool isServer)
_config.OnValueChanged(CVars.NetLidgrenLogError, LidgrenLogErrorChanged);
_config.OnValueChanged(CVars.NetVerbose, NetVerboseChanged);
+ _config.OnValueChanged(CVars.NetLogging, NetLoggingChanged);
if (isServer)
{
_config.OnValueChanged(CVars.AuthMode, OnAuthModeChanged, invokeImmediately: true);
@@ -280,6 +294,11 @@ public void Initialize(bool isServer)
}
}
+ private void NetLoggingChanged(bool obj)
+ {
+ _logPacketIssues = obj;
+ }
+
private void LidgrenLogWarningChanged(bool newValue)
{
foreach (var netPeer in _netPeers)
@@ -377,7 +396,7 @@ public void StartServer()
if (UpnpCompatible(config) && upnp)
config.EnableUPnP = true;
- var peer = IsServer ? (NetPeer) new NetServer(config) : new NetClient(config);
+ var peer = IsServer ? (NetPeer)new NetServer(config) : new NetClient(config);
peer.Start();
_netPeers.Add(new NetPeerData(peer));
}
@@ -460,8 +479,28 @@ public void Shutdown(string reason)
_initialized = false;
}
+ private static IPAddress NormalizeIp(IPAddress ip)
+ {
+ if (ip.AddressFamily != AddressFamily.InterNetworkV6) return ip;
+ var bytes = ip.GetAddressBytes();
+ for (var i = 8; i < 16; i++) bytes[i] = 0;
+ return new IPAddress(bytes);
+ }
+
+ private void CleanupDecryptFailCounts()
+ {
+ if (!IsServer) return;
+ var now = DateTime.UtcNow;
+ var intervalMinutes = _config.GetCVar(CVars.NetDecryptFailCleanupInterval);
+ if ((now - _lastDecryptFailCleanup).TotalMinutes < intervalMinutes) return;
+ _lastDecryptFailCleanup = now;
+ foreach (var (ip, (_, lastSeen)) in _decryptFailCounts)
+ { if ((now - lastSeen).TotalMinutes >= intervalMinutes) _decryptFailCounts.TryRemove(ip, out _); }
+ }
+
public void ProcessPackets()
{
+ CleanupDecryptFailCounts();
var sentMessages = 0L;
var recvMessages = 0L;
var sentBytes = 0L;
@@ -739,7 +778,7 @@ private void HandleStatusChanged(NetPeerData peer, NetIncomingMessage msg)
var sender = msg.SenderConnection;
DebugTools.Assert(sender != null);
- var newStatus = (NetConnectionStatus) msg.ReadByte();
+ var newStatus = (NetConnectionStatus)msg.ReadByte();
var reason = msg.ReadString();
_logger.Debug("{ConnectionEndpoint}: Status changed to {ConnectionStatus}, reason: {ConnectionStatusReason}",
sender.RemoteEndPoint, newStatus, reason);
@@ -835,7 +874,7 @@ private void HandleDisconnect(NetPeerData peer, NetConnection connection, string
try
{
#endif
- OnDisconnected(channel, reason);
+ OnDisconnected(channel, reason);
#if EXCEPTION_TOLERANCE
}
catch (Exception e)
@@ -871,19 +910,25 @@ private bool DispatchNetMessage(NetIncomingMessage msg)
var peer = msg.SenderConnection.Peer;
if (peer.Status == NetPeerStatus.ShutdownRequested)
{
- _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, but shutdown is requested.");
+ if (_logPacketIssues)
+ _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, but shutdown is requested.");
+
return true;
}
if (peer.Status == NetPeerStatus.NotRunning)
{
- _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, peer is not running.");
+ if (_logPacketIssues)
+ _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, peer is not running.");
+
return true;
}
if (!IsConnected)
{
- _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, but not connected.");
+ if (_logPacketIssues)
+ _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, but not connected.");
+
return true;
}
@@ -898,19 +943,57 @@ private bool DispatchNetMessage(NetIncomingMessage msg)
if (msg.LengthBytes < 1)
{
- _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received empty packet.");
+ if (_logPacketIssues)
+ _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received empty packet.");
+
+ msg.SenderConnection.Disconnect("Received empty/weird packet", false);
return true;
}
if (!_channels.TryGetValue(msg.SenderConnection, out var channel))
{
- _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got unexpected data packet before handshake completion.");
+ if (_logPacketIssues)
+ _logger.Debug($"{msg.SenderConnection.RemoteEndPoint}: Got unexpected data packet before handshake completion.");
+
- msg.SenderConnection.Disconnect("Unexpected packet before handshake completion");
+ msg.SenderConnection.Disconnect("Unexpected packet before handshake completion", false);
return true;
}
- channel.Encryption?.Decrypt(msg);
+ // Attempt to decrypt the message, only logging if we fail to decrypt and we actually have encryption.
+ if ((!channel.Encryption?.TryDecrypt(msg)) ?? false)
+ {
+ var remoteEndPoint = msg.SenderConnection.RemoteEndPoint;
+ if (IsServer)
+ {
+ var remoteIp = NormalizeIp(remoteEndPoint.Address);
+ var now = DateTime.UtcNow;
+ var maxTracked = _config.GetCVar(CVars.NetDecryptFailMaxTracked);
+
+ // Drop silently if tracking limit reached
+ if (_decryptFailCounts.Count >= maxTracked && !_decryptFailCounts.ContainsKey(remoteIp))
+ return true;
+ var (failCount, _) = _decryptFailCounts.AddOrUpdate(
+ remoteIp,
+ _ => (1, now),
+ (_, old) => (old.TotalCount + 1, now));
+ if (failCount == 1 && _logPacketIssues)
+ _logger.Debug($"{remoteEndPoint}: Got a packet that fails to decrypt.");
+
+ var banThreshold = _config.GetCVar(CVars.NetDecryptFailBanThreshold);
+ if (failCount >= banThreshold)
+ {
+ _authLogger.Warning($"[DECRYPTBAN] {remoteIp} reached {failCount} decryption failures. Consider banning this IP.");
+ if (_config.GetCVar(CVars.NetDecryptFailKick))
+ msg.SenderConnection.Disconnect("Failed to decrypt packet.", false);
+ return true;
+ }
+ }
+ else if (_logPacketIssues)
+ { _logger.Debug($"{remoteEndPoint}: Got a packet that fails to decrypt."); }
+ msg.SenderConnection.Disconnect("Failed to decrypt packet.", false);
+ return true;
+ }
var id = msg.ReadByte();
@@ -918,9 +1001,10 @@ private bool DispatchNetMessage(NetIncomingMessage msg)
if (entry == null)
{
- _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got net message with invalid ID {id}.");
+ if (_logPacketIssues)
+ _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got net message with invalid ID {id}.");
- channel.Disconnect("Got NetMessage with invalid ID");
+ channel.Disconnect("Got NetMessage with invalid ID", false);
return true;
}
@@ -928,15 +1012,16 @@ private bool DispatchNetMessage(NetIncomingMessage msg)
if (!channel.IsHandshakeComplete && !entry.IsHandshake)
{
- _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got non-handshake message {entry.Type.Name} before handshake completion.");
+ if (_logPacketIssues)
+ _logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got non-handshake message {entry.Type.Name} before handshake completion.");
- channel.Disconnect("Got unacceptable net message before handshake completion");
+ channel.Disconnect("Got unacceptable net message before handshake completion", false);
return true;
}
var type = entry.Type;
- var instance = (NetMessage) Activator.CreateInstance(type)!;
+ var instance = (NetMessage)Activator.CreateInstance(type)!;
instance.MsgChannel = channel;
#if DEBUG
@@ -956,12 +1041,16 @@ private bool DispatchNetMessage(NetIncomingMessage msg)
}
catch (InvalidCastException ice)
{
- _logger.Error($"{msg.SenderConnection.RemoteEndPoint}: Wrong deserialization of {type.Name} packet:\n{ice}");
+ if (_logPacketIssues)
+ _logger.Error($"{msg.SenderConnection.RemoteEndPoint}: Wrong deserialization of {type.Name} packet:\n{ice}");
+ channel.Disconnect("Failed to deserialize packet.", false);
return true;
}
catch (Exception e) // yes, we want to catch ALL exeptions for security
{
- _logger.Error($"{msg.SenderConnection.RemoteEndPoint}: Failed to deserialize {type.Name} packet:\n{e}");
+ if (_logPacketIssues)
+ _logger.Error($"{msg.SenderConnection.RemoteEndPoint}: Failed to deserialize {type.Name} packet:\n{e}");
+ channel.Disconnect("Failed to deserialize packet.", false);
return true;
}
@@ -1021,7 +1110,7 @@ public void RegisterNetMessage(ProcessMessage? rxCallback = null,
if (rxCallback != null && (accept & thisSide) != 0)
{
- data.Callback = msg => rxCallback((T) msg);
+ data.Callback = msg => rxCallback((T)msg);
if (id != -1)
CacheNetMsgIndex(id, name);
@@ -1043,7 +1132,7 @@ private NetOutgoingMessage BuildMessage(NetMessage message, NetPeer peer)
throw new NetManagerException(
$"[NET] No string in table with name {message.MsgName}. Was it registered?");
- packet.Write((byte) msgId);
+ packet.Write((byte)msgId);
message.WriteToBuffer(packet, _serializer);
return packet;
}
diff --git a/Robust.UnitTesting/Shared/Networking/NetEcnryptionDoSTest.cs b/Robust.UnitTesting/Shared/Networking/NetEcnryptionDoSTest.cs
new file mode 100644
index 00000000000..abcb834edaa
--- /dev/null
+++ b/Robust.UnitTesting/Shared/Networking/NetEcnryptionDoSTest.cs
@@ -0,0 +1,175 @@
+using Lidgren.Network;
+using NUnit.Framework;
+using Robust.Shared.Network;
+
+namespace Robust.Shared.Tests.Networking;
+
+public sealed class NetEncryptionDoSTest
+{
+ private const ulong Magic = 0x13377777_77777777;
+
+ [Test]
+ [Description("A control test that ensures connecting in a test works.")]
+ public void ConnectionWorks()
+ {
+ var (client, server) = MakeConnectionPair();
+
+ var message = client.CreateMessage();
+
+ message.WriteVariableUInt64(Magic);
+
+ client.SendMessage(message, NetDeliveryMethod.ReliableOrdered);
+
+ var packet = Receive(server);
+
+ Assert.That(packet, Is.Not.Null);
+
+ Assert.That(packet.ReadVariableUInt64(), Is.EqualTo(Magic));
+ }
+
+ [Test]
+ [Description("A control test that just ensures encryption works as other tests expect.")]
+ public void EncryptionWorks()
+ {
+ var (clientEnc, serverEnc) = MakeEncryptionPair();
+ var (client, server) = MakeConnectionPair();
+
+ var message = client.CreateMessage();
+
+ message.WriteVariableUInt64(Magic);
+
+ clientEnc.Encrypt(message);
+
+ client.SendMessage(message, NetDeliveryMethod.ReliableOrdered);
+
+ var packet = Receive(server);
+
+ Assert.That(packet, Is.Not.Null);
+
+ Assert.That(serverEnc.TryDecrypt(packet), Is.True);
+ }
+
+ [Test]
+ [Description("Attempt to decrypt a packet that is using the wrong encryption keys, ensuring it doesn't throw.")]
+ public void WrongKeyFailureDoesNotThrow()
+ {
+ var (clientEnc, serverEnc) = MakeEncryptionPair(disjointKey: true);
+ var (client, server) = MakeConnectionPair();
+
+ var message = client.CreateMessage();
+
+ message.WriteVariableUInt64(Magic);
+
+ clientEnc.Encrypt(message);
+
+ client.SendMessage(message, NetDeliveryMethod.ReliableOrdered);
+
+ var packet = server.WaitMessage(1000);
+
+ Assert.That(packet, Is.Not.Null);
+
+ Assert.That(serverEnc.TryDecrypt(packet), Is.False);
+ }
+
+ private static byte[][] _badMessages =
+ [
+ [1, 1, 1, 1, 1],
+ [1, 2],
+ [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,]
+ ];
+
+ [Test]
+ [Description("Attempt to decrypt a packet that is bogus, ensuring it doesn't throw.")]
+ [TestCaseSource(nameof(_badMessages))]
+ public void BadMessageDoesNotThrow(byte[] badMessage)
+ {
+ var (_, serverEnc) = MakeEncryptionPair(disjointKey: true);
+ var (client, server) = MakeConnectionPair();
+
+ var message = client.CreateMessage();
+
+ message.Write(badMessage);
+
+ // Don't encrypt at all.
+
+ client.SendMessage(message, NetDeliveryMethod.ReliableOrdered);
+
+ var packet = server.WaitMessage(1000);
+
+ Assert.That(packet, Is.Not.Null);
+
+ Assert.That(serverEnc.TryDecrypt(packet), Is.False);
+ }
+
+
+ // TODO: Generalize all this for other low level network tests.
+
+ private (NetClient client, NetServer server) MakeConnectionPair()
+ {
+ const string id = "test";
+ var client = new NetClient(new NetPeerConfiguration(id));
+
+ var server = new NetServer(new NetPeerConfiguration(id));
+
+ client.Start();
+ // Lidgren has no facilities for mocking this nicely.
+ // So we just use an actual socket.
+ server.Start();
+
+ client.Connect("localhost", server.Port);
+
+ var ready = false;
+
+ while (!ready)
+ {
+ switch (server.WaitMessage(1000))
+ {
+ case { MessageType: NetIncomingMessageType.StatusChanged } msg:
+ {
+ // hello there.
+ var status = (NetConnectionStatus)msg.ReadByte();
+
+ if (status == NetConnectionStatus.Connected)
+ ready = true;
+
+ break;
+ }
+ }
+ }
+
+ return (client, server);
+ }
+
+ private NetIncomingMessage Receive(NetPeer peer)
+ {
+ NetIncomingMessage? found = null;
+
+ while (found == null)
+ {
+ switch (peer.WaitMessage(1000))
+ {
+ case { MessageType: NetIncomingMessageType.Data } msg:
+ {
+ found = msg;
+ break;
+ }
+ }
+ }
+
+ return found;
+ }
+
+ private (NetEncryption client, NetEncryption server) MakeEncryptionPair(bool disjointKey = false)
+ {
+ var serverKey = new byte[32];
+
+ System.Random.Shared.NextBytes(serverKey.AsSpan());
+
+ var clientKey = (byte[])serverKey.Clone();
+
+ if (disjointKey)
+ System.Random.Shared.NextBytes(clientKey.AsSpan());
+
+ return (new NetEncryption(clientKey, false), new NetEncryption(serverKey, true));
+ }
+}