deepqmc.sampling.sampling_utils — DeepQMC 1.2.0 documentation
from collections.abc import Callable, Iterable
from statistics import mean, stdev
from typing import Optional

import jax
import jax.numpy as jnp

from ..hamil import MolecularHamiltonian
from ..molecule import Molecule
from ..parallel import pmap, rng_iterator, select_one_device
from ..physics import pairwise_diffs
from ..types import (
    Ansatz,
    KeyArray,
    ParametrizedWaveFunction,
    Params,
    PhysicalConfiguration,
    SamplerState,
)
from .base import ElectronSampler
from .combined_samplers import (
    MoleculeIdxSampler,
    MultiElectronicStateSampler,
    MultiNuclearGeometrySampler,
)
from .nuclei_samplers import IdleNucleiSampler, no_elec_warp

__all__ = ['combine_samplers']


def chain(*samplers) -> ElectronSampler:
    r"""Merge several sampler instances into one sampler object.

    A class is created on the fly whose bases are the classes of
    ``samplers``, in the order given, so method lookup follows that order
    (Python MRO). Its name is built by repeatedly substituting each base-class
    name for the ``'Sampler'`` substring. For example, ``DecorrSampler``
    followed by ``MetropolisSampler`` gives ``DecorrMetropolisSampler``.

    The new class's ``__init__`` does nothing. Instead, the instance
    attributes of every input sampler are copied onto the new object, and
    later samplers overwrite earlier ones when names clash.

    For example :data:`chain(DecorrSampler(10),MetropolisSampler(hamil, tau=1.))`
    will create a :class:`MetropolisSampler`, where the samples are taken from
    every 10th MCMC step. The last element of the sampler chain has to be
    either a :class:`MetropolisSampler` or a :class:`LangevinSampler`.

    Args:
        samplers (~deepqmc.sampling.base.ElectronSampler): one or more sampler
            instances to combine.

    Returns:
        :type:`~deepqmc.sampling.base.ElectronSampler`: the combined sampler.
    """
    bases = tuple(type(sampler) for sampler in samplers)

    combined_name = 'Sampler'
    for base_cls in bases:
        combined_name = combined_name.replace('Sampler', base_cls.__name__)

    def _skip_init(self):
        # State is copied over from the input samplers below, so none of the
        # base-class initializers should run.
        return None

    combined_cls = type(combined_name, bases, {'__init__': _skip_init})
    combined = combined_cls()
    for sampler in samplers:
        combined.__dict__.update(sampler.__dict__)
    return combined  # type: ignore
def combine_samplers(
    samplers, hamil: MolecularHamiltonian, wf: ParametrizedWaveFunction
) -> ElectronSampler:
    r"""Combine samplers to create more advanced sampling schemes.

    All entries of ``samplers`` except the last are used unchanged. The last
    entry is called as ``samplers[-1](hamil, wf)`` to build the underlying
    electron sampler. The result is then merged with the others by
    :func:`chain`.

    Args:
        samplers (list): the leading entries are sampler instances to wrap
            around the base sampler. The last entry is a callable taking
            ``(hamil, wf)``, presumably a sampler class or a partially
            applied sampler constructor.
        hamil (~deepqmc.hamil.MolecularHamiltonian): the molecular Hamiltonian.
        wf (~deepqmc.types.ParametrizedWaveFunction): the wave function to
            sample.

    Returns:
        :type:`~deepqmc.sampling.base.ElectronSampler`: the combined sampler.

    Raises:
        IndexError: if ``samplers`` is empty.
    """
    wrappers = samplers[:-1]
    base_sampler = samplers[-1](hamil, wf)
    return chain(*wrappers, base_sampler)
def diffs_to_nearest_nuc(r, coords):
    r"""Pick, for each electron, its pairwise-difference entry to the nearest nucleus.

    Args:
        r: electron positions. Assumed shape ``(n_elec, 3)``; TODO confirm.
        coords: nuclear positions. Assumed shape ``(n_nuc, 3)``; TODO confirm.

    Returns:
        tuple ``(z, idx)``:

        - ``idx[i]`` is the index into ``coords`` that minimizes the last
          component of :func:`pairwise_diffs` for electron ``i``. That
          component is treated as a squared distance elsewhere in this
          module.
        - ``z[i]`` is the corresponding full pairwise-difference entry.
    """
    z = pairwise_diffs(r, coords)
    idx = jnp.argmin(z[..., -1], axis=-1)
    return z[jnp.arange(len(r)), idx], idx


