Source code for ligandparam.stages.charge

from typing import Optional, Union, Any
from typing_extensions import override
from pathlib import Path
import warnings

import numpy as np
import MDAnalysis as mda

from ligandparam.stages.abstractstage import AbstractStage
from ligandparam.interfaces import Antechamber
from ligandparam.io.coordinates import Mol2Writer


class StageUpdateCharge(AbstractStage):
    """Create a new mol2 file whose atomic charges are read from a text file.

    The charges are taken from one column of ``charge_source``, assigned to
    the atoms of the input mol2, written to a temporary mol2 file and then
    passed through :class:`Antechamber` to produce ``out_mol2``.

    Parameters
    ----------
    stage_name : str
        Name of the stage (forwarded to ``AbstractStage``).
    main_input : Union[Path, str]
        Path of the input mol2 file.
    cwd : Union[Path, str]
        Working directory (forwarded to ``AbstractStage``).
    **kwargs
        charge_source
            Required. Path of the text file holding the new charges.
        out_mol2
            Required. Path of the mol2 file to produce.
        charge_column : int, optional
            Zero-based column of ``charge_source`` to read (default 3).
        net_charge : float, optional
            Net charge handed to antechamber (default 0.0).
        atom_type : str, optional
            Atom type handed to antechamber (default "gaff2").

    Raises
    ------
    KeyError
        If ``charge_source`` or ``out_mol2`` is missing from ``kwargs``.
    """

    @override
    def __init__(
        self,
        stage_name: str,
        main_input: Union[Path, str],
        cwd: Union[Path, str],
        *args,
        **kwargs,
    ) -> None:
        super().__init__(stage_name, main_input, cwd, *args, **kwargs)
        self.in_mol2 = Path(main_input)
        self.charge_source = kwargs["charge_source"]
        self.charge_column = kwargs.get("charge_column", 3)
        self.out_mol2 = Path(kwargs["out_mol2"])
        # Intermediate file: holds the updated charges before antechamber runs.
        # NOTE(review): relies on ``self.cwd`` being set by AbstractStage.__init__.
        self.tmp_mol2 = self.cwd / f"{self.out_mol2.stem}_tmp_update.mol2"
        self.net_charge = kwargs.get("net_charge", 0.0)
        self.atom_type = kwargs.get("atom_type", "gaff2")

        self.add_required(Path(self.in_mol2))
        self.add_required(Path(self.charge_source))

    def _append_stage(self, stage: "AbstractStage") -> "AbstractStage":
        """Return ``stage`` unchanged."""
        return stage

    def execute(self, dry_run=False, nproc: Optional[int] = None, mem: Optional[int] = None) -> Any:
        """Read the charges, write them into a mol2 file and call antechamber.

        Parameters
        ----------
        dry_run : bool
            When True the charge file is still read, but no mol2 file is
            written; the flag is also forwarded to ``Antechamber.call``.
        nproc : Optional[int]
            Forwarded to ``_setup_execution``.
        mem : Optional[int]
            Forwarded to ``_setup_execution``.

        Raises
        ------
        FileNotFoundError
            If ``charge_source`` does not exist.
        ValueError
            If the selected column contains non-numeric entries, or if the
            number of charges differs from the number of atoms.
        """
        super()._setup_execution(dry_run=dry_run, nproc=nproc, mem=mem)

        if not Path(self.charge_source).exists():
            raise FileNotFoundError(f"File {self.charge_source} not found.")

        # genfromtxt returns a 0-d array for a single-row file; force 1-D so
        # that ``len`` and per-atom assignment work for one-atom inputs too.
        charges = np.atleast_1d(
            np.genfromtxt(self.charge_source, usecols=self.charge_column, unpack=True)
        )

        if not dry_run:
            # genfromtxt turns unparsable fields into NaN; never write those
            # out as atomic charges.
            if np.isnan(charges).any():
                raise ValueError("Error: Charge file contains non-numeric entries in the selected column.")

            # Suppress the inevitable mol2 file warnings.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                u = mda.Universe(self.in_mol2, format="mol2")

            if len(charges) != len(u.atoms):
                raise ValueError("Error: Number of charges does not match the number of atoms.")
            u.atoms.charges = charges

            # Write the Mol2 temporary file
            Mol2Writer(u, self.tmp_mol2, selection="all").write()

        # NOTE(review): ``self.logger`` and ``self.nproc`` are presumably
        # provided by AbstractStage / ``_setup_execution`` -- confirm.
        ante = Antechamber(cwd=self.cwd, logger=self.logger, nproc=self.nproc)
        ante.call(
            i=self.tmp_mol2,
            fi="mol2",
            o=self.out_mol2,
            fo="mol2",
            pf="y",
            at=self.atom_type,
            an="no",
            nc=self.net_charge,
            dry_run=dry_run,
        )
        return

    def _clean(self):
        """Not implemented for this stage."""
        raise NotImplementedError("clean method not implemented")
