From 94470d6a626bace05e7fb4dabad88fcf87e6abba Mon Sep 17 00:00:00 2001
From: alinpahontu2912
Date: Fri, 20 Mar 2026 13:53:02 +0100
Subject: [PATCH] Add items() method to nn.Module for state_dict iteration

Add items() method to nn.Module that returns an enumerator of (name, tensor)
tuples from the module's state_dict. This enables easy iteration over all
parameters and persistent buffers, consistent with the existing items()
pattern in ModuleDict and ParameterDict.

This addresses the core request in issue #1474 by providing the items() API
needed for model merging workflows (averaging parameters between models
using state_dict + load_state_dict).

Changes:
- Add virtual items() method to Module class
- Add 'new' keyword to ModuleDict.items() and ParameterDict.items() to
  properly hide the base class method (different return types)
- Add tests for items() on simple and nested modules
- Add test demonstrating the model merge pattern from the issue

Closes #1474

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
---
 src/TorchSharp/NN/Module.cs        | 14 ++++++
 src/TorchSharp/NN/ModuleDict.cs    |  2 +-
 src/TorchSharp/NN/ParameterDict.cs |  2 +-
 test/TorchSharpTest/NN.cs          | 69 ++++++++++++++++++++++++++++++
 4 files changed, 85 insertions(+), 2 deletions(-)

diff --git a/src/TorchSharp/NN/Module.cs b/src/TorchSharp/NN/Module.cs
index 50f1c5e98..2c5a90419 100644
--- a/src/TorchSharp/NN/Module.cs
+++ b/src/TorchSharp/NN/Module.cs
@@ -510,6 +510,20 @@ public virtual void zero_grad(bool set_to_none = true)
                 /// </summary>
                 public virtual IEnumerable<Module> children() => named_children().Select(np => np.module);
 
+                /// <summary>
+                /// Return an enumeration of the module's state_dict key/value pairs.
+                ///
+                /// This is equivalent to calling state_dict() and iterating over its entries.
+                /// Both parameters and persistent buffers are included.
+                /// </summary>
+                /// <returns>An enumerator of (name, tensor) tuples</returns>
+                public virtual IEnumerator<(string name, Tensor value)> items()
+                {
+                    foreach (var kv in state_dict()) {
+                        yield return (kv.Key, kv.Value);
+                    }
+                }
+
                 /// <summary>
                 /// Returns a dictionary containing a whole state of the module.
                 /// </summary>
diff --git a/src/TorchSharp/NN/ModuleDict.cs b/src/TorchSharp/NN/ModuleDict.cs
index 89402fd41..3880228da 100644
--- a/src/TorchSharp/NN/ModuleDict.cs
+++ b/src/TorchSharp/NN/ModuleDict.cs
@@ -40,7 +40,7 @@ public void clear()
             /// Return an enumeration of the ParameterDict key/value pairs.
             /// </summary>
             /// <returns></returns>
-            public IEnumerator<(string, T)> items() => _list.GetEnumerator();
+            public new IEnumerator<(string, T)> items() => _list.GetEnumerator();
 
             /// <summary>
             /// Return the ParameterDict keys.
diff --git a/src/TorchSharp/NN/ParameterDict.cs b/src/TorchSharp/NN/ParameterDict.cs
index b8680eaae..8b555ce0d 100644
--- a/src/TorchSharp/NN/ParameterDict.cs
+++ b/src/TorchSharp/NN/ParameterDict.cs
@@ -40,7 +40,7 @@ public void clear()
             /// Return an enumeration of the ParameterDict key/value pairs.
             /// </summary>
             /// <returns></returns>
-            public IEnumerator<(string, Parameter)> items() => _list.GetEnumerator();
+            public new IEnumerator<(string, Parameter)> items() => _list.GetEnumerator();
 
             /// <summary>
             /// Return the ParameterDict keys.
diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs
index f2ed50db3..31e345603 100644
--- a/test/TorchSharpTest/NN.cs
+++ b/test/TorchSharpTest/NN.cs
@@ -3309,6 +3309,75 @@ public void TestCustomComponentName()
             Assert.True(sd.ContainsKey("_linear2.weight"));
         }
 
+        [Fact]
+        public void TestModuleItems()
+        {
+            var lin = Linear(10, 5, true);
+            var sd = lin.state_dict();
+            var items = new List<(string, Tensor)>();
+
+            using (var enumerator = lin.items()) {
+                while (enumerator.MoveNext()) {
+                    items.Add(enumerator.Current);
+                }
+            }
+
+            // items() should return the same entries as state_dict()
+            Assert.Equal(sd.Count, items.Count);
+            foreach (var (name, value) in items) {
+                Assert.True(sd.ContainsKey(name));
+                Assert.Equal(sd[name].shape, value.shape);
+            }
+        }
+
+        [Fact]
+        public void TestModuleItemsWithSubmodules()
+        {
+            var seq = Sequential(
+                ("lin1", Linear(10, 5)),
+                ("lin2", Linear(5, 2)));
+            var sd = seq.state_dict();
+            var items = new List<(string, Tensor)>();
+
+            using (var enumerator = seq.items()) {
+                while (enumerator.MoveNext()) {
+                    items.Add(enumerator.Current);
+                }
+            }
+
+            Assert.Equal(sd.Count, items.Count);
+            Assert.Contains(items, i => i.Item1 == "lin1.weight");
+            Assert.Contains(items, i => i.Item1 == "lin2.weight");
+        }
+
+        [Fact]
+        public void TestModelMergeUsingItemsAndStateDict()
+        {
+            // Demonstrate model merging pattern using items() + state_dict() + load_state_dict()
+            // This is how users would merge models, matching the PyTorch pattern
+            var model1 = Linear(10, 5, true);
+            var model2 = Linear(10, 5, true);
+
+            var sd1 = model1.state_dict();
+            var sd2 = model2.state_dict();
+
+            var merged = new Dictionary<string, Tensor>();
+            using (var enumerator = model1.items()) {
+                while (enumerator.MoveNext()) {
+                    var (name, _) = enumerator.Current;
+                    merged[name] = (sd1[name] + sd2[name]) / 2;
+                }
+            }
+
+            model1.load_state_dict(merged);
+
+            // Verify the merged parameters are the average
+            var finalSd = model1.state_dict();
+            foreach (var key in merged.Keys) {
+                Assert.True(finalSd[key].allclose(merged[key]));
+            }
+        }
+
         private class TestModule3 : Module
         {
             public TestModule3() : base(nameof(TestModule3)) { RegisterComponents(); }