def crossover_parameter(z, f, charge):
    r"""Compute the per-electron crossover parameter ``a`` used to damp the drift.

    ``a = (1 + f_hat . z_hat) / 2 + Z^2 z^2 / (10 (4 + Z^2 z^2))``, where
    ``f_hat`` and ``z_hat`` are unit vectors, ``Z`` is the nuclear charge and
    ``z^2`` is the fourth component of ``z``. This presumably follows the
    Umrigar–Nightingale–Runge drift-limiting scheme.

    Args:
        z: nearest-nucleus entries from :func:`diffs_to_nearest_nuc`. The
            first three components are the displacement, using the sign
            convention of :func:`pairwise_diffs`. The fourth is the squared
            distance.
        f: the force/drift vectors, with the same leading axes as ``z``.
        charge: charge of the nearest nucleus for each electron.

    Returns:
        The crossover parameter, with the leading axes of ``z``.
    """
    z, z2 = z[..., :3], z[..., 3]
    eps = jnp.finfo(f.dtype).eps
    # Both norms are clipped. An electron sitting exactly on a nucleus (or a
    # vanishing force) then yields a finite value instead of NaN, which
    # would otherwise propagate through clean_force.
    z_unit = z / jnp.clip(jnp.linalg.norm(z, axis=-1, keepdims=True), eps, None)
    f_unit = f / jnp.clip(jnp.linalg.norm(f, axis=-1, keepdims=True), eps, None)
    Z2z2 = charge**2 * z2
    return (1 + jnp.sum(f_unit * z_unit, axis=-1)) / 2 + Z2z2 / (10 * (4 + Z2z2))


def clean_force(force, phys_conf, mol, *, tau):
    r"""Damp and cap the force to make drift-based moves well behaved near nuclei.

    The force is processed in two steps:

    1. It is scaled by ``2 / (sqrt(1 + 2 a |F|^2 tau) + 1)``, where ``a`` is
       the :func:`crossover_parameter`.
    2. It is rescaled so that ``tau * |F|`` never exceeds the distance to the
       nearest nucleus.

    Args:
        force: per-electron force vectors, batched like ``phys_conf.r``.
        phys_conf: physical configuration. ``r`` and ``R`` are mapped over
            their leading axis with :func:`jax.vmap`.
        mol: molecule providing ``charges``, indexed by nearest-nucleus index.
        tau: time step.

    Returns:
        The cleaned force, with the same shape as ``force``.
    """
    z, idx = jax.vmap(diffs_to_nearest_nuc)(phys_conf.r, phys_conf.R)
    a = crossover_parameter(z, force, mol.charges[idx])
    av2tau = a * jnp.sum(force**2, axis=-1) * tau
    # av2tau can be small or zero, so the following expression must handle
    # that. This rationalized form equals (sqrt(1 + 2x) - 1) / x without
    # dividing by x.
    factor = 2 / (jnp.sqrt(1 + 2 * av2tau) + 1)
    force = factor[..., None] * force
    eps = jnp.finfo(phys_conf.r.dtype).eps
    # Cap the step length tau * |F| at the distance to the nearest nucleus.
    norm_factor = jnp.minimum(
        1.0,
        jnp.sqrt(z[..., -1])
        / (tau * jnp.clip(jnp.linalg.norm(force, axis=-1), eps, None)),
    )
    force = force * norm_factor[..., None]
    return force


def equilibrate(
    rng: KeyArray,
    params: Params,
    molecule_idx_sampler: MoleculeIdxSampler,
    sampler: MultiNuclearGeometrySampler,
    state: SamplerState,
    criterion: Callable[[PhysicalConfiguration], jax.Array],
    steps: Iterable[int],
    *,
    block_size: int,
    n_blocks: int = 5,
    allow_early_stopping: bool = True,
):
    r"""Run sampling steps, yielding after each one, optionally stopping early.

    On every step a batch of molecule indices is drawn and
    ``sampler.sample`` is run under :func:`pmap`. The generator then yields
    ``(step, state, mol_idxs, stats)``, where ``mol_idxs`` is taken from a
    single device.

    With early stopping, the last ``block_size * n_blocks`` values of
    ``criterion(phys_conf)`` are kept. Iteration stops once the means of the
    oldest and newest ``block_size`` values differ by less than the smaller
    of their two standard deviations. Otherwise it runs until ``steps`` is
    exhausted.

    Args:
        rng: random key, split per step by :func:`rng_iterator`.
        params: wave-function parameters.
        molecule_idx_sampler: provides the molecule indices for each step.
        sampler: the sampler to equilibrate.
        state: initial sampler state.
        criterion: maps a physical configuration to a scalar array.
        steps: step labels; also bounds the number of iterations.
        block_size: number of criterion values per comparison block.
        n_blocks: the history length is ``block_size * n_blocks``.
        allow_early_stopping: whether to apply the convergence test at all.

    Yields:
        ``(step, state, mol_idxs, stats)`` after every sampling step.

    Raises:
        ValueError: on the first iteration, if early stopping is enabled and
            ``block_size < 2`` or ``n_blocks < 1``. The standard deviation
            needs at least two points per block.
    """
    if allow_early_stopping and (block_size < 2 or n_blocks < 1):
        # Fail before any (expensive) sampling rather than after the buffer
        # fills and statistics.stdev rejects a one-point block.
        raise ValueError(
            'early stopping requires block_size >= 2 and n_blocks >= 1, got '
            f'block_size={block_size}, n_blocks={n_blocks}'
        )
    sample_wf = pmap(sampler.sample)
    buffer_size = block_size * n_blocks
    buffer: list[float] = []
    for step, rng in zip(steps, rng_iterator(rng)):
        mol_idxs = molecule_idx_sampler.sample()
        state, phys_conf, stats = sample_wf(rng, state, params, mol_idxs)
        yield step, state, select_one_device(mol_idxs), stats
        if not allow_early_stopping:
            continue
        buffer.append(criterion(phys_conf).item())
        # Keep only the most recent ``buffer_size`` values.
        del buffer[:-buffer_size]
        if len(buffer) < buffer_size:
            continue
        b1, b2 = buffer[:block_size], buffer[-block_size:]
        if abs(mean(b1) - mean(b2)) < min(stdev(b1), stdev(b2)):
            break


