From ca296b46e90295dd266d555a57913740976f7003 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Fri, 18 Jul 2025 01:44:56 +0800 Subject: [PATCH 1/4] Add RL based imaginary time evolution algorithm. --- qmb/__main__.py | 1 + qmb/rlim.py | 146 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 147 insertions(+) create mode 100644 qmb/rlim.py diff --git a/qmb/__main__.py b/qmb/__main__.py index fd45fe3..d477bed 100644 --- a/qmb/__main__.py +++ b/qmb/__main__.py @@ -12,6 +12,7 @@ from . import vmc as _ # type: ignore[no-redef] from . import imag as _ # type: ignore[no-redef] from . import rldiag as _ # type: ignore[no-redef] +from . import rlim as _ # type: ignore[no-redef] from . import precompile as _ # type: ignore[no-redef] from . import list_loss as _ # type: ignore[no-redef] from . import chop_imag as _ # type: ignore[no-redef] diff --git a/qmb/rlim.py b/qmb/rlim.py new file mode 100644 index 0000000..cd3416c --- /dev/null +++ b/qmb/rlim.py @@ -0,0 +1,146 @@ +""" +This file implements a reinforcement learning based imaginary time evolution algorithm. +""" + +import logging +import typing +import dataclasses +import torch +import torch.utils.tensorboard +import tyro +from .common import CommonConfig +from .subcommand_dict import subcommand_dict +from .optimizer import initialize_optimizer + + +@dataclasses.dataclass +class RlimConfig: + """ + The reinforcement learning based imaginary time evolution algorithm. + """ + + # pylint: disable=too-many-instance-attributes + + common: typing.Annotated[CommonConfig, tyro.conf.OmitArgPrefixes] + + # The sampling count + sampling_count: typing.Annotated[int, tyro.conf.arg(aliases=["-n"])] = 4000 + # The number of relative configurations to be used in energy calculation + relative_count: typing.Annotated[int, tyro.conf.arg(aliases=["-c"])] = 40000 + # The learning rate for the local optimizer + learning_rate: typing.Annotated[float, tyro.conf.arg(aliases=["-r"])] = 1e-3 + # The learning rate for the imaginary time evolution + evolution_time: typing.Annotated[float, tyro.conf.arg(aliases=["-t"])] = 1e-3 + # The number of steps for the local optimizer + local_step: typing.Annotated[int, tyro.conf.arg(aliases=["-s"])] = 32 + + def main(self) -> None: + """ + The main function for the RLIM optimization. + """ + # pylint: disable=too-many-statements + # pylint: disable=too-many-locals + + model, network, data = self.common.main() + ref_network = network + + logging.info( + "Arguments Summary: " + "Sampling Count: %d, " + "Relative Count: %d, " + "Learning Rate: %.10f, " + "Evolution Time: %.10f, " + "Local Steps: %d, ", + self.sampling_count, + self.relative_count, + self.learning_rate, + self.evolution_time, + self.local_step, + ) + + optimizer = initialize_optimizer( + network.parameters(), + use_lbfgs=False, + learning_rate=self.learning_rate, + state_dict=data.get("optimizer"), + ) + + if "rlim" not in data: + data["rlim"] = {"global": 0, "local": 0} + + writer = torch.utils.tensorboard.SummaryWriter(log_dir=self.common.folder()) # type: ignore[no-untyped-call] + + while True: + logging.info("Starting a new optimization cycle") + + logging.info("Sampling configurations") + configs_i, psi_i, _, _ = network.generate_unique(self.sampling_count) + ref_configs_i, ref_psi_i, _, _ = ref_network.generate_unique(self.sampling_count) + logging.info("Sampling completed, unique configurations count: %d, reference unique configurations count: %d", len(configs_i), len(ref_configs_i)) + + logging.info("Calculating relative configurations") + if self.relative_count <= len(configs_i): + configs_src = configs_i + configs_dst = configs_i + else: + configs_src = configs_i + configs_dst = torch.cat([configs_i, model.find_relative(configs_i, psi_i, self.relative_count - len(configs_i))]) + logging.info("Relative configurations calculated, count: %d", len(configs_dst)) + if self.relative_count <= len(ref_configs_i): + ref_configs_src = ref_configs_i + ref_configs_dst = ref_configs_i + else: + ref_configs_src = ref_configs_i + ref_configs_dst = torch.cat([ref_configs_i, model.find_relative(ref_configs_i, ref_psi_i, self.relative_count - len(ref_configs_i))]) + logging.info("Reference relative configurations calculated, count: %d", len(ref_configs_dst)) + + def closure() -> torch.Tensor: + # Optimizing loss + optimizer.zero_grad() + psi_src = network(configs_src) # psi s + ref_psi_src = network(ref_configs_src) # psi r + with torch.no_grad(): + psi_dst = network(configs_dst) # psi s' + ref_psi_dst = network(ref_configs_dst) # psi r' + hamiltonian_psi_dst = model.apply_within(configs_dst, psi_dst, configs_src) # H ss' psi s' + ref_hamiltonian_psi_dst = model.apply_within(ref_configs_dst, ref_psi_dst, ref_configs_src) # H rr' psi r' + a = torch.outer(psi_src.detach().conj(), ref_psi_src) - torch.outer(psi_src.conj(), ref_psi_src.detach()) + b = torch.outer(hamiltonian_psi_dst.conj(), ref_psi_src) - torch.outer(psi_src.conj(), ref_hamiltonian_psi_dst) + diff = (a - self.evolution_time * b).flatten() + loss = (diff.conj() @ diff).real + loss.backward() # type: ignore[no-untyped-call] + # Calculate energy + with torch.no_grad(): + num = psi_src.conj() @ hamiltonian_psi_dst + den = psi_src.conj() @ psi_src + energy = (num / den).real + loss.energy = energy # type: ignore[attr-defined] + return loss + + logging.info("Starting local optimization process") + + for i in range(self.local_step): + loss: torch.Tensor = optimizer.step(closure) # type: ignore[assignment,arg-type] + energy: float = loss.energy # type: ignore[attr-defined] + logging.info("Local optimization in progress, step: %d, loss: %.10f, energy: %.10f, ref energy: %.10f, energy error: %.10f", i, loss.item(), energy, model.ref_energy, + energy - model.ref_energy) + writer.add_scalar("rlim/energy", energy, data["rlim"]["local"]) # type: ignore[no-untyped-call] + writer.add_scalar("rlim/error", energy - model.ref_energy, data["rlim"]["local"]) # type: ignore[no-untyped-call] + writer.add_scalar("rlim/loss", loss, data["rlim"]["local"]) # type: ignore[no-untyped-call] + data["rlim"]["local"] += 1 + + logging.info("Local optimization process completed") + + writer.flush() # type: ignore[no-untyped-call] + + logging.info("Saving model checkpoint") + data["rlim"]["global"] += 1 + data["network"] = network.state_dict() + data["optimizer"] = optimizer.state_dict() + self.common.save(data, data["rlim"]["global"]) + logging.info("Checkpoint successfully saved") + + logging.info("Current optimization cycle completed") + + +subcommand_dict["rlim"] = RlimConfig From 7edd1175659b58ecc67f3ca7dd1ea4944249ab86 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Fri, 18 Jul 2025 11:05:47 +0800 Subject: [PATCH 2/4] Add normalization layers in the mlp networks. --- qmb/mlp.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/qmb/mlp.py b/qmb/mlp.py index 1db3c81..16ed579 100644 --- a/qmb/mlp.py +++ b/qmb/mlp.py @@ -48,14 +48,22 @@ def __init__(self, dim_input: int, dim_output: int, hidden_size: tuple[int, ...] dimensions: list[int] = [dim_input] + list(hidden_size) + [dim_output] linears: list[torch.nn.Module] = [select_linear_layer(i, j) for i, j in itertools.pairwise(dimensions)] - modules: list[torch.nn.Module] = [module for linear in linears for module in (linear, torch.nn.SiLU())][:-1] - self.model: torch.nn.Module = torch.nn.Sequential(*modules) + self.layers: torch.nn.ModuleList = torch.nn.ModuleList(linears) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass for the MLP. """ - return self.model(x) + for index, layer in enumerate(self.layers): + y = layer(x) + if x.shape != y.shape: + x = y + else: + x = x + y + if index != len(self.layers) - 1: + x = torch.nn.functional.normalize(x, dim=-1) + x = torch.nn.functional.silu(x) + return x class WaveFunctionElectronUpDown(torch.nn.Module): From af7e16823b273e1bbfa0d4eece28fe724e6a51b3 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Fri, 18 Jul 2025 03:22:25 +0800 Subject: [PATCH 3/4] Add dropout in the loss function. --- qmb/rlim.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/qmb/rlim.py b/qmb/rlim.py index cd3416c..42fa779 100644 --- a/qmb/rlim.py +++ b/qmb/rlim.py @@ -33,6 +33,8 @@ class RlimConfig: evolution_time: typing.Annotated[float, tyro.conf.arg(aliases=["-t"])] = 1e-3 # The number of steps for the local optimizer local_step: typing.Annotated[int, tyro.conf.arg(aliases=["-s"])] = 32 + # The dropout of the loss function + dropout: typing.Annotated[float, tyro.conf.arg(aliases=["-d"])] = 0.5 def main(self) -> None: """ @@ -50,12 +52,14 @@ def main(self) -> None: "Relative Count: %d, " "Learning Rate: %.10f, " "Evolution Time: %.10f, " - "Local Steps: %d, ", + "Local Steps: %d, " + "Dropout: %.2f", self.sampling_count, self.relative_count, self.learning_rate, self.evolution_time, self.local_step, + self.dropout, ) optimizer = initialize_optimizer( @@ -106,8 +110,8 @@ def closure() -> torch.Tensor: ref_hamiltonian_psi_dst = model.apply_within(ref_configs_dst, ref_psi_dst, ref_configs_src) # H rr' psi r' a = torch.outer(psi_src.detach().conj(), ref_psi_src) - torch.outer(psi_src.conj(), ref_psi_src.detach()) b = torch.outer(hamiltonian_psi_dst.conj(), ref_psi_src) - torch.outer(psi_src.conj(), ref_hamiltonian_psi_dst) - diff = (a - self.evolution_time * b).flatten() - loss = (diff.conj() @ diff).real + diff = torch.nn.functional.dropout(torch.view_as_real(a - self.evolution_time * b).abs(), p=self.dropout).flatten() + loss = diff @ diff loss.backward() # type: ignore[no-untyped-call] # Calculate energy with torch.no_grad(): From 0e183647217545a333f657fab0d403b0fa40a1b3 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Fri, 18 Jul 2025 11:58:25 +0800 Subject: [PATCH 4/4] Remove useless conjugate in rlim. --- qmb/rlim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/qmb/rlim.py b/qmb/rlim.py index 42fa779..325ed2f 100644 --- a/qmb/rlim.py +++ b/qmb/rlim.py @@ -108,8 +108,8 @@ def closure() -> torch.Tensor: ref_psi_dst = network(ref_configs_dst) # psi r' hamiltonian_psi_dst = model.apply_within(configs_dst, psi_dst, configs_src) # H ss' psi s' ref_hamiltonian_psi_dst = model.apply_within(ref_configs_dst, ref_psi_dst, ref_configs_src) # H rr' psi r' - a = torch.outer(psi_src.detach().conj(), ref_psi_src) - torch.outer(psi_src.conj(), ref_psi_src.detach()) - b = torch.outer(hamiltonian_psi_dst.conj(), ref_psi_src) - torch.outer(psi_src.conj(), ref_hamiltonian_psi_dst) + a = torch.outer(psi_src.detach(), ref_psi_src) - torch.outer(psi_src, ref_psi_src.detach()) + b = torch.outer(hamiltonian_psi_dst, ref_psi_src) - torch.outer(psi_src, ref_hamiltonian_psi_dst) diff = torch.nn.functional.dropout(torch.view_as_real(a - self.evolution_time * b).abs(), p=self.dropout).flatten() loss = diff @ diff loss.backward() # type: ignore[no-untyped-call]