From 1e3a1f5cdfb1d28c24460f98048ebb64ad32527c Mon Sep 17 00:00:00 2001 From: Huanchen Zhai Date: Wed, 17 Jul 2024 22:30:59 -0700 Subject: [PATCH] core: rescale const --- pyblock2/driver/core.py | 41 ++++++++++++++++++++++++++++++++++++ src/dmrg/sweep_algorithm.hpp | 4 ++-- tests/cr2-gs/SVP | 2 +- 3 files changed, 44 insertions(+), 3 deletions(-) diff --git a/pyblock2/driver/core.py b/pyblock2/driver/core.py index 2426ae51..5c0cdbdd 100644 --- a/pyblock2/driver/core.py +++ b/pyblock2/driver/core.py @@ -656,6 +656,7 @@ def __init__( self.pg = "c1" self.orb_sym = None self.ghamil = None + self.n_elec = None self._dmrg = None self._sweep_wfn_spectra = None @@ -874,6 +875,7 @@ def initialize_system( pg_irrep = self.pg_irrep else: pg_irrep = 0 + self.n_elec = n_elec if target is None and bw.qargs is not None: if bw.qargs == ("U1Fermi", "AbelianPG"): @@ -3195,6 +3197,7 @@ def get_qc_mpo( normal_order_ref=None, normal_order_single_ref=None, normal_order_wick=True, + rescale=None, symmetrize=True, sum_mpo_mod=-1, compute_accurate_svd_error=True, @@ -3319,6 +3322,14 @@ def get_qc_mpo( Only have effect if ``normal_order_ref is not None``. If True, will use ``WickNormalOrder`` implementation (via automatic symbolic derivation). Otherwise, will use the manual implementation. Default is True. + rescale : None or float or True + If None, will not rescale (default). + If zero or True, will adjust ``h1e`` and the const energy so that + the average diagonal element of ``h1e`` is zero. + If non-zero float, will adjust ``h1e`` and the const energy so that + the const energy becomes the given ``rescale`` number. + After rescale, the integrals will only be correct for the given + ``n_elec``. symmetrize : bool Only have effect if ``self.orb_sym is not None`` (when point group symmetry is used). If True, will symmetrize integrals so that integral elements violating point group restrictions @@ -3439,6 +3450,29 @@ def get_qc_mpo( x_orb_sym, h1e=h1e, g2e=g2e, k_symm=k_symm, iprint=iprint ) + if rescale is not None: + assert h1e is not None + assert self.n_elec is not None + if iprint >= 1: + print("original const = ", ecore) + if SymmetryTypes.SZ in bw.symm_type: + xn = len(h1e[0]) + len(h1e[1]) + x = np.trace(h1e[0]) + np.trace(h1e[1]) + else: + xn, x = len(h1e), np.trace(h1e) + if isinstance(rescale, int) and rescale == 0: + x = x / xn + else: + x = (rescale - ecore) / self.n_elec + if SymmetryTypes.SZ in bw.symm_type: + h1e[0][np.mgrid[:len(h1e[0])], np.mgrid[:len(h1e[0])]] -= x + h1e[1][np.mgrid[:len(h1e[1])], np.mgrid[:len(h1e[1])]] -= x + else: + h1e[np.mgrid[:len(h1e)], np.mgrid[:len(h1e)]] -= x + ecore += x * self.n_elec + if iprint >= 1: + print("rescaled const = ", ecore) + if integral_cutoff != 0: error = 0 if SymmetryTypes.SZ in bw.symm_type: @@ -3536,6 +3570,10 @@ def get_qc_mpo( self.ghamil = bw.bs.GeneralHamiltonian( self.vacuum, self.n_sites, self.orb_sym, self.heis_twos ) + if normal_order_ref is not None: + normal_order_ref = np.array(normal_order_ref)[idx] + if normal_order_single_ref is not None: + normal_order_single_ref = np.array(normal_order_single_ref)[idx] else: self.reorder_idx = None @@ -7157,6 +7195,7 @@ def get_random_mps( mps : MPS The output MPS (normalized). """ + import numpy as np bw = self.bw if target is None: target = self.target @@ -7201,6 +7240,8 @@ def get_random_mps( else: mps_info.set_bond_dimension_fci(left_vacuum, self.vacuum) if occs is not None: + if self.reorder_idx is not None: + occs = np.array(occs)[self.reorder_idx] mps_info.set_bond_dimension_using_occ(bond_dim, bw.b.VectorDouble(occs)) else: mps_info.set_bond_dimension(bond_dim) diff --git a/src/dmrg/sweep_algorithm.hpp b/src/dmrg/sweep_algorithm.hpp index 31cd45d9..d59bec8e 100644 --- a/src/dmrg/sweep_algorithm.hpp +++ b/src/dmrg/sweep_algorithm.hpp @@ -227,7 +227,7 @@ template struct DMRG { false); int mmps = 0; FPS error = 0.0; - tuple pdi; + tuple pdi; shared_ptr> pket = nullptr, context_pket = nullptr; shared_ptr> pdm = nullptr; @@ -839,7 +839,7 @@ template struct DMRG { } int mmps = 0; FPS error = 0.0; - tuple pdi; + tuple pdi; shared_ptr> pket = nullptr, context_pket = nullptr; shared_ptr> context_old_ket = nullptr; diff --git a/tests/cr2-gs/SVP b/tests/cr2-gs/SVP index 79d43f73..2ceb1ff2 100644 --- a/tests/cr2-gs/SVP +++ b/tests/cr2-gs/SVP @@ -8,7 +8,7 @@ BASIS "ao basis" PRINT -#BASIS SET: (14s,9p,5d) -> [5s,3p,2d] +#BASIS SET: (14s,9p,5d) -> [5s,2p,2d] Cr S 51528.0863490 0.14405823106E-02 7737.2103487 0.11036202287E-01