Skip to content

Commit

Permalink
torch install options
Browse files Browse the repository at this point in the history
  • Loading branch information
wolearyc committed Sep 17, 2024
1 parent 9784f8b commit f6b377d
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 9 deletions.
8 changes: 6 additions & 2 deletions ramannoodle/io/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@

from ramannoodle.structure.reference import ReferenceStructure
import ramannoodle.io.vasp as vasp_io
from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset

try:
from ramannoodle.polarizability.torch import dataset
except ModuleNotFoundError:
import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore

# These map between file formats and appropriate IO functions.
_PHONON_READERS = {
Expand Down Expand Up @@ -189,7 +193,7 @@ def read_structure_and_polarizability(
def read_polarizability_dataset(
filepaths: str | Path | list[str] | list[Path],
file_format: str,
) -> PolarizabilityDataset:
) -> dataset.PolarizabilityDataset:
"""Read polarizability dataset from files.
Parameters
Expand Down
14 changes: 11 additions & 3 deletions ramannoodle/io/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
IncompatibleStructureException,
)
from ramannoodle.globals import ATOM_SYMBOLS
from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset

try:
from ramannoodle.polarizability.torch import dataset
except ModuleNotFoundError:
import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore


def _skip_file_until_line_contains(file: TextIO, content: str) -> str:
Expand Down Expand Up @@ -95,7 +99,7 @@ def _read_polarizability_dataset(
[str | Path],
tuple[NDArray[np.float64], list[int], NDArray[np.float64], NDArray[np.float64]],
],
) -> PolarizabilityDataset:
) -> dataset.PolarizabilityDataset:
"""Read polarizability dataset from OUTCAR files.
Parameters
Expand All @@ -114,7 +118,11 @@ def _read_polarizability_dataset(
File has an unexpected format.
IncompatibleFileException
File is incompatible with the dataset.
ModuleNotFoundError
Torch installation could not be found.
"""
if not dataset.TORCH_PRESENT:
raise ModuleNotFoundError("torch installation not found")
filepaths = pathify_as_list(filepaths)

lattices: list[NDArray[np.float64]] = []
Expand Down Expand Up @@ -143,7 +151,7 @@ def _read_polarizability_dataset(
positions_list.append(positions)
polarizabilities.append(polarizability)

return PolarizabilityDataset(
return dataset.PolarizabilityDataset(
np.array(lattices),
atomic_numbers_list,
np.array(positions_list),
Expand Down
8 changes: 6 additions & 2 deletions ramannoodle/io/vasp/outcar.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
from ramannoodle.dynamics.phonon import Phonons
from ramannoodle.dynamics.trajectory import Trajectory
from ramannoodle.structure.reference import ReferenceStructure
from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset

try:
from ramannoodle.polarizability.torch import dataset
except ModuleNotFoundError:
import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore


# Utilities for OUTCAR. Warning: some of these functions partially read files.
Expand Down Expand Up @@ -400,7 +404,7 @@ def read_structure_and_polarizability(

def read_polarizability_dataset(
filepaths: str | Path | list[str] | list[Path],
) -> PolarizabilityDataset:
) -> dataset.PolarizabilityDataset:
"""Read polarizability dataset from OUTCAR files.
Parameters
Expand Down
8 changes: 6 additions & 2 deletions ramannoodle/io/vasp/vasprun.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from ramannoodle.dynamics.phonon import Phonons
from ramannoodle.dynamics.trajectory import Trajectory
from ramannoodle.structure.reference import ReferenceStructure
from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset

try:
from ramannoodle.polarizability.torch import dataset
except ModuleNotFoundError:
import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore


def _get_root_element(file: TextIO) -> Element:
Expand Down Expand Up @@ -195,7 +199,7 @@ def read_structure_and_polarizability(

def read_polarizability_dataset(
filepaths: str | Path | list[str] | list[Path],
) -> PolarizabilityDataset:
) -> dataset.PolarizabilityDataset:
"""Read polarizability dataset from OUTCAR files.
Parameters
Expand Down
2 changes: 2 additions & 0 deletions ramannoodle/polarizability/torch/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from ramannoodle.exceptions import verify_ndarray_shape, verify_list_len, get_type_error
import ramannoodle.polarizability.torch.utils as rn_torch_utils

TORCH_PRESENT = True


def _scale_and_flatten_polarizabilities(
polarizabilities: Tensor,
Expand Down
44 changes: 44 additions & 0 deletions ramannoodle/polarizability/torch/dummy_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Dummy polarizability PyTorch dataset.
Used when torch installation cannot be found.
:meta private:
"""

import numpy as np
from numpy.typing import NDArray

TORCH_PRESENT = False


class PolarizabilityDataset: # pylint: disable=too-few-public-methods
"""PyTorch dataset of atomic structures and polarizabilities.
Polarizabilities are scaled and flattened into vectors containing the six
independent tensor components.
Parameters
----------
lattices
| (Å) 3D array with shape (S,3,3) where S is the number of samples.
atomic_numbers
| List of length S containing lists of length N, where N is the number of atoms.
positions
| (fractional) 3D array with shape (S,N,3).
polarizabilities
| 3D array with shape (S,3,3).
scale_mode
| Supports ``"standard"`` (standard scaling), ``"stddev"`` (division by
| standard deviation), and ``"none"`` (no scaling).
"""

def __init__( # pylint: disable=too-many-arguments
self,
lattices: NDArray[np.float64],
atomic_numbers: list[list[int]],
positions: NDArray[np.float64],
polarizabilities: NDArray[np.float64],
scale_mode: str = "standard",
):
raise ModuleNotFoundError("torch installation not found")

0 comments on commit f6b377d

Please sign in to comment.