class StageNormalizeCharge(AbstractStage):
    """Normalize the atomic charges of a mol2 file to a target net charge.

    The charges are rounded to the number of decimals implied by
    ``precision``. The difference between the target net charge and the
    rounded total is then spread over the atoms in steps of roughly one
    ``precision`` each, starting with the atoms of largest absolute charge,
    until the difference is used up.

    Parameters
    ----------
    stage_name : str
        Name of the stage (forwarded to ``AbstractStage``).
    main_input : Union[Path, str]
        Path of the input mol2 file.
    cwd : Union[Path, str]
        Working directory (forwarded to ``AbstractStage``).
    **kwargs
        out_mol2
            Required. Path of the mol2 file to produce.
        atom_type : str, optional
            Atom type handed to antechamber (default "gaff2").
        net_charge : float, optional
            Target net charge (default 0.0).
        precision : float, optional
            Size of a single charge adjustment (default 0.0001).

    Raises
    ------
    KeyError
        If ``out_mol2`` is missing from ``kwargs``.
    ValueError
        If ``precision`` is not a positive, finite number with at least one
        decimal place.
    """

    def __init__(
        self,
        stage_name: str,
        main_input: Union[Path, str],
        cwd: Union[Path, str],
        *args,
        **kwargs,
    ) -> None:
        super().__init__(stage_name, main_input, cwd, *args, **kwargs)
        self.in_mol2 = Path(main_input)
        self.out_mol2 = Path(kwargs["out_mol2"])
        # NOTE(review): relies on ``self.cwd`` being set by AbstractStage.__init__.
        self.tmp_mol2 = self.cwd / f"{self.in_mol2.stem}_tmp_norm.mol2"
        self.atom_type = kwargs.get("atom_type", "gaff2")
        self.net_charge = kwargs.get("net_charge", 0.0)
        self.precision = kwargs.get("precision", 0.0001)
        self.decimals = self._decimals_from_precision(self.precision)

        self.add_required(self.in_mol2)

    @staticmethod
    def _decimals_from_precision(precision) -> int:
        """Return the number of decimal places needed to represent ``precision``.

        Going through ``Decimal`` (instead of splitting ``str(precision)`` on
        ".") also handles values that Python prints in exponent notation,
        e.g. ``1e-05`` -> 5 and ``1.5e-05`` -> 6.

        Raises
        ------
        ValueError
            If ``precision`` is not numeric, not finite, not positive, or has
            no decimal places.
        """
        # Local import: ``decimal`` is only needed by this helper.
        from decimal import Decimal, InvalidOperation

        message = f"ERROR: Invalid precision: {precision}. It should be a float between 0 and 0.1"
        try:
            parsed = Decimal(str(precision))
        except InvalidOperation:
            raise ValueError(message) from None

        # A zero or negative step can never close the charge gap, and NaN/inf
        # have no decimal exponent at all.
        if not parsed.is_finite() or parsed <= 0:
            raise ValueError(message)

        decimals = -parsed.as_tuple().exponent
        if decimals < 1:
            raise ValueError(message)
        return decimals

    def _append_stage(self, stage: "AbstractStage") -> "AbstractStage":
        """Return ``stage`` unchanged."""
        return stage

    def execute(self, dry_run=False, nproc: Optional[int] = None, mem: Optional[int] = None) -> Any:
        """Execute the stage.

        Loads the input mol2, normalizes its charges if their rounded total
        is not already close to ``net_charge``, writes a temporary mol2
        (unless ``dry_run``) and calls antechamber to produce ``out_mol2``.

        Parameters
        ----------
        dry_run : bool
            When True the charges are still read and checked, but the
            temporary mol2 file is not written; the flag is also forwarded to
            ``Antechamber.call``.
        nproc : Optional[int]
            Forwarded to ``_setup_execution``.
        mem : Optional[int]
            Forwarded to ``_setup_execution``.

        Raises
        ------
        ValueError
            If the charge normalization fails

        TODO: Check what happens when netcharge is nonzero
        TODO: Check what happens when charge difference is larger than the number of atoms
        """
        super()._setup_execution(dry_run=dry_run, nproc=nproc, mem=mem)

        self.logger.debug("Checking charges")
        self.logger.debug(f"Normalizing charges to {self.net_charge}")
        self.logger.debug(f"Precision {self.precision} with {self.decimals} decimals")

        # Suppress the warnings emitted while parsing the mol2 file.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            u = mda.Universe(self.in_mol2, format="mol2")

        rounded_charges, total_charge, charge_difference = self.check_charge(u.atoms.charges)

        if not np.isclose(total_charge, self.net_charge, rtol=1e-10):
            self.logger.info("Normalizing charges")
            new_charges = self.normalize(rounded_charges, charge_difference)
            # Re-check: only accept the adjusted charges if they hit the target.
            _, new_total, _ = self.check_charge(new_charges)
            if np.isclose(new_total, self.net_charge, rtol=1e-10):
                u.atoms.charges = new_charges
            else:
                raise ValueError(f"Error: Charge normalization failed, new charge: {new_total}.")
        else:
            self.logger.info("Charges are already normalized")

        if not dry_run:
            Mol2Writer(u, self.tmp_mol2, selection="all").write()

        # NOTE(review): ``self.logger`` and ``self.nproc`` are presumably
        # provided by AbstractStage / ``_setup_execution`` -- confirm.
        ante = Antechamber(cwd=self.cwd, logger=self.logger, nproc=self.nproc)
        ante.call(
            i=self.tmp_mol2,
            fi="mol2",
            o=self.out_mol2,
            fo="mol2",
            pf="y",
            at=self.atom_type,
            an="no",
            nc=self.net_charge,
            dry_run=dry_run,
        )

    def _clean(self):
        """Not implemented for this stage."""
        raise NotImplementedError("clean method not implemented")

    def normalize(self, charges, charge_difference):
        """Spread ``charge_difference`` over the charges in precision-sized steps.

        The input is not modified; a new array is returned.

        Parameters
        ----------
        charges : np.array
            The charges
        charge_difference : float
            The charge difference (target net charge minus current total)

        Returns
        -------
        charges : np.array
            The normalized charges
        """
        # Work on a copy so the caller's array is left untouched.
        charges = np.array(charges, dtype=float)
        natoms = len(charges)

        # Number of precision-sized steps needed to absorb the difference.
        count = int(np.round(np.abs(charge_difference) / self.precision))
        if natoms == 0 or count == 0:
            # Nothing to distribute (or nothing to distribute onto); this also
            # avoids dividing by a zero step count below.
            return charges

        adjust = np.round(charge_difference / count, self.decimals)

        # Visit atoms in order of decreasing absolute charge, so the atoms
        # with the largest charges are adjusted first.
        order = np.argsort(np.abs(charges))[::-1]

        # Wrap around when more steps than atoms are required.
        for step in range(count):
            charges[order[step % natoms]] += adjust

        return charges

    def check_charge(self, charges):
        """Round the charges and report their total and distance to the target.

        Parameters
        ----------
        charges : np.array
            The charges

        Returns
        -------
        charges : np.array
            The charges rounded to ``self.decimals`` decimal places
        total_charge : float
            The total of the rounded charges
        charge_difference : float
            ``self.net_charge`` minus the total charge
        """
        charges = np.round(charges, self.decimals)
        total_charge = np.sum(charges)
        charge_difference = self.net_charge - total_charge
        self.logger.debug(f"-> Total Charge is {total_charge}")
        self.logger.debug(f"-> Charge difference is {charge_difference}")
        return charges, total_charge, charge_difference