Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions Robust.Shared.Tests/Networking/NetEcnryptionDoSTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
using Lidgren.Network;
using NUnit.Framework;
using Robust.Shared.Network;

namespace Robust.Shared.Tests.Networking;

public sealed class NetEncryptionDoSTest
{
private const ulong Magic = 0x13377777_77777777;

[Test]
[Description("A control test that ensures connecting in a test works.")]
public void ConnectionWorks()
{
var (client, server) = MakeConnectionPair();

var message = client.CreateMessage();

message.WriteVariableUInt64(Magic);

client.SendMessage(message, NetDeliveryMethod.ReliableOrdered);

var packet = Receive(server);

Assert.That(packet, Is.Not.Null);

Assert.That(packet.ReadVariableUInt64(), Is.EqualTo(Magic));
}

[Test]
[Description("A control test that just ensures encryption works as other tests expect.")]
public void EncryptionWorks()
{
var (clientEnc, serverEnc) = MakeEncryptionPair();
var (client, server) = MakeConnectionPair();

var message = client.CreateMessage();

message.WriteVariableUInt64(Magic);

clientEnc.Encrypt(message);

client.SendMessage(message, NetDeliveryMethod.ReliableOrdered);

var packet = Receive(server);

Assert.That(packet, Is.Not.Null);

Assert.That(serverEnc.TryDecrypt(packet), Is.True);
}

[Test]
[Description("Attempt to decrypt a packet that is using the wrong encryption keys, ensuring it doesn't throw.")]
public void WrongKeyFailureDoesNotThrow()
{
var (clientEnc, serverEnc) = MakeEncryptionPair(disjointKey: true);
var (client, server) = MakeConnectionPair();

var message = client.CreateMessage();

message.WriteVariableUInt64(Magic);

clientEnc.Encrypt(message);

client.SendMessage(message, NetDeliveryMethod.ReliableOrdered);

var packet = server.WaitMessage(1000);

Assert.That(packet, Is.Not.Null);

Assert.That(serverEnc.TryDecrypt(packet), Is.False);
}

private static byte[][] _badMessages =
[
[1, 1, 1, 1, 1],
[1, 2],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,]
];

[Test]
[Description("Attempt to decrypt a packet that is bogus, ensuring it doesn't throw.")]
[TestCaseSource(nameof(_badMessages))]
public void BadMessageDoesNotThrow(byte[] badMessage)
{
var (_, serverEnc) = MakeEncryptionPair(disjointKey: true);
var (client, server) = MakeConnectionPair();

var message = client.CreateMessage();

message.Write(badMessage);

// Don't encrypt at all.

client.SendMessage(message, NetDeliveryMethod.ReliableOrdered);

var packet = server.WaitMessage(1000);

Assert.That(packet, Is.Not.Null);

Assert.That(serverEnc.TryDecrypt(packet), Is.False);
}


// TODO: Generalize all this for other low level network tests.

private (NetClient client, NetServer server) MakeConnectionPair()
{
const string id = "test";
var client = new NetClient(new NetPeerConfiguration(id));

var server = new NetServer(new NetPeerConfiguration(id));

client.Start();
// Lidgren has no facilities for mocking this nicely.
// So we just use an actual socket.
server.Start();

client.Connect("localhost", server.Port);

var ready = false;

while (!ready)
{
switch (server.WaitMessage(1000))
{
case { MessageType: NetIncomingMessageType.StatusChanged } msg:
{
// hello there.
var status = (NetConnectionStatus)msg.ReadByte();

if (status == NetConnectionStatus.Connected)
ready = true;

break;
}
}
}

return (client, server);
}

private NetIncomingMessage Receive(NetPeer peer)
{
NetIncomingMessage? found = null;

while (found == null)
{
switch (peer.WaitMessage(1000))
{
case { MessageType: NetIncomingMessageType.Data } msg:
{
found = msg;
break;
}
}
}

return found;
}

private (NetEncryption client, NetEncryption server) MakeEncryptionPair(bool disjointKey = false)
{
var serverKey = new byte[32];

System.Random.Shared.NextBytes(serverKey.AsSpan());

var clientKey = (byte[])serverKey.Clone();

if (disjointKey)
System.Random.Shared.NextBytes(clientKey.AsSpan());

return (new NetEncryption(clientKey, false), new NetEncryption(serverKey, true));
}
}
3 changes: 2 additions & 1 deletion Robust.Shared.Tests/Robust.Shared.Tests.csproj
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<Import Project="..\MSBuild\Robust.Engine.props"/>

