from typing import Optional, Union, Any
from typing_extensions import override
from pathlib import Path
import warnings
import numpy as np
import MDAnalysis as mda
from ligandparam.stages.abstractstage import AbstractStage
from ligandparam.interfaces import Antechamber
from ligandparam.io.coordinates import Mol2Writer
class StageUpdateCharge(AbstractStage):
    """
    Create a new mol2 file with updated charges.

    Charges are read from a single column of ``charge_source`` (parsed with
    ``numpy.genfromtxt``), assigned in order to the atoms of the input mol2 file,
    written to a temporary mol2 file, and finally passed through Antechamber to
    produce ``out_mol2``.

    Parameters
    ----------
    stage_name : str
        The name of the stage.
    main_input : Union[Path, str]
        Path to the input mol2 file.
    cwd : Union[Path, str]
        Current working directory.
    charge_source : str
        Path to the file containing charges: whitespace-delimited text with one row
        per atom, in the same order as the atoms of the input mol2 file.
    charge_column : int, optional
        Zero-based column index in charge_source to use for charges (default: 3).
    out_mol2 : str
        Path to the output mol2 file.
    net_charge : float, optional
        Net charge for the molecule, passed to Antechamber (default: 0.0).
    atom_type : str, optional
        Atom type passed to Antechamber (default: 'gaff2').

    Attributes
    ----------
    in_mol2 : Path
        Path to the input mol2 file.
    charge_source : str
        Path to the file containing charges.
    charge_column : int
        Column index in charge_source to use for charges.
    out_mol2 : Path
        Path to the output mol2 file.
    tmp_mol2 : Path
        Path to the temporary mol2 file (inside ``cwd``) holding the updated charges.
    net_charge : float
        Net charge for the molecule.
    atom_type : str
        Atom type.
    """

    @override
    def __init__(self, stage_name: str, main_input: Union[Path, str], cwd: Union[Path, str], *args, **kwargs) -> None:
        super().__init__(stage_name, main_input, cwd, *args, **kwargs)
        self.in_mol2 = Path(main_input)
        self.charge_source = kwargs["charge_source"]
        self.charge_column = kwargs.get("charge_column", 3)
        self.out_mol2 = Path(kwargs["out_mol2"])
        # Intermediate mol2 with the new charges; Antechamber turns it into out_mol2.
        self.tmp_mol2 = self.cwd / f"{self.out_mol2.stem}_tmp_update.mol2"
        self.net_charge = kwargs.get("net_charge", 0.0)
        self.atom_type = kwargs.get("atom_type", "gaff2")

        self.add_required(self.in_mol2)
        self.add_required(Path(self.charge_source))

    def _append_stage(self, stage: "AbstractStage") -> "AbstractStage":
        return stage

    def execute(self, dry_run=False, nproc: Optional[int] = None, mem: Optional[int] = None) -> Any:
        """
        Execute the charge update stage.

        Parameters
        ----------
        dry_run : bool, optional
            If True, the charge file is still read, but the input mol2 is not loaded,
            no temporary mol2 is written, and Antechamber is called with ``dry_run=True``.
        nproc : int, optional
            Number of processors to use.
        mem : int, optional
            Amount of memory to use.

        Raises
        ------
        FileNotFoundError
            If ``charge_source`` does not exist.
        ValueError
            If the selected column contains non-numeric values, or the number of
            charges does not match the number of atoms.
        """
        super()._setup_execution(dry_run=dry_run, nproc=nproc, mem=mem)

        if not Path(self.charge_source).exists():
            raise FileNotFoundError(f"File {self.charge_source} not found.")

        # Supress the inevitable mol2 file warnings.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            charges = np.genfromtxt(self.charge_source, usecols=self.charge_column, unpack=True)
        # A single-row charge file yields a 0-d array, on which len() fails; force 1-D.
        charges = np.atleast_1d(charges)

        if not dry_run:
            # genfromtxt turns unparsable fields into NaN; never write those to a mol2.
            if np.isnan(charges).any():
                raise ValueError(
                    f"Error: Non-numeric charge values found in column {self.charge_column} "
                    f"of {self.charge_source}."
                )
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                u = mda.Universe(self.in_mol2, format="mol2")
                if len(charges) != len(u.atoms):
                    raise ValueError("Error: Number of charges does not match the number of atoms.")
                u.atoms.charges = charges
                # Write the Mol2 temporary file
                Mol2Writer(u, self.tmp_mol2, selection="all").write()

        ante = Antechamber(cwd=self.cwd, logger=self.logger, nproc=self.nproc)
        ante.call(
            i=self.tmp_mol2, fi="mol2", o=self.out_mol2, fo="mol2", pf="y", at=self.atom_type, an="no",
            nc=self.net_charge, dry_run=dry_run
        )

    def _clean(self):
        raise NotImplementedError("clean method not implemented")
