From b52bc9137c9ac274ab721c03d9a9af20d8ad3ad1 Mon Sep 17 00:00:00 2001
From: windy-pig
Date: Thu, 12 Jun 2025 18:39:10 +0800
Subject: [PATCH] An RBM network implementation with fairly complete
 documentation and type hints; most minor issues fixed.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 qmb/rbm.py | 160 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 160 insertions(+)
 create mode 100644 qmb/rbm.py

diff --git a/qmb/rbm.py b/qmb/rbm.py
new file mode 100644
index 0000000..9152dd1
--- /dev/null
+++ b/qmb/rbm.py
@@ -0,0 +1,160 @@
+"""
+This file defines the RBM-based wave function.
+Apart from predicting the magnitude, it also provides its own way of sampling configurations.
+"""
+
+import torch
+from .mlp import MLP
+from .bitspack import pack_int, unpack_int
+
+
+class RBM(torch.nn.Module):
+    """The RBM network.
+
+    Its forward pass returns the free energy of configurations,
+    and its sample method draws configurations according to the model weights via Gibbs sampling.
+
+    Parameters
+    ----------
+    visible_dim : int
+        The number of visible nodes.
+    hidden_dim : int
+        The number of hidden nodes.
+    gamma : float
+        The relative magnitude of the initialized weights.
+    """
+
+    def __init__(self, visible_dim: int, hidden_dim: int, gamma: float) -> None:
+        super().__init__()
+        self.visible_dim = visible_dim
+        self.hidden_dim = hidden_dim
+        self.weights = torch.nn.Parameter(torch.zeros([self.visible_dim, self.hidden_dim]))
+        init_range = gamma / self.visible_dim**0.5
+        torch.nn.init.uniform_(self.weights, -init_range, init_range)
+        self.visible_bias = torch.nn.Parameter(torch.zeros(visible_dim))
+        self.hidden_bias = torch.nn.Parameter(torch.zeros(hidden_dim))
+
+    def forward(self, v: torch.Tensor) -> torch.Tensor:
+        """Returns the predicted free energy for each row.
+
+        Parameters
+        ----------
+        v : torch.Tensor
+            Batch of configurations whose free energy is to be determined. Batch first.
+
+        Returns
+        -------
+        torch.Tensor
+            Free energy.
+        """
+        # Visible-bias term: e1 = v . visible_bias.
+        e1 = (v @ self.visible_bias).view(v.size()[:-1])
+        # Hidden-unit term: e2 = sum_j log((1 + exp((v @ weights + hidden_bias)_j)) / 2).
+        mid2 = torch.nn.functional.linear(v, self.weights.T, self.hidden_bias)  # pylint: disable=not-callable
+        e2 = mid2.exp().add(1).div(2).log().sum(dim=-1)
+        return e1 + e2
+
+    @torch.jit.export
+    def sample(self, v: torch.Tensor, k: int = 1) -> torch.Tensor:
+        """Sample configurations by Gibbs sampling.
+
+        Parameters
+        ----------
+        v : torch.Tensor
+            The initial configurations to run Gibbs sampling on.
+        k : int, optional
+            Rounds of Gibbs sampling (default: 1).
+
+        Returns
+        -------
+        torch.Tensor
+            Configurations sampled as a tensor. Batch first.
+        """
+        for _ in range(k):
+            # Sample the hidden units given the visible units.
+            midh = torch.nn.functional.linear(v, self.weights.T, self.hidden_bias)  # pylint: disable=not-callable
+            ph = torch.sigmoid(midh)
+            h = torch.bernoulli(ph)
+            # Sample the visible units given the hidden units.
+            midv = torch.nn.functional.linear(h, self.weights, self.visible_bias)  # pylint: disable=not-callable
+            pv = torch.sigmoid(midv)
+            v = torch.bernoulli(pv)
+        return v
+
+
+class WaveFunctionNormal(torch.nn.Module):
+    """Wave function implemented via an RBM.
+
+    (The phase, however, is predicted with an MLP.)
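+
+    Parameters
+    ----------
+    sites : int
+        The number of sites.
+    physical_dim : int
+        The physical dimension of each site; only 2 is supported.
+    is_complex : bool
+        Whether the wave function is complex; must be True.
+    rbm_hidden_dim : int
+        The number of hidden nodes of the RBM magnitude network.
+    rbm_gamma : float
+        The relative magnitude of the initialized RBM weights.
+    mlp_hidden_size : tuple[int, ...]
+        The hidden layer sizes of the phase MLP.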
+    """
+
+    def __init__(  # pylint: disable=R0913
+        self,
+        *,
+        sites: int,
+        physical_dim: int,
+        is_complex: bool,
+        rbm_hidden_dim: int,
+        rbm_gamma: float,
+        mlp_hidden_size: tuple[int, ...],
+    ) -> None:
+        super().__init__()
+        self.sites: int = sites
+        assert physical_dim == 2
+        assert is_complex
+        self.rbm_hidden_dim: int = rbm_hidden_dim
+        self.rbm_gamma: float = rbm_gamma
+        self.mlp_hidden_size: tuple[int, ...] = mlp_hidden_size
+
+        # Build networks: the RBM predicts the magnitude, the MLP predicts the phase.
+        self.magnitude = RBM(self.sites, self.rbm_hidden_dim, self.rbm_gamma)
+        self.phase = MLP(self.sites, 1, self.mlp_hidden_size)
+
+        # Dummy parameter used to infer the device and dtype of the model.
+        self.dummy_param = torch.nn.Parameter(torch.empty(0))
+
+    @property
+    def device(self) -> torch.device:
+        """Device of the model's parameters."""
+        return self.dummy_param.device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        """Dtype of the model's parameters."""
+        return self.dummy_param.dtype
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Calculates the amplitude.
+
+        See model_dict.py.
+        """
+        batch_size: torch.Size = x.shape[:-1]
+        x = unpack_int(x, size=1, last_dim=self.sites).view(*batch_size, self.sites)
+        x_float: torch.Tensor = x.to(dtype=self.dtype)
+        free_energy: torch.Tensor = self.magnitude(x_float).double()
+        # Log-magnitude is half the free energy; the batch mean is subtracted as an
+        # overall shift, presumably to keep the exponent in a numerically stable range
+        # (note that this shift depends on the batch).
+        ln_magnitude: torch.Tensor = free_energy / 2 - free_energy.mean() / 2
+        phase: torch.Tensor = self.phase(x_float).view(*batch_size).double()
+        return (ln_magnitude + phase * 1j).exp()
+
+    @torch.jit.export
+    def generate_conf(self, batch_size: int) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Samples configurations (not necessarily unique).
+
+        Parameters
+        ----------
+        batch_size : int
+            Number of samples.
+
+        Returns
+        -------
+        samples : torch.Tensor
+            Batch of configurations. Batch first.
+        amplitude : torch.Tensor
+            The amplitude of each sample.
+        """
+        # Run Gibbs sampling starting from uniformly random configurations, then
+        # pack the bit configurations and evaluate their amplitudes.
+        samples = self.magnitude.sample(torch.bernoulli(torch.ones((batch_size, self.sites), device=self.device) * 0.5))
+        samples = pack_int(samples.byte(), size=1)
+        amplitude = self(samples)
+
+        return samples, amplitude
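
Usage sketch (not part of the patch): a minimal, illustrative example of driving the new module, assuming the package is importable as qmb and that MLP and the bitspack helpers behave as they are used above. The hyperparameter values below are placeholders, not recommendations.

    from qmb.rbm import WaveFunctionNormal

    # Placeholder hyperparameters for a small system; real values would come
    # from wherever the project configures its models.
    model = WaveFunctionNormal(
        sites=8,
        physical_dim=2,
        is_complex=True,
        rbm_hidden_dim=16,
        rbm_gamma=0.1,
        mlp_hidden_size=(32, 32),
    )

    # Draw a batch of configurations by Gibbs sampling and evaluate their amplitudes.
    samples, amplitude = model.generate_conf(batch_size=4)
    print(samples.shape, amplitude)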