diff --git a/CHANGELOG.md b/CHANGELOG.md
index dec045c993..9bddb143c6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added gradient clipping to StaticCapture utilities.
- Bistride Multiscale MeshGraphNet example.
- FIGConvUNet model and example.
+- The Transolver model.

### Changed

diff --git a/docs/img/transolver.png b/docs/img/transolver.png
new file mode 100644
index 0000000000..07e31966e3
Binary files /dev/null and b/docs/img/transolver.png differ
diff --git a/examples/cfd/darcy_transolver/README.md b/examples/cfd/darcy_transolver/README.md
new file mode 100644
index 0000000000..6e910abbd2
--- /dev/null
+++ b/examples/cfd/darcy_transolver/README.md
@@ -0,0 +1,48 @@
+
+
+# Transolver for Darcy Flow
+
+This example demonstrates how to set up a data-driven model for 2D Darcy flow using
+the Transolver model inside Modulus.
+
+![Transolver architecture](../../../docs/img/transolver.png)
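+Transolver is exposed as `modulus.models.transolver.Transolver`. As a quick
+orientation, here is a minimal sketch of how the example scripts construct and
+call the model; the hyperparameters mirror `config.yaml`, while the mesh size
+and batch size below are illustrative only:
+
+```python
+import torch
+
+from modulus.models.transolver import Transolver
+
+model = Transolver(
+    space_dim=2, n_layers=3, n_hidden=128, dropout=0.0, n_head=8,
+    Time_Input=False, act="gelu", mlp_ratio=1, fun_dim=0, out_dim=1,
+    slice_num=32, ref=8, unified_pos=1, H=64, W=64,
+)
+
+# a batch of 2 point clouds on a 64 x 64 structured mesh: [batch, H * W, space_dim]
+x = torch.randn(2, 64 * 64, 2)
+y = model(x)  # predicted field, shape [2, 64, 64, 1]
+```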
+
+Training progress can be tracked through [MLFlow](https://mlflow.org/docs/latest/index.html).
+This example runs on a single GPU.
+
+## Getting Started
+
+To train the model with Modulus's default data-generation settings, simply run
+
+```bash
+python train_transolver_darcy.py
+```
+
+Here, each batch is newly generated by solving the Darcy equation on the fly, which
+differs from the commonly used fixed-dataset setting.
+
+To reproduce the results in the paper, run
+
+```bash
+python train_transolver_darcy_fix.py
+```
+
+In this case, the training and test sets are fixed once the dataset is constructed,
+matching the setting used in the Transolver paper.
+
+## Additional Information
+
+In the fixed case, external data is required for training, and the data path must be
+provided when the `Darcy2D_fix` dataset is constructed. You can download the data
+[here](https://drive.google.com/drive/folders/1UnbQh2WWc6knEHbLn-ZaXrKUZhp7pjt-).
+
+Additional components are included for convenience. `Validators` compute the loss
+between ground truth and prediction and visualize both in `./mlruns`. Below is a
+simple example of the visualization.
+
+[![visualization](https://s21.ax1x.com/2024/09/26/pAlis3T.png)](https://imgse.com/i/pAlis3T)
+
+## References
+
+- [Transolver: A Fast Transformer Solver for PDEs on General Geometries](https://arxiv.org/abs/2402.02366)
diff --git a/examples/cfd/darcy_transolver/config.yaml b/examples/cfd/darcy_transolver/config.yaml
new file mode 100644
index 0000000000..de43f220a3
--- /dev/null
+++ b/examples/cfd/darcy_transolver/config.yaml
@@ -0,0 +1,63 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:

+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.

+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
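+
+# The `model` block below supplies the Transolver constructor arguments used in
+# train_transolver_darcy.py; `training.resolution` is additionally forwarded to
+# the model as the mesh height and width (H, W).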
+ +model: + space_dim: 2 + n_layers: 3 + n_hidden: 128 + dropout: 0.0 + n_head: 8 + Time_Input: False + act: gelu + mlp_ratio: 1 + fun_dim: 0 + out_dim: 1 + slice_dim: 32 + ref: 8 + unified_pos: 1 + slice_num: 32 + + + +normaliser: + permeability: + mean: 1.25 + std_dev: .75 + darcy: + mean: 4.52E-2 + std_dev: 2.79E-2 + +scheduler: + initial_lr: 1.E-3 + decay_rate: .85 + decay_pseudo_epochs: 8 + +training: + resolution: 256 + batch_size: 8 + rec_results_freq : 8 + max_pseudo_epochs: 256 + pseudo_epoch_sample_size: 2048 + +validation: + sample_size: 256 + validation_pseudo_epochs: 4 \ No newline at end of file diff --git a/examples/cfd/darcy_transolver/config_fix.yaml b/examples/cfd/darcy_transolver/config_fix.yaml new file mode 100644 index 0000000000..dd97eeb8d7 --- /dev/null +++ b/examples/cfd/darcy_transolver/config_fix.yaml @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: MIT License +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +model: + space_dim: 2 + n_layers: 8 + n_hidden: 128 + dropout: 0.0 + n_head: 8 + Time_Input: False + act: gelu + mlp_ratio: 1 + fun_dim: 1 + out_dim: 1 + slice_dim: 32 + ref: 8 + unified_pos: 1 + slice_num: 64 + + + +normaliser: + permeability: + mean: 1.25 + std_dev: .75 + darcy: + mean: 4.52E-2 + std_dev: 2.79E-2 + +scheduler: + initial_lr: 1.E-3 + decay_rate: 1.E-5 + weight_decay: 1.E-5 + decay_pseudo_epochs: 8 + +training: + resolution: 85 + batch_size: 4 + rec_results_freq : 100 + max_pseudo_epochs: 500 + pseudo_epoch_sample_size: 1000 + +validation: + sample_size: 200 + validation_pseudo_epochs: 2 diff --git a/examples/cfd/darcy_transolver/darcy_datapipe_fix.py b/examples/cfd/darcy_transolver/darcy_datapipe_fix.py new file mode 100644 index 0000000000..2c14469630 --- /dev/null +++ b/examples/cfd/darcy_transolver/darcy_datapipe_fix.py @@ -0,0 +1,273 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from dataclasses import dataclass
+from typing import Dict, Tuple, Union
+
+import numpy as np
+import torch
+import warp as wp
+import scipy.io as scio
+
+from ..datapipe import Datapipe
+from ..meta import DatapipeMetaData
+from .kernels.finite_difference import (
+    darcy_mgrid_jacobi_iterative_batched_2d,
+    mgrid_inf_residual_batched_2d,
+)
+from .kernels.initialization import init_uniform_random_4d
+from .kernels.utils import (
+    bilinear_upsample_batched_2d,
+    fourier_to_array_batched_2d,
+    threshold_3d,
+)
+
+Tensor = torch.Tensor
+# TODO unsure if better to remove this. Keeping this in for now
+wp.init()
+
+
+class UnitTransformer:
+    """Unit transformer class for normalizing and denormalizing data."""
+
+    def __init__(self, X):
+        self.mean = X.mean(dim=(0, 1), keepdim=True)
+        self.std = X.std(dim=(0, 1), keepdim=True) + 1e-8
+
+    def to(self, device):
+        self.mean = self.mean.to(device)
+        self.std = self.std.to(device)
+        return self
+
+    def cuda(self):
+        self.mean = self.mean.cuda()
+        self.std = self.std.cuda()
+
+    def cpu(self):
+        self.mean = self.mean.cpu()
+        self.std = self.std.cpu()
+
+    def encode(self, x):
+        x = (x - self.mean) / self.std
+        return x
+
+    def decode(self, x):
+        return x * self.std + self.mean
+
+    def transform(self, X, inverse=True, component="all"):
+        # membership test: `component` is either "all"/"all-reduce" or a channel index
+        if component in ("all", "all-reduce"):
+            if inverse:
+                orig_shape = X.shape
+                return (X * (self.std - 1e-8) + self.mean).view(orig_shape)
+            else:
+                return (X - self.mean) / self.std
+        else:
+            if inverse:
+                orig_shape = X.shape
+                return (
+                    X * (self.std[:, component] - 1e-8) + self.mean[:, component]
+                ).view(orig_shape)
+            else:
+                return (X - self.mean[:, component]) / self.std[:, component]
+
+
+@dataclass
+class MetaData(DatapipeMetaData):
+    name: str = "Darcy2D"
+    # Optimization
+    auto_device: bool = True
+    cuda_graphs: bool = True
+    # Parallel
+    ddp_sharding: bool = False
+
+
+class Darcy2D_fix(Datapipe):
+    """2D Darcy flow benchmark problem datapipe with a fixed dataset.
+
+    Unlike the on-the-fly `Darcy2D` datapipe, this variant serves a fixed dataset:
+    permeability ("coeff") and pressure ("sol") fields are loaded from the `.mat`
+    file given by `train_path` and subsampled to the working resolution, matching
+    the fixed train/test split used in the Transolver paper. The generation-related
+    parameters below are retained for interface compatibility with `Darcy2D`.
+
+    Parameters
+    ----------
+    resolution : int, optional
+        Resolution to run simulation at, by default 256
+    batch_size : int, optional
+        Batch size of simulations, by default 64
+    nr_permeability_freq : int, optional
+        Number of frequencies to use for generating random permeability. Higher
+        values will give higher-frequency permeability fields, by default 5
+    max_permeability : float, optional
+        Max permeability, by default 2.0
+    min_permeability : float, optional
+        Min permeability, by default 0.5
+    max_iterations : int, optional
+        Maximum iterations to use for each multi-grid, by default 30000
+    convergence_threshold : float, optional
+        Solver L-Infinity convergence threshold, by default 1e-6
+    iterations_per_convergence_check : int, optional
+        Number of Jacobi iterations to run before checking convergence, by default 1000
+    nr_multigrids : int, optional
+        Number of multi-grid levels, by default 4
+    normaliser : Union[Dict[str, Tuple[float, float]], None], optional
+        Dictionary with keys `permeability` and `darcy`, each mapping to a
+        `(mean, std)` tuple, by default None
+    device : Union[str, torch.device], optional
+        Device on which the datapipe places data, by default "cuda"
+
+    Raises
+    ------
+    ValueError
+        Incompatible multi-grid and resolution settings
+    """
+
+    def __init__(
+        self,
+        resolution: int = 256,
+        batch_size: int = 64,
+        nr_permeability_freq: int = 5,
+        max_permeability: float = 2.0,
+        min_permeability: float = 0.5,
+        max_iterations: int = 30000,
+        convergence_threshold: float = 1e-6,
+        iterations_per_convergence_check: int = 1000,
+        nr_multigrids: int = 4,
+        normaliser: Union[Dict[str, Tuple[float, float]], None] = None,
+        device: Union[str, torch.device] = "cuda",
+        train_path: str = None,
+        is_test: bool = False,
+        x_normalizer: UnitTransformer = None,
+        y_normalizer: UnitTransformer = None,
+    ):
+        super().__init__(meta=MetaData())
+
+        # simulation params
+        self.resolution = resolution
+        self.batch_size = batch_size
+        self.nr_permeability_freq = nr_permeability_freq
+        self.max_permeability = max_permeability
+        self.min_permeability = min_permeability
+        self.max_iterations = max_iterations
+        self.convergence_threshold = convergence_threshold
+        self.iterations_per_convergence_check = iterations_per_convergence_check
+        self.nr_multigrids = nr_multigrids
+        self.normaliser = normaliser
+
+        # check normaliser keys
+        if self.normaliser is not None:
+            if not {"permeability", "darcy"}.issubset(set(self.normaliser.keys())):
+                raise ValueError(
+                    "normaliser needs to have keys permeability and darcy with mean and std"
+                )
+
+        # Set up device for warp; warp has the same naming convention as torch.
+        if isinstance(device, torch.device):
+            device = str(device)
+        self.device = device
+
+        # spatial dims
+        self.dx = 1.0 / (self.resolution + 1)  # pad edges by 1 for multi-grid
+        self.dim = (self.batch_size, self.resolution + 1, self.resolution + 1)
+        self.fourier_dim = (
+            4,
+            self.batch_size,
+            self.nr_permeability_freq,
+            self.nr_permeability_freq,
+        )
+
+        # assert resolution is compatible with multi-grid method
+        # if (resolution % 2 ** (nr_multigrids - 1)) != 0:
+        #     raise ValueError("Resolution is incompatible with number of sub grids.")
+
+        # allocate arrays for constructing dataset
+        self.darcy0 = wp.zeros(self.dim, dtype=float, device=self.device)
+        self.darcy1 = wp.zeros(self.dim, dtype=float, device=self.device)
+        self.permeability = wp.zeros(self.dim, dtype=float, device=self.device)
+        self.rand_fourier = wp.zeros(self.fourier_dim, dtype=float, device=self.device)
+        self.inf_residual = wp.zeros([1], dtype=float, device=self.device)
+        self.train_path = train_path
+        # the raw fields are stored at 421 x 421; downsample with a stride of 5
+        self.downsample = 5
+        self.r = self.downsample
+        self.h = int(((421 - 1) / self.r) + 1)
+        self.s = self.h
+        self.dx = 1.0 / self.s
+
+        # Output tensors
+        self.output_k = None
+        self.output_p = None
+
+        self.is_test = is_test
+
+        if not self.is_test:
+            n_train = 1000
+        else:
+            n_train = 200
+        self.n_train = n_train
+
+        if self.train_path is not None:
+            self.__get_data__()
+
+            if not self.is_test:
+                self.x_normalizer = UnitTransformer(self.x_train)
+                self.y_normalizer = UnitTransformer(self.y_train)
+
+                self.x_train = self.x_normalizer.encode(self.x_train)
+                self.y_train = self.y_normalizer.encode(self.y_train)
+            else:
+                self.x_train = x_normalizer.encode(self.x_train)
+
+    def __get_normalizer__(self):
+        return self.x_normalizer, self.y_normalizer
+
+    def __get_data__(self):
+        x = np.linspace(0, 1, self.s)
+        y = np.linspace(0, 1, self.s)
+        x, y = np.meshgrid(x, y)
+        pos = np.c_[x.ravel(), y.ravel()]
+        pos = torch.tensor(pos, dtype=torch.float).unsqueeze(0).cuda()
+        self.x_train = scio.loadmat(self.train_path)["coeff"][
+            : self.n_train, :: self.r, :: self.r
+        ][:, : self.s, : self.s]
+        self.x_train = self.x_train.reshape(self.n_train, -1)
+        self.x_train = torch.from_numpy(self.x_train).float().cuda()
+        self.y_train = scio.loadmat(self.train_path)["sol"][
+            : self.n_train, :: self.r, :: self.r
+        ][:, : self.s, : self.s]
+        self.y_train = self.y_train.reshape(self.n_train, -1)
+        self.y_train = torch.from_numpy(self.y_train).float().cuda()
+        self.pos_train = pos.repeat(self.n_train, 1, 1)
+
+    def __iter__(self):
+        """
+        Yields
+        ------
+        Iterator[Tuple[Tensor, Tensor, Tensor]]
+            Infinite iterator that returns batches of (position, permeability,
+            darcy pressure) tensors; positions have shape [batch, s*s, 2] and
+            the two fields are flattened to [batch, s*s]
+        """
+        # infinite generator; sample uniformly from the available samples
+        while True:
+            idx = np.random.choice(self.n_train, self.batch_size)
+            x = self.x_train[idx]
+            y = self.y_train[idx]
+            pos = self.pos_train[idx]
+            yield pos, x, y
+
+    def __len__(self):
+        return self.n_train // self.batch_size
diff --git a/examples/cfd/darcy_transolver/train_transolver_darcy.py b/examples/cfd/darcy_transolver/train_transolver_darcy.py
new file mode 100644
index 0000000000..10525f5368
--- /dev/null
+++ b/examples/cfd/darcy_transolver/train_transolver_darcy.py
@@ -0,0 +1,176 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hydra +from omegaconf import DictConfig +from math import ceil + +from torch.nn import MSELoss +from utils.testloss import TestLoss +from torch.optim import Adam, lr_scheduler + +from modulus.models.transolver import Transolver +from modulus.datapipes.benchmarks.darcy import Darcy2D +from modulus.distributed import DistributedManager +from modulus.utils import StaticCaptureTraining, StaticCaptureEvaluateNoGrad +from modulus.launch.utils import load_checkpoint, save_checkpoint +from modulus.launch.logging import PythonLogger, LaunchLogger, initialize_mlflow + +from validator import GridValidator + + +@hydra.main(version_base="1.3", config_path=".", config_name="config.yaml") +def darcy_trainer(cfg: DictConfig) -> None: + """Training for the 2D Darcy flow benchmark problem.""" + DistributedManager.initialize() # Only call this once in the entire script! + dist = DistributedManager() # call if required elsewhere + + # initialize monitoring + log = PythonLogger(name="darcy_transolver") + log.file_logging() + initialize_mlflow( + experiment_name=f"Darcy_Transolver", + experiment_desc=f"training a Transformer-based PDE solver for the Darcy problem", + run_name=f"Darcy Transolver training", + run_desc=f"training Transolver for Darcy", + user_name="Haixu Wu, Huakun Luo, Haowen Wang", + mode="offline", + ) + LaunchLogger.initialize(use_mlflow=True) # Modulus launch logger + + # define model, loss, optimiser, scheduler, data loader + model = Transolver( + space_dim=cfg.model.space_dim, + n_layers=cfg.model.n_layers, + n_hidden=cfg.model.n_hidden, + dropout=cfg.model.dropout, + n_head=cfg.model.n_head, + Time_Input=cfg.model.Time_Input, + act=cfg.model.act, + mlp_ratio=cfg.model.mlp_ratio, + fun_dim=cfg.model.fun_dim, + out_dim=cfg.model.out_dim, + slice_num=cfg.model.slice_num, + ref=cfg.model.ref, + unified_pos=cfg.model.unified_pos, + H=cfg.training.resolution, + W=cfg.training.resolution, + ).to(dist.device) + loss_fun = TestLoss(size_average=False) + optimizer = Adam(model.parameters(), lr=cfg.scheduler.initial_lr) + scheduler = lr_scheduler.LambdaLR( + optimizer, lr_lambda=lambda step: cfg.scheduler.decay_rate**step + ) + norm_vars = cfg.normaliser + normaliser = { + "permeability": (norm_vars.permeability.mean, norm_vars.permeability.std_dev), + "darcy": (norm_vars.darcy.mean, norm_vars.darcy.std_dev), + } + dataloader = Darcy2D( + resolution=cfg.training.resolution, + batch_size=cfg.training.batch_size, + normaliser=normaliser, + ) + validator = GridValidator(loss_fun=TestLoss(size_average=False), norm=normaliser) + + ckpt_args = { + "path": f"./checkpoints", + "optimizer": optimizer, + "scheduler": scheduler, + "models": model, + } + loaded_pseudo_epoch = load_checkpoint(device=dist.device, **ckpt_args) + + # calculate steps per pseudo epoch + steps_per_pseudo_epoch = ceil( + cfg.training.pseudo_epoch_sample_size / cfg.training.batch_size + ) + validation_iters = ceil(cfg.validation.sample_size / 
cfg.training.batch_size) + log_args = { + "name_space": "train", + "num_mini_batch": steps_per_pseudo_epoch, + "epoch_alert_freq": 1, + } + if cfg.training.pseudo_epoch_sample_size % cfg.training.batch_size != 0: + log.warning( + f"increased pseudo_epoch_sample_size to multiple of \ + batch size: {steps_per_pseudo_epoch*cfg.training.batch_size}" + ) + if cfg.validation.sample_size % cfg.training.batch_size != 0: + log.warning( + f"increased validation sample size to multiple of \ + batch size: {validation_iters*cfg.training.batch_size}" + ) + + # define forward passes for training and inference + @StaticCaptureTraining( + model=model, optim=optimizer, logger=log, use_amp=False, use_graphs=False + ) + def forward_train(invars, target): + pred = model(invars) + loss = loss_fun(pred, target) + return loss + + @StaticCaptureEvaluateNoGrad( + model=model, logger=log, use_amp=False, use_graphs=False + ) + def forward_eval(invars): + return model(invars) + + if loaded_pseudo_epoch == 0: + log.success("Training started...") + else: + log.warning(f"Resuming training from pseudo epoch {loaded_pseudo_epoch+1}.") + + for pseudo_epoch in range( + max(1, loaded_pseudo_epoch + 1), cfg.training.max_pseudo_epochs + 1 + ): + # Wrap epoch in launch logger for console / MLFlow logs + with LaunchLogger(**log_args, epoch=pseudo_epoch) as logger: + for _, batch in zip(range(steps_per_pseudo_epoch), dataloader): + loss = forward_train(batch["permeability"], batch["darcy"]) + logger.log_minibatch({"loss": loss.detach()}) + logger.log_epoch({"Learning Rate": optimizer.param_groups[0]["lr"]}) + + # save checkpoint + if pseudo_epoch % cfg.training.rec_results_freq == 0: + save_checkpoint(**ckpt_args, epoch=pseudo_epoch) + + # validation step + if pseudo_epoch % cfg.validation.validation_pseudo_epochs == 0: + with LaunchLogger("valid", epoch=pseudo_epoch) as logger: + total_loss = 0.0 + for _, batch in zip(range(validation_iters), dataloader): + val_loss = validator.compare( + batch["permeability"], + batch["darcy"], + forward_eval(batch["permeability"]), + pseudo_epoch, + logger, + ) + total_loss += val_loss + logger.log_epoch({"Validation error": total_loss / validation_iters}) + + # update learning rate + if pseudo_epoch % cfg.scheduler.decay_pseudo_epochs == 0: + scheduler.step() + + save_checkpoint(**ckpt_args, epoch=cfg.training.max_pseudo_epochs) + log.success("Training completed *yay*") + + +if __name__ == "__main__": + darcy_trainer() diff --git a/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py b/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py new file mode 100644 index 0000000000..b8ddf03600 --- /dev/null +++ b/examples/cfd/darcy_transolver/train_transolver_darcy_fix.py @@ -0,0 +1,272 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
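+
+"""Reproduce the Transolver paper's Darcy setup: fixed train/test splits are
+loaded from .mat files by ``Darcy2D_fix``, inputs and targets are normalized
+with ``UnitTransformer``, and the model is trained with AdamW, a OneCycleLR
+schedule, and a relative-L2 loss (``TestLoss``)."""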
+
+import hydra
+from omegaconf import DictConfig
+from math import ceil
+
+from utils.testloss import TestLoss
+from torch.optim import lr_scheduler, AdamW
+
+from modulus.models.transolver import Transolver
+from modulus.distributed import DistributedManager
+from modulus.utils import StaticCaptureTraining, StaticCaptureEvaluateNoGrad
+from modulus.launch.utils import load_checkpoint, save_checkpoint
+from modulus.launch.logging import PythonLogger, LaunchLogger, initialize_mlflow
+
+from darcy_datapipe_fix import Darcy2D_fix
+from validator_fix import GridValidator
+
+
+class UnitTransformer:
+    """Unit transformer class for normalizing and denormalizing data."""
+
+    def __init__(self, X):
+        self.mean = X.mean(dim=(0, 1), keepdim=True)
+        self.std = X.std(dim=(0, 1), keepdim=True) + 1e-8
+
+    def to(self, device):
+        self.mean = self.mean.to(device)
+        self.std = self.std.to(device)
+        return self
+
+    def cuda(self):
+        self.mean = self.mean.cuda()
+        self.std = self.std.cuda()
+
+    def cpu(self):
+        self.mean = self.mean.cpu()
+        self.std = self.std.cpu()
+
+    def encode(self, x):
+        x = (x - self.mean) / self.std
+        return x
+
+    def decode(self, x):
+        return x * self.std + self.mean
+
+    def transform(self, X, inverse=True, component="all"):
+        # membership test: `component` is either "all"/"all-reduce" or a channel index
+        if component in ("all", "all-reduce"):
+            if inverse:
+                orig_shape = X.shape
+                return (X * (self.std - 1e-8) + self.mean).view(orig_shape)
+            else:
+                return (X - self.mean) / self.std
+        else:
+            if inverse:
+                orig_shape = X.shape
+                return (
+                    X * (self.std[:, component] - 1e-8) + self.mean[:, component]
+                ).view(orig_shape)
+            else:
+                return (X - self.mean[:, component]) / self.std[:, component]
+
+
+def count_parameters(model):
+    """Print and return the number of trainable parameters."""
+    total_params = 0
+    for parameter in model.parameters():
+        if not parameter.requires_grad:
+            continue
+        total_params += parameter.numel()
+    print(f"Total Trainable Params: {total_params}")
+    return total_params
+
+
+@hydra.main(version_base="1.3", config_path=".", config_name="config_fix.yaml")
+def darcy_trainer(cfg: DictConfig) -> None:
+    """Training for the 2D Darcy flow benchmark problem."""
+    DistributedManager.initialize()  # Only call this once in the entire script!
+ dist = DistributedManager() # call if required elsewhere + + # initialize monitoring + log = PythonLogger(name="darcy_transolver") + log.file_logging() + initialize_mlflow( + experiment_name=f"Darcy_Transolver", + experiment_desc=f"training a Transformer-based PDE solver for the Darcy problem", + run_name=f"Darcy Transolver training", + run_desc=f"training Transolver for Darcy", + user_name="Haixu Wu, Huakun Luo, Haowen Wang", + mode="offline", + ) + LaunchLogger.initialize(use_mlflow=True) # Modulus launch logger + + # define model, loss, optimiser, scheduler, data loader + model = Transolver( + space_dim=cfg.model.space_dim, + n_layers=cfg.model.n_layers, + n_hidden=cfg.model.n_hidden, + dropout=cfg.model.dropout, + n_head=cfg.model.n_head, + Time_Input=cfg.model.Time_Input, + act=cfg.model.act, + mlp_ratio=cfg.model.mlp_ratio, + fun_dim=cfg.model.fun_dim, + out_dim=cfg.model.out_dim, + slice_num=cfg.model.slice_num, + ref=cfg.model.ref, + unified_pos=cfg.model.unified_pos, + H=cfg.training.resolution, + W=cfg.training.resolution, + ).to(dist.device) + count_parameters(model) + loss_fun = TestLoss(size_average=False) + optimizer = AdamW( + model.parameters(), + lr=cfg.scheduler.initial_lr, + weight_decay=cfg.scheduler.weight_decay, + ) + # scheduler = lr_scheduler.LambdaLR( + # optimizer, lr_lambda=lambda step: cfg.scheduler.decay_rate**step + # ) + + norm_vars = cfg.normaliser + normaliser = { + "permeability": (norm_vars.permeability.mean, norm_vars.permeability.std_dev), + "darcy": (norm_vars.darcy.mean, norm_vars.darcy.std_dev), + } + # train_dataloader = Darcy2D_fix( + # resolution=cfg.training.resolution, + # batch_size=cfg.training.batch_size, + # normaliser=normaliser, + # train_path="/data/fno/piececonst_r421_N1024_smooth1.mat", + # is_test=False, + # ) + train_dataloader = Darcy2D_fix( + resolution=cfg.training.resolution, + batch_size=cfg.training.batch_size, + normaliser=normaliser, + train_path="/data/fno/piececonst_r421_N1024_smooth1.mat", + is_test=False, + ) + # calculate steps per pseudo epoch + steps_per_pseudo_epoch = ceil( + cfg.training.pseudo_epoch_sample_size / cfg.training.batch_size + ) + + scheduler = lr_scheduler.OneCycleLR( + optimizer, + max_lr=cfg.scheduler.initial_lr, + steps_per_epoch=steps_per_pseudo_epoch, + epochs=cfg.training.max_pseudo_epochs, + ) + + x_normalizer, y_normalizer = train_dataloader.__get_normalizer__() + + test_dataloader = Darcy2D_fix( + resolution=cfg.training.resolution, + batch_size=cfg.training.batch_size, + normaliser=normaliser, + train_path="/data/fno/piececonst_r421_N1024_smooth2.mat", + is_test=True, + x_normalizer=x_normalizer, + ) + + validator = GridValidator(loss_fun=TestLoss(size_average=False), norm=y_normalizer) + + ckpt_args = { + "path": f"./checkpoints", + "optimizer": optimizer, + "scheduler": scheduler, + "models": model, + } + loaded_pseudo_epoch = load_checkpoint(device=dist.device, **ckpt_args) + + validation_iters = ceil(cfg.validation.sample_size / cfg.training.batch_size) + log_args = { + "name_space": "train", + "num_mini_batch": steps_per_pseudo_epoch, + "epoch_alert_freq": 1, + } + if cfg.training.pseudo_epoch_sample_size % cfg.training.batch_size != 0: + log.warning( + f"increased pseudo_epoch_sample_size to multiple of \ + batch size: {steps_per_pseudo_epoch*cfg.training.batch_size}" + ) + if cfg.validation.sample_size % cfg.training.batch_size != 0: + log.warning( + f"increased validation sample size to multiple of \ + batch size: {validation_iters*cfg.training.batch_size}" + ) + + # define forward 
passes for training and inference + @StaticCaptureTraining( + model=model, optim=optimizer, logger=log, use_amp=False, use_graphs=False + ) + def forward_train(pos, x, y, y_normalizer): + pred = model(pos, fx=x.unsqueeze(-1)).squeeze(-1) + pred = y_normalizer.decode(pred) + loss = loss_fun(pred, y) + return loss + + @StaticCaptureEvaluateNoGrad( + model=model, logger=log, use_amp=False, use_graphs=False + ) + def forward_eval(pos, x, y, y_normalizer): + pred = model(pos, fx=x.unsqueeze(-1)).squeeze(-1) + return y_normalizer.decode(pred) + + if loaded_pseudo_epoch == 0: + log.success("Training started...") + else: + log.warning(f"Resuming training from pseudo epoch {loaded_pseudo_epoch+1}.") + + for pseudo_epoch in range( + max(1, loaded_pseudo_epoch + 1), cfg.training.max_pseudo_epochs + 1 + ): + # Wrap epoch in launch logger for console / MLFlow logs + with LaunchLogger(**log_args, epoch=pseudo_epoch) as logger: + for _, batch in zip(range(steps_per_pseudo_epoch), train_dataloader): + loss = forward_train(*batch, y_normalizer) + logger.log_minibatch({"loss": loss.detach() / cfg.training.batch_size}) + scheduler.step() + logger.log_epoch({"Learning Rate": optimizer.param_groups[0]["lr"]}) + + # save checkpoint + if pseudo_epoch % cfg.training.rec_results_freq == 0: + save_checkpoint(**ckpt_args, epoch=pseudo_epoch) + + # validation step + if pseudo_epoch % cfg.validation.validation_pseudo_epochs == 0: + with LaunchLogger("valid", epoch=pseudo_epoch) as logger: + total_loss = 0.0 + for _, batch in zip(range(validation_iters), test_dataloader): + val_loss = validator.compare( + batch[2], + forward_eval(*batch, y_normalizer), + pseudo_epoch, + logger, + ) + total_loss += val_loss + logger.log_epoch( + { + "Validation error": total_loss + / (validation_iters * cfg.training.batch_size) + } + ) + + # update learning rate + # if pseudo_epoch % cfg.scheduler.decay_pseudo_epochs == 0: + + save_checkpoint(**ckpt_args, epoch=cfg.training.max_pseudo_epochs) + log.success("Training completed *yay*") + + +if __name__ == "__main__": + darcy_trainer() diff --git a/examples/cfd/darcy_transolver/utils/__init__.py b/examples/cfd/darcy_transolver/utils/__init__.py new file mode 100644 index 0000000000..36f2ae17fe --- /dev/null +++ b/examples/cfd/darcy_transolver/utils/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .testloss import TestLoss diff --git a/examples/cfd/darcy_transolver/utils/testloss.py b/examples/cfd/darcy_transolver/utils/testloss.py new file mode 100644 index 0000000000..28cdb0854a --- /dev/null +++ b/examples/cfd/darcy_transolver/utils/testloss.py @@ -0,0 +1,79 @@ +# ignore_header_test +# ruff: noqa: E402 +"""""" +""" +Transolver model. 
This code was modified from, https://github.com/thuml/Transolver + +The following license is provided from their source, + +MIT License + +Copyright (c) 2024 THUML @ Tsinghua University + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +import torch + + +class TestLoss(object): + def __init__(self, d=2, p=2, size_average=True, reduction=True): + super(TestLoss, self).__init__() + + assert d > 0 and p > 0 + + self.d = d + self.p = p + self.reduction = reduction + self.size_average = size_average + + def abs(self, x, y): + num_examples = x.size()[0] + + h = 1.0 / (x.size()[1] - 1.0) + + all_norms = (h ** (self.d / self.p)) * torch.norm( + x.view(num_examples, -1) - y.view(num_examples, -1), self.p, 1 + ) + + if self.reduction: + if self.size_average: + return torch.mean(all_norms) + else: + return torch.sum(all_norms) + + return all_norms + + def rel(self, x, y): + num_examples = x.size()[0] + + diff_norms = torch.norm( + x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1 + ) + y_norms = torch.norm(y.reshape(num_examples, -1), self.p, 1) + if self.reduction: + if self.size_average: + return torch.mean(diff_norms / y_norms) + else: + return torch.sum(diff_norms / y_norms) + + return diff_norms / y_norms + + def __call__(self, x, y): + return self.rel(x, y) diff --git a/examples/cfd/darcy_transolver/validator.py b/examples/cfd/darcy_transolver/validator.py new file mode 100644 index 0000000000..6e3f162b7e --- /dev/null +++ b/examples/cfd/darcy_transolver/validator.py @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
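+
+# NOTE: GridValidator below denormalizes the input and target with the
+# per-channel (mean, std) tuples in `norm` and logs a four-panel figure
+# (input, truth, prediction, relative error) to the launch logger / MLflow.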
+ +import matplotlib.pyplot as plt +from torch import FloatTensor +from modulus.launch.logging import LaunchLogger + + +class GridValidator: + """Grid Validator + + The validator compares model output and target, inverts normalisation and plots a sample + + Parameters + ---------- + loss_fun : MSELoss + loss function for assessing validation error + norm : Dict, optional + mean and standard deviation for each channel to normalise input and target + font_size : float, optional + font size used in figures + + """ + + def __init__( + self, + loss_fun, + norm: dict = {"permeability": (0.0, 1.0), "darcy": (0.0, 1.0)}, + font_size: float = 28.0, + ): + self.norm = norm + self.criterion = loss_fun + self.font_size = font_size + self.headers = ("invar", "truth", "prediction", "relative error") + + def compare( + self, + invar: FloatTensor, + target: FloatTensor, + prediction: FloatTensor, + step: int, + logger: LaunchLogger, + ) -> float: + """compares model output, target and plots everything + + Parameters + ---------- + invar : FloatTensor + input to model + target : FloatTensor + ground truth + prediction : FloatTensor + model output + step : int + iteration counter + logger : LaunchLogger + logger to which figure is passed + + Returns + ------- + float + validation error + """ + loss = self.criterion(prediction, target) + norm = self.norm + + # pick first sample from batch + invar = invar * norm["permeability"][1] + norm["permeability"][0] + target = target * norm["darcy"][1] + norm["darcy"][0] + prediction = prediction * norm["darcy"][1] + norm["darcy"][0] + invar = invar.cpu().numpy()[0, -1, :, :] + target = target.cpu().numpy()[0, 0, :, :] + prediction = prediction.detach().cpu().numpy()[0, 0, :, :] + + plt.close("all") + plt.rcParams.update({"font.size": self.font_size}) + fig, ax = plt.subplots(1, 4, figsize=(15 * 4, 15), sharey=True) + im = [] + im.append(ax[0].imshow(invar)) + im.append(ax[1].imshow(target)) + im.append(ax[2].imshow(prediction)) + im.append(ax[3].imshow((prediction - target) / norm["darcy"][1])) + + for ii in range(len(im)): + fig.colorbar(im[ii], ax=ax[ii], location="bottom", fraction=0.046, pad=0.04) + ax[ii].set_title(self.headers[ii]) + + logger.log_figure(figure=fig, artifact_file=f"validation_step_{step:03d}.png") + + return loss diff --git a/examples/cfd/darcy_transolver/validator_fix.py b/examples/cfd/darcy_transolver/validator_fix.py new file mode 100644 index 0000000000..170ed5e64a --- /dev/null +++ b/examples/cfd/darcy_transolver/validator_fix.py @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
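+
+# NOTE: unlike validator.py, this validator receives predictions that are
+# already denormalized (decoding happens in forward_eval in the training
+# script), so compare() only reshapes the fields to the 85 x 85 grid and
+# plots target, prediction, and error.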
+
+import matplotlib.pyplot as plt
+from torch import FloatTensor
+from modulus.launch.logging import LaunchLogger
+
+
+class GridValidator:
+    """Grid Validator
+
+    The validator compares model output and target and plots a sample;
+    predictions are expected to arrive already denormalized
+
+    Parameters
+    ----------
+    loss_fun : TestLoss
+        loss function for assessing validation error
+    norm : UnitTransformer
+        normalizer for the solution field, kept for reference
+    font_size : float, optional
+        font size used in figures
+
+    """
+
+    def __init__(
+        self,
+        loss_fun,
+        norm,
+        font_size: float = 28.0,
+    ):
+        self.norm = norm
+        self.criterion = loss_fun
+        self.font_size = font_size
+        self.headers = ("true", "prediction", "error")
+
+    def compare(
+        self,
+        target: FloatTensor,
+        prediction: FloatTensor,
+        step: int,
+        logger: LaunchLogger,
+    ) -> float:
+        """compares model output and target and plots everything
+
+        Parameters
+        ----------
+        target : FloatTensor
+            ground truth
+        prediction : FloatTensor
+            model output
+        step : int
+            iteration counter
+        logger : LaunchLogger
+            logger to which the figure is passed
+
+        Returns
+        -------
+        float
+            validation error
+        """
+        loss = self.criterion(prediction, target)
+        # pick the first sample from the batch and restore the 85 x 85 grid
+        target = target.reshape(-1, 85, 85).cpu().numpy()[0, :, :]
+        prediction = prediction.reshape(-1, 85, 85).detach().cpu().numpy()[0, :, :]
+
+        plt.close("all")
+        plt.rcParams.update({"font.size": self.font_size})
+        fig, ax = plt.subplots(1, 3, figsize=(15 * 3, 15), sharey=True)
+        im = []
+        im.append(ax[0].imshow(target))
+        im.append(ax[1].imshow(prediction))
+        im.append(ax[2].imshow(prediction - target))
+
+        for ii in range(len(im)):
+            fig.colorbar(im[ii], ax=ax[ii], location="bottom", fraction=0.046, pad=0.04)
+            ax[ii].set_title(self.headers[ii])
+
+        logger.log_figure(figure=fig, artifact_file=f"validation_step_{step:03d}.png")
+
+        return loss
diff --git a/modulus/models/transolver/Embedding.py b/modulus/models/transolver/Embedding.py
new file mode 100644
index 0000000000..b359d89091
--- /dev/null
+++ b/modulus/models/transolver/Embedding.py
@@ -0,0 +1,122 @@
+# ignore_header_test
+# ruff: noqa: E402
+""""""
+"""
+Transolver model. This code was modified from https://github.com/thuml/Transolver
+
+The following license is provided from their source:
+
+MIT License
+
+Copyright (c) 2024 THUML @ Tsinghua University
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+""" + +import math + +import torch +import torch.nn as nn +from einops import rearrange + + +class RotaryEmbedding(nn.Module): + "ROPE: Rotary Position Embedding" + + def __init__(self, dim, min_freq=1 / 2, scale=1.0): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.min_freq = min_freq + self.scale = scale + self.register_buffer("inv_freq", inv_freq) + + def forward(self, coordinates, device): + # coordinates [b, n] + t = coordinates.to(device).type_as(self.inv_freq) + t = t * (self.scale / self.min_freq) + freqs = torch.einsum("... i , j -> ... i j", t, self.inv_freq) # [b, n, d//2] + return torch.cat((freqs, freqs), dim=-1) # [b, n, d] + + +def rotate_half(x): + x = rearrange(x, "... (j d) -> ... j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(t, freqs): + return (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) + + +def apply_2d_rotary_pos_emb(t, freqs_x, freqs_y): + # split t into first half and second half + # t: [b, h, n, d] + # freq_x/y: [b, n, d] + d = t.shape[-1] + t_x, t_y = t[..., : d // 2], t[..., d // 2 :] + + return torch.cat( + (apply_rotary_pos_emb(t_x, freqs_x), apply_rotary_pos_emb(t_y, freqs_y)), dim=-1 + ) + + +class PositionalEncoding(nn.Module): + "Implement the PE function." + + def __init__(self, d_model, dropout, max_len=421 * 421): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + # Compute the positional encodings once in log space. + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, x): + x = x + self.pe[:, : x.size(1)].requires_grad_(False) + return self.dropout(x) + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding diff --git a/modulus/models/transolver/Physics_Attention.py b/modulus/models/transolver/Physics_Attention.py new file mode 100644 index 0000000000..519935e900 --- /dev/null +++ b/modulus/models/transolver/Physics_Attention.py @@ -0,0 +1,273 @@ +# ignore_header_test +# ruff: noqa: E402 +"""""" +""" +Transolver model. 
This code was modified from, https://github.com/thuml/Transolver + +The following license is provided from their source, + +MIT License + +Copyright (c) 2024 THUML @ Tsinghua University + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +import torch +import torch.nn as nn +from einops import rearrange + + +class Physics_Attention_Irregular_Mesh(nn.Module): + "for irregular meshes in 1D, 2D or 3D space" + + def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, slice_num=64): + super().__init__() + inner_dim = dim_head * heads + self.dim_head = dim_head + self.heads = heads + self.scale = dim_head**-0.5 + self.softmax = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) + + self.in_project_x = nn.Linear(dim, inner_dim) + self.in_project_fx = nn.Linear(dim, inner_dim) + self.in_project_slice = nn.Linear(dim_head, slice_num) + for l_i in [self.in_project_slice]: + torch.nn.init.orthogonal_(l_i.weight) # use a principled initialization + self.to_q = nn.Linear(dim_head, dim_head, bias=False) + self.to_k = nn.Linear(dim_head, dim_head, bias=False) + self.to_v = nn.Linear(dim_head, dim_head, bias=False) + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + + def forward(self, x): + # B N C + B, N, C = x.shape + + ### (1) Slice + fx_mid = ( + self.in_project_fx(x) + .reshape(B, N, self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .contiguous() + ) # B H N C + x_mid = ( + self.in_project_x(x) + .reshape(B, N, self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .contiguous() + ) # B H N C + slice_weights = self.softmax( + self.in_project_slice(x_mid) / self.temperature + ) # B H N G + slice_norm = slice_weights.sum(2) # B H G + slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights) + slice_token = slice_token / ( + (slice_norm + 1e-5)[:, :, :, None].repeat(1, 1, 1, self.dim_head) + ) + + ### (2) Attention among slice tokens + q_slice_token = self.to_q(slice_token) + k_slice_token = self.to_k(slice_token) + v_slice_token = self.to_v(slice_token) + dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale + attn = self.softmax(dots) + attn = self.dropout(attn) + out_slice_token = torch.matmul(attn, v_slice_token) # B H G D + + ### (3) Deslice + out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights) + out_x = rearrange(out_x, "b h n d -> b n (h d)") + return self.to_out(out_x) + + +class Physics_Attention_Structured_Mesh_2D(nn.Module): + "for structured mesh in 2D 
space" + + def __init__( + self, + dim, + heads=8, + dim_head=64, + dropout=0.0, + slice_num=64, + H=101, + W=31, + kernel=3, + ): # kernel=3): + super().__init__() + inner_dim = dim_head * heads + self.dim_head = dim_head + self.heads = heads + self.scale = dim_head**-0.5 + self.softmax = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) + self.H = H + self.W = W + + self.in_project_x = nn.Conv2d(dim, inner_dim, kernel, 1, kernel // 2) + self.in_project_fx = nn.Conv2d(dim, inner_dim, kernel, 1, kernel // 2) + self.in_project_slice = nn.Linear(dim_head, slice_num) + for l_i in [self.in_project_slice]: + torch.nn.init.orthogonal_(l_i.weight) # use a principled initialization + self.to_q = nn.Linear(dim_head, dim_head, bias=False) + self.to_k = nn.Linear(dim_head, dim_head, bias=False) + self.to_v = nn.Linear(dim_head, dim_head, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + + def forward(self, x): + # B N C + B, N, C = x.shape + x = ( + x.reshape(B, self.H, self.W, C) + .contiguous() + .permute(0, 3, 1, 2) + .contiguous() + ) # B C H W + + ### (1) Slice + fx_mid = ( + self.in_project_fx(x) + .permute(0, 2, 3, 1) + .contiguous() + .reshape(B, N, self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .contiguous() + ) # B H N C + x_mid = ( + self.in_project_x(x) + .permute(0, 2, 3, 1) + .contiguous() + .reshape(B, N, self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .contiguous() + ) # B H N G + slice_weights = self.softmax( + self.in_project_slice(x_mid) / torch.clamp(self.temperature, min=0.1, max=5) + ) # B H N G + slice_norm = slice_weights.sum(2) # B H G + slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights) + slice_token = slice_token / ( + (slice_norm + 1e-5)[:, :, :, None].repeat(1, 1, 1, self.dim_head) + ) + + ### (2) Attention among slice tokens + q_slice_token = self.to_q(slice_token) + k_slice_token = self.to_k(slice_token) + v_slice_token = self.to_v(slice_token) + dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale + attn = self.softmax(dots) + attn = self.dropout(attn) + out_slice_token = torch.matmul(attn, v_slice_token) # B H G D + + ### (3) Deslice + out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights) + out_x = rearrange(out_x, "b h n d -> b n (h d)") + return self.to_out(out_x) + + +class Physics_Attention_Structured_Mesh_3D(nn.Module): + "for structured mesh in 3D space" + + def __init__( + self, + dim, + heads=8, + dim_head=64, + dropout=0.0, + slice_num=32, + H=32, + W=32, + D=32, + kernel=3, + ): + super().__init__() + inner_dim = dim_head * heads + self.dim_head = dim_head + self.heads = heads + self.scale = dim_head**-0.5 + self.softmax = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) + self.H = H + self.W = W + self.D = D + + self.in_project_x = nn.Conv3d(dim, inner_dim, kernel, 1, kernel // 2) + self.in_project_fx = nn.Conv3d(dim, inner_dim, kernel, 1, kernel // 2) + self.in_project_slice = nn.Linear(dim_head, slice_num) + for l_i in [self.in_project_slice]: + torch.nn.init.orthogonal_(l_i.weight) # use a principled initialization + self.to_q = nn.Linear(dim_head, dim_head, bias=False) + self.to_k = nn.Linear(dim_head, dim_head, bias=False) + self.to_v = nn.Linear(dim_head, dim_head, bias=False) + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + + def forward(self, x): + # B N 
C + B, N, C = x.shape + x = ( + x.reshape(B, self.H, self.W, self.D, C) + .contiguous() + .permute(0, 4, 1, 2, 3) + .contiguous() + ) # B C H W + + ### (1) Slice + fx_mid = ( + self.in_project_fx(x) + .permute(0, 2, 3, 4, 1) + .contiguous() + .reshape(B, N, self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .contiguous() + ) # B H N C + x_mid = ( + self.in_project_x(x) + .permute(0, 2, 3, 4, 1) + .contiguous() + .reshape(B, N, self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .contiguous() + ) # B H N G + slice_weights = self.softmax( + self.in_project_slice(x_mid) / torch.clamp(self.temperature, min=0.1, max=5) + ) # B H N G + slice_norm = slice_weights.sum(2) # B H G + slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights) + slice_token = slice_token / ( + (slice_norm + 1e-5)[:, :, :, None].repeat(1, 1, 1, self.dim_head) + ) + + ### (2) Attention among slice tokens + q_slice_token = self.to_q(slice_token) + k_slice_token = self.to_k(slice_token) + v_slice_token = self.to_v(slice_token) + dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale + attn = self.softmax(dots) + attn = self.dropout(attn) + out_slice_token = torch.matmul(attn, v_slice_token) # B H G D + + ### (3) Deslice + out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights) + out_x = rearrange(out_x, "b h n d -> b n (h d)") + return self.to_out(out_x) diff --git a/modulus/models/transolver/__init__.py b/modulus/models/transolver/__init__.py new file mode 100644 index 0000000000..e2ee88fe95 --- /dev/null +++ b/modulus/models/transolver/__init__.py @@ -0,0 +1,32 @@ +# ignore_header_test +# ruff: noqa: E402 +"""""" +""" +Transolver model. This code was modified from, https://github.com/thuml/Transolver + +The following license is provided from their source, + +MIT License + +Copyright (c) 2024 THUML @ Tsinghua University + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +from .transolver import Transolver diff --git a/modulus/models/transolver/transolver.py b/modulus/models/transolver/transolver.py new file mode 100644 index 0000000000..60b51e4280 --- /dev/null +++ b/modulus/models/transolver/transolver.py @@ -0,0 +1,401 @@ +# ignore_header_test +# ruff: noqa: E402 +"""""" +""" +Transolver model. 
This code was modified from, https://github.com/thuml/Transolver + +The following license is provided from their source, + +MIT License + +Copyright (c) 2024 THUML @ Tsinghua University + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +from dataclasses import dataclass + +import numpy as np +import torch +import torch.nn as nn +from timm.layers import trunc_normal_ + +import modulus # noqa: F401 for docs + +from ..meta import ModelMetaData +from ..module import Module +from .Embedding import timestep_embedding +from .Physics_Attention import Physics_Attention_Structured_Mesh_2D + +ACTIVATION = { + "gelu": nn.GELU, + "tanh": nn.Tanh, + "sigmoid": nn.Sigmoid, + "relu": nn.ReLU, + "leaky_relu": nn.LeakyReLU(0.1), + "softplus": nn.Softplus, + "ELU": nn.ELU, + "silu": nn.SiLU, +} + + +class MLP(nn.Module): + def __init__(self, n_input, n_hidden, n_output, n_layers=1, act="gelu", res=True): + super(MLP, self).__init__() + + if act in ACTIVATION.keys(): + act = ACTIVATION[act] + else: + raise NotImplementedError + self.n_input = n_input + self.n_hidden = n_hidden + self.n_output = n_output + self.n_layers = n_layers + self.res = res + self.linear_pre = nn.Sequential(nn.Linear(n_input, n_hidden), act()) + self.linear_post = nn.Linear(n_hidden, n_output) + self.linears = nn.ModuleList( + [ + nn.Sequential(nn.Linear(n_hidden, n_hidden), act()) + for _ in range(n_layers) + ] + ) + + def forward(self, x): + # print(x.shape) + x = self.linear_pre(x) + for i in range(self.n_layers): + if self.res: + x = self.linears[i](x) + x + else: + x = self.linears[i](x) + x = self.linear_post(x) + return x + + +class Transolver_block(nn.Module): + """Transformer encoder block.""" + + def __init__( + self, + num_heads: int, + hidden_dim: int, + dropout: float, + act="gelu", + mlp_ratio=4, + last_layer=False, + out_dim=1, + slice_num=32, + H=85, + W=85, + ): + super().__init__() + self.last_layer = last_layer + self.ln_1 = nn.LayerNorm(hidden_dim) + self.Attn = Physics_Attention_Structured_Mesh_2D( + hidden_dim, + heads=num_heads, + dim_head=hidden_dim // num_heads, + dropout=dropout, + slice_num=slice_num, + H=H, + W=W, + ) + + self.ln_2 = nn.LayerNorm(hidden_dim) + self.mlp = MLP( + hidden_dim, + hidden_dim * mlp_ratio, + hidden_dim, + n_layers=0, + res=False, + act=act, + ) + if self.last_layer: + self.ln_3 = nn.LayerNorm(hidden_dim) + self.mlp2 = nn.Linear(hidden_dim, out_dim) + + def forward(self, fx): + fx = self.Attn(self.ln_1(fx)) + fx + fx = self.mlp(self.ln_2(fx)) + fx + if self.last_layer: + return self.mlp2(self.ln_3(fx)) + 
        else:
            return fx


class Model(nn.Module):
    """Internal Transolver backbone operating on flattened structured-mesh tokens."""

    def __init__(
        self,
        space_dim=1,
        n_layers=5,
        n_hidden=256,
        dropout=0.0,
        n_head=8,
        Time_Input=False,
        act="gelu",
        mlp_ratio=1,
        fun_dim=1,
        out_dim=1,
        slice_num=32,
        ref=8,
        unified_pos=False,
        H=85,
        W=85,
    ):
        super().__init__()
        self.__name__ = "Transolver_2D"
        self.H = H
        self.W = W
        self.ref = ref
        self.unified_pos = unified_pos
        if self.unified_pos:
            self.pos = self.get_grid()
            self.preprocess = MLP(
                fun_dim + self.ref * self.ref,
                n_hidden * 2,
                n_hidden,
                n_layers=0,
                res=False,
                act=act,
            )
        else:
            self.preprocess = MLP(
                fun_dim + space_dim,
                n_hidden * 2,
                n_hidden,
                n_layers=0,
                res=False,
                act=act,
            )

        self.Time_Input = Time_Input
        self.n_hidden = n_hidden
        self.space_dim = space_dim
        if Time_Input:
            self.time_fc = nn.Sequential(
                nn.Linear(n_hidden, n_hidden), nn.SiLU(), nn.Linear(n_hidden, n_hidden)
            )

        self.blocks = nn.ModuleList(
            [
                Transolver_block(
                    num_heads=n_head,
                    hidden_dim=n_hidden,
                    dropout=dropout,
                    act=act,
                    mlp_ratio=mlp_ratio,
                    out_dim=out_dim,
                    slice_num=slice_num,
                    H=H,
                    W=W,
                    last_layer=(i == n_layers - 1),
                )
                for i in range(n_layers)
            ]
        )
        self.initialize_weights()
        self.placeholder = nn.Parameter(
            (1 / n_hidden) * torch.rand(n_hidden, dtype=torch.float)
        )

    def initialize_weights(self):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_grid(self, batchsize=1):
        size_x, size_y = self.H, self.W
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        grid = torch.cat((gridx, gridy), dim=-1)  # B H W 2

        gridx = torch.tensor(np.linspace(0, 1, self.ref), dtype=torch.float)
        gridx = gridx.reshape(1, self.ref, 1, 1).repeat([batchsize, 1, self.ref, 1])
        gridy = torch.tensor(np.linspace(0, 1, self.ref), dtype=torch.float)
        gridy = gridy.reshape(1, 1, self.ref, 1).repeat([batchsize, self.ref, 1, 1])
        grid_ref = torch.cat((gridx, gridy), dim=-1)  # B ref ref 2

        # Distance from every mesh point to every reference-grid point,
        # flattened into a (B, H, W, ref*ref) positional encoding.
        pos = (
            torch.sqrt(
                torch.sum(
                    (grid[:, :, :, None, None, :] - grid_ref[:, None, None, :, :, :])
                    ** 2,
                    dim=-1,
                )
            )
            .reshape(batchsize, size_x, size_y, self.ref * self.ref)
            .contiguous()
        )
        return pos

    def forward(self, x, fx, T=None):
        if self.unified_pos:
            x = (
                self.pos.repeat(x.shape[0], 1, 1, 1)
                .reshape(x.shape[0], self.H * self.W, self.ref * self.ref)
                .to(x.device)
            )
        if fx is not None:
            fx = torch.cat((x, fx), -1)
            fx = self.preprocess(fx)
        else:
            fx = self.preprocess(x)
        fx = fx + self.placeholder[None, None, :]

        if T is not None:
            Time_emb = timestep_embedding(T, self.n_hidden).repeat(1, x.shape[1], 1)
            Time_emb = self.time_fc(Time_emb)
            fx = fx + Time_emb

        for block in self.blocks:
            fx = block(fx)

        return fx


@dataclass
class MetaData(ModelMetaData):
    name: str = "Transolver"
    # Optimization
    jit: bool = False
    cuda_graphs: bool = False
    amp: bool = False
    # Inference
    onnx_cpu: bool = False
    onnx_gpu: bool = True
    onnx_runtime: bool = True
    # Physics informed
    var_dim: int = 1
    func_torch: bool = False
    auto_grad: bool = False


class Transolver(Module):
    """Transformer-based solver for PDEs.

    Note
    ----
    Transolver is a model specifically designed for structured 2D mesh data.

    Parameters
    ----------
    space_dim : int
        The spatial dimension of the input coordinates.
    n_layers : int
        The number of transformer layers.
    n_hidden : int
        The hidden dimension of the transformer.
    dropout : float
        The dropout rate.
    n_head : int
        The number of attention heads.
    Time_Input : bool
        Whether to include time embeddings.
    act : str
        The activation function.
    mlp_ratio : int
        The ratio of hidden dimension in the MLP.
    fun_dim : int
        The number of channels of the input function values.
    out_dim : int
        The output dimension.
    slice_num : int
        The number of slices in the structured attention.
    ref : int
        The resolution of the reference grid used when ``unified_pos`` is set.
    unified_pos : bool
        Whether to use unified positional embeddings.
    H : int
        The height of the mesh.
    W : int
        The width of the mesh.
    """

    def __init__(
        self,
        space_dim: int,
        n_layers: int,
        n_hidden: int,
        dropout: float,
        n_head: int,
        Time_Input: bool,
        act: str,
        mlp_ratio: int,
        fun_dim: int,
        out_dim: int,
        slice_num: int,
        ref: int,
        unified_pos: bool,
        H: int,
        W: int,
    ) -> None:
        super().__init__(meta=MetaData())
        self.H = H
        self.W = W
        self.model = Model(
            space_dim=space_dim,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout=dropout,
            n_head=n_head,
            Time_Input=Time_Input,
            act=act,
            mlp_ratio=mlp_ratio,
            fun_dim=fun_dim,
            out_dim=out_dim,
            slice_num=slice_num,
            ref=ref,
            unified_pos=unified_pos,
            H=H,
            W=W,
        )

    def forward(
        self, x: torch.Tensor, fx: torch.Tensor = None, T: torch.Tensor = None
    ) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            The input spatial tensor.
        fx : torch.Tensor, optional
            The input function values, of shape (batch, H*W, fun_dim).
        T : torch.Tensor, optional
            The time tensor; only used when ``Time_Input`` is True.

        Returns
        -------
        torch.Tensor
            The output tensor, of shape (batch, H, W, out_dim).

        """
        y = self.model(x, fx, T)
        y = y.reshape(x.shape[0], self.H, self.W, -1)
        return y
diff --git a/test/models/data/transolver_output.pth b/test/models/data/transolver_output.pth
new file mode 100644
index 0000000000..1cbd941dbf
Binary files /dev/null and b/test/models/data/transolver_output.pth differ
diff --git a/test/models/test_transolver.py b/test/models/test_transolver.py
new file mode 100644
index 0000000000..35154398db
--- /dev/null
+++ b/test/models/test_transolver.py
@@ -0,0 +1,260 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import pytest
import torch

from modulus.models.transolver import Transolver

from . import common
@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_transolver_forward(device):
    """Test Transolver forward pass"""
    torch.manual_seed(0)
    # Construct Transolver model
    model = Transolver(
        space_dim=2,
        n_layers=8,
        n_hidden=64,
        dropout=0.0,
        n_head=4,
        Time_Input=False,
        act="gelu",
        mlp_ratio=1,
        fun_dim=1,
        out_dim=1,
        slice_num=32,
        ref=8,
        unified_pos=True,
        H=85,
        W=85,
    ).to(device)

    bsize = 4
    pos = torch.randn(bsize, 85, 85).to(device)
    invar = torch.randn(bsize, 85 * 85, 1).to(device)

    assert common.validate_forward_accuracy(
        model,
        (
            pos,
            invar,
        ),
        file_name="transolver_output.pth",
        atol=1e-3,
    )


@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_transolver_constructor(device):
    """Test Transolver constructor options"""
    # Construct model with a fixed set of constructor args
    model = Transolver(
        space_dim=2,
        n_layers=8,
        n_hidden=64,
        dropout=0.0,
        n_head=4,
        Time_Input=False,
        act="gelu",
        mlp_ratio=1,
        fun_dim=1,
        out_dim=1,
        slice_num=32,
        ref=8,
        unified_pos=True,
        H=85,
        W=85,
    ).to(device)

    bsize = random.randint(1, 4)
    pos = torch.randn(bsize, 85, 85).to(device)
    invar = torch.randn(bsize, 85 * 85, 1).to(device)

    outvar = model(pos, invar)
    assert outvar.shape == (bsize, 85, 85, 1)


@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_transolver_optims(device):
    """Test Transolver optimizations"""

    def setup_model():
        """Sets up a fresh Transolver model and inputs for each optim test"""
        model = Transolver(
            space_dim=2,
            n_layers=8,
            n_hidden=64,
            dropout=0.0,
            n_head=4,
            Time_Input=False,
            act="gelu",
            mlp_ratio=1,
            fun_dim=1,
            out_dim=1,
            slice_num=32,
            ref=8,
            unified_pos=True,
            H=85,
            W=85,
        ).to(device)

        bsize = random.randint(1, 2)
        pos = torch.randn(bsize, 85, 85).to(device)
        invar = torch.randn(bsize, 85 * 85, 1).to(device)

        return model, pos, invar

    # Ideally always check graphs first
    model, pos, invar = setup_model()
    assert common.validate_cuda_graphs(
        model,
        (
            pos,
            invar,
        ),
    )

    # Check JIT
    model, pos, invar = setup_model()
    assert common.validate_jit(
        model,
        (
            pos,
            invar,
        ),
    )
    # Check AMP
    model, pos, invar = setup_model()
    assert common.validate_amp(
        model,
        (
            pos,
            invar,
        ),
    )
    # Check Combo
    model, pos, invar = setup_model()
    assert common.validate_combo_optims(
        model,
        (
            pos,
            invar,
        ),
    )


@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_transolver_checkpoint(device):
    """Test Transolver checkpoint save/load"""
    # Construct two Transolver models
    model_1 = Transolver(
        space_dim=2,
        n_layers=8,
        n_hidden=64,
        dropout=0.0,
        n_head=4,
        Time_Input=False,
        act="gelu",
        mlp_ratio=1,
        fun_dim=1,
        out_dim=1,
        slice_num=32,
        ref=8,
        unified_pos=True,
        H=85,
        W=85,
    ).to(device)

    model_2 = Transolver(
        space_dim=2,
        n_layers=8,
        n_hidden=64,
        dropout=0.0,
        n_head=4,
        Time_Input=False,
        act="gelu",
        mlp_ratio=1,
        fun_dim=1,
        out_dim=1,
        slice_num=32,
        ref=8,
        unified_pos=True,
        H=85,
        W=85,
    ).to(device)

    bsize = random.randint(1, 2)
    pos = torch.randn(bsize, 85, 85).to(device)
    invar = torch.randn(bsize, 85 * 85, 1).to(device)

    assert common.validate_checkpoint(
        model_1,
        model_2,
        (
            pos,
            invar,
        ),
    )


@common.check_ort_version()
@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_transolver_deploy(device):
    """Test Transolver deployment support"""
    # Construct Transolver model
    model = Transolver(
        space_dim=2,
        n_layers=8,
        n_hidden=64,
        dropout=0.0,
        n_head=4,
        Time_Input=False,
        act="gelu",
        mlp_ratio=1,
        fun_dim=1,
        out_dim=1,
        slice_num=32,
        ref=8,
        unified_pos=True,
        H=85,
        W=85,
    ).to(device)

    bsize = random.randint(1, 2)
    pos = torch.randn(bsize, 85, 85).to(device)
    invar = torch.randn(bsize, 85 * 85, 1).to(device)

    assert common.validate_onnx_export(
        model,
        (
            pos,
            invar,
        ),
    )
    assert common.validate_onnx_runtime(
        model,
        (
            pos,
            invar,
        ),
        1e-2,
        1e-2,
    )
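
For reviewers, a minimal usage sketch of the new model. It is not part of this diff; the hyperparameters and the import path `modulus.models.transolver.Transolver` are taken directly from the test file above, and any other configuration is an untested assumption.

```python
import torch

from modulus.models.transolver import Transolver

# Mirror the configuration exercised by the tests above.
model = Transolver(
    space_dim=2,
    n_layers=8,
    n_hidden=64,
    dropout=0.0,
    n_head=4,
    Time_Input=False,
    act="gelu",
    mlp_ratio=1,
    fun_dim=1,
    out_dim=1,
    slice_num=32,
    ref=8,
    unified_pos=True,
    H=85,
    W=85,
)

batch = 2
pos = torch.randn(batch, 85, 85)        # spatial input
invar = torch.randn(batch, 85 * 85, 1)  # one function value per mesh point

out = model(pos, invar)
print(out.shape)  # torch.Size([2, 85, 85, 1])
```

With `unified_pos=True` the model builds its positional encoding internally from the reference grid, so `pos` mainly fixes the batch size and device of the forward pass, while `invar` carries the `fun_dim` input channels that are reshaped back to the `H x W` mesh on output.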