Skip to content

Commit

Permalink
feat: Analysis Module (#176)
Browse files Browse the repository at this point in the history
* Create branch

* improve trjcat

* add ugly plot_energy

* add time to runmgr

* WIP: radical population analysis

* finish radical population analysis

---------

Co-authored-by: Eric Hartmann <hartmaec@rh05659.villa-bosch.de>
  • Loading branch information
ehhartmann and Eric Hartmann authored Aug 10, 2023
1 parent 031f2a5 commit f37d6c8
Show file tree
Hide file tree
Showing 7 changed files with 338 additions and 80 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,5 @@ plugins =
[options.entry_points]
console_scripts =
kimmdy = kimmdy.cmd:kimmdy
kimmdy-analysis = kimmdy.cmd:analysis
kimmdy-build_examples = kimmdy.cmd:build_examples
219 changes: 219 additions & 0 deletions src/kimmdy/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
from typing import Union
from pathlib import Path
import subprocess
import argparse
from math import isclose
import matplotlib.pyplot as plt
import MDAnalysis as mda

from kimmdy.utils import run_shell_cmd
from kimmdy.parsing import read_json, write_json


def get_subdirs(run_dir: Path, steps: Union[list, str]):
    """Collect the numbered task directories of a run that match `steps`.

    Task directories are expected to be named ``<number>_<step name>``. Only
    the part between the first and the second underscore is compared with
    `steps`.

    Parameters
    ----------
    run_dir :
        Directory whose direct subdirectories are searched.
    steps :
        List of step names to keep, or the string "all" to keep every
        numbered subdirectory. Any other single string is treated as exactly
        one step name.

    Returns
    -------
    list of Path
        Matching subdirectories, sorted by their numeric prefix.

    Raises
    ------
    ValueError
        If no subdirectory matches.
    """
    numbered = []
    for d in run_dir.glob("*_*/"):
        if not d.is_dir():
            continue
        try:
            index = int(d.name.split("_")[0])
        except ValueError:
            # Contains an underscore but has no numeric prefix (e.g.
            # "my_notes"): not a task directory, so skip it instead of
            # letting the sort key crash.
            continue
        numbered.append((index, d))
    # Sort on the index only; sorted() is stable, so ties keep glob order.
    subdirs_sorted = [d for _, d in sorted(numbered, key=lambda pair: pair[0])]

    if steps == "all":
        steps = list({d.name.split("_")[1] for d in subdirs_sorted})
    elif isinstance(steps, str):
        # A bare string would otherwise be substring-matched by `in`.
        steps = [steps]

    subdirs_matched = [d for d in subdirs_sorted if d.name.split("_")[1] in steps]

    if not subdirs_matched:
        raise ValueError(
            f"Could not find directories {steps} in {run_dir}. Thus, no trajectories can be concatenated"
        )

    return subdirs_matched


def concat_traj(args: argparse.Namespace):
    """Find and concatenate trajectories (.xtc files) from KIMMDY runs.

    All .xtc files of the matching task directories are joined with
    ``gmx trjcat`` into ``<run_dir>/analysis/tmp.xtc``. That file is then
    passed through ``gmx trjconv -center -pbc mol`` to produce
    ``<run_dir>/analysis/concat.xtc``, using the first .tpr file found.

    Parameters
    ----------
    args :
        Parsed command line arguments. Reads ``args.dir`` (run directory) and
        ``args.steps`` (list of step names, or the string "all").

    Raises
    ------
    ValueError
        If ``args.steps`` is neither a list nor "all", or if no .xtc or no
        .tpr file is found in the matching directories.
    """
    import shlex  # local import: only needed to quote paths for the shell

    run_dir = Path(args.dir).expanduser().resolve()
    steps: Union[list, str] = args.steps

    ## check if step argument is valid
    if not isinstance(steps, list) and steps != "all":
        raise ValueError(f"Steps argument {steps} can not be dealt with.")

    subdirs_matched = get_subdirs(run_dir, steps)

    ## create output dir
    analysis_dir = run_dir / "analysis"
    analysis_dir.mkdir(exist_ok=True)
    out = analysis_dir / "concat.xtc"
    tmp = out.with_name("tmp.xtc")

    ## gather trajectories (sorted per directory for a reproducible order)
    trajectories = []
    tprs = []
    for d in subdirs_matched:
        trajectories.extend(sorted(d.glob("*.xtc")))
        tprs.extend(sorted(d.glob("*.tpr")))

    # Explicit exceptions instead of assert: asserts vanish under `python -O`.
    if not trajectories:
        raise ValueError(
            f"No .xtc trajectories found to concatenate in {run_dir} with subdirectory names {steps}"
        )
    if not tprs:
        raise ValueError(
            f"No .tpr file found in {run_dir} with subdirectory names {steps}"
        )

    # The commands are shell strings (the second one uses a pipe), so every
    # path is quoted to survive spaces and other shell metacharacters.
    traj_args = " ".join(shlex.quote(str(t)) for t in trajectories)
    tmp_arg = shlex.quote(str(tmp))
    tpr_arg = shlex.quote(str(tprs[0]))
    out_arg = shlex.quote(str(out))

    ## write concatenated trajectory
    run_shell_cmd(
        f"gmx trjcat -f {traj_args} -o {tmp_arg} -cat",
        cwd=run_dir,
    )
    run_shell_cmd(
        f"echo '1 0' | gmx trjconv -f {tmp_arg} -s {tpr_arg} -o {out_arg} -center -pbc mol",
        cwd=run_dir,
    )


def plot_energy(args: argparse.Namespace):
    """Plot GROMACS energy terms over the matching steps of a KIMMDY run.

    For every .edr file in the matching task directories, ``gmx energy`` is
    run to write ``<run_dir>/analysis/energy_xvgs/<task dir name>.xvg``. The
    data rows of all files are appended one after another and plotted against
    their running snapshot number. A vertical line labelled with the step
    name marks the first snapshot of each energy file. The figure is saved as
    ``<run_dir>/analysis/energy.png``.

    NOTE(review): values are assigned to terms by column position, in the
    order the terms were given. Confirm that ``gmx energy`` writes its
    columns in that order when several terms are requested.

    Parameters
    ----------
    args :
        Parsed command line arguments. Reads ``args.dir`` (run directory),
        ``args.steps`` (list of step names or "all") and ``args.terms``
        (list of energy term names).

    Raises
    ------
    ValueError
        If no energy term is given, no .edr file is found, or the energy
        files contain no data rows.
    """
    import shlex  # local import: only needed to quote paths for the shell

    run_dir = Path(args.dir).expanduser().resolve()
    steps: Union[list, str] = args.steps
    terms_list = list(args.terms)
    if not terms_list:
        # `--terms` without values yields an empty list.
        raise ValueError("At least one energy term is required for plotting.")
    xvg_entries = ["time"] + terms_list
    terms: str = "\n".join(terms_list)

    subdirs_matched = get_subdirs(run_dir, steps)

    ## create output dir
    analysis_dir = run_dir / "analysis"
    analysis_dir.mkdir(exist_ok=True)
    xvgs_dir = analysis_dir / "energy_xvgs"
    xvgs_dir.mkdir(exist_ok=True)

    ## gather energy files (sorted per directory for a reproducible order)
    edrs = []
    for d in subdirs_matched:
        edrs.extend(sorted(d.glob("*.edr")))
    if not edrs:
        raise ValueError(
            f"No GROMACS energy files in {run_dir} with subdirectory names {steps}"
        )

    energy = []
    sim_start = []  # index into `energy` of the first row of each file
    sim_names = []  # step name belonging to each entry of `sim_start`
    for edr in edrs:
        step_dir = edr.parent
        print(step_dir.name + ".xvg")
        xvg = xvgs_dir / (step_dir.name + ".xvg")
        edr_arg = shlex.quote(str(edr))
        xvg_arg = shlex.quote(str(xvg))

        ## write energy .xvg file
        run_shell_cmd(
            f"echo '{terms} \n\n' | gmx energy -f {edr_arg} -o {xvg_arg}",
            cwd=run_dir,
        )

        ## read energy .xvg file
        first_row = len(energy)
        with open(xvg, "r") as f:
            for line in f:
                stripped = line.strip()
                # Skip blank lines and xvg header/comment lines.
                if not stripped or stripped.startswith(("@", "#")):
                    continue
                energy.append(
                    {k: float(v) for k, v in zip(xvg_entries, stripped.split())}
                )
        # Marker and label are recorded together per file, so they can never
        # get out of step with each other.
        if len(energy) > first_row:
            sim_start.append(first_row)
            sim_names.append(step_dir.name.split("_")[1])

    if not energy:
        raise ValueError(f"No energy data could be read from {xvgs_dir}")

    ## plot energy
    snapshot = range(len(energy))
    limy = [energy[0][terms_list[0]], energy[0][terms_list[0]]]
    print(sim_start)

    for term in terms_list:
        val = [x[term] for x in energy]
        lowest, highest = min(val), max(val)
        print(term, lowest, highest)
        limy[0] = min(limy[0], lowest)
        limy[1] = max(limy[1], highest)
        plt.plot(snapshot, val, label=term)

    for pos, name in zip(sim_start, sim_names):
        plt.plot([pos, pos], limy, c="k", linewidth=1)
        plt.text(pos + 1, limy[1] - 0.05 * (limy[1] - limy[0]), name)

    plt.xlabel("Snapshot #")
    plt.ylabel("Energy [kJ mol-1]")
    plt.legend()
    plt.savefig(str(analysis_dir / "energy.png"), dpi=300)

    print(limy)


def _first_gro(run_dir: Path):
    """Return the .gro files of the lowest-numbered task directory that has any."""
    numbered = []
    for d in run_dir.glob("*_*/"):
        if not d.is_dir():
            continue
        try:
            numbered.append((int(d.name.split("_")[0]), d))
        except ValueError:
            # Underscore in the name but no numeric prefix: not a task dir.
            continue
    for _, subdir in sorted(numbered, key=lambda pair: pair[0]):
        gro = sorted(subdir.glob("*.gro"))
        if gro:
            return gro
    raise ValueError(f"No .gro file found in the subdirectories of {run_dir}")


def radical_population(args):
    """Plot how often each selected atom is a radical in one or more runs.

    All ``radicals.json`` files below the given run directories are read and
    their "time" and "radicals" entries are gathered. For every atom of the
    selection, the fraction of gathered states in which its id is listed as a
    radical is computed. The results are written to the ``analysis``
    directory of the first mentioned run directory:

    - ``radical_population.json``: the gathered radical information
    - ``radical_population_fingerprint``: bar plot of the occupancy
    - ``radical_population_selection.pdb``: protein atoms, with the occupancy
      of the selected atoms stored as tempfactors

    Parameters
    ----------
    args :
        Parsed command line arguments. Reads ``args.dir`` (list of run
        directories) and ``args.select_atoms`` (MDAnalysis selection string).

    Raises
    ------
    ValueError
        If no run directory is given or no .gro file can be found.
    """
    # TODO: weigh radical population by time

    if not args.dir:
        raise ValueError("At least one KIMMDY run directory is required.")

    select_atoms = args.select_atoms
    ## set up dictionary to store radical information
    radical_info = {"time": [], "radicals": []}

    # Iterate in reverse so that `run_dir` and `gro` end up referring to the
    # first mentioned directory once the loop is done.
    for curr_dir in args.dir[::-1]:
        run_dir = Path(curr_dir).expanduser().resolve()

        ## find .gro file
        gro = _first_gro(run_dir)

        ## gather and parse radical info
        for radical_json in sorted(run_dir.glob("**/radicals.json")):
            data = read_json(radical_json)
            for k in radical_info:
                radical_info[k].append(data[k])

    ## create output dir (only goes to first mentioned run_dir)
    out = run_dir / "analysis"
    out.mkdir(exist_ok=True)

    ## write gathered radical info
    write_json(radical_info, out / "radical_population.json")

    ## get info from gro file
    u = mda.Universe(str(gro[0]), format="gro")
    print(u)
    atoms = u.select_atoms(select_atoms)
    atoms_identifier = [
        f"{resid!s}-{name!s}" for resid, name in zip(atoms.resids, atoms.names)
    ]
    atoms_id = atoms.ids
    print(atoms_identifier)
    print(atoms_id)

    ## plot fingerprint
    counts = {int(i): 0.0 for i in atoms_id}
    n_states = len(radical_info["time"])
    for radicals in radical_info["radicals"]:
        for idx in radicals:
            key = int(idx)
            if key in counts:
                counts[key] += 1 / n_states
    print(counts)
    occupancy = list(counts.values())

    plt.bar(x=atoms_identifier, height=occupancy)
    plt.xlabel("Atom identifier")
    plt.ylabel("Fractional Radical Occupancy")
    plt.ylim(0, 1)
    plt.xticks(atoms_identifier, rotation=90, ha="right")
    plt.tight_layout()
    plt.savefig(str(out / "radical_population_fingerprint"), dpi=300)

    ## store occupancy as tempfactors and write the structure
    u.add_TopologyAttr("tempfactors")
    atoms = u.select_atoms(select_atoms)
    print(occupancy)
    print(atoms.tempfactors)
    atoms.tempfactors = occupancy
    protein = u.select_atoms("protein")
    # Write to an explicit file inside the analysis directory; passing the
    # directory path itself gives the writer no proper file name.
    protein.write(str(out / "radical_population_selection.pdb"))
113 changes: 87 additions & 26 deletions src/kimmdy/cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from pathlib import Path
import dill
from kimmdy.config import Config
from kimmdy.misc_helper import concat_traj, _build_examples
from kimmdy.analysis import concat_traj, plot_energy, radical_population
from kimmdy.misc_helper import _build_examples
from kimmdy.runmanager import RunManager
from kimmdy.utils import check_gmx_version, increment_logfile
import importlib.resources as pkg_resources
Expand Down Expand Up @@ -45,17 +46,6 @@ def get_cmdline_args():
"--logfile", "-f", type=str, help="logfile", default="kimmdy.log"
)
parser.add_argument("--checkpoint", "-c", type=str, help="checkpoint file")
parser.add_argument(
"--concat",
type=Path,
nargs="?",
const=True,
help=(
"Concatenate trrs of this run"
"Optionally, the run directory can be give"
"Will save as concat.trr in current directory"
),
)

