-
Notifications
You must be signed in to change notification settings - Fork 1
Dev/add rbm network #35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
windy-pig
wants to merge
1
commit into
main
Choose a base branch
from
dev/add-rbm-network
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,160 @@ | ||
| """ | ||
| This file defines the RBM based wave function. | ||
| Apart from predicting the magnitude, it also has a unique way of sampling. | ||
| """ | ||
|
|
||
| import torch | ||
| from .mlp import MLP | ||
| from .bitspack import pack_int, unpack_int | ||
|
|
||
|
|
||
| class RBM(torch.nn.Module): | ||
| """The RBM Network. | ||
|
|
||
| It returns the free energy of configurations in forward, | ||
| and also samples configurations according to its weight in the sample function. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| visible_dim : int | ||
| The num of visible nodes. | ||
| hidden_dim : int | ||
| The num of hidden nodes. | ||
| gamma : float | ||
| The relative magnitude of initialized weight. | ||
| """ | ||
|
|
||
| def __init__(self, visible_dim: int, hidden_dim: int, gamma: float) -> None: | ||
|
|
||
| super().__init__() | ||
| self.visible_dim = visible_dim | ||
| self.hidden_dim = hidden_dim | ||
| self.weights = torch.nn.Parameter(torch.zeros([self.visible_dim, self.hidden_dim])) | ||
| init_range = gamma / torch.sqrt(torch.tensor(self.visible_dim)) | ||
| torch.nn.init.uniform_(self.weights, -init_range, init_range) | ||
| self.visible_bias = torch.nn.Parameter(torch.zeros(visible_dim)) | ||
| self.hidden_bias = torch.nn.Parameter(torch.zeros(hidden_dim)) | ||
|
|
||
| def forward(self, v: torch.Tensor) -> torch.Tensor: | ||
| """Returns the predicted free energy for each row. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| v : torch.Tensor | ||
| Batch of configurations with free energy to be determined. Batch first. | ||
|
|
||
| Returns | ||
| ------- | ||
| torch.Tensor | ||
| Free energy. | ||
| """ | ||
| e1 = (v @ self.visible_bias).view(v.size()[:-1]) | ||
| mid2 = torch.nn.functional.linear(v, self.weights.T, self.hidden_bias) # pylint: disable=not-callable | ||
| e2 = mid2.exp().add(1).div(2).log().sum(dim=-1) | ||
| return e1 + e2 | ||
|
|
||
| @torch.jit.export | ||
| def sample(self, v: torch.Tensor, k: int = 1) -> torch.Tensor: | ||
| """Sample configurations by gibbs sampling. | ||
|
|
||
| Parameters | ||
| --------- | ||
| v : torch.Tensor | ||
| The initial configurations to run Gibbs sampling on. | ||
| k : int, optional | ||
| Rounds of Gibbs sampling (default: 1). | ||
|
|
||
| Returns | ||
| ------- | ||
| torch.Tensor | ||
| Configurations sampled as a tensor. Batch first. | ||
| """ | ||
| for _ in range(k): | ||
| # samp h | ||
| midh = torch.nn.functional.linear(v, self.weights.T, self.hidden_bias) # pylint: disable=not-callable | ||
| ph = torch.sigmoid(midh) | ||
| h = torch.bernoulli(ph) | ||
| # samp v | ||
| midv = torch.nn.functional.linear(h, self.weights, self.visible_bias) # pylint: disable=not-callable | ||
| pv = torch.sigmoid(midv) | ||
| v = torch.bernoulli(pv) | ||
| return v | ||
|
|
||
|
|
||
| class WaveFunctionNormal(torch.nn.Module): | ||
| """WaveFunction implemented via RBM. | ||
|
|
||
| (The phase is predicted with MLP, however.) | ||
| """ | ||
|
|
||
| def __init__( # pylint: disable=R0913 | ||
| self, | ||
| *, | ||
| sites: int, | ||
| physical_dim: int, | ||
| is_complex: bool, | ||
| rbm_hidden_dim: int, | ||
| rbm_gamma: float, | ||
| mlp_hidden_size: tuple[int, ...], | ||
| ) -> None: | ||
| super().__init__() | ||
| self.sites: int = sites | ||
| assert physical_dim == 2 | ||
| assert is_complex == True # pylint: disable=singleton-comparison | ||
| self.rbm_hidden_dim: int = rbm_hidden_dim | ||
| self.rbm_gamma: float = rbm_gamma | ||
| self.mlp_hidden_size: tuple[int, ...] = mlp_hidden_size | ||
|
|
||
| # Build Networks | ||
| self.magnitude = RBM(self.sites, self.rbm_hidden_dim, rbm_gamma) | ||
| self.phase = MLP(self.sites, 1, self.mlp_hidden_size) | ||
|
|
||
| # Dummy Parameter for Device and Dtype Retrieval | ||
| # This parameter is used to infer the device and dtype of the model. | ||
| self.dummy_param = torch.nn.Parameter(torch.empty(0)) | ||
|
|
||
| @property | ||
| def device(self) -> torch.device: | ||
| """Device of the model's parameters""" | ||
| return self.dummy_param.device | ||
|
|
||
| @property | ||
| def dtype(self) -> torch.dtype: | ||
| """dtype of the model's parameters""" | ||
| return self.dummy_param.dtype | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| """Calculates amplitude | ||
|
|
||
| See model_dict.py | ||
| """ | ||
| batch_size: torch.Size = x.shape[:-1] | ||
| x = unpack_int(x, size=1, last_dim=self.sites).view(*batch_size, self.sites) | ||
| x_float: torch.Tensor = x.to(dtype=self.dtype) | ||
| free_energy: torch.Tensor = self.magnitude(x_float).double() | ||
| ln_magnitude: torch.Tensor = free_energy / 2 - free_energy.mean() / 2 # ?? | ||
| phase: torch.Tensor = self.phase(x_float).view(*batch_size).double() | ||
| return (ln_magnitude + phase * 1j).exp() | ||
|
|
||
| @torch.jit.export | ||
| def generate_conf(self, batch_size: int) -> tuple[torch.Tensor, torch.Tensor]: | ||
| """ | ||
| Samples configurations(Not unique). | ||
|
|
||
| Parameters | ||
| ---------- | ||
| batch_size : int | ||
| Num of samples | ||
|
|
||
| Returns | ||
| ------- | ||
| samples : torch.Tensor | ||
| Batch of configurations. Batch first. | ||
| amplitude : torch.Tensor | ||
| The amplitude of each sample. | ||
| """ | ||
| samples = self.magnitude.sample(torch.bernoulli(torch.ones((batch_size, self.sites), device=self.device) * 0.5)) | ||
| samples = pack_int(samples.byte(), size=1) | ||
| amplitude = self(samples) | ||
|
|
||
| return samples, amplitude | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Prefer
assert is_complexover comparing toTrueto simplify the expression and satisfy linters without disabling them.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不用管他