class StageNormalizeCharge(AbstractStage):
    """
    Normalize the charges in a mol2 file to the specified net charge.

    The charges are first rounded to the number of decimals implied by ``precision``.
    If their sum differs from ``net_charge``, the difference is split into
    ``round(|difference| / precision)`` increments (at least one). Each increment is
    added to one atom, starting with the atom of largest absolute charge and wrapping
    around when there are more increments than atoms.

    Parameters
    ----------
    stage_name : str
        The name of the stage.
    main_input : Union[Path, str]
        Path to the input mol2 file.
    cwd : Union[Path, str]
        Current working directory.
    out_mol2 : str
        Path to the output mol2 file.
    atom_type : str, optional
        Atom type (default: 'gaff2').
    net_charge : float, optional
        Net charge for the molecule (default: 0.0).
    precision : float, optional
        Precision for charge normalization (default: 0.0001). Must be a positive,
        finite number with at least one decimal place (e.g. 0.001, 1e-05).

    Attributes
    ----------
    in_mol2 : Path
        Path to the input mol2 file.
    out_mol2 : Path
        Path to the output mol2 file.
    tmp_mol2 : Path
        Path to the temporary mol2 file.
    atom_type : str
        Atom type.
    net_charge : float
        Net charge for the molecule.
    precision : float
        Precision for charge normalization.
    decimals : int
        Number of decimals for rounding charges.

    Raises
    ------
    ValueError
        If ``precision`` is not a valid positive decimal number.
    """

    @override
    def __init__(self, stage_name: str, main_input: Union[Path, str], cwd: Union[Path, str], *args, **kwargs) -> None:
        super().__init__(stage_name, main_input, cwd, *args, **kwargs)
        self.in_mol2 = Path(main_input)
        self.out_mol2 = Path(kwargs["out_mol2"])
        self.tmp_mol2 = self.cwd / f"{self.in_mol2.stem}_tmp_norm.mol2"
        self.atom_type = kwargs.get("atom_type", "gaff2")
        self.net_charge = kwargs.get("net_charge", 0.0)
        precision = kwargs.get("precision", 0.0001)
        # Validate before coercing so an invalid value raises the documented ValueError.
        self.decimals = self._decimals_from_precision(precision)
        self.precision = float(precision)
        self.add_required(self.in_mol2)

    @staticmethod
    def _decimals_from_precision(precision) -> int:
        """
        Return the number of decimal places in ``precision``.

        Uses ``Decimal`` so that values whose repr is in scientific notation
        (e.g. ``1e-05`` -> 5) are handled, unlike a split on ``"."``.

        Raises
        ------
        ValueError
            If ``precision`` is not numeric, not finite, not positive, or has no
            fractional decimal places.
        """
        from decimal import Decimal

        try:
            value = float(precision)
        except (TypeError, ValueError):
            value = None
        exponent = Decimal(str(value)).as_tuple().exponent if value is not None else None
        # Non-finite Decimals report a str exponent ('n', 'N', 'F'); reject those too.
        if value is None or not isinstance(exponent, int) or exponent >= 0 or not value > 0:
            raise ValueError(f"ERROR: Invalid precision: {precision}. It should be a float between 0 and 0.1")
        return -exponent

    def _append_stage(self, stage: "AbstractStage") -> "AbstractStage":
        return stage

    def execute(self, dry_run=False, nproc: Optional[int] = None, mem: Optional[int] = None) -> Any:
        """
        Execute the charge normalization stage.

        Parameters
        ----------
        dry_run : bool, optional
            If True, charges are still checked and normalized in memory, but no
            temporary mol2 is written and Antechamber is called with ``dry_run=True``.
        nproc : int, optional
            Number of processors to use.
        mem : int, optional
            Amount of memory to use (in GB).

        Raises
        ------
        ValueError
            If the charge normalization fails.

        Notes
        -----
        If the charges already sum to ``net_charge`` (within ``np.isclose``), they are
        written unchanged, i.e. without rounding.

        TODO: Check what happens when netcharge is nonzero.
        """
        super()._setup_execution(dry_run=dry_run, nproc=nproc, mem=mem)
        self.logger.debug("Checking charges")
        self.logger.debug(f"Normalizing charges to {self.net_charge}")
        self.logger.debug(f"Precision {self.precision} with {self.decimals} decimals")

        # Supress the inevitable mol2 file warnings.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            u = mda.Universe(self.in_mol2, format="mol2")

        rounded_charges, total_charge, charge_difference = self.check_charge(u.atoms.charges)
        if np.isclose(total_charge, self.net_charge, rtol=1e-10):
            self.logger.info("Charges are already normalized")
        else:
            self.logger.info("Normalizing charges")
            new_charges = self.normalize(rounded_charges, charge_difference)
            _, new_total, _ = self.check_charge(new_charges)
            if not np.isclose(new_total, self.net_charge, rtol=1e-10):
                raise ValueError(f"Error: Charge normalization failed, new charge: {new_total}.")
            u.atoms.charges = new_charges

        if not dry_run:
            Mol2Writer(u, self.tmp_mol2, selection="all").write()
        ante = Antechamber(cwd=self.cwd, logger=self.logger, nproc=self.nproc)
        ante.call(
            i=self.tmp_mol2, fi="mol2", o=self.out_mol2, fo="mol2", pf="y", at=self.atom_type, an="no",
            nc=self.net_charge, dry_run=dry_run
        )

    def _clean(self):
        raise NotImplementedError("clean method not implemented")

    def normalize(self, charges, charge_difference):
        """
        Normalize the charges to the net charge.

        The difference is split into ``round(|charge_difference| / precision)``
        increments (at least one, so a residual smaller than half the precision is
        still applied). Each increment is added to one atom, visiting atoms from the
        largest absolute charge down and wrapping around when there are more
        increments than atoms.

        Parameters
        ----------
        charges : np.ndarray
            Array of atomic charges. It is not modified.
        charge_difference : float
            The charge difference to be corrected.

        Returns
        -------
        np.ndarray
            A new float64 array with the normalized charges.
        """
        # Work on a copy so the caller's array is left untouched.
        new_charges = np.array(charges, dtype=float)
        if charge_difference == 0 or new_charges.size == 0:
            return new_charges

        count = max(int(np.rint(abs(charge_difference) / self.precision)), 1)
        adjust = np.round(charge_difference / count, self.decimals)

        # Largest absolute charges first: a fixed-size adjustment is relatively smallest there.
        order = np.argsort(np.abs(new_charges))[::-1]
        for step in range(count):
            new_charges[order[step % order.size]] += adjust
        return new_charges

    def check_charge(self, charges):
        """
        Check the total charge and the charge difference.

        Parameters
        ----------
        charges : np.ndarray
            Array of atomic charges.

        Returns
        -------
        tuple
            charges : np.ndarray
                Charges rounded to ``self.decimals`` (float64 copy).
            total_charge : float
                The total of the rounded charges.
            charge_difference : float
                ``net_charge - total_charge``.
        """
        # Promote to float64 in case the input is single precision, so the sum is
        # not limited by the input's dtype.
        charges = np.round(np.asarray(charges, dtype=float), self.decimals)
        total_charge = np.sum(charges)
        charge_difference = self.net_charge - total_charge
        self.logger.debug(f"-> Total Charge is {total_charge}")
        self.logger.debug(f"-> Charge difference is {charge_difference}")
        return charges, total_charge, charge_difference