14 changes: 14 additions & 0 deletions src/TorchSharp/NN/Module.cs
@@ -510,6 +510,20 @@ public virtual void zero_grad(bool set_to_none = true)
/// </summary>
public virtual IEnumerable<Module> children() => named_children().Select(np => np.module);

/// <summary>
/// Return an enumeration of the module's state_dict key/value pairs.
///
/// This is equivalent to calling state_dict() and iterating over its entries.
/// Both parameters and persistent buffers are included.
/// </summary>
/// <returns>An enumerator of (name, tensor) tuples</returns>
public virtual IEnumerator<(string name, Tensor value)> items()
{
foreach (var kv in state_dict()) {
yield return (kv.Key, kv.Value);
}
Comment on lines +522 to +524
Copilot AI Mar 20, 2026
Module.items() is implemented by calling state_dict(), which allocates and populates a new Dictionary<string, Tensor> on every enumeration. Since this API is intended for iteration, consider yielding directly from named_parameters(recurse: true) and named_buffers(recurse: true, include_nonpersistent: false) to avoid the intermediate dictionary allocation and duplicate traversal work.

Suggested change:
-            foreach (var kv in state_dict()) {
-                yield return (kv.Key, kv.Value);
-            }
+            foreach (var p in named_parameters(recurse: true)) {
+                yield return (p.name, p.Item2);
+            }
+            foreach (var b in named_buffers(recurse: true, include_nonpersistent: false)) {
+                yield return (b.name, b.Item2);
+            }

}

/// <summary>
/// Returns a dictionary containing a whole state of the module.
///
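For readers unfamiliar with the new API, a hedged usage sketch (not part of the diff; it assumes the usual `using static TorchSharp.torch.nn;` context and that `Linear`'s state_dict carries `weight` and `bias` entries, as in PyTorch):

```csharp
// Enumerate a module's (name, tensor) state_dict entries via the new items() API.
var lin = Linear(10, 5, true);
using (var it = lin.items()) {
    while (it.MoveNext()) {
        var (name, tensor) = it.Current;
        // Expected entries for Linear(10, 5, true): "weight" [5, 10] and "bias" [5].
        Console.WriteLine($"{name}: [{string.Join(", ", tensor.shape)}]");
    }
}
```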
2 changes: 1 addition & 1 deletion src/TorchSharp/NN/ModuleDict.cs
@@ -40,7 +40,7 @@ public void clear()
/// Return an enumeration of the ParameterDict key/value pairs.
/// </summary>
/// <returns></returns>
public IEnumerator<(string, T)> items() => _list.GetEnumerator();
public new IEnumerator<(string, T)> items() => _list.GetEnumerator();

/// <summary>
/// Return the ParameterDict keys.
2 changes: 1 addition & 1 deletion src/TorchSharp/NN/ParameterDict.cs
@@ -40,7 +40,7 @@ public void clear()
/// Return an enumeration of the ParameterDict key/value pairs.
/// </summary>
/// <returns></returns>
public IEnumerator<(string, Parameter)> items() => _list.GetEnumerator();
public new IEnumerator<(string, Parameter)> items() => _list.GetEnumerator();

/// <summary>
/// Return the ParameterDict keys.
69 changes: 69 additions & 0 deletions test/TorchSharpTest/NN.cs
@@ -3309,6 +3309,75 @@ public void TestCustomComponentName()
Assert.True(sd.ContainsKey("_linear2.weight"));
}

[Fact]
public void TestModuleItems()
{
var lin = Linear(10, 5, true);
var sd = lin.state_dict();
var items = new List<(string, Tensor)>();

using (var enumerator = lin.items()) {
while (enumerator.MoveNext()) {
items.Add(enumerator.Current);
}
}

// items() should return the same entries as state_dict()
Assert.Equal(sd.Count, items.Count);
foreach (var (name, value) in items) {
Assert.True(sd.ContainsKey(name));
Assert.Equal(sd[name].shape, value.shape);
}
}
Comment on lines +3312 to +3331
Copilot AI Mar 20, 2026
The new items() API is documented to include persistent buffers as well as parameters, but the added tests only exercise parameters (Linear / Sequential). Consider adding an assertion using a module with persistent buffers (e.g., BatchNorm) to verify that items() exposes buffer entries too (e.g., running_mean / running_var) and matches state_dict() for those keys.

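Along the lines of this review comment, a buffer-covering test could look like the following sketch (hypothetical, not part of this PR; it assumes TorchSharp's `BatchNorm1d` registers `running_mean`/`running_var` as persistent buffers, mirroring PyTorch):

```csharp
[Fact]
public void TestModuleItemsIncludesBuffers()
{
    var bn = BatchNorm1d(10);
    var sd = bn.state_dict();
    var names = new List<string>();

    using (var enumerator = bn.items()) {
        while (enumerator.MoveNext()) {
            names.Add(enumerator.Current.name);
        }
    }

    // Buffers, not just parameters, should surface through items().
    Assert.Contains("running_mean", names);
    Assert.Contains("running_var", names);
    Assert.Equal(sd.Count, names.Count);
}
```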

[Fact]
public void TestModuleItemsWithSubmodules()
{
var seq = Sequential(
("lin1", Linear(10, 5)),
("lin2", Linear(5, 2)));
var sd = seq.state_dict();
var items = new List<(string, Tensor)>();

using (var enumerator = seq.items()) {
while (enumerator.MoveNext()) {
items.Add(enumerator.Current);
}
}

Assert.Equal(sd.Count, items.Count);
Assert.Contains(items, i => i.Item1 == "lin1.weight");
Assert.Contains(items, i => i.Item1 == "lin2.weight");
}

[Fact]
public void TestModelMergeUsingItemsAndStateDict()
{
// Demonstrate model merging pattern using items() + state_dict() + load_state_dict()
// This is how users would merge models, matching the PyTorch pattern
var model1 = Linear(10, 5, true);
var model2 = Linear(10, 5, true);

var sd1 = model1.state_dict();
var sd2 = model2.state_dict();

var merged = new Dictionary<string, Tensor>();
using (var enumerator = model1.items()) {
while (enumerator.MoveNext()) {
var (name, _) = enumerator.Current;
merged[name] = (sd1[name] + sd2[name]) / 2;
Comment on lines +3365 to +3368
Copilot AI Mar 20, 2026

In TestModelMergeUsingItemsAndStateDict, the averaging operation is performed outside a torch.no_grad() context. Since state_dict() returns parameters that typically have requires_grad=true, the arithmetic will build an autograd graph and retain references unnecessarily. Wrapping the merge computation in torch.no_grad() (or detaching/cloning the source tensors) would better reflect the recommended model-merge pattern and avoid extra graph/memory overhead.

Suggested change:
-            using (var enumerator = model1.items()) {
-                while (enumerator.MoveNext()) {
-                    var (name, _) = enumerator.Current;
-                    merged[name] = (sd1[name] + sd2[name]) / 2;
+            using (var _ = no_grad()) {
+                using (var enumerator = model1.items()) {
+                    while (enumerator.MoveNext()) {
+                        var (name, _) = enumerator.Current;
+                        merged[name] = (sd1[name] + sd2[name]) / 2;
+                    }

}
}

model1.load_state_dict(merged);

// Verify the merged parameters are the average
var finalSd = model1.state_dict();
foreach (var key in merged.Keys) {
Assert.True(finalSd[key].allclose(merged[key]));
}
}

private class TestModule3 : Module<Tensor, Tensor>
{
public TestModule3() : base(nameof(TestModule3)) { RegisterComponents(); }