<PropertyGroup>
Expand All @@ -19,6 +19,7 @@
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\Lidgren.Network\Lidgren.Network.csproj" />
<ProjectReference Include="..\NetSerializer\NetSerializer\NetSerializer.csproj" />
<ProjectReference Include="..\Robust.Shared.Maths.Tests\Robust.Shared.Maths.Tests.csproj" />
<ProjectReference Include="..\Robust.Shared.Maths\Robust.Shared.Maths.csproj"/>
Expand Down
25 changes: 25 additions & 0 deletions Robust.Shared/CVars.cs
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,31 @@ protected CVars()
public static readonly CVarDef<string> NetLidgrenAppIdentifier =
CVarDef.Create("net.lidgren_app_identifier", "RobustToolbox");

/// <summary>
/// Whether to disconnect clients that exceed the decryption failure threshold.
/// </summary>
public static readonly CVarDef<bool> NetDecryptFailKick =
CVarDef.Create("net.dos_fail_kick", true, CVar.SERVERONLY);

/// <summary>
/// Number of decryption failures from a single IP (or /64 subnet for IPv6) before logging a ban warning and optionally disconnecting.
/// </summary>
public static readonly CVarDef<int> NetDecryptFailBanThreshold =
CVarDef.Create("net.dos_fail_ban_threshold", 10, CVar.SERVERONLY);

/// <summary>
/// How often (in minutes) to clean up stale decryption failure records.
/// Records are only removed if they have not been seen for this many minutes.
/// </summary>
public static readonly CVarDef<int> NetDecryptFailCleanupInterval =
CVarDef.Create("net.dos_fail_cleanup_interval", 10, CVar.SERVERONLY);

/// <summary>
/// Maximum number of IPs tracked for decryption failures. Prevents memory exhaustion from botnet attacks.
/// </summary>
public static readonly CVarDef<int> NetDecryptFailMaxTracked =
CVarDef.Create("net.dos_fail_max_tracked", 10000, CVar.SERVERONLY);

/// <summary>
/// Add random fake network loss to all outgoing UDP network packets, as a ratio of how many packets to drop.
/// 0 = no packet loss, 1 = all packets dropped
Expand Down
18 changes: 12 additions & 6 deletions Robust.Shared/Network/NetEncryption.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System;
using System;
using System.Buffers;
using System.Buffers.Binary;
using System.Threading;
Expand Down Expand Up @@ -84,7 +84,12 @@ public unsafe void Encrypt(NetOutgoingMessage message)
ArrayPool<byte>.Shared.Return(returnPool);
}

public unsafe void Decrypt(NetIncomingMessage message)
/// <summary>
/// Attempts to decrypt an incoming network message, falliably.
/// </summary>
/// <param name="message">The message to decrypt in-place. This will be mutated with the decrypted results.</param>
/// <returns>Whether the operation was successful. If this fails, you likely want to drop the connection.</returns>
public unsafe bool TryDecrypt(NetIncomingMessage message)
{
var nonce = message.ReadUInt64();
var cipherText = message.Data.AsSpan(sizeof(ulong), message.LengthBytes - sizeof(ulong));
Expand All @@ -109,12 +114,13 @@ public unsafe void Decrypt(NetIncomingMessage message)
// key
_key);

message.Position = 0;
message.LengthBytes = messageLength;

ArrayPool<byte>.Shared.Return(buffer);

if (!result)
throw new SodiumException("Decryption operation failed!");
return false;

message.Position = 0;
message.LengthBytes = messageLength;
return true;
}
}
13 changes: 10 additions & 3 deletions Robust.Shared/Network/NetManager.ClientConnect.cs
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,20 @@ private async Task CCDoHandshake(

// Expect login success here.
response = await AwaitData(connection, cancel);
encryption?.Decrypt(response);

// Attempt to decrypt the message, only logging if we fail to decrypt and we actually have encryption.
if ((!encryption?.TryDecrypt(response)) ?? false)
{
const string msg = "Failed to decrypt login success.";
connection.Disconnect(msg);
throw new Exception(msg);
}
}

var msgSuc = new MsgLoginSuccess();
msgSuc.ReadFromBuffer(response, _serializer);

var channel = new NetChannel(this, connection, msgSuc.UserData with { HWId = [..legacyHwid] }, msgSuc.Type);
var channel = new NetChannel(this, connection, msgSuc.UserData with { HWId = [.. legacyHwid] }, msgSuc.Type);
_channels.Add(connection, channel);
peer.AddChannel(channel);

Expand Down Expand Up @@ -440,7 +447,7 @@ private Task<NetIncomingMessage> AwaitData(
if (ipAddress.AddressFamily == AddressFamily.InterNetwork
|| ipAddress.AddressFamily == AddressFamily.InterNetworkV6)
{
return new[] {ipAddress};
return new[] { ipAddress };
}

throw new ArgumentException("This method will not currently resolve other than IPv4 or IPv6 addresses");
Expand Down
Loading
Loading