def initialize_sampling(
    rng: KeyArray,
    hamil: MolecularHamiltonian,
    ansatz: Ansatz,
    mols: list[Molecule],
    electronic_states: int,
    molecule_batch_size: int,
    *,
    elec_sampler,
    nuc_sampler=None,
    elec_warp_fn: Optional[Callable] = None,
    update_nuc_period: Optional[int] = None,
    elec_equilibration_steps: Optional[int] = None,
) -> tuple[MoleculeIdxSampler, MultiNuclearGeometrySampler]:
    r"""Build the molecule-index sampler and the combined nuclear/electronic sampler.

    Args:
        rng: random key passed to :class:`MoleculeIdxSampler`.
        hamil: the molecular Hamiltonian. Its ``mol.charges`` are passed to
            the nuclear sampler.
        ansatz: the ansatz. ``ansatz.apply`` is the wave function sampled.
        mols: the molecules. Only ``len(mols)`` is used here.
        electronic_states: number of electronic states, passed to
            :class:`MultiElectronicStateSampler`.
        molecule_batch_size: number of molecule indices drawn per sample.
        elec_sampler: callable invoked as ``elec_sampler(hamil=..., wf=...)``.
        nuc_sampler: callable invoked with the nuclear charges. Defaults to
            :class:`IdleNucleiSampler`.
        elec_warp_fn: electron warp function. Defaults to ``no_elec_warp``.
        update_nuc_period: forwarded to :class:`MultiNuclearGeometrySampler`.
        elec_equilibration_steps: forwarded to
            :class:`MultiNuclearGeometrySampler`.

    Returns:
        ``(molecule_idx_sampler, sampler)``.
    """
    molecule_idx_sampler = MoleculeIdxSampler(
        rng, len(mols), molecule_batch_size, 'once'
    )
    elec_sampler = elec_sampler(hamil=hamil, wf=ansatz.apply)
    multi_state_elec_sampler = MultiElectronicStateSampler(
        elec_sampler, electronic_states
    )
    nuc_sampler = (IdleNucleiSampler if nuc_sampler is None else nuc_sampler)(
        hamil.mol.charges,
    )
    elec_warp_fn = no_elec_warp if elec_warp_fn is None else elec_warp_fn
    sampler = MultiNuclearGeometrySampler(
        multi_state_elec_sampler,
        nuc_sampler,
        elec_warp_fn,
        update_nuc_period,
        elec_equilibration_steps,
    )
    return molecule_idx_sampler, sampler


def initialize_sampler_state(rng, sampler, params, electron_batch_size, mols):
    r"""Initialize the sampler state on every local device via :func:`jax.pmap`.

    ``rng`` and ``params`` are mapped over their leading axis, so they must
    carry one entry per device. Each device calls ``sampler.init`` with
    ``electron_batch_size // jax.device_count()`` and the stacked coordinates
    of all ``mols``.

    NOTE(review): the integer division silently drops any remainder when
    ``electron_batch_size`` is not divisible by the device count. Confirm
    callers guarantee divisibility.

    Returns:
        The per-device sampler states produced by ``sampler.init``.
    """

    @jax.pmap
    def sampler_state_initializer(rng, params):
        return sampler.init(
            rng,
            params,
            electron_batch_size // jax.device_count(),
            jnp.stack([mol.coords for mol in mols]),
        )

    return sampler_state_initializer(rng, params)