# flag to show available plugins
parser.add_argument(
Expand All @@ -74,6 +64,78 @@ def get_cmdline_args():
return parser.parse_args()


def get_analysis_cmdline_args():
    """Parse the command line of the KIMMDY analysis entry point.

    The first positional argument selects a sub-command and is stored as
    ``module``:

    trjcat :
        Concatenate trajectories of a KIMMDY run.
    plot_energy :
        Plot GROMACS energy for a KIMMDY run.
    radical_population :
        Plot population of radicals for one or multiple KIMMDY run(s).

    Returns
    -------
    argparse.Namespace
        The arguments parsed from ``sys.argv``.
    """
    top_parser = argparse.ArgumentParser(
        description="Welcome to the KIMMDY analysis module"
    )
    commands = top_parser.add_subparsers(
        required=True, metavar="module", dest="module"
    )

    def add_dir_and_steps(sub_parser):
        # Arguments shared by the sub-commands that work on a single run.
        sub_parser.add_argument(
            "dir", type=str, help="KIMMDY run directory to be analysed."
        )
        sub_parser.add_argument(
            "--steps",
            "-s",
            nargs="*",
            default="all",
            help=(
                "Apply analysis method to subdirectories with these names. Uses all subdirectories by default"
            ),
        )

    ## trjcat
    trjcat_cmd = commands.add_parser(
        "trjcat", help="Concatenate trajectories of a KIMMDY run"
    )
    add_dir_and_steps(trjcat_cmd)

    ## plot_energy
    energy_cmd = commands.add_parser(
        "plot_energy", help="Plot GROMACS energy for a KIMMDY run"
    )
    add_dir_and_steps(energy_cmd)
    energy_cmd.add_argument(
        "--terms",
        "-t",
        nargs="*",
        default=["Potential"],
        help=(
            "Terms from gmx energy that will be plotted. Uses 'Potential' by default"
        ),
    )

    ## radical population
    radical_cmd = commands.add_parser(
        "radical_population",
        help="Plot population of radicals for one or multiple KIMMDY run(s)",
    )
    radical_cmd.add_argument(
        "dir", nargs="+", help="KIMMDY run directory to be analysed. Can be multiple."
    )
    radical_cmd.add_argument(
        "--select_atoms",
        "-a",
        type=str,
        help="Atoms chosen for radical population analysis, default is protein (uses MDAnalysis selection syntax)",
        default="protein",
    )

    return top_parser.parse_args()


def configure_logging(args: argparse.Namespace, color=False):
"""Configure logging.
Expand Down Expand Up @@ -133,15 +195,6 @@ def _run(args: argparse.Namespace):

exit()

if args.concat:
logging.info("KIMMDY will concatenate trrs and exit.")

run_dir = Path().cwd()
if type(args.concat) != bool:
run_dir = args.concat
concat_traj(run_dir)
exit()

logging.info("Welcome to KIMMDY")
logging.info("KIMMDY is running with these command line options:")
logging.info(args)
Expand All @@ -166,7 +219,6 @@ def kimmdy_run(
loglevel: str = "DEBUG",
logfile: Path = Path("kimmdy.log"),
checkpoint: str = "",
concat: bool = False,
show_plugins: bool = False,
show_schema_path: bool = False,
):
Expand All @@ -188,9 +240,6 @@ def kimmdy_run(
File path of the logfile.
checkpoint :
File path if a kimmdy.cpt file to restart KIMMDY from a checkpoint.
concat :
Don't perform a full KIMMDY run but instead concatenate trajectories
from a previous run.
show_plugins :
Show available plugins and exit.
show_schema_path :
Expand All @@ -201,7 +250,6 @@ def kimmdy_run(
loglevel=loglevel,
logfile=logfile,
checkpoint=checkpoint,
concat=concat,
show_plugins=show_plugins,
show_schema_path=show_schema_path,
)
Expand Down Expand Up @@ -235,6 +283,19 @@ def build_examples():
pass


def analysis():
    """Analyse existing KIMMDY runs.

    Parses the analysis command line, prints the parsed arguments and hands
    them to the function belonging to the selected module.
    """
    args = get_analysis_cmdline_args()
    print(args)

    # Map each sub-command name to the function that implements it.
    handlers = {
        "trjcat": concat_traj,
        "plot_energy": plot_energy,
        "radical_population": radical_population,
    }
    handler = handlers.get(args.module)
    if handler is not None:
        handler(args)


def kimmdy():
"""Run KIMMDY from the command line.
Expand Down
Loading

0 comments on commit f37d6c8

Please sign in to comment.