From f4f07f17ff698a197da608c9b6e0c7c2b877e12f Mon Sep 17 00:00:00 2001 From: MhaWay Date: Wed, 25 Mar 2026 05:28:25 +0100 Subject: [PATCH 1/4] Add typed packet serialization system - Add IPacket interface with PacketBuffer/PacketReader/PacketWriter for strongly-typed packet serialization/deserialization - Add PacketDefinitionAttribute, SerializedPacket, PacketTypeInfo - Add FragmentedPacket record for structured fragment handling - Add ByteReader.ReadEnum() and ByteWriter.WriteEnum() helpers - Refactor PacketHandlerAttribute: replace IsFragmentedAttribute with allowFragmented parameter, add TypedPacketHandlerAttribute, PacketHandlerClassAttribute, FragmentedPacketHandlerAttribute - Add PacketHandlerInvoker delegate type replacing FastInvokeHandler for packet handler invocation - Refactor MpConnectionState: encapsulate handler lookup via GetPacketHandler(), add RegisterTypedPacketHandler and RegisterFragmentedPacketHandler with DynamicMethod IL emit - Add typed packet support to AsyncConnectionState: TypedPacket(), TypedPacketOrNull(), TypedPacketAwaitable() - Update ConnectionBase: add Send for typed packets, Send(SerializedPacket), SendFragmented(SerializedPacket), use WriteEnum/WriteRaw in GetDisconnectBytes - Migrate all [IsFragmented] attributes to [PacketHandler(allowFragmented)] --- Languages | 2 +- Source/Common/ByteReader.cs | 2 +- Source/Common/ByteWriter.cs | 2 +- .../Common/Networking/AsyncConnectionState.cs | 100 ++++-- Source/Common/Networking/ConnectionBase.cs | 35 +- Source/Common/Networking/MpConnectionState.cs | 105 +++++- Source/Common/Networking/Packet/IPacket.cs | 318 ++++++++++++++++++ .../Networking/PacketHandlerAttribute.cs | 14 +- .../Common/Networking/PacketReadException.cs | 4 + .../Syncing/Logger/LoggingByteReader.cs | 2 +- .../Syncing/Logger/LoggingByteWriter.cs | 2 +- 11 files changed, 524 insertions(+), 62 deletions(-) create mode 100644 Source/Common/Networking/Packet/IPacket.cs diff --git a/Languages b/Languages index 0c38ba2e0..8e44aab97 160000 --- a/Languages +++ b/Languages @@ -1 +1 @@ -Subproject commit 0c38ba2e075ab40f676d3658959765464985ceda +Subproject commit 8e44aab97e7f839ede8b98a7ae2160dfebc4ec1d diff --git a/Source/Common/ByteReader.cs b/Source/Common/ByteReader.cs index 4e87dbe43..7e8a1bfdf 100644 --- a/Source/Common/ByteReader.cs +++ b/Source/Common/ByteReader.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Text; namespace Multiplayer.Common diff --git a/Source/Common/ByteWriter.cs b/Source/Common/ByteWriter.cs index 6deaabec1..3980cbe6c 100644 --- a/Source/Common/ByteWriter.cs +++ b/Source/Common/ByteWriter.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections; using System.Collections.Generic; using System.IO; diff --git a/Source/Common/Networking/AsyncConnectionState.cs b/Source/Common/Networking/AsyncConnectionState.cs index 3f7becf97..a1ca6a719 100644 --- a/Source/Common/Networking/AsyncConnectionState.cs +++ b/Source/Common/Networking/AsyncConnectionState.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Threading.Tasks; +using Multiplayer.Common.Networking.Packet; using Multiplayer.Common.Util; namespace Multiplayer.Common; @@ -58,6 +59,23 @@ protected PacketAwaitable Packet(Packets packet) return packetAwaitable!; } + /// + /// Wait for a packet of the given type. The packet must arrive after the call to this method. + /// An exception is thrown if this is called again before the packet arrives. The player is disconnected if a + /// different packet type arrives. + /// + protected Task TypedPacket() where T: struct, IPacket + { + if (packetAwaitable != null) + throw new Exception($"Already waiting for another packet: {packetAwaitable}"); + + ServerLog.Verbose($"{connection} waiting for {PacketTypeInfo.Id}"); + + packetAwaitable = TypedPacketAwaitable(out var task, announcePacketFailure: false); + return task + .ContinueWith(finishedTask => (T)finishedTask.Result!, TaskContinuationOptions.OnlyOnRanToCompletion); + } + /// /// Wait for a packet of the given type. The packet must arrive after the call to this method. /// An exception is thrown if this is called again before the packet arrives. The player is disconnected if a @@ -75,15 +93,31 @@ protected PacketAwaitable Packet(Packets packet) return packetAwaitable; } + + /// + /// Wait for a packet of the given type. The packet must arrive after the call to this method. + /// An exception is thrown if this is called again before the packet arrives. The player is disconnected if a + /// different packet type arrives. + /// + protected Task TypedPacketOrNull() where T: struct, IPacket + { + if (packetAwaitable != null) + throw new Exception($"Already waiting for another packet: {packetAwaitable}"); + + ServerLog.Verbose($"{connection} waiting for {PacketTypeInfo.Id}"); + + packetAwaitable = TypedPacketAwaitable(out var task, announcePacketFailure: true); + return task; + } + public override PacketHandlerInfo? GetPacketHandler(Packets packet) { if (packetAwaitable != null && packetAwaitable.PacketType == packet) - return new PacketHandlerInfo((_, args) => + return new PacketHandlerInfo((_, data) => { var source = packetAwaitable; packetAwaitable = null; - source.SetResult((ByteReader)args[0]); - return null; + source.SetResult(data); }, packetAwaitable.Fragment); return base.GetPacketHandler(packet); @@ -95,35 +129,46 @@ protected async Task EndIfDead() await new Blackhole(); return true; } -} -public class PacketAwaitable : INotifyCompletion -{ - private List continuations = new(); - public Packets PacketType { get; } - public bool AnnouncePacketFailure { get; } - private T? result; + private static PacketAwaitable TypedPacketAwaitable(out Task task, bool announcePacketFailure) where TPacket : struct, IPacket + { + var awaitable = new PacketAwaitable(PacketTypeInfo.Id, announcePacketFailure); + if (PacketTypeInfo.AllowFragmented) awaitable.Fragmented(); - public bool Fragment { get; private set; } + task = CreateTask(); + return awaitable; - public PacketAwaitable(Packets packetType, bool announcePacketFailure) - { - PacketType = packetType; - AnnouncePacketFailure = announcePacketFailure; + async Task CreateTask() + { + // Announce packet failure is false, so this won't be null. + var reader = await awaitable; + if (reader == null) return null; + var packet = default(TPacket); + try + { + packet.Bind(new PacketReader(reader)); + } + catch (Exception e) + { + ServerLog.Error($"Failed to bind packet {PacketTypeInfo.Id}: {e}"); + throw; + } + return packet; + } } +} - public void OnCompleted(Action continuation) - { - continuations.Add(continuation); - } +public class PacketAwaitable(Packets packetType, bool announcePacketFailure) : INotifyCompletion +{ + private List continuations = new(); + public Packets PacketType { get; } = packetType; + public bool AnnouncePacketFailure { get; } = announcePacketFailure; + public bool Fragment { get; private set; } + private T? result; + public void OnCompleted(Action continuation) => continuations.Add(continuation); public bool IsCompleted => result != null; - - public T GetResult() - { - return result!; - } - + public T GetResult() => result!; public PacketAwaitable GetAwaiter() => this; public void SetResult(T r) @@ -133,10 +178,7 @@ public void SetResult(T r) continuation(); } - public override string ToString() - { - return PacketType.ToString(); - } + public override string ToString() => PacketType.ToString(); public PacketAwaitable Fragmented() { diff --git a/Source/Common/Networking/ConnectionBase.cs b/Source/Common/Networking/ConnectionBase.cs index 787a4e20a..73aa53fbe 100644 --- a/Source/Common/Networking/ConnectionBase.cs +++ b/Source/Common/Networking/ConnectionBase.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using Multiplayer.Common.Networking.Packet; namespace Multiplayer.Common { @@ -33,14 +34,26 @@ public void ChangeState(ConnectionStateEnum state) StateObj?.StartState(); } - public void Send(Packets id) - { - Send(id, Array.Empty()); - } + public void Send(Packets id) => Send(id, Array.Empty()); + + public void Send(Packets id, params object[] msg) => Send(id, ByteWriter.GetBytes(msg)); + + public void Send(SerializedPacket packet, bool reliable = true) => Send(packet.id, packet.data, reliable); - public void Send(Packets id, params object[] msg) + public void Send(T packet, bool reliable = true) where T : struct, IPacket { - Send(id, ByteWriter.GetBytes(msg)); + var writer = new ByteWriter(); + writer.WriteByte((byte)(Convert.ToByte(packet.GetId()) & 0x3F)); + packet.Bind(new PacketWriter(writer)); + + if (State == ConnectionStateEnum.Disconnected) + return; + + var dataLen = writer.Position - 1; // The first byte is metadata. + if (dataLen > MaxSinglePacketSize) + throw new PacketSendException($"Packet {packet.GetId()} too big for sending ({dataLen}>{MaxSinglePacketSize})"); + + SendRaw(writer.ToArray(), reliable); } public virtual void Send(Packets id, byte[] message, bool reliable = true) @@ -124,10 +137,9 @@ public void SendFragmented(Packets id, byte[] message) } } - public void SendFragmented(Packets id, params object[] msg) - { - SendFragmented(id, ByteWriter.GetBytes(msg)); - } + public void SendFragmented(SerializedPacket packet) => SendFragmented(packet.id, packet.data); + + public void SendFragmented(Packets id, params object[] msg) => SendFragmented(id, ByteWriter.GetBytes(msg)); protected abstract void SendRaw(byte[] raw, bool reliable = true); @@ -164,6 +176,7 @@ protected virtual void HandleReceiveMsg(int msgId, int fragState, ByteReader rea if (reliable && !Lenient) throw new PacketReadException($"No handler for packet {packetType} in state {State}"); ServerLog.Error($"No handler for packet {packetType} in state {State}"); + reader.Seek(reader.Length); return; } @@ -227,7 +240,7 @@ public static byte[] GetDisconnectBytes(MpDisconnectReason reason, byte[]? data { var writer = new ByteWriter(); writer.WriteEnum(reason); - writer.WritePrefixedBytes(data ?? Array.Empty()); + writer.WriteRaw(data ?? Array.Empty()); return writer.ToArray(); } } diff --git a/Source/Common/Networking/MpConnectionState.cs b/Source/Common/Networking/MpConnectionState.cs index 99132211b..31f5f3e8d 100644 --- a/Source/Common/Networking/MpConnectionState.cs +++ b/Source/Common/Networking/MpConnectionState.cs @@ -1,6 +1,8 @@ -using System; +using System; using System.Reflection; +using System.Reflection.Emit; using HarmonyLib; +using Multiplayer.Common.Networking.Packet; namespace Multiplayer.Common { @@ -24,7 +26,8 @@ public virtual void OnDisconnect() packetHandlers[(int)connection.State, (int)id]; public static Type[] stateImpls = new Type[(int)ConnectionStateEnum.Count]; - private static PacketHandlerInfo?[,] packetHandlers = new PacketHandlerInfo?[(int)ConnectionStateEnum.Count, (int)Packets.Count]; + private static PacketHandlerInfo?[,] packetHandlers = + new PacketHandlerInfo?[(int)ConnectionStateEnum.Count, (int)Packets.Count]; public static void SetImplementation(ConnectionStateEnum state, Type type) { @@ -32,13 +35,23 @@ public static void SetImplementation(ConnectionStateEnum state, Type type) stateImpls[(int)state] = type; - foreach (var method in type.GetMethods(BindingFlags.Instance | BindingFlags.Public | BindingFlags.DeclaredOnly)) + var typeAttr = type.GetAttribute(); + if (typeAttr == null) + ServerLog.Log($"Packet handler {type.FullName} does not have a PacketHandlerClass attribute"); + + var bindingFlags = BindingFlags.Instance | BindingFlags.Public; + if (typeAttr?.inheritHandlers != true) bindingFlags |= BindingFlags.DeclaredOnly; + + foreach (var method in type.GetMethods(bindingFlags)) { var attr = method.GetAttribute(); if (attr != null) RegisterPacketHandler(state, method, attr); var attr2 = method.GetAttribute(); if (attr2 != null) RegisterFragmentedPacketHandler(state, method, attr2); + + var attr3 = method.GetAttribute(); + if (attr3 != null) RegisterTypedPacketHandler(state, method); } for (var packetId = 0; packetId < packetHandlers.GetLength(1); packetId++) @@ -52,30 +65,92 @@ public static void SetImplementation(ConnectionStateEnum state, Type type) } } - private static void RegisterPacketHandler(ConnectionStateEnum state, MethodInfo method, PacketHandlerAttribute attr) + private static void RegisterPacketHandler(ConnectionStateEnum state, Packets packet, bool allowFragmented, + Func produceInvoker) { - if (method.GetParameters().Length != 1 || method.GetParameters()[0].ParameterType != typeof(ByteReader)) - throw new Exception($"Bad packet handler signature for {method}"); - - var packetHandlerInfo = packetHandlers[(int)state, (int)attr.packet]; + var packetHandlerInfo = packetHandlers[(int)state, (int)packet]; if (packetHandlerInfo == null) { - packetHandlers[(int)state, (int)attr.packet] = - new PacketHandlerInfo(MethodInvoker.GetHandler(method), attr.allowFragmented); + packetHandlers[(int)state, (int)packet] = + new PacketHandlerInfo(produceInvoker(), allowFragmented); return; } if (packetHandlerInfo.Method != null) - throw new Exception($"Packet {state}:{attr.packet} already has a handler"); + throw new Exception($"Packet {state}:{packet} already has a handler"); - if (!attr.allowFragmented && packetHandlerInfo.FragmentHandler != null) - throw new Exception($"Packet {state}:{attr.packet} has a fragment handler despite not being allowed to"); + if (!allowFragmented && packetHandlerInfo.FragmentHandler != null) + throw new Exception($"Packet {state}:{packet} has a fragment handler despite not being allowed to"); - packetHandlers[(int)state, (int)attr.packet] = packetHandlerInfo with + packetHandlers[(int)state, (int)packet] = packetHandlerInfo with { - Method = MethodInvoker.GetHandler(method), Fragment = attr.allowFragmented + Method = produceInvoker(), Fragment = allowFragmented }; } + private static void RegisterPacketHandler(ConnectionStateEnum state, MethodInfo method, PacketHandlerAttribute attr) + { + if (method.GetParameters().Length != 1 || method.GetParameters()[0].ParameterType != typeof(ByteReader)) + throw new Exception($"Bad packet handler signature for {method}: must have 1 parameter of type {typeof(ByteReader)}"); + + RegisterPacketHandler(state, attr.packet, attr.allowFragmented, () => + { + DynamicMethod invoker = new DynamicMethod($"PacketHandlerInvoker_{attr.packet}_{method.Name}", + typeof(void), [typeof(object), typeof(ByteReader)]); + var il = invoker.GetILGenerator(); + il.Emit(OpCodes.Ldarg_0); // object target + il.Emit(OpCodes.Castclass, method.DeclaringType ?? throw new InvalidOperationException()); + il.Emit(OpCodes.Ldarg_1); // ByteReader data + il.Emit(OpCodes.Callvirt, method); + il.Emit(OpCodes.Ret); + return (PacketHandlerInvoker)invoker.CreateDelegate(typeof(PacketHandlerInvoker)); + }); + } + + private static void RegisterTypedPacketHandler(ConnectionStateEnum state, MethodInfo method) + { + if (method.GetParameters().Length != 1) + throw new Exception($"Bad packet handler signature for {method}: must have exactly 1 parameter"); + + var paramType = method.GetParameters()[0].ParameterType; + if (!typeof(IPacket).IsAssignableFrom(paramType)) + throw new Exception($"Bad packet handler signature for {method}: the parameter must be of type IPacket"); + + if (!AccessTools.IsStruct(paramType)) + throw new Exception($"Bad packet handler signature for {method}: the parameter must be a struct"); + + var packetDef = paramType.GetAttribute(); + if (packetDef == null) + throw new Exception($"Bad packet handler signature for {method}: the parameter's type must have a [PacketDefinition] attribute"); + + RegisterPacketHandler(state, packetDef.packet, packetDef.allowFragmented, + () => + { + var invoker = new DynamicMethod($"TypedPacketHandlerInvoker_{packetDef.packet}_{method.Name}", + typeof(void), [typeof(object), typeof(ByteReader)]); + var il = invoker.GetILGenerator(); + var paramLocal = il.DeclareLocal(paramType); + + il.Emit(OpCodes.Ldloca, paramLocal); + il.Emit(OpCodes.Initobj, paramType); + + il.Emit(OpCodes.Ldloca, paramLocal); + il.Emit(OpCodes.Ldarg_1); // ByteReader data + il.Emit(OpCodes.Newobj, AccessTools.DeclaredConstructor(typeof(PacketReader), [typeof(ByteReader)])); + // Use the type's method instead of just referencing the interface method to avoid additional + // indirection of going through the vtable. + il.Emit(OpCodes.Call, + paramType.GetMethod(nameof(IPacket.Bind), [typeof(PacketBuffer)]) ?? + throw new InvalidOperationException()); + + il.Emit(OpCodes.Ldarg_0); // object target (handler's class instance) + il.Emit(OpCodes.Castclass, method.DeclaringType ?? throw new InvalidOperationException()); + il.Emit(OpCodes.Ldloc, paramLocal); + il.Emit(OpCodes.Callvirt, method); + + il.Emit(OpCodes.Ret); + return (PacketHandlerInvoker)invoker.CreateDelegate(typeof(PacketHandlerInvoker)); + }); + } private static void RegisterFragmentedPacketHandler(ConnectionStateEnum state, MethodInfo method, FragmentedPacketHandlerAttribute attr) { if (method.GetParameters().Length != 1 || method.GetParameters()[0].ParameterType != typeof(FragmentedPacket)) diff --git a/Source/Common/Networking/Packet/IPacket.cs b/Source/Common/Networking/Packet/IPacket.cs new file mode 100644 index 000000000..9acf3392d --- /dev/null +++ b/Source/Common/Networking/Packet/IPacket.cs @@ -0,0 +1,318 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; + +namespace Multiplayer.Common.Networking.Packet; + +[AttributeUsage(AttributeTargets.Struct)] +public class PacketDefinitionAttribute(Packets packet, bool allowFragmented = false) : Attribute +{ + public readonly Packets packet = packet; + public readonly bool allowFragmented = allowFragmented; +} + +/// Must have a [PacketDefinitionAttribute] attribute +public interface IPacket : IPacketBufferable; + +public struct SerializedPacket(Packets id, byte[] data) +{ + public Packets id = id; + public byte[] data = data; + + public static SerializedPacket From(T packet) where T : IPacket + { + var writer = new ByteWriter(); + packet.Bind(new PacketWriter(writer)); + return new SerializedPacket(packet.GetId(), writer.ToArray()); + } +} + +[SuppressMessage("ReSharper", "StaticMemberInGenericType")] +public static class PacketTypeInfo where T : IPacket +{ + private static readonly PacketDefinitionAttribute attr = typeof(T).GetAttribute() ?? + throw new InvalidOperationException(); + + public static readonly Packets Id = attr.packet; + public static readonly bool AllowFragmented = attr.allowFragmented; +} + +public static class PacketExt +{ + public static Packets GetId(this T _) where T : IPacket => PacketTypeInfo.Id; + + public static SerializedPacket Serialize(this T packet) where T : IPacket => SerializedPacket.From(packet); +} + +public interface IPacketBufferable +{ + void Bind(PacketBuffer buf); +} + +public delegate void Binder(PacketBuffer buf, ref T obj); + +public static class BinderOf +{ + public static Binder Identity() where T : struct, IPacketBufferable => + (PacketBuffer buf, ref T obj) => buf.Bind(ref obj); + + public static Binder Int() => + (PacketBuffer buf, ref int obj) => buf.Bind(ref obj); + + public static Binder UInt() => + (PacketBuffer buf, ref uint obj) => buf.Bind(ref obj); + + public static Binder Enum() where T: Enum => + (PacketBuffer buf, ref T obj) => buf.BindEnum(ref obj); +} + +public static class BinderExtensions +{ + public static byte[] Serialize(this Binder binder, T value) + { + var writer = new ByteWriter(); + binder(new PacketWriter(writer), ref value); + return writer.ToArray(); + } + + public static T Deserialize(this Binder binder, byte[] src) + { + var obj = default(T); + binder(new PacketReader(new ByteReader(src)), ref obj); + return obj; + } +} + +public abstract class PacketBuffer(bool isWriting) +{ + public const int DefaultMaxLength = 32767; + public bool isWriting = isWriting; + + public virtual ByteReader Reader => throw new Exception(); + public virtual ByteWriter Writer => throw new Exception(); + + public abstract bool DataRemaining { get; } + + public abstract void Bind(ref byte obj); + + public abstract void Bind(ref sbyte obj); + + public abstract void Bind(ref short obj); + + public abstract void Bind(ref ushort obj); + + public abstract void Bind(ref int obj); + + public abstract void Bind(ref uint obj); + + public abstract void Bind(ref long obj); + + public abstract void Bind(ref ulong obj); + + public abstract void Bind(ref float obj); + + public abstract void Bind(ref double obj); + + public abstract void Bind(ref string obj, int maxLength = DefaultMaxLength); + + public abstract void Bind(ref bool obj); + + public abstract void BindEnum(ref T obj) where T : Enum; + + public void Bind(ref T obj) where T : struct, IPacketBufferable => obj.Bind(this); + + public void BindWith(ref T obj, Binder bind) => bind(this, ref obj); + + public abstract void BindWith(ref T obj, Action write, Func read); + + public abstract void BindBytes(ref byte[] obj, int maxLength = DefaultMaxLength); + + public abstract void Bind(ref T[] obj, Binder bind, int maxLength = DefaultMaxLength); + + /// There can only be one remaining bind for a single packet, and it must be the last one. The advantage of this + /// compared to BindBytes is that this method does not serialize the buffer's length and instead just binds all + /// data that hasn't been read yet. + public abstract void BindRemaining(ref byte[] obj, int maxLength = DefaultMaxLength); + + public abstract void Bind(ref List obj, Binder bind, int maxLength = DefaultMaxLength); + + public abstract void Bind(ref Dictionary obj, Binder bindKey, Binder bindValue, + int maxLength = DefaultMaxLength); +} + +public sealed class PacketReader(ByteReader reader) : PacketBuffer(false) +{ + public override ByteReader Reader => reader; + public override bool DataRemaining => reader.Left > 0; + + public override void Bind(ref byte obj) => obj = reader.ReadByte(); + + public override void Bind(ref sbyte obj) => obj = reader.ReadSByte(); + + public override void Bind(ref short obj) => obj = reader.ReadShort(); + + public override void Bind(ref ushort obj) => obj = reader.ReadUShort(); + + public override void Bind(ref int obj) => obj = reader.ReadInt32(); + + public override void Bind(ref uint obj) => obj = reader.ReadUInt32(); + + public override void Bind(ref long obj) => obj = reader.ReadLong(); + + public override void Bind(ref ulong obj) => obj = reader.ReadULong(); + + public override void Bind(ref float obj) => obj = reader.ReadFloat(); + + public override void Bind(ref double obj) => obj = reader.ReadDouble(); + + public override void Bind(ref bool obj) => obj = reader.ReadBool(); + + public override void BindEnum(ref T obj) => obj = reader.ReadEnum(); + + public override void BindWith(ref T obj, Action write, Func read) => obj = read(); + + public override void Bind(ref string obj, int maxLength = DefaultMaxLength) => obj = reader.ReadString(maxLength); + + public override void BindBytes(ref byte[] obj, int maxLength = DefaultMaxLength) + { + int len = reader.ReadInt32(); + obj = reader.ReadRaw(len); + } + + public override void Bind(ref T[] obj, Binder bind, int maxLength = DefaultMaxLength) + { + int len = reader.ReadInt32(); + if (len > maxLength && maxLength != -1) throw new ReaderException($"Array too big ({len}>{maxLength})"); + + obj = new T[len]; + for (var i = 0; i < len; i++) + { + var item = default(T); + bind(this, ref item); + obj[i] = item; + } + } + + public override void BindRemaining(ref byte[] obj, int maxLength = DefaultMaxLength) + { + if (reader.Left > maxLength) throw new ReaderException($"Remaining bytes too big ({reader.Left}>{maxLength})"); + obj = reader.ReadRaw(reader.Left); + } + + public override void Bind(ref List obj, Binder bind, int maxLength = DefaultMaxLength) + { + int len = reader.ReadInt32(); + if (len > maxLength) throw new ReaderException($"List too big ({len}>{maxLength})"); + + obj = new List(len); + for (var i = 0; i < len; i++) + { + var item = default(T); + bind(this, ref item); + obj.Add(item); + } + } + + public override void Bind(ref Dictionary obj, Binder bindKey, Binder bindValue, + int maxLength = DefaultMaxLength) + { + int len = reader.ReadInt32(); + if (len > maxLength && maxLength != -1) throw new ReaderException($"Dictionary too big ({len}>{maxLength})"); + + obj = new Dictionary(len); + for (int i = 0; i < len; i++) + { + var key = default(K); + bindKey(this, ref key); + + var value = default(V); + bindValue(this, ref value); + obj.Add(key, value); + } + } +} + +public sealed class PacketWriter(ByteWriter writer) : PacketBuffer(true) +{ + public override ByteWriter Writer => writer; + public override bool DataRemaining => false; + + public override void Bind(ref byte obj) => writer.WriteByte(obj); + + public override void Bind(ref sbyte obj) => writer.WriteSByte(obj); + + public override void Bind(ref short obj) => writer.WriteShort(obj); + + public override void Bind(ref ushort obj) => writer.WriteUShort(obj); + + public override void Bind(ref int obj) => writer.WriteInt32(obj); + + public override void Bind(ref uint obj) => writer.WriteUInt32(obj); + + public override void Bind(ref long obj) => writer.WriteLong(obj); + + public override void Bind(ref ulong obj) => writer.WriteULong(obj); + + public override void Bind(ref float obj) => writer.WriteFloat(obj); + + public override void Bind(ref double obj) => writer.WriteDouble(obj); + + public override void Bind(ref bool obj) => writer.WriteBool(obj); + + public override void BindEnum(ref T obj) => writer.WriteEnum(obj); + + public override void Bind(ref string obj, int maxLength = DefaultMaxLength) + { + if (obj != null && obj.Length > maxLength) throw new WriterException($"Too long string ({obj.Length}>{maxLength})"); + writer.WriteString(obj); + } + + public override void BindWith(ref T obj, Action write, Func read) => write(obj); + + public override void BindBytes(ref byte[] obj, int maxLength = DefaultMaxLength) + { + writer.WriteInt32(obj.Length); + writer.WriteRaw(obj); + } + + public override void Bind(ref T[] obj, Binder bind, int maxLength = DefaultMaxLength) + { + if (obj.Length > maxLength) throw new WriterException($"Array too big ({obj.Length}>{maxLength})"); + writer.WriteInt32(obj.Length); + for (var i = 0; i < obj.Length; i++) + { + bind(this, ref obj[i]); + } + } + + public override void BindRemaining(ref byte[] obj, int maxLength = DefaultMaxLength) + { + if (obj.Length > maxLength) throw new WriterException($"Remaining bytes too big ({obj.Length}>{maxLength})"); + writer.WriteRaw(obj); + } + + public override void Bind(ref List obj, Binder bind, int maxLength = DefaultMaxLength) + { + if (obj.Count > maxLength) throw new WriterException($"List too big ({obj.Count}>{maxLength})"); + writer.WriteInt32(obj.Count); + for (var i = 0; i < obj.Count; i++) + { + var item = obj[i]; + bind(this, ref item); + } + } + + public override void Bind(ref Dictionary obj, Binder bindKey, Binder bindValue, + int maxLength = DefaultMaxLength) + { + if (obj.Count > maxLength) throw new WriterException($"Dictionary too big ({obj.Count}>{maxLength})"); + writer.WriteInt32(obj.Count); + foreach (var (key, value) in obj) + { + var k = key; + bindKey(this, ref k); + var v = value; + bindValue(this, ref v); + } + } +} diff --git a/Source/Common/Networking/PacketHandlerAttribute.cs b/Source/Common/Networking/PacketHandlerAttribute.cs index 904dfcd3d..78a7cbb18 100644 --- a/Source/Common/Networking/PacketHandlerAttribute.cs +++ b/Source/Common/Networking/PacketHandlerAttribute.cs @@ -1,4 +1,4 @@ -using System; +using System; using HarmonyLib; using JetBrains.Annotations; @@ -12,6 +12,14 @@ public class PacketHandlerAttribute(Packets packet, bool allowFragmented = false public readonly bool allowFragmented = allowFragmented; } + public class TypedPacketHandlerAttribute : Attribute; + + [AttributeUsage(AttributeTargets.Class)] + public class PacketHandlerClassAttribute(bool inheritHandlers = false) : Attribute + { + public readonly bool inheritHandlers = inheritHandlers; + } + [MeansImplicitUse] [AttributeUsage(AttributeTargets.Method)] public class FragmentedPacketHandlerAttribute(Packets packet) : Attribute @@ -19,5 +27,7 @@ public class FragmentedPacketHandlerAttribute(Packets packet) : Attribute public readonly Packets packet = packet; } - public record PacketHandlerInfo(FastInvokeHandler Method, bool Fragment, FastInvokeHandler? FragmentHandler = null); + public delegate void PacketHandlerInvoker(object target, ByteReader data); + + public record PacketHandlerInfo(PacketHandlerInvoker Method, bool Fragment, FastInvokeHandler? FragmentHandler = null); } diff --git a/Source/Common/Networking/PacketReadException.cs b/Source/Common/Networking/PacketReadException.cs index 656453784..94ff4b7da 100644 --- a/Source/Common/Networking/PacketReadException.cs +++ b/Source/Common/Networking/PacketReadException.cs @@ -7,5 +7,9 @@ public class PacketReadException : Exception public PacketReadException(string msg) : base(msg) { } + + public PacketReadException(string message, Exception innerException) : base(message, innerException) + { + } } } diff --git a/Source/Common/Syncing/Logger/LoggingByteReader.cs b/Source/Common/Syncing/Logger/LoggingByteReader.cs index 5e359af1c..b82041f8b 100644 --- a/Source/Common/Syncing/Logger/LoggingByteReader.cs +++ b/Source/Common/Syncing/Logger/LoggingByteReader.cs @@ -1,4 +1,4 @@ -using Multiplayer.Common; +using Multiplayer.Common; namespace Multiplayer.Client { diff --git a/Source/Common/Syncing/Logger/LoggingByteWriter.cs b/Source/Common/Syncing/Logger/LoggingByteWriter.cs index b8d2eb26c..58601cf0c 100644 --- a/Source/Common/Syncing/Logger/LoggingByteWriter.cs +++ b/Source/Common/Syncing/Logger/LoggingByteWriter.cs @@ -1,4 +1,4 @@ -using Multiplayer.Common; +using Multiplayer.Common; namespace Multiplayer.Client { From ee557563bac49710fc1b8f50db74285331e1c619 Mon Sep 17 00:00:00 2001 From: MhaWay Date: Wed, 25 Mar 2026 05:36:24 +0100 Subject: [PATCH 2/4] Improve fragmented packet support with progress UI - Rewrite SendFragmented() with smaller chunk sizes (1KB vs 64KB), fragment IDs for tracking, and expected parts/size metadata headers - Rewrite fragment receive handling with HandleReceiveFragment(): fragment reassembly by ID, size validation, concurrent fragment limit - Add ExecuteMessageHandler() with error handling and hex dump on failure - Rename FragmentSize -> MaxSinglePacketSize, MaxPacketSize -> MaxFragmentPacketTotalSize, add MaxFragmentPacketSize - Add packet consumption check in HandleReceiveRaw - Add fragment progress handler support via PacketHandlerInfo.FragmentHandler - Update ClientLoadingState: add WorldExpectedSize/WorldReceivedSize tracking, download speed calculation, fragment handler for progress, game load timing - Update ConnectingWindow: download progress bar with speed, ETA, and percentage display --- .../Networking/State/ClientLoadingState.cs | 19 ++++++-- Source/Client/Windows/ConnectingWindow.cs | 7 ++- Source/Common/Networking/ConnectionBase.cs | 44 ++++++++++++++++--- 3 files changed, 55 insertions(+), 15 deletions(-) diff --git a/Source/Client/Networking/State/ClientLoadingState.cs b/Source/Client/Networking/State/ClientLoadingState.cs index 7a190832a..d3daee86f 100644 --- a/Source/Client/Networking/State/ClientLoadingState.cs +++ b/Source/Client/Networking/State/ClientLoadingState.cs @@ -1,5 +1,6 @@ -using System; +using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using Ionic.Zlib; using Multiplayer.Client.Saving; @@ -14,6 +15,7 @@ public enum LoadingState Downloading } +[PacketHandlerClass(inheritHandlers: true)] public class ClientLoadingState(ConnectionBase connection) : ClientBaseState(connection) { public LoadingState subState = LoadingState.Waiting; @@ -29,19 +31,23 @@ public int DownloadSpeedKBps var firstCheckpoint = downloadCheckpoints.First(); var lastCheckpoint = downloadCheckpoints.Last(); var timeTakenMs = Utils.MillisNow - firstCheckpoint.Item1; - var timeTakenSecs = Math.Max(1, timeTakenMs / 1000); + var timeTakenSecs = timeTakenMs / 1000f; + var downloadedBytes = lastCheckpoint.Item2 - firstCheckpoint.Item2; - return (int)(downloadedBytes / 1000 / timeTakenSecs); + var downloadedKBytes = downloadedBytes / 1000; + return (int)(downloadedKBytes / timeTakenSecs); } } private List<(long, uint)> downloadCheckpoints = new(capacity: 64); + private Stopwatch downloadTimeStopwatch = new(); [PacketHandler(Packets.Server_WorldDataStart)] public void HandleWorldDataStart(ByteReader data) { subState = LoadingState.Downloading; connection.Lenient = false; // Lenient is set while rejoining + downloadTimeStopwatch.Start(); } [FragmentedPacketHandler(Packets.Server_WorldData)] @@ -57,7 +63,9 @@ public void HandleWorldDataFragment(FragmentedPacket packet) [PacketHandler(Packets.Server_WorldData, allowFragmented: true)] public void HandleWorldData(ByteReader data) { - Log.Message("Game data size: " + data.Length); + var downloadMs = downloadTimeStopwatch.ElapsedMilliseconds; + downloadTimeStopwatch.Reset(); + Log.Message($"Game data size: {data.Length}. Took {downloadMs}ms to receive."); int factionId = data.ReadInt32(); Multiplayer.session.myFactionId = factionId; @@ -125,7 +133,10 @@ public void HandleWorldData(ByteReader data) onCancel: GenScene.GoToMainMenu // Calls StopMultiplayer through a patch ); + Stopwatch watch = Stopwatch.StartNew(); Loader.ReloadGame(mapsToLoad, true, false); + var loadingMs = watch.ElapsedMilliseconds; + Log.Message($"Loaded game in {loadingMs}ms"); connection.ChangeState(ConnectionStateEnum.ClientPlaying); } } diff --git a/Source/Client/Windows/ConnectingWindow.cs b/Source/Client/Windows/ConnectingWindow.cs index 804738a7a..2d5a68fec 100644 --- a/Source/Client/Windows/ConnectingWindow.cs +++ b/Source/Client/Windows/ConnectingWindow.cs @@ -1,7 +1,7 @@ -using Multiplayer.Client.Networking; -using Steamworks; using System.Linq; +using Multiplayer.Client.Networking; using Multiplayer.Client.Util; +using Steamworks; using UnityEngine; using Verse; @@ -134,8 +134,7 @@ public class RejoiningWindow : BaseConnectingWindow public class ConnectingWindow(string address, int port) : BaseConnectingWindow { - protected override string ConnectingString => - string.Format("MpConnectingTo".Translate("{0}", port), address); + protected override string ConnectingString => "MpConnectingTo".Translate(address, port); } public class SteamConnectingWindow(CSteamID hostId) : BaseConnectingWindow diff --git a/Source/Common/Networking/ConnectionBase.cs b/Source/Common/Networking/ConnectionBase.cs index 73aa53fbe..8a7b698e5 100644 --- a/Source/Common/Networking/ConnectionBase.cs +++ b/Source/Common/Networking/ConnectionBase.cs @@ -76,7 +76,8 @@ public virtual void Send(Packets id, byte[] message, bool reliable = true) // We can send a single packet up to MaxSinglePacketSize but when specifically sending a fragmented packet, // use smaller sizes so that we have more control over them. public const int MaxFragmentPacketSize = 1024; - public const int MaxPacketSize = 33_554_432; + // Max size of a packet that can be sent fragmented. It is not possible to send any packets larger than this. + public const int MaxFragmentPacketTotalSize = 33_554_432; private const int FragNone = 0x0; private const int FragMore = 0x40; @@ -97,13 +98,19 @@ public void SendFragmented(Packets id, byte[] message) return; } - var fragId = sendFragId++; // every packet has an additional 2 bytes of overhead const int maxFragmentSize = MaxFragmentPacketSize - 2; // the first packet has an additional 6 bytes of overhead var totalLength = message.Length + 6; + if (totalLength > MaxFragmentPacketTotalSize) + { + throw new PacketSendException( + $"Tried to send too big packet {id}. Max size: {MaxFragmentPacketTotalSize}, requested size (incl. overhead): {totalLength}."); + } + // Divide rounding up var fragParts = (totalLength + maxFragmentSize - 1) / maxFragmentSize; + var fragId = sendFragId++; int read = 0; var writer = new ByteWriter(MaxFragmentPacketSize); while (read < message.Length) @@ -155,7 +162,10 @@ public virtual void HandleReceiveRaw(ByteReader data, bool reliable) byte msgId = (byte)(info & 0x3F); byte fragState = (byte)(info & 0xC0); + int msgLen = data.Left; HandleReceiveMsg(msgId, fragState, data, reliable); + if (data.Left > 0) + ServerLog.Error($"Packet was not fully consumed: {msgId}, msg len: {msgLen}"); } private const int MaxFragmentedPackets = 1; @@ -180,7 +190,7 @@ protected virtual void HandleReceiveMsg(int msgId, int fragState, ByteReader rea return; } - if (fragState == FragNone) handler.Method(StateObj, reader); + if (fragState == FragNone) ExecuteMessageHandler(handler, packetType, reader); else HandleReceiveFragment(reader, packetType, handler); } @@ -203,8 +213,8 @@ private void HandleReceiveFragment(ByteReader reader, Packets packetType, Packet var expectedSize = reader.ReadUInt32(); if (expectedParts < 2) ServerLog.Error($"Received fragmented packet with only {expectedParts} expected parts (packet type: {packetType}, fragment id: {fragId}, expected size: {expectedSize})."); - if (expectedSize > MaxPacketSize) - throw new PacketReadException($"Full packet {packetType} too big {expectedSize}>{MaxPacketSize}"); + if (expectedSize > MaxFragmentPacketTotalSize) + throw new PacketReadException($"Full packet {packetType} too big {expectedSize}>{MaxFragmentPacketTotalSize}"); fragPacket = FragmentedPacket.Create(packetType, expectedParts, expectedSize); fragIndex = fragments.Count; @@ -220,6 +230,7 @@ private void HandleReceiveFragment(ByteReader reader, Packets packetType, Packet fragPacket.Data.Write(reader.GetBuffer(), reader.Position, reader.Left); fragPacket.ReceivedSize += Convert.ToUInt32(reader.Left); fragPacket.ReceivedPartsCount++; + reader.Seek(reader.Length); if (fragPacket.ReceivedPartsCount < fragPacket.ExpectedPartsCount) { @@ -231,7 +242,26 @@ private void HandleReceiveFragment(ByteReader reader, Packets packetType, Packet throw new PacketReadException($"Fragmented packet {packetType} (fragId {fragId}) recombined with different than expected size: {fragPacket.ReceivedSize} != {fragPacket.ExpectedSize}"); fragments.RemoveAt(fragIndex); - handler.Method(StateObj, new ByteReader(fragPacket.Data.GetBuffer())); + ExecuteMessageHandler(handler, packetType, new ByteReader(fragPacket.Data.GetBuffer())); + } + + private void ExecuteMessageHandler(PacketHandlerInfo handler, Packets packet, ByteReader data) + { + var pos = data.Position; + try + { + handler.Method(StateObj, data); + } + catch (Exception e) + { + // Don't assume the actual packet's data is at index 0. Packets store extra metadata at the start + // of the same ByteReader. We do not care about that metadata here. + var packetLen = data.Length - pos; + var bytesToShow = Math.Min(128, packetLen); + var bytes = data.GetBuffer().SubArray(pos, bytesToShow); + var bytesStr = bytes.ToHexString(); + throw new PacketReadException($"Exception handling packet {packet} in state {State} (first {bytesToShow}/{packetLen} bytes: {bytesStr})", e); + } } public abstract void Close(MpDisconnectReason reason, byte[]? data = null); @@ -240,7 +270,7 @@ public static byte[] GetDisconnectBytes(MpDisconnectReason reason, byte[]? data { var writer = new ByteWriter(); writer.WriteEnum(reason); - writer.WriteRaw(data ?? Array.Empty()); + writer.WriteRaw(data ?? []); return writer.ToArray(); } } From 2f2e7fb50d10e735947edcaeb0809e6f732356e0 Mon Sep 17 00:00:00 2001 From: MhaWay Date: Wed, 25 Mar 2026 05:54:18 +0100 Subject: [PATCH 3/4] Typed ServerDisconnectPacket with Close/OnClose refactor - Add ServerDisconnectPacket record struct with BindEnum(reason) + BindRemaining(data) - Add Server_Disconnect packet type to Packets enum - Add IsServer()/IsClient() extension methods to ConnectionStateEnum - Refactor ConnectionBase.Close() from abstract to concrete: sends ServerDisconnectPacket for server-side states, then calls OnClose() - Add abstract OnClose() to ConnectionBase, override in all subclasses (LiteNetConnection, NetworkingInMemory, NetworkingSteam, ReplayConnection) - Add typed HandleDisconnected handler in ClientBaseState - Add SessionDisconnectInfo.From() factory method mapping MpDisconnectReason to UI strings - Update IConnectionStatusListener.Disconnected() to accept SessionDisconnectInfo - Remove ProcessDisconnectPacket() and disconnectInfo field from MultiplayerSession - Update all callers: NetworkingLiteNet, NetworkingSteam, ClientUtil, ConnectingWindow --- Source/Client/Networking/ClientUtil.cs | 7 +- .../Networking/ConnectionStatusListeners.cs | 6 +- .../Client/Networking/NetworkingInMemory.cs | 4 +- Source/Client/Networking/NetworkingLiteNet.cs | 18 +++-- Source/Client/Networking/NetworkingSteam.cs | 25 +++--- .../Networking/State/ClientBaseState.cs | 9 ++- Source/Client/Saving/ReplayConnection.cs | 2 +- Source/Client/Session/MultiplayerSession.cs | 73 +---------------- .../Client/Session/SessionDisconnectInfo.cs | 81 +++++++++++++++++++ Source/Client/Windows/ConnectingWindow.cs | 2 +- Source/Common/Networking/ConnectionBase.cs | 15 +++- .../Common/Networking/ConnectionStateEnum.cs | 8 ++ Source/Common/Networking/LiteNetConnection.cs | 4 +- .../Networking/Packet/DisconnectPacket.cs | 14 ++++ Source/Common/Networking/Packets.cs | 3 + 15 files changed, 165 insertions(+), 106 deletions(-) create mode 100644 Source/Common/Networking/Packet/DisconnectPacket.cs diff --git a/Source/Client/Networking/ClientUtil.cs b/Source/Client/Networking/ClientUtil.cs index ef4ed6c62..9236eb703 100644 --- a/Source/Client/Networking/ClientUtil.cs +++ b/Source/Client/Networking/ClientUtil.cs @@ -59,9 +59,12 @@ public static void HandleReceive(ByteReader data, bool reliable) { Log.Error($"Exception handling packet by {Multiplayer.Client}: {e}"); - Multiplayer.session.disconnectInfo.titleTranslated = "MpPacketErrorLocal".Translate(); + var info = new SessionDisconnectInfo + { + titleTranslated = "MpPacketErrorLocal".Translate() + }; - ConnectionStatusListeners.TryNotifyAll_Disconnected(); + ConnectionStatusListeners.TryNotifyAll_Disconnected(info); Multiplayer.StopMultiplayer(); } } diff --git a/Source/Client/Networking/ConnectionStatusListeners.cs b/Source/Client/Networking/ConnectionStatusListeners.cs index a6c8a397d..949e64c64 100644 --- a/Source/Client/Networking/ConnectionStatusListeners.cs +++ b/Source/Client/Networking/ConnectionStatusListeners.cs @@ -8,7 +8,7 @@ namespace Multiplayer.Client.Networking public interface IConnectionStatusListener { void Connected(); - void Disconnected(); + void Disconnected(SessionDisconnectInfo info); } public static class ConnectionStatusListeners @@ -45,13 +45,13 @@ public static void TryNotifyAll_Connected() } } - public static void TryNotifyAll_Disconnected() + public static void TryNotifyAll_Disconnected(SessionDisconnectInfo info) { foreach (var listener in All) { try { - listener.Disconnected(); + listener.Disconnected(info); } catch (Exception e) { diff --git a/Source/Client/Networking/NetworkingInMemory.cs b/Source/Client/Networking/NetworkingInMemory.cs index 12e1c7b24..90ab70a82 100644 --- a/Source/Client/Networking/NetworkingInMemory.cs +++ b/Source/Client/Networking/NetworkingInMemory.cs @@ -30,7 +30,7 @@ protected override void SendRaw(byte[] raw, bool reliable) }); } - public override void Close(MpDisconnectReason reason, byte[] data) + protected override void OnClose() { } @@ -66,7 +66,7 @@ protected override void SendRaw(byte[] raw, bool reliable) }); } - public override void Close(MpDisconnectReason reason, byte[] data) + protected override void OnClose() { } diff --git a/Source/Client/Networking/NetworkingLiteNet.cs b/Source/Client/Networking/NetworkingLiteNet.cs index a00e61437..eb7be519d 100644 --- a/Source/Client/Networking/NetworkingLiteNet.cs +++ b/Source/Client/Networking/NetworkingLiteNet.cs @@ -34,10 +34,11 @@ public void OnNetworkReceive(NetPeer peer, NetPacketReader reader, DeliveryMetho public void OnPeerDisconnected(NetPeer peer, DisconnectInfo info) { + // Fallback: should generally be handled by ClientBaseState.HandleDisconnected. MpDisconnectReason reason; - byte[] data; + ByteReader reader; - if (info.AdditionalData.IsNull) + if (info.AdditionalData.IsNull || info.AdditionalData.AvailableBytes == 0) { if (info.Reason is DisconnectReason.DisconnectPeerCalled or DisconnectReason.RemoteConnectionClose) reason = MpDisconnectReason.Generic; @@ -46,17 +47,18 @@ public void OnPeerDisconnected(NetPeer peer, DisconnectInfo info) else reason = MpDisconnectReason.NetFailed; - data = new [] { (byte)info.Reason }; + var writer = new ByteWriter(); + writer.WriteEnum(info.Reason); + reader = new ByteReader(writer.ToArray()); } else { - var reader = new ByteReader(info.AdditionalData.GetRemainingBytes()); - reason = reader.ReadEnum(); - data = reader.ReadPrefixedBytes(); + var rawReader = new ByteReader(info.AdditionalData.GetRemainingBytes()); + reason = rawReader.ReadEnum(); + reader = rawReader; } - Multiplayer.session.ProcessDisconnectPacket(reason, data); - ConnectionStatusListeners.TryNotifyAll_Disconnected(); + ConnectionStatusListeners.TryNotifyAll_Disconnected(SessionDisconnectInfo.From(reason, reader)); Multiplayer.StopMultiplayer(); MpLog.Log($"Net client disconnected {info.Reason}"); diff --git a/Source/Client/Networking/NetworkingSteam.cs b/Source/Client/Networking/NetworkingSteam.cs index e838549b2..acd02caad 100644 --- a/Source/Client/Networking/NetworkingSteam.cs +++ b/Source/Client/Networking/NetworkingSteam.cs @@ -39,10 +39,8 @@ public void SendRawSteam(byte[] raw, bool reliable) ); } - public override void Close(MpDisconnectReason reason, byte[] data) + protected override void OnClose() { - if (State != ConnectionStateEnum.ClientSteam) - Send(Packets.Special_Steam_Disconnect, GetDisconnectBytes(reason, data)); } public abstract void OnError(EP2PSessionError error); @@ -61,11 +59,8 @@ protected override void HandleReceiveMsg(int msgId, int fragState, ByteReader re { if (msgId == (int)Packets.Special_Steam_Disconnect) { - Multiplayer.session.ProcessDisconnectPacket( - reader.ReadEnum(), - reader.ReadPrefixedBytes() - ); - OnDisconnect(); + var reason = reader.ReadEnum(); + OnDisconnect(SessionDisconnectInfo.From(reason, reader)); return; } @@ -74,15 +69,19 @@ protected override void HandleReceiveMsg(int msgId, int fragState, ByteReader re public override void OnError(EP2PSessionError error) { - Multiplayer.session.disconnectInfo.titleTranslated = - error == EP2PSessionError.k_EP2PSessionErrorTimeout ? "MpSteamTimedOut".Translate() : "MpSteamGenericError".Translate(); + var info = new SessionDisconnectInfo + { + titleTranslated = error == EP2PSessionError.k_EP2PSessionErrorTimeout + ? "MpSteamTimedOut".Translate() + : "MpSteamGenericError".Translate() + }; - OnDisconnect(); + OnDisconnect(info); } - private void OnDisconnect() + private void OnDisconnect(SessionDisconnectInfo info) { - ConnectionStatusListeners.TryNotifyAll_Disconnected(); + ConnectionStatusListeners.TryNotifyAll_Disconnected(info); Multiplayer.StopMultiplayer(); } } diff --git a/Source/Client/Networking/State/ClientBaseState.cs b/Source/Client/Networking/State/ClientBaseState.cs index 058212117..216f588b7 100644 --- a/Source/Client/Networking/State/ClientBaseState.cs +++ b/Source/Client/Networking/State/ClientBaseState.cs @@ -1,12 +1,17 @@ +using Multiplayer.Client.Networking; using Multiplayer.Common; +using Multiplayer.Common.Networking.Packet; namespace Multiplayer.Client; -public abstract class ClientBaseState : MpConnectionState +public abstract class ClientBaseState(ConnectionBase connection) : MpConnectionState(connection) { protected MultiplayerSession Session => Multiplayer.session; - public ClientBaseState(ConnectionBase connection) : base(connection) + [TypedPacketHandler] + public void HandleDisconnected(ServerDisconnectPacket packet) { + ConnectionStatusListeners.TryNotifyAll_Disconnected(SessionDisconnectInfo.From(packet.reason, new ByteReader(packet.data))); + Multiplayer.StopMultiplayer(); } } diff --git a/Source/Client/Saving/ReplayConnection.cs b/Source/Client/Saving/ReplayConnection.cs index 18dff2981..4dfc691fd 100644 --- a/Source/Client/Saving/ReplayConnection.cs +++ b/Source/Client/Saving/ReplayConnection.cs @@ -30,7 +30,7 @@ public override void HandleReceiveRaw(ByteReader data, bool reliable) { } - public override void Close(MpDisconnectReason reason, byte[] data) + protected override void OnClose() { } } diff --git a/Source/Client/Session/MultiplayerSession.cs b/Source/Client/Session/MultiplayerSession.cs index c8e1597e0..02d7a5e5f 100644 --- a/Source/Client/Session/MultiplayerSession.cs +++ b/Source/Client/Session/MultiplayerSession.cs @@ -44,8 +44,6 @@ public class MultiplayerSession : IConnectionStatusListener public bool desynced; - public SessionDisconnectInfo disconnectInfo; - public List pendingSteam = new(); public List knownUsers = new(); @@ -114,73 +112,6 @@ public void NotifyChat() SoundDefOf.PageChange.PlayOneShotOnCamera(); } - public void ProcessDisconnectPacket(MpDisconnectReason reason, byte[] data) - { - var reader = new ByteReader(data); - string titleKey = null; - string descKey = null; - - if (reason == MpDisconnectReason.GenericKeyed) titleKey = reader.ReadString(); - - if (reason == MpDisconnectReason.Protocol) - { - titleKey = "MpWrongProtocol"; - - string strVersion = reader.ReadString(); - int proto = reader.ReadInt32(); - - disconnectInfo.wideWindow = true; - disconnectInfo.descTranslated = "MpWrongMultiplayerVersionDesc".Translate(strVersion, proto, MpVersion.Version, MpVersion.Protocol); - - if (proto < MpVersion.Protocol) - disconnectInfo.descTranslated += "\n" + "MpWrongVersionUpdateInfoHost".Translate(); - else - disconnectInfo.descTranslated += "\n" + "MpWrongVersionUpdateInfo".Translate(); - } - - if (reason == MpDisconnectReason.ConnectingFailed) - { - var netReason = reader.ReadEnum(); - - disconnectInfo.titleTranslated = - netReason == DisconnectReason.ConnectionFailed ? - "MpConnectionFailed".Translate() : - "MpConnectionFailedWithInfo".Translate(netReason.ToString().CamelSpace().ToLowerInvariant()); - } - - if (reason == MpDisconnectReason.NetFailed) - { - var netReason = reader.ReadEnum(); - - disconnectInfo.titleTranslated = - "MpDisconnectedWithInfo".Translate(netReason.ToString().CamelSpace().ToLowerInvariant()); - } - - if (reason == MpDisconnectReason.UsernameAlreadyOnline) - { - titleKey = "MpInvalidUsernameAlreadyPlaying"; - descKey = "MpChangeUsernameInfo"; - - var newName = Multiplayer.username.Substring(0, Math.Min(Multiplayer.username.Length, MultiplayerServer.MaxUsernameLength - 3)); - newName += new Random().Next(1000); - - disconnectInfo.specialButtonTranslated = "MpConnectAsUsername".Translate(newName); - disconnectInfo.specialButtonAction = () => Reconnect(newName); - } - - if (reason == MpDisconnectReason.UsernameLength) { titleKey = "MpInvalidUsernameLength"; descKey = "MpChangeUsernameInfo"; } - if (reason == MpDisconnectReason.UsernameChars) { titleKey = "MpInvalidUsernameChars"; descKey = "MpChangeUsernameInfo"; } - if (reason == MpDisconnectReason.ServerClosed) titleKey = "MpServerClosed"; - if (reason == MpDisconnectReason.ServerFull) titleKey = "MpServerFull"; - if (reason == MpDisconnectReason.ServerStarting) titleKey = "MpDisconnectServerStarting"; - if (reason == MpDisconnectReason.Kick) titleKey = "MpKicked"; - if (reason == MpDisconnectReason.ServerPacketRead) descKey = "MpPacketErrorRemote"; - if (reason == MpDisconnectReason.BadGamePassword) descKey = "MpBadGamePassword"; - - disconnectInfo.titleTranslated ??= titleKey?.Translate(); - disconnectInfo.descTranslated ??= descKey?.Translate(); - } - public void Reconnect(string username) { Multiplayer.username = username; @@ -195,11 +126,11 @@ public void Connected() { } - public void Disconnected() + public void Disconnected(SessionDisconnectInfo info) { MpUI.ClearWindowStack(); - Find.WindowStack.Add(new DisconnectedWindow(disconnectInfo) + Find.WindowStack.Add(new DisconnectedWindow(info) { returnToServerBrowser = Multiplayer.Client?.State != ConnectionStateEnum.ClientPlaying }); diff --git a/Source/Client/Session/SessionDisconnectInfo.cs b/Source/Client/Session/SessionDisconnectInfo.cs index 96dfb2419..cbc74ed4f 100644 --- a/Source/Client/Session/SessionDisconnectInfo.cs +++ b/Source/Client/Session/SessionDisconnectInfo.cs @@ -1,4 +1,7 @@ using System; +using LiteNetLib; +using Multiplayer.Common; +using Verse; namespace Multiplayer.Client; @@ -9,4 +12,82 @@ public struct SessionDisconnectInfo public string specialButtonTranslated; public Action specialButtonAction; public bool wideWindow; + + public static SessionDisconnectInfo From(MpDisconnectReason reason, ByteReader reader) + { + var disconnectInfo = new SessionDisconnectInfo(); + string titleKey = null; + string descKey = null; + + if (reason == MpDisconnectReason.GenericKeyed) titleKey = reader.ReadString(); + + if (reason == MpDisconnectReason.Protocol) + { + titleKey = "MpWrongProtocol"; + string strVersion = reader.ReadString(); + int proto = reader.ReadInt32(); + disconnectInfo.wideWindow = true; + disconnectInfo.descTranslated = + "MpWrongMultiplayerVersionDesc".Translate(strVersion, proto, MpVersion.Version, MpVersion.Protocol); + if (proto < MpVersion.Protocol) + disconnectInfo.descTranslated += "\n" + "MpWrongVersionUpdateInfoHost".Translate(); + else + disconnectInfo.descTranslated += "\n" + "MpWrongVersionUpdateInfo".Translate(); + } + + if (reason == MpDisconnectReason.ConnectingFailed) + { + var netReason = reader.ReadEnum(); + disconnectInfo.titleTranslated = + netReason == DisconnectReason.ConnectionFailed + ? "MpConnectionFailed".Translate() + : "MpConnectionFailedWithInfo".Translate(netReason.ToString().CamelSpace().ToLowerInvariant()); + } + + if (reason == MpDisconnectReason.NetFailed) + { + var netReason = reader.ReadEnum(); + disconnectInfo.titleTranslated = + "MpDisconnectedWithInfo".Translate(netReason.ToString().CamelSpace().ToLowerInvariant()); + } + + if (reason == MpDisconnectReason.UsernameAlreadyOnline) + { + titleKey = "MpInvalidUsernameAlreadyPlaying"; + descKey = "MpChangeUsernameInfo"; + var newName = Multiplayer.username.Substring(0, + Math.Min(Multiplayer.username.Length, MultiplayerServer.MaxUsernameLength - 3)); + newName += new Random().Next(1000); + disconnectInfo.specialButtonTranslated = "MpConnectAsUsername".Translate(newName); + var session = Multiplayer.session; + disconnectInfo.specialButtonAction = () => session.Reconnect(newName); + } + + if (reason == MpDisconnectReason.UsernameLength) + { + titleKey = "MpInvalidUsernameLength"; + descKey = "MpChangeUsernameInfo"; + } + + if (reason == MpDisconnectReason.UsernameChars) + { + titleKey = "MpInvalidUsernameChars"; + descKey = "MpChangeUsernameInfo"; + } + + if (reason == MpDisconnectReason.ServerClosed) titleKey = "MpServerClosed"; + if (reason == MpDisconnectReason.ServerFull) titleKey = "MpServerFull"; + if (reason == MpDisconnectReason.ServerStarting) titleKey = "MpDisconnectServerStarting"; + if (reason == MpDisconnectReason.Kick) titleKey = "MpKicked"; + if (reason == MpDisconnectReason.ServerPacketRead) descKey = "MpPacketErrorRemote"; + if (reason == MpDisconnectReason.BadGamePassword) descKey = "MpBadGamePassword"; + + disconnectInfo.titleTranslated ??= titleKey?.Translate(); + disconnectInfo.descTranslated ??= descKey?.Translate(); + + Log.Message($"Processed disconnect packet ({reason}). Title: {disconnectInfo.titleTranslated} ({titleKey}), " + + $"description: {disconnectInfo.descTranslated} ({descKey})"); + + return disconnectInfo; + } } diff --git a/Source/Client/Windows/ConnectingWindow.cs b/Source/Client/Windows/ConnectingWindow.cs index 2d5a68fec..11aef7617 100644 --- a/Source/Client/Windows/ConnectingWindow.cs +++ b/Source/Client/Windows/ConnectingWindow.cs @@ -124,7 +124,7 @@ public override void PostClose() } public void Connected() => result = "MpConnected".Translate(); - public void Disconnected() { } + public void Disconnected(SessionDisconnectInfo info) { } } public class RejoiningWindow : BaseConnectingWindow diff --git a/Source/Common/Networking/ConnectionBase.cs b/Source/Common/Networking/ConnectionBase.cs index 8a7b698e5..541511240 100644 --- a/Source/Common/Networking/ConnectionBase.cs +++ b/Source/Common/Networking/ConnectionBase.cs @@ -264,7 +264,20 @@ private void ExecuteMessageHandler(PacketHandlerInfo handler, Packets packet, By } } - public abstract void Close(MpDisconnectReason reason, byte[]? data = null); + public void Close(MpDisconnectReason reason, byte[]? data = null) + { + // State.IsServer check only used when disconnecting from a self-hosted local server + if (State != ConnectionStateEnum.Disconnected && State.IsServer()) + Send(new ServerDisconnectPacket { reason = reason, data = data ?? [] }); + OnClose(); + } + + protected abstract void OnClose(); + + /// Invoked after a keep alive timer arrives. Only used by the server + public virtual void OnKeepAliveArrived(bool idMatched) + { + } public static byte[] GetDisconnectBytes(MpDisconnectReason reason, byte[]? data = null) { diff --git a/Source/Common/Networking/ConnectionStateEnum.cs b/Source/Common/Networking/ConnectionStateEnum.cs index 564923315..2dceb271a 100644 --- a/Source/Common/Networking/ConnectionStateEnum.cs +++ b/Source/Common/Networking/ConnectionStateEnum.cs @@ -15,3 +15,11 @@ public enum ConnectionStateEnum : byte Count, Disconnected } + +public static class ConnectionStateEnumExt +{ + public static bool IsClient(this ConnectionStateEnum state) => + state is >= ConnectionStateEnum.ClientJoining and <= ConnectionStateEnum.ClientSteam; + public static bool IsServer(this ConnectionStateEnum state) => + state is >= ConnectionStateEnum.ServerJoining and <= ConnectionStateEnum.ServerSteam; +} diff --git a/Source/Common/Networking/LiteNetConnection.cs b/Source/Common/Networking/LiteNetConnection.cs index 0b58dc4b3..07b9068e0 100644 --- a/Source/Common/Networking/LiteNetConnection.cs +++ b/Source/Common/Networking/LiteNetConnection.cs @@ -16,10 +16,10 @@ protected override void SendRaw(byte[] raw, bool reliable) peer.Send(raw, reliable ? DeliveryMethod.ReliableOrdered : DeliveryMethod.Unreliable); } - public override void Close(MpDisconnectReason reason, byte[]? data) + protected override void OnClose() { peer.NetManager.TriggerUpdate(); // todo: is this needed? - peer.NetManager.DisconnectPeer(peer, GetDisconnectBytes(reason, data)); + peer.NetManager.DisconnectPeer(peer); } public override string ToString() diff --git a/Source/Common/Networking/Packet/DisconnectPacket.cs b/Source/Common/Networking/Packet/DisconnectPacket.cs new file mode 100644 index 000000000..8385dbc48 --- /dev/null +++ b/Source/Common/Networking/Packet/DisconnectPacket.cs @@ -0,0 +1,14 @@ +namespace Multiplayer.Common.Networking.Packet; + +[PacketDefinition(Packets.Server_Disconnect)] +public record struct ServerDisconnectPacket : IPacket +{ + public MpDisconnectReason reason; + public byte[] data; + + public void Bind(PacketBuffer buf) + { + buf.BindEnum(ref reason); + buf.BindRemaining(ref data); + } +} diff --git a/Source/Common/Networking/Packets.cs b/Source/Common/Networking/Packets.cs index b7deb0fec..0b3f409e2 100644 --- a/Source/Common/Networking/Packets.cs +++ b/Source/Common/Networking/Packets.cs @@ -62,6 +62,9 @@ public enum Packets : byte Server_Traces, Server_SetFaction, + // All states (Joining, Loading, Playing) + Server_Disconnect, + Count, Special_Steam_Disconnect = 63 // Also the max packet id } From 94982817c1059e7c1e1eb470117f0b7893dc7aa7 Mon Sep 17 00:00:00 2001 From: MhaWay Date: Thu, 26 Mar 2026 22:55:56 +0100 Subject: [PATCH 4/4] Bootstrap: fix net48 runtime compatibility --- Source/Common/ActionQueue.cs | 19 ++- Source/Common/LiteNetManager.cs | 38 ++++-- Source/Common/MultiplayerServer.cs | 2 +- Source/Common/Networking/Packet/IPacket.cs | 6 +- Source/Common/ServerSettings.cs | 2 +- Source/Common/Util/MpReflection.cs | 2 +- Source/Common/Util/TomlSettingsCommon.cs | 149 +++++++++++++++++++++ Source/Server/Server.cs | 33 ++++- Source/Server/Server.csproj | 1 - Source/Server/TomlSettings.cs | 72 ---------- 10 files changed, 226 insertions(+), 98 deletions(-) create mode 100644 Source/Common/Util/TomlSettingsCommon.cs delete mode 100644 Source/Server/TomlSettings.cs diff --git a/Source/Common/ActionQueue.cs b/Source/Common/ActionQueue.cs index b03249ef8..e73e04bf3 100644 --- a/Source/Common/ActionQueue.cs +++ b/Source/Common/ActionQueue.cs @@ -3,10 +3,14 @@ namespace Multiplayer.Common { + // Uses List instead of Queue because Queue is in System.dll on + // .NET Framework but in mscorlib on Unity/Mono. Common compiles against + // Krafs.Rimworld.Ref (Mono) so the emitted reference targets mscorlib, + // which causes a TypeLoadException when the Server runs on .NET Framework. public class ActionQueue { - private Queue queue = new(); - private Queue tempQueue = new(); + private List queue = new(); + private List tempQueue = new(); public void RunQueue(Action errorLogger) { @@ -14,8 +18,7 @@ public void RunQueue(Action errorLogger) { if (queue.Count > 0) { - foreach (Action a in queue) - tempQueue.Enqueue(a); + tempQueue.AddRange(queue); queue.Clear(); } } @@ -23,7 +26,11 @@ public void RunQueue(Action errorLogger) try { while (tempQueue.Count > 0) - tempQueue.Dequeue().Invoke(); + { + var action = tempQueue[0]; + tempQueue.RemoveAt(0); + action.Invoke(); + } } catch (Exception e) { @@ -34,7 +41,7 @@ public void RunQueue(Action errorLogger) public void Enqueue(Action action) { lock (queue) - queue.Enqueue(action); + queue.Add(action); } } } diff --git a/Source/Common/LiteNetManager.cs b/Source/Common/LiteNetManager.cs index f0bfbb79f..81f16889a 100644 --- a/Source/Common/LiteNetManager.cs +++ b/Source/Common/LiteNetManager.cs @@ -27,11 +27,14 @@ public LiteNetManager(MultiplayerServer server) public void Tick() { - foreach (var (_, man) in netManagers) - man.PollEvents(); + foreach (var (_, man) in netManagers.ToArray()) + SafePollEvents(man); - lanManager?.PollEvents(); - arbiter?.PollEvents(); + if (lanManager != null) + SafePollEvents(lanManager); + + if (arbiter != null) + SafePollEvents(arbiter); if (lanManager != null && broadcastTimer % 60 == 0) lanManager.SendBroadcast(Encoding.UTF8.GetBytes("mp-server"), 5100); @@ -46,7 +49,7 @@ public void StartNet() if (server.settings.direct) { var liteNetEndpoints = new Dictionary(); - var split = server.settings.directAddress.Split(MultiplayerServer.EndpointSeparator); + var split = server.settings.directAddress.Split(new[] { MultiplayerServer.EndpointSeparator }); foreach (var str in split) if (Endpoints.TryParse(str, MultiplayerServer.DefaultPort, out var endpoint)) @@ -57,10 +60,10 @@ public void StartNet() liteNetEndpoints.GetOrAddNew(endpoint.Port).ipv6 = endpoint.Address; } - foreach (var (port, endpoint) in liteNetEndpoints) + foreach (var kvp in liteNetEndpoints) { - endpoint.port = port; - netManagers.Add((endpoint, CreateNetManager(endpoint.ipv6 != null ? IPv6Mode.SeparateSocket : IPv6Mode.Disabled))); + kvp.Value.port = kvp.Key; + netManagers.Add((kvp.Value, CreateNetManager(kvp.Value.ipv6 != null ? IPv6Mode.SeparateSocket : IPv6Mode.Disabled))); } foreach (var (endpoint, man) in netManagers) @@ -100,7 +103,7 @@ NetManager CreateNetManager(IPv6Mode ipv6) public void StopNet() { - foreach (var (_, man) in netManagers) + foreach (var (_, man) in netManagers.ToArray()) man.Stop(); netManagers.Clear(); lanManager?.Stop(); @@ -117,6 +120,23 @@ public void OnServerStop() StopNet(); arbiter?.Stop(); } + + private static void SafePollEvents(NetManager manager) + { + try + { + manager.PollEvents(); + } + catch (InvalidOperationException e) when (IsQueueEmptyRace(e)) + { + // LiteNetLib can race while its internal event queue is being drained during disconnect/shutdown. + } + } + + private static bool IsQueueEmptyRace(InvalidOperationException e) + { + return e.Message == "Queue empty." || e.Message == "Coda vuota."; + } } public class LiteNetEndpoint diff --git a/Source/Common/MultiplayerServer.cs b/Source/Common/MultiplayerServer.cs index c05b4ac1e..58ec8c663 100644 --- a/Source/Common/MultiplayerServer.cs +++ b/Source/Common/MultiplayerServer.cs @@ -243,7 +243,7 @@ public void RegisterChatCmd(string cmdName, ChatCmdHandler handler) public void HandleChatCmd(IChatSource source, string cmd) { - var parts = cmd.Split(' '); + var parts = cmd.Split(new[] { ' ' }); var handler = GetChatCmdHandler(parts[0]); if (handler != null) diff --git a/Source/Common/Networking/Packet/IPacket.cs b/Source/Common/Networking/Packet/IPacket.cs index 9acf3392d..21e8bfe9b 100644 --- a/Source/Common/Networking/Packet/IPacket.cs +++ b/Source/Common/Networking/Packet/IPacket.cs @@ -307,11 +307,11 @@ public override void Bind(ref Dictionary obj, Binder bindKey, Bin { if (obj.Count > maxLength) throw new WriterException($"Dictionary too big ({obj.Count}>{maxLength})"); writer.WriteInt32(obj.Count); - foreach (var (key, value) in obj) + foreach (var kvp in obj) { - var k = key; + var k = kvp.Key; bindKey(this, ref k); - var v = value; + var v = kvp.Value; bindValue(this, ref v); } } diff --git a/Source/Common/ServerSettings.cs b/Source/Common/ServerSettings.cs index eef436fbb..8bc6d310f 100644 --- a/Source/Common/ServerSettings.cs +++ b/Source/Common/ServerSettings.cs @@ -4,7 +4,7 @@ namespace Multiplayer.Common { public class ServerSettings { - public string gameName; + public string gameName = ""; public string lanAddress; public string directAddress = $"0.0.0.0:{MultiplayerServer.DefaultPort}"; diff --git a/Source/Common/Util/MpReflection.cs b/Source/Common/Util/MpReflection.cs index d44e25906..e2a394f09 100644 --- a/Source/Common/Util/MpReflection.cs +++ b/Source/Common/Util/MpReflection.cs @@ -79,7 +79,7 @@ public static Type PathType(string memberPath) public static Type? IndexType(string memberPath) { InitPropertyOrField(memberPath); - return indexTypes.GetValueOrDefault(memberPath); + return indexTypes.TryGetValue(memberPath, out var value) ? value : null; } /// diff --git a/Source/Common/Util/TomlSettingsCommon.cs b/Source/Common/Util/TomlSettingsCommon.cs new file mode 100644 index 000000000..2dc23b638 --- /dev/null +++ b/Source/Common/Util/TomlSettingsCommon.cs @@ -0,0 +1,149 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.IO; +using System.Text; + +namespace Multiplayer.Common.Util +{ + /// + /// Manual TOML reader/writer for ServerSettings. + /// Does NOT use Tomlyn to avoid netstandard 2.1 dependency issues on .NET Framework 4.8. + /// Only supports flat key-value pairs (string, int, float, bool, enum). + /// + public static class TomlSettingsCommon + { + public static ServerSettings Load(string filename) + { + var scribe = new SimpleTomlScribe(); + scribe.ParseFile(filename); + scribe.mode = SimpleTomlMode.Loading; + + ScribeLike.provider = scribe; + + var settings = new ServerSettings(); + settings.ExposeData(); + + return settings; + } + + public static void Save(ServerSettings settings, string filename) + { + var scribe = new SimpleTomlScribe { mode = SimpleTomlMode.Saving }; + ScribeLike.provider = scribe; + + settings.ExposeData(); + + File.WriteAllText(filename, scribe.ToToml()); + } + } + + internal enum SimpleTomlMode + { + Loading, Saving + } + + internal class SimpleTomlScribe : ScribeLike.Provider + { + private readonly Dictionary data = new Dictionary(); + private readonly List> entries = new List>(); + public SimpleTomlMode mode; + + public void ParseFile(string filename) + { + foreach (var line in File.ReadAllLines(filename)) + { + var trimmed = line.Trim(); + if (trimmed.Length == 0 || trimmed.StartsWith("#")) + continue; + + var eqIdx = trimmed.IndexOf('='); + if (eqIdx < 0) + continue; + + var key = trimmed.Substring(0, eqIdx).Trim(); + var val = trimmed.Substring(eqIdx + 1).Trim(); + data[key] = val; + } + } + + public override void Look(ref T value, string label, T defaultValue, bool forceSave) + { + if (mode == SimpleTomlMode.Loading) + { + if (data.TryGetValue(label, out var raw)) + value = ParseValue(raw); + else + value = defaultValue; + } + else + { + entries.Add(new KeyValuePair(label, FormatValue(value))); + } + } + + private static T ParseValue(string raw) + { + var type = typeof(T); + + if (type == typeof(string)) + return (T)(object)Unquote(raw); + + if (type == typeof(bool)) + return (T)(object)(raw.Equals("true", StringComparison.OrdinalIgnoreCase)); + + if (type == typeof(int)) + return (T)(object)int.Parse(raw, CultureInfo.InvariantCulture); + + if (type == typeof(float)) + return (T)(object)float.Parse(raw, CultureInfo.InvariantCulture); + + if (type == typeof(double)) + return (T)(object)double.Parse(raw, CultureInfo.InvariantCulture); + + if (type == typeof(long)) + return (T)(object)long.Parse(raw, CultureInfo.InvariantCulture); + + if (type.IsEnum) + return (T)Enum.Parse(type, Unquote(raw)); + + return (T)Convert.ChangeType(raw, type, CultureInfo.InvariantCulture); + } + + private static string Unquote(string s) + { + if (s.Length >= 2 && s.StartsWith("\"") && s.EndsWith("\"")) + { + s = s.Substring(1, s.Length - 2); + s = s.Replace("\\\"", "\"").Replace("\\\\", "\\"); + } + return s; + } + + private static string FormatValue(T value) + { + if (value == null) return "\"\""; + if (value is bool b) return b ? "true" : "false"; + if (value is int i) return i.ToString(CultureInfo.InvariantCulture); + if (value is float f) return f.ToString(CultureInfo.InvariantCulture); + if (value is double d) return d.ToString(CultureInfo.InvariantCulture); + if (value is long l) return l.ToString(CultureInfo.InvariantCulture); + if (typeof(T).IsEnum) return Quote(value.ToString()); + if (value is string s) return Quote(s); + return Quote(value.ToString()); + } + + private static string Quote(string s) + { + return "\"" + (s ?? "").Replace("\\", "\\\\").Replace("\"", "\\\"") + "\""; + } + + public string ToToml() + { + var sb = new StringBuilder(); + for (int i = 0; i < entries.Count; i++) + sb.AppendLine(entries[i].Key + " = " + entries[i].Value); + return sb.ToString(); + } + } +} diff --git a/Source/Server/Server.cs b/Source/Server/Server.cs index 0d78cd7a0..11dbaca0f 100644 --- a/Source/Server/Server.cs +++ b/Source/Server/Server.cs @@ -1,9 +1,14 @@ using System.IO.Compression; +using System.Linq; +using System.Net; +using System.Net.Sockets; +using System.Threading; using Multiplayer.Common; using Multiplayer.Common.Util; using Server; ServerLog.detailEnabled = true; +Directory.SetCurrentDirectory(AppContext.BaseDirectory); const string settingsFile = "settings.toml"; const string stopCmd = "stop"; @@ -16,9 +21,12 @@ }; if (File.Exists(settingsFile)) - settings = TomlSettings.Load(settingsFile); +{ + settings = TomlSettingsCommon.Load(settingsFile); + if (settings.lan) settings.lanAddress = GetLocalIpAddress() ?? "127.0.0.1"; +} else - TomlSettings.Save(settings, settingsFile); // Save default settings + TomlSettingsCommon.Save(settings, settingsFile); // Save default settings var server = MultiplayerServer.instance = new MultiplayerServer(settings) { @@ -63,7 +71,7 @@ static void LoadSave(MultiplayerServer server, string path) // Parse cmds entry for each map foreach (var entry in zip.GetEntries("maps/*_cmds")) { - var parts = entry.FullName.Split('_'); + var parts = entry.FullName.Split(new[] { '_' }); if (parts.Length == 3) { @@ -75,7 +83,7 @@ static void LoadSave(MultiplayerServer server, string path) // Parse save entry for each map foreach (var entry in zip.GetEntries("maps/*_save")) { - var parts = entry.FullName.Split('_'); + var parts = entry.FullName.Split(new[] { '_' }); if (parts.Length == 3) { @@ -102,6 +110,23 @@ static byte[] Compress(byte[] input) return result.ToArray(); } +static string GetLocalIpAddress() +{ + try + { + using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.IP)) + { + socket.Connect("8.8.8.8", 65530); + var endPoint = socket.LocalEndPoint as IPEndPoint; + return endPoint.Address.ToString(); + } + } + catch + { + return Dns.GetHostEntry(Dns.GetHostName()).AddressList.FirstOrDefault(i => i.AddressFamily == AddressFamily.InterNetwork)?.ToString(); + } +} + class ConsoleSource : IChatSource { public void SendMsg(string msg) diff --git a/Source/Server/Server.csproj b/Source/Server/Server.csproj index 4d548aaa7..52a290eb6 100644 --- a/Source/Server/Server.csproj +++ b/Source/Server/Server.csproj @@ -18,7 +18,6 @@ - diff --git a/Source/Server/TomlSettings.cs b/Source/Server/TomlSettings.cs deleted file mode 100644 index cc495a2d2..000000000 --- a/Source/Server/TomlSettings.cs +++ /dev/null @@ -1,72 +0,0 @@ -using Multiplayer.Common; -using Tomlyn; -using Tomlyn.Model; - -namespace Server; - -public static class TomlSettings -{ - public static ServerSettings Load(string filename) - { - var toml = new TomlScribe - { - mode = TomlScribeMode.Loading, - root = Toml.ToModel(File.ReadAllText(filename)) - }; - - ScribeLike.provider = toml; - - var settings = new ServerSettings(); - settings.ExposeData(); - - return settings; - } - - public static void Save(ServerSettings settings, string filename) - { - var toml = new TomlScribe { mode = TomlScribeMode.Saving }; - ScribeLike.provider = toml; - - settings.ExposeData(); - - File.WriteAllText(filename, Toml.FromModel(toml.root)); - } -} - -class TomlScribe : ScribeLike.Provider -{ - public TomlTable root = new(); - public TomlScribeMode mode; - - public override void Look(ref T value, string label, T defaultValue, bool forceSave) - { - if (mode == TomlScribeMode.Loading) - { - if (root.ContainsKey(label)) - { - if (typeof(T).IsEnum) - value = (T)Enum.Parse(typeof(T), (string)root[label]); - else if (root[label] is IConvertible convertible) - value = (T)convertible.ToType(typeof(T), null); - else - value = (T)root[label]; - } - else - { - value = defaultValue; - } - } - else if (mode == TomlScribeMode.Saving) - { - if (typeof(T).IsEnum) - root[label] = value.ToString()!; - else - root[label] = value; - } - } -} - -enum TomlScribeMode -{ - Saving, Loading -}