diff --git a/.circleci/config.yml b/.circleci/config.yml index 19c2d377a..c86323420 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -18,25 +18,45 @@ setup_env: &setup_env - run: name: Setup environment command: | - python3.8 --version - python3.8 -m pip install --upgrade pip - cd python - python3.8 setup.py bdist_wheel - sudo python3.8 -m pip install --no-input dist/*.whl - cd .. - python3.8 -m pip install pytest - python3.8 -m pip install torch - python3.8 -m pip install numpy - python3.8 -m pip install jinja2 - python3.8 -m pip install recordtype - python3.8 -m pip install parameterized - python3.8 -m pip install einops - git submodule sync - git submodule update --init - echo 'export PYTHONPATH=$PWD/python:$PYTHONPATH' >> $BASH_ENV - echo 'export PATH=/usr/local/cuda-11.4/bin:$PATH' >> $BASH_ENV - echo 'export CI_FLAG=CIRCLECI' >> $BASH_ENV - echo 'export CACHE_DIR=$PWD/tests/ci_profile_cache' >> $BASH_ENV + for i in {1..3}; do + python3.8 --version && + python3.8 -m pip install --upgrade pip && + cd /home/circleci/project/python && + python3.8 setup.py bdist_wheel && + sudo python3.8 -m pip install --no-input dist/*.whl && + cd /home/circleci/project && + python3.8 -m pip install pytest && + python3.8 -m pip install torch && + python3.8 -m pip install numpy && + python3.8 -m pip install jinja2 && + python3.8 -m pip install recordtype && + python3.8 -m pip install parameterized && + python3.8 -m pip install einops && + git submodule sync && + git submodule update --init && + echo 'export PYTHONPATH=$PWD/python:$PYTHONPATH' >> $BASH_ENV && + echo 'export PATH=/usr/local/cuda-11.4/bin:$PATH' >> $BASH_ENV && + echo 'export CI_FLAG=CIRCLECI' >> $BASH_ENV && + echo 'export CACHE_DIR=$PWD/tests/ci_profile_cache' >> $BASH_ENV && + break || sleep 5; + done + + +setup_fx2ait_env: &setup_fx2ait_env + - run: + name: Setup fx2ait environment + command: | + for i in {1..3}; do + wget https://developer.download.nvidia.com/compute/redist/cudnn/v8.7.0/local_installers/11.8/cudnn-linux-x86_64-8.7.0.84_cuda11-archive.tar.xz + tar -xvf cudnn-*-archive.tar.xz + sudo cp cudnn-*-archive/include/cudnn*.h /usr/local/cuda/include + sudo cp -P cudnn-*-archive/lib/libcudnn* /usr/local/cuda/lib64 + sudo chmod a+r /usr/local/cuda/include/cudnn*.h /usr/local/cuda/lib64/libcudnn* + python3.8 -m pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116 + python3.8 fx2ait/setup.py install --prefix=/home/circleci/ + echo 'export PYTHONPATH=$PWD/fx2ait:$PYTHONPATH' >> $BASH_ENV + break || sleep 5; + done basic_tests: &basic_tests - run: @@ -44,19 +64,38 @@ basic_tests: &basic_tests command: | set -e TEST_FILES=$(circleci tests glob "tests/unittest/**/test_*.py" | grep -v benchmark | circleci tests split --split-by=timings) - mkdir test-results - python3.8 -m pytest $TEST_FILES --junitxml=test-results/junit.xml --verbose --continue-on-collection-errors -rA + mkdir ~/test-results + python3.8 -m pytest $TEST_FILES -o junit_family=xunit1 --junitxml=~/test-results/junit.xml --verbose --continue-on-collection-errors -rA +fx2ait_tests: &fx2ait_tests + - run: + name: Run fx2ait tests + command: | + source $BASH_ENV + mkdir -p ~/test-fx2ait-results + TEST_FILES=$(circleci tests glob "fx2ait/fx2ait/test/test_*.py" "fx2ait/fx2ait/test/converters/**/test_*.py") + python3.8 -m pytest $TEST_FILES -o junit_family=xunit1 --junitxml=~/test-fx2ait-results/junit.xml --verbose --continue-on-collection-errors -rA # Define a job to be invoked later in a workflow. 
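The `for i in {1..3}; do ... && break || sleep 5; done` wrapper added above simply retries the whole setup block up to three times, sleeping briefly between attempts, so that transient network failures (pip, wget, git submodules) do not fail the CI job outright. As a rough illustration of the same retry idea outside the shell, here is a minimal Python sketch; it is not part of this diff, and the helper name is hypothetical:

    import subprocess
    import time

    def run_with_retries(cmd: str, attempts: int = 3, delay: int = 5) -> None:
        # Retry a shell command a few times before giving up, mirroring the
        # `... && break || sleep 5` pattern used in the CI config above.
        for attempt in range(1, attempts + 1):
            result = subprocess.run(cmd, shell=True)
            if result.returncode == 0:
                return
            if attempt < attempts:
                time.sleep(delay)
        raise RuntimeError(f"command failed after {attempts} attempts: {cmd}")

    run_with_retries("python3.8 -m pip install --upgrade pip")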
# See: https://circleci.com/docs/2.0/configuration-reference/#jobs jobs: + fx2ait-test: + machine: + image: ubuntu-2004-cuda-11.4:202110-01 + resource_class: gpu.nvidia.medium + steps: + - checkout + - <<: *setup_env + - <<: *setup_fx2ait_env + - <<: *fx2ait_tests + - store_test_results: + path: ~/test-fx2ait-results + build-and-test: machine: image: ubuntu-2004-cuda-11.4:202110-01 # Check T101565170 for multi-gpu use cases. resource_class: gpu.nvidia.medium - parallelism: 10 # Checkout the code as the first step. This is a dedicated CircleCI step. @@ -69,7 +108,7 @@ jobs: - <<: *setup_env - <<: *basic_tests - store_test_results: - path: test-results + path: ~/test-results # Invoke jobs via workflows # See: https://circleci.com/docs/2.0/configuration-reference/#workflows @@ -77,4 +116,5 @@ workflows: unittest: # This is the name of the workflow, feel free to change it to better match your workflow. # Inside the workflow, you define the jobs you want to run. jobs: + - fx2ait-test - build-and-test diff --git a/.flake8 b/.flake8 index 71a5883ed..9ef66bc0d 100644 --- a/.flake8 +++ b/.flake8 @@ -7,111 +7,111 @@ ignore = # Found in https://github.com/psf/black/issues/429 # Line too long. B950, - # Indentation is not a multiple of four. - E111, + # Indentation is not a multiple of four. + E111, # Expected an indented block (comment). - E115, + E115, # Over-indented. E117, - # Continuation line under-indented for hanging indent. + # Continuation line under-indented for hanging indent. E121, - # Continuation line missing indentation or outdented. + # Continuation line missing indentation or outdented. E122, - # Closing bracket does not match indentation of opening bracket's line. + # Closing bracket does not match indentation of opening bracket's line. E123, - # Closing bracket does not match visual indentation. + # Closing bracket does not match visual indentation. E124, - # Continuation line with same indent as next logical line. + # Continuation line with same indent as next logical line. E125, - # Continuation line over-indented for hanging indent. + # Continuation line over-indented for hanging indent. E126, - # Continuation line over-indented for visual indent. + # Continuation line over-indented for visual indent. E127, - # Continuation line under-indented for visual indent. + # Continuation line under-indented for visual indent. E128, - # Visually indented line with same indent as next logical line. + # Visually indented line with same indent as next logical line. E129, - # Continuation line unaligned for hanging indent. + # Continuation line unaligned for hanging indent. E131, - # Whitespace after '('. + # Whitespace after '('. E201, - # Whitespace before ')'. + # Whitespace before ')'. E202, - # Whitespace before ':'. + # Whitespace before ':'. E203, - # Multiple spaces before operator. + # Multiple spaces before operator. E221, - # Multiple spaces after operator. + # Multiple spaces after operator. E222, - # Missing whitespace around operator. + # Missing whitespace around operator. E225, - # Missing whitespace around arithmetic operator. + # Missing whitespace around arithmetic operator. E226, - # Missing whitespace around bitwise or shift operator. + # Missing whitespace around bitwise or shift operator. E227, - # Missing whitespace after ',', ';', or ':'. + # Missing whitespace after ',', ';', or ':'. E231, - # Multiple spaces after ','. + # Multiple spaces after ','. E241, - # Unexpected spaces around keyword / parameter equals. 
+ # Unexpected spaces around keyword / parameter equals. E251, - # Missing whitespace around parameter equals. + # Missing whitespace around parameter equals. E252, - # At least two spaces before inline comment. - E261, + # At least two spaces before inline comment. + E261, # Inline comment should start with '# '. - E262, + E262, # Block comment should start with '# '. E265, - # Multiple spaces after keyword. + # Multiple spaces after keyword. E271, - # Multiple spaces before keyword. + # Multiple spaces before keyword. E272, - # Expected 1 blank line, found 0. + # Expected 1 blank line, found 0. E301, - # Expected 2 blank lines, found 0. + # Expected 2 blank lines, found 0. E302, - # Too many blank lines (3). + # Too many blank lines (3). E303, - # Expected 2 blank lines after end of function or class. + # Expected 2 blank lines after end of function or class. E305, - # Expected 1 blank line before a nested definition. + # Expected 1 blank line before a nested definition. E306, - # Line too long (82 > 79 characters). + # Line too long (82 > 79 characters). E501, - # The backslash is redundant between brackets. + # The backslash is redundant between brackets. E502, - # Multiple statements on one line (colon). + # Multiple statements on one line (colon). E701, - # Multiple statements on one line (semicolon). + # Multiple statements on one line (semicolon). E702, - # Statement ends with a semicolon. + # Statement ends with a semicolon. E703, - # Multiple statements on one line (def). + # Multiple statements on one line (def). E704, - # Trailing whitespace. + # Trailing whitespace. W291, - # No newline at end of file. + # No newline at end of file. W292, - # Blank line contains whitespace. + # Blank line contains whitespace. W293, - # Blank line at end of file. + # Blank line at end of file. W391, - # Line break occurred after a binary operator. - W504, + # Line break occurred after a binary operator. + W504, # Too opinionated. # Block comment should start with '# '. E265, - # Too many leading '#' for block comment. + # Too many leading '#' for block comment. E266, - # Module level import not at top of file. (Use cases like demandimport https://fburl.com/demandimport require statements before imports) - E402, + # Module level import not at top of file. (Use cases like demandimport https://fburl.com/demandimport require statements before imports) + E402, # Do not use bare except, specify exception instead. (Duplicate of B001) - E722, + E722, # (Duplicate of B003) - P207, + P207, # (Duplicate of C403) P208, # Line break occurred before a binary operator. - W503 + W503 diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml new file mode 100644 index 000000000..6c1bd8ba9 --- /dev/null +++ b/.github/workflows/docs.yaml @@ -0,0 +1,39 @@ +name: Docs + +on: + push: + branches: + - main + + pull_request: + branches: + - main +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install autodocsumm + pip install sphinx_rtd_theme + pip install sphinx_gallery + pip install sphinxcontrib-inlinesyntaxhighlight + pip install sphinx_toolbox + cd python + python setup.py develop + cd .. + pip install numpy + - name: Build documents with Sphinx + run: | + cd docs + make html + cd .. 
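Most of the codes kept in the .flake8 ignore list earlier in this change (the change itself only touches trailing whitespace in the comments) are suppressed because Black owns formatting in this repository; for example, Black's output for slices and wrapped expressions would otherwise trip E203 (whitespace before ':') and W503 (line break before a binary operator). A small, purely illustrative snippet, not taken from the codebase:

    # Formatting in the style Black produces; without the ignore list,
    # flake8 would flag E203 and W503 on the lines marked below.
    def slice_window(values, lower, upper, offset):
        window = values[lower + offset : upper + offset]  # E203 without the ignore
        total = (
            sum(window)
            + len(window)  # W503 without the ignore
        )
        return total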
diff --git a/.github/workflows/docs.yml b/.github/workflows/pages.yaml similarity index 96% rename from .github/workflows/docs.yml rename to .github/workflows/pages.yaml index 208bd1f77..d9074b8b3 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/pages.yaml @@ -1,5 +1,5 @@ # Simple workflow for deploying static content to GitHub Pages -name: Documentation +name: Deploy docs to Pages on: # Runs on pushes targeting the default branch @@ -29,7 +29,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9"] + python-version: ["3.8"] steps: - name: Checkout uses: actions/checkout@v3 diff --git a/.github/workflows/lint.yml b/.github/workflows/pylint.yaml similarity index 91% rename from .github/workflows/lint.yml rename to .github/workflows/pylint.yaml index dbd4beb83..be97139fa 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/pylint.yaml @@ -23,9 +23,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install ufmt - pip install click - pip install flake8 + pip install ufmt==2.0.1 click==8.1.3 black==22.12.0 flake8==5.0.4 - name: Analyzing the code with flake8 run: | echo "::add-matcher::tests/lint/flake8_problem_matcher.json" @@ -38,4 +36,4 @@ jobs: - name: Check Meta copyright header run: | python tests/lint/check_meta_header.py --path=./tests --fixit=False - python tests/lint/check_meta_header.py --path=./python --fixit=False \ No newline at end of file + python tests/lint/check_meta_header.py --path=./python --fixit=False diff --git a/examples/05_stable_diffusion/benchmark.py b/examples/05_stable_diffusion/benchmark.py deleted file mode 100644 index 6f0e3f695..000000000 --- a/examples/05_stable_diffusion/benchmark.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import logging - -import click - -import numpy as np -import torch -from aitemplate.compiler import Model -from aitemplate.testing import detect_target -from aitemplate.testing.benchmark_pt import benchmark_torch_function -from diffusers import StableDiffusionPipeline - -from torch import autocast -from transformers import CLIPTokenizer - -USE_CUDA = detect_target().name() == "cuda" - -access_token = True -pipe = None - - -def get_int_shape(x): - shape = [it.value() for it in x._attrs["shape"]] - return shape - - -def mark_output(y): - if type(y) is not tuple: - y = (y,) - for i in range(len(y)): - y[i]._attrs["is_output"] = True - y[i]._attrs["name"] = "output_%d" % (i) - y_shape = [d._attrs["values"][0] for d in y[i]._attrs["shape"]] - print("AIT output_{} shape: {}".format(i, y_shape)) - - -def benchmark_unet( - batch_size=2, - hh=64, - ww=64, - dim=320, - hidden_size=1024, - benchmark_pt=False, - verify=False, -): - - exe_module = Model("./tmp/UNet2DConditionModel/test.so") - if exe_module is None: - print("Error!! 
Cannot find compiled module for UNet2DConditionModel.") - exit(-1) - - # run PT unet model - pt_mod = pipe.unet - pt_mod = pt_mod.eval() - - latent_model_input_pt = torch.randn(batch_size, 4, hh, ww).cuda().half() - text_embeddings_pt = torch.randn(batch_size, 64, hidden_size).cuda().half() - timesteps_pt = torch.Tensor([1, 1]).cuda().half() - - with autocast("cuda"): - pt_ys = pt_mod( - latent_model_input_pt, - timesteps_pt, - encoder_hidden_states=text_embeddings_pt, - ).sample - - # PT benchmark - if benchmark_pt: - args = (latent_model_input_pt, 1, text_embeddings_pt) - pt_time = benchmark_torch_function(100, pt_mod, *args) - print(f"PT batch_size: {batch_size}, {pt_time} ms") - with open("sd_pt_benchmark.txt", "a") as f: - f.write(f"unet batch_size: {batch_size}, latency: {pt_time} ms\n") - - print("pt output:", pt_ys.shape) - - # run AIT unet model - inputs = { - "input0": latent_model_input_pt.permute((0, 2, 3, 1)).contiguous(), - "input1": timesteps_pt, - "input2": text_embeddings_pt, - } - - ys = [] - num_ouputs = len(exe_module.get_output_name_to_index_map()) - for i in range(num_ouputs): - shape = exe_module.get_output_maximum_shape(i) - ys.append(torch.empty(shape).cuda().half()) - exe_module.run_with_tensors(inputs, ys) - - # verification - y_transpose = ys[0].permute((0, 3, 1, 2)) - - if verify: - eps = 1e-1 - np.testing.assert_allclose( - pt_ys.detach().cpu().numpy(), - y_transpose.cpu().numpy(), - atol=eps, - rtol=eps, - ) - print("UNet2DCondition verification pass") - - # AIT benchmark - # warmup - exe_module.benchmark_with_tensors(inputs, ys, count=100, repeat=4) - # benchmark - t, _, _ = exe_module.benchmark_with_tensors(inputs, ys, count=100, repeat=4) - with open("sd_ait_benchmark.txt", "a") as f: - f.write(f"unet batch_size: {batch_size}, latency: {t} ms\n") - - -def benchmark_clip( - batch_size=1, - seqlen=64, - benchmark_pt=False, - verify=False, -): - mask_seq = 0 - version = "openai/clip-vit-large-patch14" - - exe_module = Model("./tmp/CLIPTextModel/test.so") - if exe_module is None: - print("Error!! 
Cannot find compiled module for CLIPTextModel.") - exit(-1) - - # run PT clip - pt_mod = pipe.text_encoder - pt_mod = pt_mod.eval() - - tokenizer = CLIPTokenizer.from_pretrained(version) - text_input = tokenizer( - ["a photo of an astronaut riding a horse on mars"], - padding="max_length", - max_length=seqlen, - truncation=True, - return_tensors="pt", - ) - input_ids = text_input["input_ids"].cuda() - - attention_mask = torch.ones((batch_size, seqlen)) - attention_mask[-1, -mask_seq:] = 0 - attention_mask = None - - position_ids = torch.arange(seqlen).expand((batch_size, -1)).cuda() - pt_ys = pt_mod(input_ids, attention_mask, position_ids) - print("pt output:", pt_ys[0].shape) - - # PT benchmark - if benchmark_pt: - args = (input_ids, attention_mask, position_ids) - pt_time = benchmark_torch_function(100, pt_mod, *args) - print(f"PT batch_size: {batch_size}, {pt_time} ms") - with open("sd_pt_benchmark.txt", "a") as f: - f.write(f"clip batch_size: {batch_size}, latency: {pt_time} ms\n") - - # run AIT clip - inputs = { - "input0": input_ids, - "input1": position_ids, - } - ys = [] - num_ouputs = len(exe_module.get_output_name_to_index_map()) - for i in range(num_ouputs): - shape = exe_module.get_output_maximum_shape(i) - ys.append(torch.empty(shape).cuda().half()) - exe_module.run_with_tensors(inputs, ys) - - # verification - if verify: - eps = 1e-1 - pt_np = pt_ys[0].detach().cpu().numpy() - np.testing.assert_allclose( - pt_np, - ys[0].cpu().numpy(), - atol=eps, - rtol=eps, - ) - print("CLIPTextTransformer verification pass") - - # AIT benchmark - # warmup - exe_module.benchmark_with_tensors(inputs, ys, count=100, repeat=4) - # benchmark - t, _, _ = exe_module.benchmark_with_tensors(inputs, ys, count=100, repeat=4) - with open("sd_ait_benchmark.txt", "a") as f: - f.write(f"clip batch_size: {batch_size}, latency: {t} ms\n") - - -def benchmark_vae(batch_size=1, height=64, width=64, benchmark_pt=False, verify=False): - - latent_channels = 4 - - exe_module = Model("./tmp/AutoencoderKL/test.so") - if exe_module is None: - print("Error!! 
Cannot find compiled module for AutoencoderKL.") - exit(-1) - - # run PT vae - pt_vae = pipe.vae - pt_vae = pt_vae.cuda().half() - pt_vae.eval() - - pt_input = torch.rand([batch_size, latent_channels, height, width]).cuda().half() - print("pt_input shape", pt_input.shape) - with autocast("cuda"): - pt_output = pt_vae.decode(pt_input).sample - pt_output = pt_output.half() - - # PT benchmark - if benchmark_pt: - args = (pt_input,) - pt_time = benchmark_torch_function(100, pt_vae.decode, *args) - print(f"PT batch_size: {batch_size}, {pt_time} ms") - with open("sd_pt_benchmark.txt", "a") as f: - f.write(f"vae batch_size: {batch_size}, latency: {pt_time} ms\n") - - # run AIT vae - y = ( - torch.empty( - pt_output.size(0), - pt_output.size(2), - pt_output.size(3), - pt_output.size(1), - ) - .cuda() - .half() - ) - ait_input_pt_tensor = torch.permute(pt_input, (0, 2, 3, 1)).contiguous() - print("input pt tensor size: ", ait_input_pt_tensor.shape) - print("output pt tensor size: ", y.shape) - exe_module.run_with_tensors([ait_input_pt_tensor], [y]) - - # verification - if verify: - y_pt = torch.permute(y, (0, 3, 1, 2)) - eps = 1e-1 - np.testing.assert_allclose( - pt_output.detach().cpu().numpy(), - y_pt.cpu().numpy(), - atol=eps, - rtol=eps, - ) - logging.info("VAE Verification done!") - - # AIT benchmark: - # warmup - exe_module.benchmark_with_tensors([ait_input_pt_tensor], [y], count=100, repeat=4) - # benchmark - t, _, _ = exe_module.benchmark_with_tensors( - [ait_input_pt_tensor], [y], count=100, repeat=4 - ) - with open("sd_ait_benchmark.txt", "a") as f: - f.write(f"vae batch_size: {batch_size}, latency: {t} ms\n") - - -@click.command() -@click.option("--token", default="", help="access token") -@click.option("--batch-size", default=1, help="batch size") -@click.option("--verify", type=bool, default=False, help="verify correctness") -@click.option("--benchmark-pt", type=bool, default=False, help="run pt benchmark") -def benchmark_diffusers(token, batch_size, verify, benchmark_pt): - assert batch_size == 1, "batch size must be 1 for submodule verification" - logging.getLogger().setLevel(logging.INFO) - np.random.seed(0) - torch.manual_seed(4896) - - global access_token, pipe - if token != "": - access_token = token - - pipe = StableDiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-2", - revision="fp16", - torch_dtype=torch.float16, - use_auth_token=access_token, - ).to("cuda") - - # CLIP - benchmark_clip(batch_size=batch_size, benchmark_pt=benchmark_pt, verify=verify) - # UNet - benchmark_unet(batch_size=batch_size * 2, benchmark_pt=benchmark_pt, verify=verify) - # VAE - benchmark_vae(batch_size=batch_size, benchmark_pt=benchmark_pt, verify=verify) - - -if __name__ == "__main__": - benchmark_diffusers() diff --git a/examples/05_stable_diffusion/benchmark_pt.py b/examples/05_stable_diffusion/benchmark_pt.py deleted file mode 100644 index aa9af8596..000000000 --- a/examples/05_stable_diffusion/benchmark_pt.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# - -import click -import torch - -from aitemplate.testing.benchmark_pt import benchmark_torch_function -from diffusers import StableDiffusionPipeline - - -@click.command() -@click.option("--token", default="", help="access token") -@click.option("--prompt", default="A vision of paradise, Unreal Engine", help="prompt") -@click.option( - "--benchmark", type=bool, default=False, help="run stable diffusion e2e benchmark" -) -def run(token, prompt, benchmark): - pipe = StableDiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-2", - revision="fp16", - torch_dtype=torch.float16, - use_auth_token=token, - ).to("cuda") - - with torch.autocast("cuda"): - image = pipe(prompt).images[0] - if benchmark: - t = benchmark_torch_function(10, pipe, prompt) - print(f"sd pt e2e: {t} ms") - - image.save("example_pt.png") - - -if __name__ == "__main__": - run() diff --git a/examples/05_stable_diffusion/compile.py b/examples/05_stable_diffusion/compile.py deleted file mode 100644 index f9f5224df..000000000 --- a/examples/05_stable_diffusion/compile.py +++ /dev/null @@ -1,379 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import logging -from collections import OrderedDict - -import click -import numpy as np - -import torch - -from aitemplate.compiler import compile_model -from aitemplate.frontend import Tensor -from aitemplate.testing import detect_target -from diffusers import StableDiffusionPipeline - -from modeling.clip import CLIPTextTransformer as ait_CLIPTextTransformer - -from modeling.unet_2d_condition import UNet2DConditionModel as ait_UNet2DConditionModel - -from modeling.vae import AutoencoderKL as ait_AutoencoderKL - - -USE_CUDA = detect_target().name() == "cuda" - -access_token = True -pipe = None - - -def mark_output(y): - if type(y) is not tuple: - y = (y,) - for i in range(len(y)): - y[i]._attrs["is_output"] = True - y[i]._attrs["name"] = "output_%d" % (i) - y_shape = [d._attrs["values"][0] for d in y[i]._attrs["shape"]] - print("AIT output_{} shape: {}".format(i, y_shape)) - - -def map_unet_params(pt_mod, dim): - pt_params = dict(pt_mod.named_parameters()) - params_ait = {} - for key, arr in pt_params.items(): - if len(arr.shape) == 4: - arr = arr.permute((0, 2, 3, 1)).contiguous() - elif key.endswith("ff.net.0.proj.weight"): - w1, w2 = arr.chunk(2, dim=0) - params_ait[key.replace(".", "_")] = w1 - params_ait[key.replace(".", "_").replace("proj", "gate")] = w2 - continue - elif key.endswith("ff.net.0.proj.bias"): - w1, w2 = arr.chunk(2, dim=0) - params_ait[key.replace(".", "_")] = w1 - params_ait[key.replace(".", "_").replace("proj", "gate")] = w2 - continue - params_ait[key.replace(".", "_")] = arr - - params_ait["arange"] = ( - torch.arange(start=0, end=dim // 2, dtype=torch.float32).cuda().half() - ) - return params_ait - - -def map_vae_params(ait_module, pt_module, batch_size, seq_len): - pt_params = dict(pt_module.named_parameters()) - mapped_pt_params = OrderedDict() - for name, _ in ait_module.named_parameters(): - ait_name = name.replace(".", "_") - if name in pt_params: - if ( - "conv" in name - and "norm" not in name - and name.endswith(".weight") - and len(pt_params[name].shape) == 4 - ): - mapped_pt_params[ait_name] = torch.permute( - pt_params[name], [0, 2, 3, 1] - ).contiguous() - else: - mapped_pt_params[ait_name] = pt_params[name] - elif name.endswith("attention.qkv.weight"): - prefix = name[: -len("attention.qkv.weight")] - q_weight = pt_params[prefix + "query.weight"] - k_weight = pt_params[prefix + "key.weight"] - v_weight = pt_params[prefix + "value.weight"] - qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) - mapped_pt_params[ait_name] = qkv_weight - elif name.endswith("attention.qkv.bias"): - prefix = name[: -len("attention.qkv.bias")] - q_bias = pt_params[prefix + "query.bias"] - k_bias = pt_params[prefix + "key.bias"] - v_bias = pt_params[prefix + "value.bias"] - qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0) - mapped_pt_params[ait_name] = qkv_bias - elif name.endswith("attention.proj.weight"): - prefix = name[: -len("attention.proj.weight")] - pt_name = prefix + "proj_attn.weight" - mapped_pt_params[ait_name] = pt_params[pt_name] - elif name.endswith("attention.proj.bias"): - prefix = name[: -len("attention.proj.bias")] - pt_name = prefix + "proj_attn.bias" - mapped_pt_params[ait_name] = pt_params[pt_name] - elif name.endswith("attention.cu_length"): - cu_len = np.cumsum([0] + [seq_len] * batch_size).astype("int32") - mapped_pt_params[ait_name] = torch.from_numpy(cu_len).cuda() - else: - pt_param = pt_module.get_parameter(name) - mapped_pt_params[ait_name] = pt_param - - return mapped_pt_params - - -def map_clip_params(pt_mod, batch_size, 
seqlen, depth): - - params_pt = list(pt_mod.named_parameters()) - - params_ait = {} - pt_params = {} - for key, arr in params_pt: - pt_params[key.replace("text_model.", "")] = arr - - pt_params = dict(pt_mod.named_parameters()) - for key, arr in pt_params.items(): - name = key.replace("text_model.", "") - ait_name = name.replace(".", "_") - if name.endswith("out_proj.weight"): - ait_name = ait_name.replace("out_proj", "proj") - elif name.endswith("out_proj.bias"): - ait_name = ait_name.replace("out_proj", "proj") - elif name.endswith("q_proj.weight"): - ait_name = ait_name.replace("q_proj", "qkv") - prefix = key[: -len("q_proj.weight")] - q = pt_params[prefix + "q_proj.weight"] - k = pt_params[prefix + "k_proj.weight"] - v = pt_params[prefix + "v_proj.weight"] - qkv_weight = torch.cat([q, k, v], dim=0) - params_ait[ait_name] = qkv_weight - continue - elif name.endswith("q_proj.bias"): - ait_name = ait_name.replace("q_proj", "qkv") - prefix = key[: -len("q_proj.bias")] - q = pt_params[prefix + "q_proj.bias"] - k = pt_params[prefix + "k_proj.bias"] - v = pt_params[prefix + "v_proj.bias"] - qkv_bias = torch.cat([q, k, v], dim=0) - params_ait[ait_name] = qkv_bias - continue - elif name.endswith("k_proj.weight"): - continue - elif name.endswith("k_proj.bias"): - continue - elif name.endswith("v_proj.weight"): - continue - elif name.endswith("v_proj.bias"): - continue - params_ait[ait_name] = arr - - if USE_CUDA: - for i in range(depth): - prefix = "encoder_layers_%d_self_attn_cu_length" % (i) - cu_len = np.cumsum([0] + [seqlen] * batch_size).astype("int32") - params_ait[prefix] = torch.from_numpy(cu_len).cuda() - - return params_ait - - -def compile_unet( - batch_size=2, - hh=64, - ww=64, - dim=320, - hidden_dim=1024, - use_fp16_acc=False, - convert_conv_to_gemm=False, -): - - ait_mod = ait_UNet2DConditionModel( - sample_size=64, - cross_attention_dim=hidden_dim, - attention_head_dim=[5, 10, 20, 20], - ) - ait_mod.name_parameter_tensor() - - # set AIT parameters - pt_mod = pipe.unet - pt_mod = pt_mod.eval() - params_ait = map_unet_params(pt_mod, dim) - - latent_model_input_ait = Tensor( - [batch_size, hh, ww, 4], name="input0", is_input=True - ) - timesteps_ait = Tensor([batch_size], name="input1", is_input=True) - text_embeddings_pt_ait = Tensor( - [batch_size, 64, hidden_dim], name="input2", is_input=True - ) - - Y = ait_mod(latent_model_input_ait, timesteps_ait, text_embeddings_pt_ait) - mark_output(Y) - - target = detect_target( - use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm - ) - compile_model(Y, target, "./tmp", "UNet2DConditionModel", constants=params_ait) - - -def compile_clip( - batch_size=1, - seqlen=64, - dim=768, - num_heads=12, - use_fp16_acc=False, - convert_conv_to_gemm=False, -): - mask_seq = 0 - causal = True - depth = 23 - - ait_mod = ait_CLIPTextTransformer( - num_hidden_layers=depth, - hidden_size=dim, - num_attention_heads=num_heads, - batch_size=batch_size, - seq_len=seqlen, - causal=causal, - mask_seq=mask_seq, - ) - ait_mod.name_parameter_tensor() - - pt_mod = pipe.text_encoder - pt_mod = pt_mod.eval() - params_ait = map_clip_params(pt_mod, batch_size, seqlen, depth) - - input_ids_ait = Tensor( - [batch_size, seqlen], name="input0", dtype="int64", is_input=True - ) - position_ids_ait = Tensor( - [batch_size, seqlen], name="input1", dtype="int64", is_input=True - ) - Y = ait_mod(input_ids=input_ids_ait, position_ids=position_ids_ait) - mark_output(Y) - - target = detect_target( - use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm 
- ) - compile_model(Y, target, "./tmp", "CLIPTextModel", constants=params_ait) - - -def compile_vae( - batch_size=1, height=64, width=64, use_fp16_acc=False, convert_conv_to_gemm=False -): - in_channels = 3 - out_channels = 3 - down_block_types = [ - "DownEncoderBlock2D", - "DownEncoderBlock2D", - "DownEncoderBlock2D", - "DownEncoderBlock2D", - ] - up_block_types = [ - "UpDecoderBlock2D", - "UpDecoderBlock2D", - "UpDecoderBlock2D", - "UpDecoderBlock2D", - ] - block_out_channels = [128, 256, 512, 512] - layers_per_block = 2 - act_fn = "silu" - latent_channels = 4 - sample_size = 512 - - ait_vae = ait_AutoencoderKL( - batch_size, - height, - width, - in_channels=in_channels, - out_channels=out_channels, - down_block_types=down_block_types, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - latent_channels=latent_channels, - sample_size=sample_size, - ) - ait_input = Tensor( - shape=[batch_size, height, width, latent_channels], - name="vae_input", - is_input=True, - ) - ait_vae.name_parameter_tensor() - - pt_mod = pipe.vae - pt_mod = pt_mod.eval() - params_ait = map_vae_params(ait_vae, pt_mod, batch_size, height * width) - - Y = ait_vae.decode(ait_input) - mark_output(Y) - target = detect_target( - use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm - ) - compile_model( - Y, - target, - "./tmp", - "AutoencoderKL", - constants=params_ait, - ) - - -@click.command() -@click.option("--token", default="", help="access token") -@click.option("--width", default=512, help="Width of generated image") -@click.option("--height", default=512, help="Height of generated image") -@click.option("--batch-size", default=1, help="batch size") -@click.option("--use-fp16-acc", default=True, help="use fp16 accumulation") -@click.option("--convert-conv-to-gemm", default=True, help="convert 1x1 conv to gemm") -def compile_diffusers( - token, width, height, batch_size, use_fp16_acc=True, convert_conv_to_gemm=True -): - logging.getLogger().setLevel(logging.INFO) - np.random.seed(0) - torch.manual_seed(4896) - - if detect_target().name() == "rocm": - convert_conv_to_gemm = False - - global access_token, pipe - if token != "": - access_token = token - - pipe = StableDiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-2", - revision="fp16", - torch_dtype=torch.float16, - use_auth_token=access_token, - ).to("cuda") - - ww = width // 8 - hh = height // 8 - - # CLIP - compile_clip( - batch_size=batch_size, - dim=1024, - num_heads=16, - use_fp16_acc=use_fp16_acc, - convert_conv_to_gemm=convert_conv_to_gemm, - ) - # UNet - compile_unet( - batch_size=batch_size * 2, - ww=ww, - hh=hh, - use_fp16_acc=use_fp16_acc, - convert_conv_to_gemm=convert_conv_to_gemm, - ) - # VAE - compile_vae( - batch_size=batch_size, - width=ww, - height=hh, - use_fp16_acc=use_fp16_acc, - convert_conv_to_gemm=convert_conv_to_gemm, - ) - - -if __name__ == "__main__": - compile_diffusers() diff --git a/examples/05_stable_diffusion/demo.py b/examples/05_stable_diffusion/demo.py deleted file mode 100644 index 1a2fca835..000000000 --- a/examples/05_stable_diffusion/demo.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import click -import torch - -from aitemplate.testing.benchmark_pt import benchmark_torch_function -from diffusers import EulerDiscreteScheduler -from pipeline_stable_diffusion_ait import StableDiffusionAITPipeline - - -@click.command() -@click.option("--token", default="", help="access token") -@click.option("--width", default=512, help="Width of generated image") -@click.option("--height", default=512, help="Height of generated image") -@click.option("--prompt", default="A vision of paradise, Unreal Engine", help="prompt") -@click.option( - "--benchmark", type=bool, default=False, help="run stable diffusion e2e benchmark" -) -def run(token, width, height, prompt, benchmark): - - model_id = "stabilityai/stable-diffusion-2" - scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") - - pipe = StableDiffusionAITPipeline.from_pretrained( - model_id, - scheduler=scheduler, - revision="fp16", - torch_dtype=torch.float16, - use_auth_token=token, - ).to("cuda") - - with torch.autocast("cuda"): - image = pipe(prompt, height, width).images[0] - if benchmark: - t = benchmark_torch_function(10, pipe, prompt, height=height, width=width) - print(f"sd e2e: {t} ms") - - image.save("example_ait.png") - - -if __name__ == "__main__": - run() diff --git a/examples/05_stable_diffusion/demo_img2img.py b/examples/05_stable_diffusion/demo_img2img.py deleted file mode 100644 index 569a713ed..000000000 --- a/examples/05_stable_diffusion/demo_img2img.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -from io import BytesIO - -import click -import requests -import torch - -from aitemplate.testing.benchmark_pt import benchmark_torch_function -from PIL import Image -from pipeline_stable_diffusion_img2img_ait import StableDiffusionImg2ImgAITPipeline - - -@click.command() -@click.option("--token", default="", help="access token") -@click.option("--width", default=512, help="Width of generated image") -@click.option("--height", default=512, help="Height of generated image") -@click.option( - "--prompt", default="A fantasy landscape, trending on artstation", help="prompt" -) -@click.option( - "--benchmark", type=bool, default=False, help="run stable diffusion e2e benchmark" -) -def run(token, width, height, prompt, benchmark): - - # load the pipeline - device = "cuda" - model_id_or_path = "runwayml/stable-diffusion-v1-5" - pipe = StableDiffusionImg2ImgAITPipeline.from_pretrained( - model_id_or_path, - revision="fp16", - torch_dtype=torch.float16, - use_auth_token=token, - ) - pipe = pipe.to(device) - - # let's download an initial image - url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" - - response = requests.get(url) - init_image = Image.open(BytesIO(response.content)).convert("RGB") - init_image = init_image.resize((height, width)) - - with torch.autocast("cuda"): - images = pipe( - prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5 - ).images - if benchmark: - args = (prompt, init_image) - t = benchmark_torch_function(10, pipe, *args) - print(f"sd e2e: {t} ms") - - images[0].save("fantasy_landscape_ait.png") - - -if __name__ == "__main__": - run() diff --git a/examples/05_stable_diffusion/modeling/attention.py b/examples/05_stable_diffusion/modeling/attention.py deleted file mode 100644 index 14993e6d9..000000000 --- a/examples/05_stable_diffusion/modeling/attention.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Implementations are translated from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py. -""" - -from typing import Optional - -from aitemplate.compiler.ops import reshape - -from aitemplate.frontend import nn, Tensor - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted - to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - Uses three q, k, v linear layers to compute attention. - Parameters: - batch_size (:obj:`int`): The number of examples per batch. - height (:obj:`int`): Height of each image example. - width (:obj:`int`): Width of each image example. - channels (:obj:`int`): The number of channels in the input and output. - num_head_channels (:obj:`int`, *optional*): - The number of channels in each head. If None, then `num_heads` = 1. 
- num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. - eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. - """ - - def __init__( - self, - batch_size: int, - height: int, - width: int, - channels: int, - num_head_channels: Optional[int] = None, - num_groups: int = 32, - rescale_output_factor: float = 1.0, - eps: float = 1e-5, - ): - super().__init__() - self.batch_size = batch_size - self.height = height - self.width = width - self.channels = channels - self.num_heads = ( - channels // num_head_channels if num_head_channels is not None else 1 - ) - self.num_head_size = num_head_channels - self.group_norm = nn.GroupNorm(num_groups, channels, eps) - self.attention = nn.MultiheadAttention( - channels, - batch_size, - height * width, - self.num_heads, - qkv_bias=True, - has_residual=True, - use_mem_eff=True, - ) - self.rescale_output_factor = rescale_output_factor - - def forward(self, hidden_states) -> Tensor: - """ - input hidden_states shape: [batch, height, width, channel] - output shape: [batch, height, width, channel] - """ - residual = hidden_states - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = reshape()( - hidden_states, [self.batch_size, self.height * self.width, self.channels] - ) - - batch, hw, channel = hidden_states.shape() - if ( - batch.value() != self.batch_size - or hw.value() != self.width * self.height - or channel.value() != self.channels - ): - raise RuntimeError( - "nchw params do not match! " - f"Expected: {self.batch_size}, {self.channels}, {self.height} * {self.width}, " - f"actual: {batch}, {channel}, {hw}." - ) - - res = self.attention(hidden_states, residual) * (1 / self.rescale_output_factor) - res = reshape()(res, [self.batch_size, self.height, self.width, self.channels]) - - return res diff --git a/examples/05_stable_diffusion/modeling/clip.py b/examples/05_stable_diffusion/modeling/clip.py deleted file mode 100644 index 8d6079988..000000000 --- a/examples/05_stable_diffusion/modeling/clip.py +++ /dev/null @@ -1,588 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -from inspect import isfunction -from typing import Optional - -from aitemplate.compiler import ops -from aitemplate.frontend import nn, Tensor -from aitemplate.testing import detect_target - -# pylint: disable=W0102 - -USE_CUDA = detect_target().name() == "cuda" - - -def get_shape(x): - shape = [it.value() for it in x._attrs["shape"]] - return shape - - -def exists(val): - return val is not None - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -class CrossAttention(nn.Module): - def __init__( - self, - query_dim, - context_dim=None, - heads=8, - dim_head=64, - dropout=0.0, - dtype="float16", - ): - super().__init__() - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - - self.scale = dim_head**-0.5 - self.heads = heads - self.dim_head = dim_head - - self.to_q_weight = nn.Parameter(shape=[inner_dim, query_dim], dtype=dtype) - self.to_k_weight = nn.Parameter(shape=[inner_dim, context_dim], dtype=dtype) - self.to_v_weight = nn.Parameter(shape=[inner_dim, context_dim], dtype=dtype) - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) - ) - - def forward(self, x, context=None, mask=None, residual=None): - nheads = self.heads - d = self.dim_head - - layout = "20314" if USE_CUDA else "m2n3" - - bs, seqlen, _ = get_shape(x) - q = ops.gemm_rcr_permute(shape=(seqlen, 1, nheads), layout=layout)( - ops.reshape()(x, [bs * seqlen, -1]), self.to_q_weight.tensor() - ) - context = default(context, x) - - seqlen = get_shape(context)[1] - k = ops.gemm_rcr_permute(shape=(seqlen, 1, nheads), layout=layout)( - ops.reshape()(context, [bs * seqlen, -1]), self.to_k_weight.tensor() - ) - v = ops.gemm_rcr_permute(shape=(seqlen, 1, nheads), layout=layout)( - ops.reshape()(context, [bs * seqlen, -1]), self.to_v_weight.tensor() - ) - - if USE_CUDA: - attn_op = ops.mem_eff_attention(causal=False) - out = attn_op( - (ops.reshape()(q, [bs, nheads, -1, d])), - (ops.reshape()(k, [bs, nheads, -1, d])), - (ops.reshape()(v, [bs, nheads, -1, d])), - ) - else: - OP = ops.bmm_softmax_bmm_permute(shape=(nheads,), scale=self.scale) - out = OP( - (ops.reshape()(q, [bs * nheads, -1, d])), - (ops.reshape()(k, [bs * nheads, -1, d])), - (ops.reshape()(v, [bs * nheads, -1, d])), - ) - out = ops.reshape()(out, [bs, -1, nheads * d]) - proj = self.to_out(out) - proj = ops.reshape()(proj, [bs, -1, nheads * d]) - if residual is not None: - return proj + residual - else: - return proj - - -class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out, specialization="mul") - self.gate = nn.Linear(dim_in, dim_out, specialization="fast_gelu") - - def forward(self, x): - return self.proj(x, self.gate(x)) - - -class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): - super().__init__() - inner_dim = int(dim * mult) - dim_out = default(dim_out, dim) - project_in = ( - nn.Sequential( - nn.Linear(dim, inner_dim, specialization="fast_gelu"), - ) - if not glu - else GEGLU(dim, inner_dim) - ) - - self.net = nn.Sequential( - project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) - ) - - def forward(self, x, residual=None): - shape = ops.size()(x) - x = self.net(x) - x = ops.reshape()(x, shape) - if residual is not None: - return x + residual - else: - return x - - -class BasicTransformerBlock(nn.Module): - def __init__( - self, - dim, - n_heads, - d_head, - dropout=0.0, - context_dim=None, - gated_ff=True, - checkpoint=True, - 
): - super().__init__() - self.attn1 = CrossAttention( - query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout - ) # is a self-attention - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = CrossAttention( - query_dim=dim, - context_dim=context_dim, - heads=n_heads, - dim_head=d_head, - dropout=dropout, - ) - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.norm3 = nn.LayerNorm(dim) - self.checkpoint = checkpoint - - self.param = (dim, n_heads, d_head, context_dim, gated_ff, checkpoint) - - def forward(self, x, context=None): - x = self.attn1(self.norm1(x), residual=x) - x = self.attn2(self.norm2(x), context=context, residual=x) - x = self.ff(self.norm3(x), residual=x) - return x - - -def Normalize(in_channels): - return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. - First, project the input (aka embedding) - and reshape to b, t, d. - Then apply standard transformer action. - Finally, reshape to image - """ - - def __init__( - self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None - ): - super().__init__() - self.in_channels = in_channels - inner_dim = n_heads * d_head - self.norm = Normalize(in_channels) # Group Norm - - self.proj_in = nn.Conv2dBias( - in_channels, inner_dim, kernel_size=1, stride=1, padding=0 - ) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim - ) - for d in range(depth) - ] - ) - - self.proj_out = nn.Conv2dBias( - inner_dim, in_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x, context=None): - # note: if no context is given, cross-attention defaults to self-attention - b, h, w, c = get_shape(x) - x_in = x - x = self.norm(x) - x = self.proj_in(x) - x = ops.reshape()(x, [b, -1, c]) - for block in self.transformer_blocks: - x = block(x, context=context) - x = ops.reshape()(x, [b, h, w, c]) - x = self.proj_out(x) - return x + x_in - - -class CLIPAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - hidden_size=768, - num_attention_heads=12, - attention_dropout=0.0, - batch_size=1, - seq_len=16, - layer_norm_eps=1e-5, - hidden_dropout_prob=0.0, - causal=False, - mask_seq=0, - ): - super().__init__() - self.attn = nn.MultiheadAttention( - dim=hidden_size, - batch_size=batch_size, - seq_len=seq_len, - num_heads=num_attention_heads, - qkv_bias=True, - attn_drop=attention_dropout, - proj_drop=hidden_dropout_prob, - has_residual=False, - causal=causal, - mask_seq=mask_seq, - ) - - def forward( - self, - hidden_states: Tensor, - attention_mask: Optional[Tensor] = None, - causal_attention_mask: Optional[Tensor] = None, - output_attentions: Optional[bool] = False, - residual: Optional[Tensor] = None, - ): - if residual is not None: - self_output = self.attn(hidden_states, residual) - else: - self_output = self.attn(hidden_states) - return self_output - - -class QuickGELUActivation(nn.Module): - """ - Applies GELU approximation that is fast but somewhat inaccurate. 
See: https://github.com/hendrycks/GELUs - """ - - def forward(self, x): - x1 = x * 1.702 - x1 = ops.sigmoid(x1) - x = x * x1 - return x - - -class CLIPMLP(nn.Module): - """MLP as used in Vision Transformer, MLP-Mixer and related networks""" - - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer="GELU", - drop=0, - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - - self.fc1 = nn.Linear( - in_features, - hidden_features, - specialization="gelu", - ) - self.fc2 = nn.Linear(hidden_features, out_features, specialization="add") - - def forward(self, x, res): - shape = get_shape(x) - x = self.fc1(x) - x = self.fc2(x, res) - return ops.reshape()(x, shape) - - -class CLIPEncoderLayer(nn.Module): - def __init__( - self, - hidden_size=768, - num_attention_heads=12, - attention_dropout=0.0, - mlp_ratio=4.0, - batch_size=1, - seq_len=16, - causal=False, - mask_seq=0, - ): - super().__init__() - self.embed_dim = hidden_size - self.self_attn = nn.MultiheadAttention( - dim=hidden_size, - batch_size=batch_size, - seq_len=seq_len, - num_heads=num_attention_heads, - qkv_bias=True, - attn_drop=attention_dropout, - proj_drop=0, - has_residual=True, - causal=causal, - mask_seq=mask_seq, - use_mem_eff=True, - ) - self.layer_norm1 = nn.LayerNorm(self.embed_dim) - self.mlp = CLIPMLP(hidden_size, int(hidden_size * mlp_ratio)) - self.layer_norm2 = nn.LayerNorm(self.embed_dim) - - def forward( - self, - hidden_states: Tensor, - output_attentions: Optional[bool] = False, - ): - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - `(config.encoder_attention_heads,)`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - """ - residual = hidden_states - - hidden_states = self.layer_norm1(hidden_states) - hidden_states = self.self_attn(hidden_states, residual) - - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states, residual) - - return hidden_states - - -class CLIPEncoder(nn.Module): - """ - Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a - [`CLIPEncoderLayer`]. 
- Args: - config: CLIPConfig - """ - - def __init__( - self, - num_hidden_layers=12, - output_attentions=False, - output_hidden_states=False, - use_return_dict=False, - hidden_size=768, - num_attention_heads=12, - batch_size=1, - seq_len=64, - causal=False, - mask_seq=0, - ): - super().__init__() - self.layers = nn.ModuleList( - [ - CLIPEncoderLayer( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - batch_size=batch_size, - seq_len=seq_len, - causal=causal, - mask_seq=mask_seq, - ) - for _ in range(num_hidden_layers) - ] - ) - self.output_attentions = output_attentions - self.output_hidden_states = output_hidden_states - self.use_return_dict = use_return_dict - - def forward( - self, - inputs_embeds, - attention_mask: Optional[Tensor] = None, - causal_attention_mask: Optional[Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - r""" - Args: - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - [What are attention masks?](../glossary#attention-mask) - causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Causal mask for the text model. Mask values selected in `[0, 1]`: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
- """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.use_return_dict - - encoder_states = () if output_hidden_states else None - # all_attentions = () if output_attentions else None - - hidden_states = inputs_embeds - for _, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - layer_outputs = encoder_layer(hidden_states) - hidden_states = layer_outputs - - return hidden_states - - -class CLIPTextEmbeddings(nn.Module): - def __init__( - self, - hidden_size=768, - vocab_size=49408, - max_position_embeddings=77, - dtype="float16", - ): - super().__init__() - embed_dim = hidden_size - - self.token_embedding = nn.Embedding(shape=[vocab_size, embed_dim], dtype=dtype) - self.position_embedding = nn.Embedding( - shape=[max_position_embeddings, embed_dim], dtype=dtype - ) - - def forward( - self, - input_ids: Tensor, - position_ids: Tensor, - inputs_embeds: Optional[Tensor] = None, - ) -> Tensor: - - input_shape = ops.size()(input_ids) - - # [B * S] - input_ids = ops.reshape()(input_ids, [-1]) - - position_ids = ops.reshape()(position_ids, [-1]) - - if inputs_embeds is None: - inputs_embeds = ops.batch_gather()(self.token_embedding.tensor(), input_ids) - - position_embeddings = ops.batch_gather()( - self.position_embedding.tensor(), position_ids - ) - - embeddings = inputs_embeds + position_embeddings - - embeddings = ops.reshape()(embeddings, [input_shape[0], input_shape[1], -1]) - - return embeddings - - -class CLIPTextTransformer(nn.Module): - def __init__( - self, - hidden_size=768, - output_attentions=False, - output_hidden_states=False, - use_return_dict=False, - num_hidden_layers=12, - num_attention_heads=12, - batch_size=1, - seq_len=64, - causal=False, - mask_seq=0, - ): - super().__init__() - embed_dim = hidden_size - self.embeddings = CLIPTextEmbeddings(hidden_size=hidden_size) - self.encoder = CLIPEncoder( - num_hidden_layers=num_hidden_layers, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - batch_size=batch_size, - seq_len=seq_len, - causal=causal, - mask_seq=mask_seq, - ) - self.final_layer_norm = nn.LayerNorm(embed_dim) - - self.output_attentions = output_attentions - self.output_hidden_states = output_hidden_states - self.use_return_dict = use_return_dict - - def forward( - self, - input_ids: Optional[Tensor] = None, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - r""" - Returns: - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.use_return_dict - - if input_ids is None: - raise ValueError("You have to specify either input_ids") - - hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) - - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - ) - - last_hidden_state = encoder_outputs - last_hidden_state = self.final_layer_norm(last_hidden_state) - return last_hidden_state diff 
--git a/examples/05_stable_diffusion/modeling/embeddings.py b/examples/05_stable_diffusion/modeling/embeddings.py deleted file mode 100644 index 36b96a4fb..000000000 --- a/examples/05_stable_diffusion/modeling/embeddings.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import math - -from aitemplate.compiler import ops -from aitemplate.frontend import nn, Tensor - - -def get_shape(x): - shape = [it.value() for it in x._attrs["shape"]] - return shape - - -def get_timestep_embedding( - timesteps: Tensor, - embedding_dim: int, - flip_sin_to_cos: bool = False, - downscale_freq_shift: float = 1, - scale: float = 1, - max_period: int = 10000, -): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the - embeddings. :return: an [N x dim] Tensor of positional embeddings. - """ - assert len(get_shape(timesteps)) == 1, "Timesteps should be a 1d-array" - - half_dim = embedding_dim // 2 - - exponent = (-math.log(max_period)) * Tensor( - shape=[half_dim], dtype="float16", name="arange" - ) - - exponent = exponent * (1.0 / (half_dim - downscale_freq_shift)) - - emb = ops.exp(exponent) - emb = ops.reshape()(timesteps, [-1, 1]) * ops.reshape()(emb, [1, -1]) - - # scale embeddings - emb = scale * emb - - # concat sine and cosine embeddings - if flip_sin_to_cos: - emb = ops.concatenate()( - [ops.cos(emb), ops.sin(emb)], - dim=-1, - ) - else: - emb = ops.concatenate()( - [ops.sin(emb), ops.cos(emb)], - dim=-1, - ) - return emb - - -class TimestepEmbedding(nn.Module): - def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"): - super().__init__() - - self.linear_1 = nn.Linear(channel, time_embed_dim, specialization="swish") - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) - - def forward(self, sample): - sample = self.linear_1(sample) - sample = self.linear_2(sample) - return sample - - -class Timesteps(nn.Module): - def __init__( - self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float - ): - super().__init__() - self.num_channels = num_channels - self.flip_sin_to_cos = flip_sin_to_cos - self.downscale_freq_shift = downscale_freq_shift - - def forward(self, timesteps): - t_emb = get_timestep_embedding( - timesteps, - self.num_channels, - flip_sin_to_cos=self.flip_sin_to_cos, - downscale_freq_shift=self.downscale_freq_shift, - ) - return t_emb diff --git a/examples/05_stable_diffusion/modeling/resnet.py b/examples/05_stable_diffusion/modeling/resnet.py deleted file mode 100644 index 03e4f8023..000000000 --- a/examples/05_stable_diffusion/modeling/resnet.py +++ /dev/null @@ -1,238 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
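The sinusoidal formula that `get_timestep_embedding` builds above is easier to see in plain PyTorch. The sketch below is a reference-only illustration of the standard DDPM embedding with the defaults used here (`downscale_freq_shift=1`, `scale=1`, sin before cos); it is not the AITemplate graph, and the `UNet2DConditionModel` later in this diff constructs `Timesteps` with `flip_sin_to_cos=True`, which simply swaps the two halves.

```python
import math
import torch

def timestep_embedding_reference(timesteps: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    """Plain-PyTorch sketch of the sinusoidal timestep embedding above.

    timesteps: 1-D tensor of shape [N]; returns a tensor of shape [N, dim].
    """
    half = dim // 2
    # frequencies decay geometrically from 1 down to 1/max_period (downscale_freq_shift=1)
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / (half - 1))
    args = timesteps.float()[:, None] * freqs[None, :]            # [N, half]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # [N, dim]

# e.g. a 320-dim embedding for three timesteps, matching block_out_channels[0] of the UNet below
emb = timestep_embedding_reference(torch.tensor([0, 500, 999]), 320)
```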
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from aitemplate.compiler import ops -from aitemplate.frontend import nn - - -def get_shape(x): - shape = [it.value() for it in x._attrs["shape"]] - return shape - - -class Upsample2D(nn.Module): - """ - An upsampling layer with an optional convolution. - - :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is - applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. - """ - - def __init__( - self, - channels, - use_conv=False, - use_conv_transpose=False, - out_channels=None, - name="conv", - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - - conv = None - if use_conv_transpose: - conv = nn.ConvTranspose2dBias(channels, self.out_channels, 4, 2, 1) - elif use_conv: - conv = nn.Conv2dBias(self.channels, self.out_channels, 3, 1, 1) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if name == "conv": - self.conv = conv - else: - self.Conv2d_0 = conv - - def forward(self, x): - assert get_shape(x)[-1] == self.channels - if self.use_conv_transpose: - return self.conv(x) - - x = nn.Upsampling2d(scale_factor=2.0, mode="nearest")(x) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if self.use_conv: - if self.name == "conv": - x = self.conv(x) - else: - x = self.Conv2d_0(x) - - return x - - -class Downsample2D(nn.Module): - """ - A downsampling layer with an optional convolution. - - :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is - applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - downsampling occurs in the inner-two dimensions. 
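One convention worth calling out: these blocks are written for channels-last (NHWC) tensors, which is why `Upsample2D.forward` asserts `get_shape(x)[-1] == self.channels` and why `nn.Conv2dBias` reads the channel count from the last dimension. Below is a hedged sketch of the permutes typically needed when handing a channels-first PyTorch tensor to such a graph and reading the result back; the pipelines later in this diff do exactly this in their `unet_inference`/`vae_inference` helpers.

```python
import torch

# PyTorch convolutional tensors are NCHW; the AITemplate blocks above expect NHWC.
x_nchw = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")

# to the AIT graph: NCHW -> NHWC, made contiguous so the runtime sees a dense buffer
x_nhwc = x_nchw.permute(0, 2, 3, 1).contiguous()

# back from the AIT graph: NHWC -> NCHW
y_nchw = x_nhwc.permute(0, 3, 1, 2)
```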
- """ - - def __init__( - self, channels, use_conv=False, out_channels=None, padding=1, name="conv" - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - stride = 2 - self.name = name - - if use_conv: - conv = nn.Conv2dBias( - self.channels, self.out_channels, 3, stride=stride, padding=padding - ) - else: - assert self.channels == self.out_channels - conv = nn.AvgPool2d(kernel_size=stride, stride=stride, padding=0) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if name == "conv": - self.Conv2d_0 = conv - self.conv = conv - elif name == "Conv2d_0": - self.conv = conv - else: - self.conv = conv - - def forward(self, x): - assert get_shape(x)[-1] == self.channels - x = self.conv(x) - - return x - - -class ResnetBlock2D(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout=0.0, - temb_channels=512, - groups=32, - groups_out=None, - pre_norm=True, - eps=1e-6, - non_linearity="swish", - time_embedding_norm="default", - kernel=None, - output_scale_factor=1.0, - use_nin_shortcut=None, - up=False, - down=False, - ): - super().__init__() - self.pre_norm = pre_norm - self.pre_norm = True - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - self.time_embedding_norm = time_embedding_norm - self.up = up - self.down = down - self.output_scale_factor = output_scale_factor - - if groups_out is None: - groups_out = groups - - self.norm1 = nn.GroupNorm( - num_groups=groups, - num_channels=in_channels, - eps=eps, - affine=True, - use_swish=True, - ) - - self.conv1 = nn.Conv2dBias( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - - if temb_channels is not None: - self.time_emb_proj = nn.Linear(temb_channels, out_channels) - else: - self.time_emb_proj = None - - self.norm2 = nn.GroupNorm( - num_groups=groups_out, - num_channels=out_channels, - eps=eps, - affine=True, - use_swish=True, - ) - self.dropout = nn.Dropout(dropout) - self.conv2 = nn.Conv2dBias( - out_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - - self.upsample = self.downsample = None - - self.use_nin_shortcut = ( - self.in_channels != self.out_channels - if use_nin_shortcut is None - else use_nin_shortcut - ) - - if self.use_nin_shortcut: - self.conv_shortcut = nn.Conv2dBias( - in_channels, out_channels, 1, 1, 0 - ) # kernel_size=1, stride=1, padding=0) # conv_bias_add - else: - self.conv_shortcut = None - - def forward(self, x, temb=None): - hidden_states = x - - # make sure hidden states is in float32 - # when running in half-precision - hidden_states = self.norm1( - hidden_states - ) # .float()).type(hidden_states.dtype) # fused swish - # hidden_states = self.nonlinearity(hidden_states) - - if self.upsample is not None: - x = self.upsample(x) - hidden_states = self.upsample(hidden_states) - elif self.downsample is not None: - x = self.downsample(x) - hidden_states = self.downsample(hidden_states) - - hidden_states = self.conv1(hidden_states) - - if temb is not None: - temb = self.time_emb_proj(ops.silu(temb)) - bs, dim = get_shape(temb) - temb = ops.reshape()(temb, [bs, 1, 1, dim]) - hidden_states = hidden_states + temb - - # make sure hidden states is in float32 - # when running in half-precision - hidden_states = self.norm2(hidden_states) - - hidden_states = self.dropout(hidden_states) - 
hidden_states = self.conv2(hidden_states) - - if self.conv_shortcut is not None: - x = self.conv_shortcut(x) - - out = hidden_states + x - - return out diff --git a/examples/05_stable_diffusion/modeling/unet_2d_condition.py b/examples/05_stable_diffusion/modeling/unet_2d_condition.py deleted file mode 100644 index a21879dea..000000000 --- a/examples/05_stable_diffusion/modeling/unet_2d_condition.py +++ /dev/null @@ -1,256 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from typing import Optional, Tuple, Union - -from aitemplate.frontend import nn - -from modeling.embeddings import TimestepEmbedding, Timesteps -from modeling.unet_blocks import get_down_block, get_up_block, UNetMidBlock2DCrossAttn - - -class UNet2DConditionModel(nn.Module): - r""" - UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep - and returns sample shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) - - Parameters: - sample_size (`int`, *optional*): The size of the input sample. - in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. - center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - flip_sin_to_cos (`bool`, *optional*, defaults to `False`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): - The tuple of upsample blocks to use. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. - mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. - norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. - cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. - attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. 
- """ - - def __init__( - self, - sample_size: Optional[int] = None, - in_channels: int = 4, - out_channels: int = 4, - center_input_sample: bool = False, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - up_block_types: Tuple[str] = ( - "UpBlock2D", - "CrossAttnUpBlock2D", - "CrossAttnUpBlock2D", - "CrossAttnUpBlock2D", - ), - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: int = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: int = 32, - norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, - attention_head_dim: Union[int, Tuple[int]] = 8, - ): - super().__init__() - self.center_input_sample = center_input_sample - self.sample_size = sample_size - time_embed_dim = block_out_channels[0] * 4 - - # input - self.conv_in = nn.Conv2dBias(in_channels, block_out_channels[0], 3, 1, 1) - # time - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - - self.down_blocks = nn.ModuleList([]) - self.up_blocks = nn.ModuleList([]) - - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[i], - downsample_padding=downsample_padding, - ) - self.down_blocks.append(down_block) - - # mid - self.mid_block = UNetMidBlock2DCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift="default", - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - reversed_attention_head_dim = list(reversed(attention_head_dim)) - - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[ - min(i + 1, len(block_out_channels) - 1) - ] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = get_up_block( - up_block_type, - num_layers=layers_per_block + 1, - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, - add_upsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=reversed_attention_head_dim[i], - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], - 
num_groups=norm_num_groups, - eps=norm_eps, - use_swish=True, - ) - - self.conv_out = nn.Conv2dBias(block_out_channels[0], out_channels, 3, 1, 1) - - def forward( - self, - sample, - timesteps, - encoder_hidden_states, - return_dict: bool = True, - ): - """r - Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. - - Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - - # 1. time - t_emb = self.time_proj(timesteps) - emb = self.time_embedding(t_emb) - - # 2. pre-process - sample = self.conv_in(sample) - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if ( - hasattr(downsample_block, "attentions") - and downsample_block.attentions is not None - ): - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - # 4. mid - sample = self.mid_block( - sample, emb, encoder_hidden_states=encoder_hidden_states - ) - - # 5. up - for upsample_block in self.up_blocks: - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[ - : -len(upsample_block.resnets) - ] - - if ( - hasattr(upsample_block, "attentions") - and upsample_block.attentions is not None - ): - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - ) - else: - sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples - ) - - # 6. post-process - # make sure hidden states is in float32 - # when running in half-precision - sample = self.conv_norm_out(sample) - sample = self.conv_out(sample) - return sample diff --git a/examples/05_stable_diffusion/modeling/unet_blocks.py b/examples/05_stable_diffusion/modeling/unet_blocks.py deleted file mode 100644 index 75de2e0c8..000000000 --- a/examples/05_stable_diffusion/modeling/unet_blocks.py +++ /dev/null @@ -1,761 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
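To make the `UNet2DConditionModel` above concrete, the sketch below shows roughly how such a graph is declared and compiled into the `UNet2DConditionModel/test.so` artifact the pipeline later loads. The shapes, tensor names, and the 768-dim text context are illustrative assumptions (SD v1.x at 512x512), and `name_parameter_tensor`, `detect_target`, and `compile_model` are the usual AITemplate entry points rather than anything defined in this file, so treat this as a sketch, not the example's actual compile script.

```python
from aitemplate.compiler import compile_model
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target

from modeling.unet_2d_condition import UNet2DConditionModel  # the (deleted) module above

batch, hh, ww, seq_len = 2, 64, 64, 64     # 512x512 images -> 64x64 latents; assumed sizes

unet = UNet2DConditionModel(cross_attention_dim=768, attention_head_dim=8)
unet.name_parameter_tensor()               # give every weight a stable name for later binding

# graph inputs: NHWC latents, per-sample timesteps, CLIP text embeddings
latents = Tensor([batch, hh, ww, 4], name="input0", dtype="float16", is_input=True)
timesteps = Tensor([batch], name="input1", dtype="float16", is_input=True)
text_emb = Tensor([batch, seq_len, 768], name="input2", dtype="float16", is_input=True)

noise_pred = unet(latents, timesteps, text_emb)
noise_pred._attrs["is_output"] = True
noise_pred._attrs["name"] = "output0"

compile_model(noise_pred, detect_target(), "./tmp", "UNet2DConditionModel")
```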
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and - -# flake8: noqa -from aitemplate.compiler import ops - -from aitemplate.frontend import nn, Tensor -from aitemplate.testing import detect_target -from modeling.attention import AttentionBlock - -from modeling.clip import SpatialTransformer -from modeling.resnet import Downsample2D, ResnetBlock2D, Upsample2D - -# pylint: disable=W0102 - - -def get_down_block( - down_block_type, - num_layers, - in_channels, - out_channels, - temb_channels, - add_downsample, - resnet_eps, - resnet_act_fn, - attn_num_head_channels, - cross_attention_dim=None, - downsample_padding=None, -): - down_block_type = ( - down_block_type[7:] - if down_block_type.startswith("UNetRes") - else down_block_type - ) - if down_block_type == "DownBlock2D": - return DownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - ) - elif down_block_type == "AttnDownBlock2D": - return AttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, - ) - elif down_block_type == "CrossAttnDownBlock2D": - if cross_attention_dim is None: - raise ValueError( - "cross_attention_dim must be specified for CrossAttnDownBlock2D" - ) - return CrossAttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - ) - elif down_block_type == "SkipDownBlock2D": - return SkipDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - ) - elif down_block_type == "AttnSkipDownBlock2D": - return AttnSkipDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, - ) - elif down_block_type == "DownEncoderBlock2D": - return DownEncoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - ) - - -def get_up_block( - up_block_type, - num_layers, - in_channels, - out_channels, - prev_output_channel, - temb_channels, - add_upsample, - resnet_eps, - resnet_act_fn, - attn_num_head_channels, - cross_attention_dim=None, -): - up_block_type = ( - up_block_type[7:] if 
up_block_type.startswith("UNetRes") else up_block_type - ) - if up_block_type == "UpBlock2D": - return UpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - ) - elif up_block_type == "CrossAttnUpBlock2D": - if cross_attention_dim is None: - raise ValueError( - "cross_attention_dim must be specified for CrossAttnUpBlock2D" - ) - return CrossAttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - ) - elif up_block_type == "AttnUpBlock2D": - return AttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - attn_num_head_channels=attn_num_head_channels, - ) - elif up_block_type == "SkipUpBlock2D": - return SkipUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - ) - elif up_block_type == "AttnSkipUpBlock2D": - return AttnSkipUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - attn_num_head_channels=attn_num_head_channels, - ) - elif up_block_type == "UpDecoderBlock2D": - return UpDecoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - ) - raise ValueError(f"{up_block_type} does not exist.") - - -class UNetMidBlock2DCrossAttn(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - attention_type="default", - output_scale_factor=1.0, - cross_attention_dim=1280, - **kwargs, - ): - super().__init__() - - self.attention_type = attention_type - self.attn_num_head_channels = attn_num_head_channels - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for _ in range(num_layers): - attentions.append( - SpatialTransformer( - in_channels, - attn_num_head_channels, - in_channels // attn_num_head_channels, - depth=1, - context_dim=cross_attention_dim, - ) - ) - resnets.append( - ResnetBlock2D( - 
in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states) - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - -class CrossAttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - attention_type="default", - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - ): - super().__init__() - - resnets = [] - attentions = [] - - self.attention_type = attention_type - self.attn_num_head_channels = attn_num_head_channels - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - SpatialTransformer( - out_channels, - attn_num_head_channels, - out_channels // attn_num_head_channels, - depth=1, - context_dim=cross_attention_dim, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - in_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=encoder_hidden_states) - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class DownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - 
time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - in_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states, temb=None): - output_states = () - - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb) - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class CrossAttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - attention_type="default", - output_scale_factor=1.0, - downsample_padding=1, - add_upsample=True, - ): - super().__init__() - - resnets = [] - attentions = [] - - self.attention_type = attention_type - self.attn_num_head_channels = attn_num_head_channels - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - SpatialTransformer( - out_channels, - attn_num_head_channels, - out_channels // attn_num_head_channels, - depth=1, - context_dim=cross_attention_dim, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - - def forward( - self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - ): - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = ops.concatenate()( - [hidden_states, res_hidden_states], dim=-1 - ) - - hidden_states = resnet(hidden_states, temb=temb) - hidden_states = attn(hidden_states, context=encoder_hidden_states) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class UpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - ): - 
super().__init__() - resnets = [] - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = ops.concatenate()( - [hidden_states, res_hidden_states], dim=-1 - ) - - hidden_states = resnet(hidden_states, temb) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class UpDecoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - - def forward(self, hidden_states): - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=None) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class UNetMidBlock2D(nn.Module): - def __init__( - self, - batch_size, - height, - width, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - attention_type="default", - output_scale_factor=1.0, - **kwargs, - ): - super().__init__() - - if attention_type != "default": - raise NotImplementedError( - f"attention_type must be default! 
current value: {attention_type}" - ) - - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for _ in range(num_layers): - attentions.append( - AttentionBlock( - batch_size, - height, - width, - in_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - num_groups=resnet_groups, - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, temb=None, encoder_states=None): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states) - hidden_states = resnet(hidden_states, temb) - - return hidden_states diff --git a/examples/05_stable_diffusion/modeling/vae.py b/examples/05_stable_diffusion/modeling/vae.py deleted file mode 100644 index 6a239f233..000000000 --- a/examples/05_stable_diffusion/modeling/vae.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Translated from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/vae.py. 
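The down and up blocks above wire the U-Net skip connections through plain tuples: every down-block resnet (and downsampler) appends its output to `output_states`, and each up block pops `len(self.resnets)` entries off the end and concatenates them along the channel axis (the last dimension, since the layout is NHWC). The runnable toy below only mirrors that push/pop order for the default 4-level, 2-layers-per-block configuration; the strings stand in for tensors.

```python
# Push phase: conv_in, then 2 resnet outputs per level, plus a downsampler output on all but the last level.
skips = ("conv_in",)
for level in range(4):
    for layer in range(2):
        skips += (f"down{level}_res{layer}",)
    if level < 3:
        skips += (f"down{level}_downsample",)

# Pop phase: each up block has layers_per_block + 1 = 3 resnets and consumes the deepest remaining skips,
# innermost first (res_hidden_states_tuple[-1] is taken before the earlier entries).
for level in range(4):
    group, skips = skips[-3:], skips[:-3]
    print(f"up block {level} concatenates:", list(reversed(group)))

assert skips == ()   # every skip is consumed exactly once
```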
-""" - -from typing import Tuple - -from aitemplate.frontend import nn, Tensor -from modeling.unet_blocks import get_up_block, UNetMidBlock2D - - -class Decoder(nn.Module): - def __init__( - self, - batch_size, - height, - width, - in_channels=3, - out_channels=3, - up_block_types=("UpDecoderBlock2D",), - block_out_channels=(64,), - layers_per_block=2, - act_fn="silu", - ): - super().__init__() - self.layers_per_block = layers_per_block - - self.conv_in = nn.Conv2dBias( - in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1 - ) - - # mid - self.mid_block = UNetMidBlock2D( - batch_size, - height, - width, - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default", - attn_num_head_channels=None, - resnet_groups=32, - temb_channels=None, - ) - - # up - self.up_blocks = nn.ModuleList([]) - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = get_up_block( - up_block_type, - num_layers=self.layers_per_block + 1, - in_channels=prev_output_channel, - out_channels=output_channel, - prev_output_channel=None, - add_upsample=not is_final_block, - resnet_eps=1e-6, - resnet_act_fn=act_fn, - attn_num_head_channels=None, - temb_channels=None, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - num_groups_out = 32 - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], - num_groups=num_groups_out, - eps=1e-6, - use_swish=True, - ) - self.conv_out = nn.Conv2dBias( - block_out_channels[0], out_channels, kernel_size=3, padding=1, stride=1 - ) - - def forward(self, z) -> Tensor: - sample = z - sample = self.conv_in(sample) - - # middle - sample = self.mid_block(sample) - - # up - for up_block in self.up_blocks: - sample = up_block(sample) - - sample = self.conv_norm_out(sample) - sample = self.conv_out(sample) - - return sample - - -class AutoencoderKL(nn.Module): - def __init__( - self, - batch_size: int, - height: int, - width: int, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock2D",), - up_block_types: Tuple[str] = ("UpDecoderBlock2D",), - block_out_channels: Tuple[int] = (64,), - layers_per_block: int = 1, - act_fn: str = "silu", - latent_channels: int = 4, - sample_size: int = 32, - ): - super().__init__() - self.decoder = Decoder( - batch_size, - height, - width, - in_channels=latent_channels, - out_channels=out_channels, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - ) - self.post_quant_conv = nn.Conv2dBias( - latent_channels, latent_channels, kernel_size=1, stride=1, padding=0 - ) - - def decode(self, z: Tensor, return_dict: bool = True): - - z = self.post_quant_conv(z) - dec = self.decoder(z) - return dec - - def forward(self): - raise NotImplementedError("Only decode() is implemented for AutoencoderKL!") diff --git a/examples/05_stable_diffusion/pipeline_stable_diffusion_ait.py b/examples/05_stable_diffusion/pipeline_stable_diffusion_ait.py deleted file mode 100644 index 3a14debcc..000000000 --- a/examples/05_stable_diffusion/pipeline_stable_diffusion_ait.py +++ /dev/null @@ -1,410 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import inspect - -import os -import warnings -from typing import List, Optional, Union - -import torch -from aitemplate.compiler import Model - -from diffusers import ( - AutoencoderKL, - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - StableDiffusionPipeline, - UNet2DConditionModel, -) - -from diffusers.pipelines.stable_diffusion import ( - StableDiffusionPipelineOutput, - StableDiffusionSafetyChecker, -) - -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - - -class StableDiffusionAITPipeline(StableDiffusionPipeline): - r""" - Pipeline for text-to-image generation using Stable Diffusion. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offsensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
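For orientation, this is roughly how the pipeline described above is invoked once the three compiled modules (`CLIPTextModel`, `UNet2DConditionModel`, `AutoencoderKL`) already sit under the hard-coded `tmp/` workdir; the model id and prompt are illustrative, and the PyTorch-side weights still come from the usual diffusers `from_pretrained` machinery.

```python
import torch

from pipeline_stable_diffusion_ait import StableDiffusionAITPipeline

# illustrative checkpoint; any SD v1.x model with matching tensor shapes should fit the compiled graphs
pipe = StableDiffusionAITPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

with torch.autocast("cuda"):
    image = pipe("a photo of an astronaut riding a horse on mars").images[0]
image.save("astronaut.png")
```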
- """ - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - requires_safety_checker=requires_safety_checker, - ) - - workdir = "tmp/" - self.clip_ait_exe = self.init_ait_module( - model_name="CLIPTextModel", workdir=workdir - ) - self.unet_ait_exe = self.init_ait_module( - model_name="UNet2DConditionModel", workdir=workdir - ) - self.vae_ait_exe = self.init_ait_module( - model_name="AutoencoderKL", workdir=workdir - ) - - def init_ait_module( - self, - model_name, - workdir, - ): - mod = Model(os.path.join(workdir, model_name, "test.so")) - return mod - - def unet_inference(self, latent_model_input, timesteps, encoder_hidden_states): - exe_module = self.unet_ait_exe - timesteps_pt = timesteps.expand(latent_model_input.shape[0]) - inputs = { - "input0": latent_model_input.permute((0, 2, 3, 1)) - .contiguous() - .cuda() - .half(), - "input1": timesteps_pt.cuda().half(), - "input2": encoder_hidden_states.cuda().half(), - } - ys = [] - num_ouputs = len(exe_module.get_output_name_to_index_map()) - for i in range(num_ouputs): - shape = exe_module.get_output_maximum_shape(i) - ys.append(torch.empty(shape).cuda().half()) - exe_module.run_with_tensors(inputs, ys, graph_mode=False) - noise_pred = ys[0].permute((0, 3, 1, 2)).float() - return noise_pred - - def clip_inference(self, input_ids, seqlen=64): - exe_module = self.clip_ait_exe - bs = input_ids.shape[0] - position_ids = torch.arange(seqlen).expand((bs, -1)).cuda() - inputs = { - "input0": input_ids, - "input1": position_ids, - } - ys = [] - num_ouputs = len(exe_module.get_output_name_to_index_map()) - for i in range(num_ouputs): - shape = exe_module.get_output_maximum_shape(i) - ys.append(torch.empty(shape).cuda().half()) - exe_module.run_with_tensors(inputs, ys, graph_mode=False) - return ys[0].float() - - def vae_inference(self, vae_input): - exe_module = self.vae_ait_exe - inputs = [torch.permute(vae_input, (0, 2, 3, 1)).contiguous().cuda().half()] - ys = [] - num_ouputs = len(exe_module.get_output_name_to_index_map()) - for i in range(num_ouputs): - shape = exe_module.get_output_maximum_shape(i) - ys.append(torch.empty(shape).cuda().half()) - exe_module.run_with_tensors(inputs, ys, graph_mode=False) - vae_out = ys[0].permute((0, 3, 1, 2)).float() - return vae_out - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - height: Optional[int] = 512, - width: Optional[int] = 512, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - **kwargs, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. 
- height (`int`, *optional*, defaults to 512): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - - if "torch_device" in kwargs: - device = kwargs.pop("torch_device") - warnings.warn( - "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." - " Consider using `pipe.to(torch_device)` instead." - ) - - # Set device as before (to be removed in 0.3.0) - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.to(device) - - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError( - f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" - ) - - if height % 8 != 0 or width % 8 != 0: - raise ValueError( - f"`height` and `width` have to be divisible by 8 but are {height} and {width}." 
- ) - - # get prompt text embeddings - text_input = self.tokenizer( - prompt, - padding="max_length", - max_length=64, # self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_embeddings = self.clip_inference(text_input.input_ids.to(self.device)) - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - uncond_tokens: List[str] - max_length = text_input.input_ids.shape[-1] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - return_tensors="pt", - ) - uncond_embeddings = self.clip_inference( - uncond_input.input_ids.to(self.device) - ) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - # get the initial random noise unless the user supplied it - - # Unlike in other pipelines, latents need to be generated in the target device - # for 1-to-1 results reproducibility with the CompVis implementation. - # However this currently doesn't work in `mps`. - latents_device = "cpu" if self.device.type == "mps" else self.device - latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) - if latents is None: - latents = torch.randn( - latents_shape, - generator=generator, - device=latents_device, - ) - else: - if latents.shape != latents_shape: - raise ValueError( - f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}" - ) - latents = latents.to(self.device) - - # set timesteps - accepts_offset = "offset" in set( - inspect.signature(self.scheduler.set_timesteps).parameters.keys() - ) - extra_set_kwargs = {} - if accepts_offset: - extra_set_kwargs["offset"] = 1 - - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) - - latents = latents * self.scheduler.init_noise_sigma - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) if do_classifier_free_guidance else latents - ) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[i] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) - - # predict the noise residual - noise_pred = self.unet_inference( - latent_model_input, t, encoder_hidden_states=text_embeddings - ) - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - - # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step( - noise_pred, i, latents, **extra_step_kwargs - ).prev_sample - else: - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs - ).prev_sample - - # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents - image = self.vae_inference(latents) - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - - # run safety checker - if self.safety_checker is not None: - safety_cheker_input = self.feature_extractor( - self.numpy_to_pil(image), return_tensors="pt" - ).to(self.device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_cheker_input.pixel_values - ) - else: - has_nsfw_concept = None - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput( - images=image, nsfw_content_detected=has_nsfw_concept - ) diff --git a/examples/05_stable_diffusion/pipeline_stable_diffusion_img2img_ait.py b/examples/05_stable_diffusion/pipeline_stable_diffusion_img2img_ait.py deleted file mode 100644 index 7380aeebd..000000000 --- a/examples/05_stable_diffusion/pipeline_stable_diffusion_img2img_ait.py +++ /dev/null @@ -1,402 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
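The guidance step inside the denoising loop above is a plain weighted extrapolation from the unconditional prediction toward the text-conditioned one (`guidance_scale` is the `w` of the Imagen paper cited in the docstring). A minimal, self-contained illustration with random stand-in tensors:

```python
import torch

guidance_scale = 7.5
noise_pred = torch.randn(2, 4, 64, 64)     # stacked [uncond, text] predictions, as produced by the batched UNet call
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
assert guided.shape == (1, 4, 64, 64)      # one guided prediction per prompt
```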
-# - -# flakes8: noqa -import inspect -import os -from typing import List, Optional, Union - -import numpy as np - -import PIL -import torch -from aitemplate.compiler import Model - -from diffusers import ( - AutoencoderKL, - DDIMScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - StableDiffusionImg2ImgPipeline, - UNet2DConditionModel, -) -from diffusers.pipelines.stable_diffusion import ( - StableDiffusionPipelineOutput, - StableDiffusionSafetyChecker, -) -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - - -def preprocess(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 - - -class StableDiffusionImg2ImgAITPipeline(StableDiffusionImg2ImgPipeline): - r""" - Pipeline for text-guided image to image generation using Stable Diffusion. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offsensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
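The `preprocess` helper above snaps the image to the nearest lower multiple of 32 (presumably so the latent dimensions stay compatible with the VAE's 8x downsampling and the U-Net strides) and maps pixel values from [0, 1] to [-1, 1]. A quick check of the size arithmetic, with made-up dimensions:

w, h = 513, 769                       # hypothetical input size
w, h = map(lambda x: x - x % 32, (w, h))
print(w, h)                           # 512 768

# pixel mapping 2*x - 1: 0.0 -> -1.0, 0.5 -> 0.0, 1.0 -> 1.0
print(2.0 * 0.0 - 1.0, 2.0 * 0.5 - 1.0, 2.0 * 1.0 - 1.0)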
- """ - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - # super().__init__() - super().__init__( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - requires_safety_checker=requires_safety_checker, - ) - scheduler = scheduler.set_format("pt") - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - - workdir = "tmp/" - self.clip_ait_exe = self.init_ait_module( - model_name="CLIPTextModel", workdir=workdir - ) - self.unet_ait_exe = self.init_ait_module( - model_name="UNet2DConditionModel", workdir=workdir - ) - self.vae_ait_exe = self.init_ait_module( - model_name="AutoencoderKL", workdir=workdir - ) - - def init_ait_module( - self, - model_name, - workdir, - ): - mod = Model(os.path.join(workdir, model_name, "test.so")) - return mod - - def unet_inference(self, latent_model_input, timesteps, encoder_hidden_states): - exe_module = self.unet_ait_exe - timesteps_pt = timesteps.expand(latent_model_input.shape[0]) - inputs = { - "input0": latent_model_input.permute((0, 2, 3, 1)) - .contiguous() - .cuda() - .half(), - "input1": timesteps_pt.cuda().half(), - "input2": encoder_hidden_states.cuda().half(), - } - ys = [] - num_ouputs = len(exe_module.get_output_name_to_index_map()) - for i in range(num_ouputs): - shape = exe_module.get_output_maximum_shape(i) - ys.append(torch.empty(shape).cuda().half()) - exe_module.run_with_tensors(inputs, ys, graph_mode=False) - noise_pred = ys[0].permute((0, 3, 1, 2)).float() - return noise_pred - - def clip_inference(self, input_ids, seqlen=64): - exe_module = self.clip_ait_exe - bs = input_ids.shape[0] - position_ids = torch.arange(seqlen).expand((bs, -1)).cuda() - inputs = { - "input0": input_ids, - "input1": position_ids, - } - ys = [] - num_ouputs = len(exe_module.get_output_name_to_index_map()) - for i in range(num_ouputs): - shape = exe_module.get_output_maximum_shape(i) - ys.append(torch.empty(shape).cuda().half()) - exe_module.run_with_tensors(inputs, ys, graph_mode=False) - return ys[0].float() - - def vae_inference(self, vae_input): - exe_module = self.vae_ait_exe - inputs = [torch.permute(vae_input, (0, 2, 3, 1)).contiguous().cuda().half()] - ys = [] - num_ouputs = len(exe_module.get_output_name_to_index_map()) - for i in range(num_ouputs): - shape = exe_module.get_output_maximum_shape(i) - ys.append(torch.empty(shape).cuda().half()) - exe_module.run_with_tensors(inputs, ys, graph_mode=False) - vae_out = ys[0].permute((0, 3, 1, 2)).float() - return vae_out - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - init_image: Union[torch.FloatTensor, PIL.Image.Image], - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - ): - r""" - Function invoked when calling the pipeline for generation. 
- - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - init_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. - `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The - number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added - noise will be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This parameter will be modulated by `strength`. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. 
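A hedged usage sketch for the call signature documented above. `pipe` is assumed to be an already-constructed StableDiffusionImg2ImgAITPipeline (i.e. the compiled CLIP/UNet/VAE modules exist under the expected workdir); the image path and prompt are placeholders:

import PIL.Image
import torch

init_image = PIL.Image.open("sketch.png").convert("RGB")
generator = torch.Generator(device="cuda").manual_seed(0)

result = pipe(
    prompt="a watercolor painting of a lighthouse",
    init_image=init_image,
    strength=0.8,             # keep ~80% of the denoising schedule
    num_inference_steps=50,
    guidance_scale=7.5,
    generator=generator,
)
image = result.images[0]      # StableDiffusionPipelineOutput when return_dict=True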
- """ - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError( - f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" - ) - - if strength < 0 or strength > 1: - raise ValueError( - f"The value of strength should in [0.0, 1.0] but is {strength}" - ) - - # set timesteps - accepts_offset = "offset" in set( - inspect.signature(self.scheduler.set_timesteps).parameters.keys() - ) - extra_set_kwargs = {} - offset = 0 - if accepts_offset: - offset = 1 - extra_set_kwargs["offset"] = 1 - - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) - - if isinstance(init_image, PIL.Image.Image): - init_image = preprocess(init_image) - - # encode the init image into latents and scale the latents - init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - init_latents = 0.18215 * init_latents - - # expand init_latents for batch_size - init_latents = torch.cat([init_latents] * batch_size) - - # get the original timestep using init_timestep - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - if isinstance(self.scheduler, LMSDiscreteScheduler): - timesteps = torch.tensor( - [num_inference_steps - init_timestep] * batch_size, - dtype=torch.long, - device=self.device, - ) - else: - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor( - [timesteps] * batch_size, dtype=torch.long, device=self.device - ) - - # add noise to latents using the timesteps - noise = torch.randn(init_latents.shape, generator=generator, device=self.device) - init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to( - self.device - ) - - # get prompt text embeddings - text_input = self.tokenizer( - prompt, - padding="max_length", - max_length=64, # self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_embeddings = self.clip_inference(text_input.input_ids.to(self.device)) - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - max_length = text_input.input_ids.shape[-1] - uncond_input = self.tokenizer( - [""] * batch_size, - padding="max_length", - max_length=max_length, - return_tensors="pt", - ) - uncond_embeddings = self.clip_inference( - uncond_input.input_ids.to(self.device) - ) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - latents = init_latents - - t_start = max(num_inference_steps - init_timestep + offset, 0) - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])): - t_index = t_start + i - - # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) if do_classifier_free_guidance else latents - ) - - # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[t_index] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) - latent_model_input = latent_model_input.to(self.unet.dtype) - t = t.to(self.unet.dtype) - - # predict the noise residual - noise_pred = self.unet_inference( - latent_model_input, t, encoder_hidden_states=text_embeddings - ) - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - - # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step( - noise_pred, t_index, latents, **extra_step_kwargs - ).prev_sample - else: - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs - ).prev_sample - - # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents - image = self.vae_inference(latents) - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - - # run safety checker - safety_cheker_input = self.feature_extractor( - self.numpy_to_pil(image), return_tensors="pt" - ).to(self.device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_cheker_input.pixel_values - ) - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput( - images=image, nsfw_content_detected=has_nsfw_concept - ) diff --git a/fx2ait/fx2ait/TARGETS b/fx2ait/fx2ait/TARGETS deleted file mode 100644 index c247fbd11..000000000 --- a/fx2ait/fx2ait/TARGETS +++ /dev/null @@ -1,41 +0,0 @@ -# @noautodeps -load("@fbcode_macros//build_defs:python_library.bzl", "python_library") -load("@fbsource//tools/build_defs:glob_defs.bzl", "glob") - -oncall("aitemplate") - -# Note that we exclude common acc_tracer python files here and will reuse -# those in torch_tensorrt/fx/tracer/acc_tracer/ -python_library( - name = "fx2ait", - srcs = glob( - [ - "converters/*.py", - "*.py", - "passes/*.py", - "tools/*.py", - ] + [ - "acc_tracer/ait_acc_normalizer.py", - "acc_tracer/ait_acc_ops_registry.py", - "acc_tracer/ait_acc_ops.py", - ], - exclude = [ - "cache.py", - ], - ), - base_module = "fx2ait", - deps = [ - "fbsource//third-party/pypi/graphviz:graphviz", - "fbsource//third-party/pypi/numpy:numpy", - "fbsource//third-party/pypi/pydot:pydot", - "//aitemplate/AITemplate/fx2ait/fx2ait/fb:acc_import_helper", - "//aitemplate/AITemplate/fx2ait/fx2ait/fb/lower:ait_lowering_setting", - "//aitemplate/AITemplate/python/aitemplate:aitemplate", - 
"//caffe2:torch", - "//deeplearning/ait:AITModel", - "//executorch/exir:graph_module", - "//executorch/exir:lib", - "//executorch/exir:tracer", - "//pytorch/vision:torchvision", - ], -) diff --git a/fx2ait/fx2ait/csrc/TARGETS b/fx2ait/fx2ait/csrc/TARGETS deleted file mode 100644 index 88893b1f7..000000000 --- a/fx2ait/fx2ait/csrc/TARGETS +++ /dev/null @@ -1,29 +0,0 @@ -load("@fbcode_macros//build_defs:cpp_library.bzl", "cpp_library") - -oncall("aitemplate") - -cpp_library( - name = "AITModelImpl", - srcs = ["AITModelImpl.cpp"], - headers = ["AITModelImpl.h"], - propagated_pp_flags = [ - "-DFBCODE_AIT", - "-Iaitemplate/AITemplate/static/include", - ], - supports_python_dlopen = True, - deps = [ - "//caffe2:ATen-cu", - "//caffe2/c10:c10", - "//caffe2/c10:c10_cuda", - "//folly:map_util", - ], - exported_deps = [ - "//aitemplate/AITemplate/static/include:aitemplate", # @manual - "//caffe2:ATen-cu", - "//caffe2:torch-cpp", - "//folly/container:f14_hash", - ], - exported_external_deps = [ - ("glibc", None, "dl"), - ], -) diff --git a/fx2ait/fx2ait/test/TARGETS b/fx2ait/fx2ait/test/TARGETS deleted file mode 100644 index 465522f7d..000000000 --- a/fx2ait/fx2ait/test/TARGETS +++ /dev/null @@ -1,78 +0,0 @@ -load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") -load("@fbsource//tools/build_defs:glob_defs.bzl", "glob") - -oncall("aitemplate") - -[ - python_unittest( - name = test_file.split("/")[-1][:-3], - srcs = [ - test_file, - ], - env = { - "NUM_BUILDERS": "12", - }, - par_style = "xar", - tags = [ - "re_opts_capabilities={\"platform\": \"gpu-remote-execution\", \"subplatform\": \"A100\"}", - "serialize_test_cases", - "supports_remote_execution", - ], - deps = [ - "fbsource//third-party/pypi/numpy:numpy", - "fbsource//third-party/pypi/parameterized:parameterized", - "//aitemplate/AITemplate/fx2ait/fx2ait:fx2ait", - "//aitemplate/AITemplate/fx2ait/fx2ait/fb/converters:internal_converters", - "//caffe2:test-lib", - "//caffe2:torch", - "//deeplearning/trt/torch_tensorrt/py/torch_tensorrt:acc_tracer", - "//deeplearning/trt/torch_tensorrt/py/torch_tensorrt/fb:internal_passes", - "//glow/fb/fx/acc_tracer:acc_tracer", - ], - ) - for test_file in glob( - [ - "fb/converters/test*.py", - "converters/test*.py", - "converters/*/test*.py", - "test*.py", - ], - exclude = [ - "test_fx2ait.py", - "test_ait_lower.py", - ], - ) -] - -[ - python_unittest( - name = test_file.split("/")[-1][:-3], - srcs = [ - test_file, - ], - env = { - "NUM_BUILDERS": "12", - }, - par_style = "xar", - tags = [ - "re_opts_capabilities={\"platform\": \"gpu-remote-execution\", \"subplatform\": \"A100\"}", - "serialize_test_cases", - "supports_remote_execution", - ], - deps = [ - "fbsource//third-party/pypi/numpy:numpy", - "fbsource//third-party/pypi/parameterized:parameterized", - "fbsource//third-party/pypi/transformers:transformers", - "//aitemplate/AITemplate/fx2ait/fx2ait:fx2ait", - "//aitemplate/AITemplate/fx2ait/fx2ait/fb/converters:internal_converters_aten", - "//caffe2:test-lib", - "//caffe2:torch", - "//caffe2/functorch:functorch", - ], - ) - for test_file in glob( - [ - "converters_aten/test*.py", - ], - ) -] diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add.py deleted file mode 100644 index c556485f1..000000000 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -""" -GEMM Specialization for -C = ADD(GeMM(A, B) + bias, D0) -where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] -bias[RowMajor][N], D0[RowMajor][M, N] -""" -from ... import registry -from . import common, common_bias_broadcast -from .layout import RCR - -# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 - -UNARY_OP1 = "cutlass::epilogue::thread::Identity" -BINARY_OP1 = "cutlass::plus" -BINARY_OP2 = None -UNARY_OP2 = "cutlass::epilogue::thread::Identity" - - -@registry.reg("cuda.gemm_rcr_bias_add.config") -def gemm_rcr_config(func_attrs, dtype="float16"): - return common_bias_broadcast.gemm_bias_broadcast_config(func_attrs, RCR) - - -@registry.reg("cuda.gemm_rcr_bias_add.gen_profiler") -def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): - return common_bias_broadcast.gen_profiler( - func_attrs, - workdir, - profiler_filename, - dim_info_dict, - RCR, - UNARY_OP1, - BINARY_OP1, - BINARY_OP2, - UNARY_OP2, - ) - - -@registry.reg("cuda.gemm_rcr_bias_add.gen_function") -def gen_function( - func_attrs, - exec_cond_template, - dim_info_dict, -): - return common_bias_broadcast.gen_function( - func_attrs, - exec_cond_template, - dim_info_dict, - RCR, - UNARY_OP1, - BINARY_OP1, - BINARY_OP2, - UNARY_OP2, - ) - - -@registry.reg("cuda.gemm_rcr_bias_add.func_decl") -def gen_function_decl(func_attrs): - return common_bias_broadcast.gen_function_decl(func_attrs) - - -@registry.reg("cuda.gemm_rcr_bias_add.func_call") -def gen_function_call(func_attrs, indent=" "): - return common_bias_broadcast.gen_function_call(func_attrs, indent) - - -@registry.reg("cuda.gemm_rcr_bias_add.filter") -def function_filter(cfg, func_attrs, ab_alignment): - """Generates function filter. - - Parameters - ---------- - cfg: str - The filename generated for profiler. - func_attrs : Dict - Stores the operation attributes. - ab_alignment: - Input alignments. - - Returns - ------- - bool - If input cfg should be filtered. - """ - return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add.py deleted file mode 100644 index bd2988abf..000000000 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
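All of the deleted gemm_rcr_bias_* backend files above and below follow one template: the epilogue is described by four slots, composed as UNARY_OP2(BINARY_OP2(BINARY_OP1(UNARY_OP1(GeMM(A, B) + bias), D0), D1)), and each variant only swaps the CUTLASS functors and the registry key before delegating to common_bias_broadcast. A plain PyTorch reference of one such epilogue (gemm_rcr_bias_sigmoid_mul, i.e. Mul(Sigmoid(A @ B.T + bias), D0)), offered as a hedged mental model rather than a description of the generated CUDA:

import torch

M, N, K = 4, 8, 16
A = torch.randn(M, K)        # RowMajor [M, K]
B = torch.randn(N, K)        # ColMajor [N, K] -> multiply by B.T
bias = torch.randn(N)
D0 = torch.randn(M, N)

# gemm_rcr_bias_sigmoid_mul: UNARY_OP1=Sigmoid, BINARY_OP1=multiplies,
# BINARY_OP2=None, UNARY_OP2=Identity
C = torch.sigmoid(A @ B.T + bias) * D0
print(C.shape)               # torch.Size([4, 8])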
-# -""" -GEMM Specialization for -C = RELU(ADD(ADD(GeMM(A, B) + bias, D0), D1)) -where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] -bias[RowMajor][N], D0[RowMajor][M, N], D1[RowMajor][M, N] -""" -from ... import registry -from . import common, common_bias_broadcast -from .layout import RCR - -# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 - -UNARY_OP1 = "cutlass::epilogue::thread::Identity" -BINARY_OP1 = "cutlass::plus" -BINARY_OP2 = "cutlass::plus" -UNARY_OP2 = "cutlass::epilogue::thread::Identity" - - -@registry.reg("cuda.gemm_rcr_bias_add_add.config") -def gemm_rcr_config(func_attrs, dtype="float16"): - return common_bias_broadcast.gemm_bias_broadcast_config(func_attrs, RCR) - - -@registry.reg("cuda.gemm_rcr_bias_add_add.gen_profiler") -def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): - return common_bias_broadcast.gen_profiler( - func_attrs, - workdir, - profiler_filename, - dim_info_dict, - RCR, - UNARY_OP1, - BINARY_OP1, - BINARY_OP2, - UNARY_OP2, - ) - - -@registry.reg("cuda.gemm_rcr_bias_add_add.gen_function") -def gen_function( - func_attrs, - exec_cond_template, - dim_info_dict, -): - return common_bias_broadcast.gen_function( - func_attrs, - exec_cond_template, - dim_info_dict, - RCR, - UNARY_OP1, - BINARY_OP1, - BINARY_OP2, - UNARY_OP2, - ) - - -@registry.reg("cuda.gemm_rcr_bias_add_add.func_decl") -def gen_function_decl(func_attrs): - return common_bias_broadcast.gen_function_decl(func_attrs) - - -@registry.reg("cuda.gemm_rcr_bias_add_add.func_call") -def gen_function_call(func_attrs, indent=" "): - return common_bias_broadcast.gen_function_call(func_attrs, indent) - - -@registry.reg("cuda.gemm_rcr_bias_add_add.filter") -def function_filter(cfg, func_attrs, ab_alignment): - """Generates function filter. - - Parameters - ---------- - cfg: str - The filename generated for profiler. - func_attrs : Dict - Stores the operation attributes. - ab_alignment: - Input alignments. - - Returns - ------- - bool - If input cfg should be filtered. - """ - return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add_relu.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add_relu.py deleted file mode 100644 index 5d262712e..000000000 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add_relu.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -""" -GEMM Specialization for -C = RELU(ADD(ADD(GeMM(A, B) + bias, D0), D1)) -where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] -bias[RowMajor][N], D0[RowMajor][M, N], D1[RowMajor][M, N] -""" -from ... import registry -from . 
import common, common_bias_broadcast -from .layout import RCR - -# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 - -UNARY_OP1 = "cutlass::epilogue::thread::Identity" -BINARY_OP1 = "cutlass::plus" -BINARY_OP2 = "cutlass::plus" -UNARY_OP2 = "cutlass::epilogue::thread::ReLu" - - -@registry.reg("cuda.gemm_rcr_bias_add_add_relu.config") -def gemm_rcr_config(func_attrs, dtype="float16"): - return common_bias_broadcast.gemm_bias_broadcast_config(func_attrs, RCR) - - -@registry.reg("cuda.gemm_rcr_bias_add_add_relu.gen_profiler") -def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): - return common_bias_broadcast.gen_profiler( - func_attrs, - workdir, - profiler_filename, - dim_info_dict, - RCR, - UNARY_OP1, - BINARY_OP1, - BINARY_OP2, - UNARY_OP2, - ) - - -@registry.reg("cuda.gemm_rcr_bias_add_add_relu.gen_function") -def gen_function( - func_attrs, - exec_cond_template, - dim_info_dict, -): - return common_bias_broadcast.gen_function( - func_attrs, - exec_cond_template, - dim_info_dict, - RCR, - UNARY_OP1, - BINARY_OP1, - BINARY_OP2, - UNARY_OP2, - ) - - -@registry.reg("cuda.gemm_rcr_bias_add_add_relu.func_decl") -def gen_function_decl(func_attrs): - return common_bias_broadcast.gen_function_decl(func_attrs) - - -@registry.reg("cuda.gemm_rcr_bias_add_add_relu.func_call") -def gen_function_call(func_attrs, indent=" "): - return common_bias_broadcast.gen_function_call(func_attrs, indent) - - -@registry.reg("cuda.gemm_rcr_bias_add_add_relu.filter") -def function_filter(cfg, func_attrs, ab_alignment): - """Generates function filter. - - Parameters - ---------- - cfg: str - The filename generated for profiler. - func_attrs : Dict - Stores the operation attributes. - ab_alignment: - Input alignments. - - Returns - ------- - bool - If input cfg should be filtered. - """ - return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_relu.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_relu.py deleted file mode 100644 index 212b01a74..000000000 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_relu.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -""" -GEMM Specialization for -C = RELU(ADD(GeMM(A, B) + bias, D0)) -where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] -bias[RowMajor][N], D0[RowMajor][M, N] -""" -from ... import registry -from . 
import common, common_bias_broadcast -from .layout import RCR - -# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 - -UNARY_OP1 = "cutlass::epilogue::thread::Identity" -BINARY_OP1 = "cutlass::plus" -BINARY_OP2 = None -UNARY_OP2 = "cutlass::epilogue::thread::ReLu" - - -@registry.reg("cuda.gemm_rcr_bias_add_relu.config") -def gemm_rcr_config(func_attrs, dtype="float16"): - return common_bias_broadcast.gemm_bias_broadcast_config(func_attrs, RCR) - - -@registry.reg("cuda.gemm_rcr_bias_add_relu.gen_profiler") -def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): - return common_bias_broadcast.gen_profiler( - func_attrs, - workdir, - profiler_filename, - dim_info_dict, - RCR, - UNARY_OP1, - BINARY_OP1, - BINARY_OP2, - UNARY_OP2, - ) - - -@registry.reg("cuda.gemm_rcr_bias_add_relu.gen_function") -def gen_function( - func_attrs, - exec_cond_template, - dim_info_dict, -): - return common_bias_broadcast.gen_function( - func_attrs, - exec_cond_template, - dim_info_dict, - RCR, - UNARY_OP1, - BINARY_OP1, - BINARY_OP2, - UNARY_OP2, - ) - - -@registry.reg("cuda.gemm_rcr_bias_add_relu.func_decl") -def gen_function_decl(func_attrs): - return common_bias_broadcast.gen_function_decl(func_attrs) - - -@registry.reg("cuda.gemm_rcr_bias_add_relu.func_call") -def gen_function_call(func_attrs, indent=" "): - return common_bias_broadcast.gen_function_call(func_attrs, indent) - - -@registry.reg("cuda.gemm_rcr_bias_add_relu.filter") -def function_filter(cfg, func_attrs, ab_alignment): - """Generates function filter. - - Parameters - ---------- - cfg: str - The filename generated for profiler. - func_attrs : Dict - Stores the operation attributes. - ab_alignment: - Input alignments. - - Returns - ------- - bool - If input cfg should be filtered. - """ - return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul.py deleted file mode 100644 index 1b2dea303..000000000 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -""" -GEMM Specialization for -C = ADD(GeMM(A, B) + bias, D0) -where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] -bias[RowMajor][N], D0[RowMajor][M, N] -""" -from ... import registry -from . 
import common, common_bias_broadcast -from .layout import RCR - -# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 - -UNARY_OP1 = "cutlass::epilogue::thread::Identity" -BINARY_OP1 = "cutlass::multiplies" -BINARY_OP2 = None -UNARY_OP2 = "cutlass::epilogue::thread::Identity" - - -@registry.reg("cuda.gemm_rcr_bias_mul.config") -def gemm_rcr_config(func_attrs, dtype="float16"): - return common_bias_broadcast.gemm_bias_broadcast_config(func_attrs, RCR) - - -@registry.reg("cuda.gemm_rcr_bias_mul.gen_profiler") -def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): - return common_bias_broadcast.gen_profiler( - func_attrs, - workdir, - profiler_filename, - dim_info_dict, - RCR, - UNARY_OP1, - BINARY_OP1, - BINARY_OP2, - UNARY_OP2, - ) - - -@registry.reg("cuda.gemm_rcr_bias_mul.gen_function") -def gen_function( - func_attrs, - exec_cond_template, - dim_info_dict, -): - return common_bias_broadcast.gen_function( - func_attrs, - exec_cond_template, - dim_info_dict, - RCR, - UNARY_OP1, - BINARY_OP1, - BINARY_OP2, - UNARY_OP2, - ) - - -@registry.reg("cuda.gemm_rcr_bias_mul.func_decl") -def gen_function_decl(func_attrs): - return common_bias_broadcast.gen_function_decl(func_attrs) - - -@registry.reg("cuda.gemm_rcr_bias_mul.func_call") -def gen_function_call(func_attrs, indent=" "): - return common_bias_broadcast.gen_function_call(func_attrs, indent) - - -@registry.reg("cuda.gemm_rcr_bias_mul.filter") -def function_filter(cfg, func_attrs, ab_alignment): - """Generates function filter. - - Parameters - ---------- - cfg: str - The filename generated for profiler. - func_attrs : Dict - Stores the operation attributes. - ab_alignment: - Input alignments. - - Returns - ------- - bool - If input cfg should be filtered. - """ - return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_add.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_add.py deleted file mode 100644 index 12bce07ae..000000000 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_add.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -""" -GEMM Specialization for -C = Add(Mul(GeMM(A, B) + bias, D0), D1), -where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] -bias[RowMajor][N], D0[RowMajor][M, N], D1[RowMajor][M, N] -""" -from ... import registry -from . 
import common, common_bias_broadcast -from .layout import RCR - -# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 - -UNARY_OP1 = "cutlass::epilogue::thread::Identity" -BINARY_OP1 = "cutlass::multiplies" -BINARY_OP2 = "cutlass::plus" -UNARY_OP2 = "cutlass::epilogue::thread::Identity" - - -@registry.reg("cuda.gemm_rcr_bias_mul_add.config") -def gemm_rcr_config(func_attrs, dtype="float16"): - return common_bias_broadcast.gemm_bias_broadcast_config(func_attrs, RCR) - - -@registry.reg("cuda.gemm_rcr_bias_mul_add.gen_profiler") -def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): - return common_bias_broadcast.gen_profiler( - func_attrs, - workdir, - profiler_filename, - dim_info_dict, - RCR, - UNARY_OP1, - BINARY_OP1, - BINARY_OP2, - UNARY_OP2, - ) - - -@registry.reg("cuda.gemm_rcr_bias_mul_add.gen_function") -def gen_function( - func_attrs, - exec_cond_template, - dim_info_dict, -): - return common_bias_broadcast.gen_function( - func_attrs, - exec_cond_template, - dim_info_dict, - RCR, - UNARY_OP1, - BINARY_OP1, - BINARY_OP2, - UNARY_OP2, - ) - - -@registry.reg("cuda.gemm_rcr_bias_mul_add.func_decl") -def gen_function_decl(func_attrs): - return common_bias_broadcast.gen_function_decl(func_attrs) - - -@registry.reg("cuda.gemm_rcr_bias_mul_add.func_call") -def gen_function_call(func_attrs, indent=" "): - return common_bias_broadcast.gen_function_call(func_attrs, indent) - - -@registry.reg("cuda.gemm_rcr_bias_mul_add.filter") -def function_filter(cfg, func_attrs, ab_alignment): - """Generates function filter. - - Parameters - ---------- - cfg: str - The filename generated for profiler. - func_attrs : Dict - Stores the operation attributes. - ab_alignment: - Input alignments. - - Returns - ------- - bool - If input cfg should be filtered. - """ - return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_tanh.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_tanh.py deleted file mode 100644 index c8be43f28..000000000 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_tanh.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -""" -GEMM Specialization for -C = TANH(Mul((GeMM(A, B) + bias), D0)) -where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] -bias[RowMajor][N], D0[RowMajor][M, N] -""" -from ... import registry -from . 
import common, common_bias_broadcast -from .layout import RCR - -# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 - -UNARY_OP1 = "cutlass::epilogue::thread::Identity" -BINARY_OP1 = "cutlass::multiplies" -BINARY_OP2 = None -UNARY_OP2 = "cutlass::epilogue::thread::Tanh" - - -@registry.reg("cuda.gemm_rcr_bias_mul_tanh.config") -def gemm_rcr_config(func_attrs, dtype="float16"): - return common_bias_broadcast.gemm_bias_broadcast_config(func_attrs, RCR) - - -@registry.reg("cuda.gemm_rcr_bias_mul_tanh.gen_profiler") -def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): - return common_bias_broadcast.gen_profiler( - func_attrs, - workdir, - profiler_filename, - dim_info_dict, - RCR, - UNARY_OP1, - BINARY_OP1, - BINARY_OP2, - UNARY_OP2, - ) - - -@registry.reg("cuda.gemm_rcr_bias_mul_tanh.gen_function") -def gen_function( - func_attrs, - exec_cond_template, - dim_info_dict, -): - return common_bias_broadcast.gen_function( - func_attrs, - exec_cond_template, - dim_info_dict, - RCR, - UNARY_OP1, - BINARY_OP1, - BINARY_OP2, - UNARY_OP2, - ) - - -@registry.reg("cuda.gemm_rcr_bias_mul_tanh.func_decl") -def gen_function_decl(func_attrs): - return common_bias_broadcast.gen_function_decl(func_attrs) - - -@registry.reg("cuda.gemm_rcr_bias_mul_tanh.func_call") -def gen_function_call(func_attrs, indent=" "): - return common_bias_broadcast.gen_function_call(func_attrs, indent) - - -@registry.reg("cuda.gemm_rcr_bias_mul_tanh.filter") -def function_filter(cfg, func_attrs, ab_alignment): - """Generates function filter. - - Parameters - ---------- - cfg: str - The filename generated for profiler. - func_attrs : Dict - Stores the operation attributes. - ab_alignment: - Input alignments. - - Returns - ------- - bool - If input cfg should be filtered. - """ - return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul.py deleted file mode 100644 index 2828d379d..000000000 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -""" -GEMM Specialization for -C = Mul(Sigmoid(GeMM(A, B) + bias), D0) -where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] -bias[RowMajor][N], D0[RowMajor][M, N] -""" -from ... import registry -from . 
import common, common_bias_broadcast -from .layout import RCR - -# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 - -UNARY_OP1 = "cutlass::epilogue::thread::Sigmoid" -BINARY_OP1 = "cutlass::multiplies" -BINARY_OP2 = None -UNARY_OP2 = "cutlass::epilogue::thread::Identity" - - -@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul.config") -def gemm_rcr_config(func_attrs, dtype="float16"): - return common_bias_broadcast.gemm_bias_broadcast_config(func_attrs, RCR) - - -@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul.gen_profiler") -def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): - return common_bias_broadcast.gen_profiler( - func_attrs, - workdir, - profiler_filename, - dim_info_dict, - RCR, - UNARY_OP1, - BINARY_OP1, - BINARY_OP2, - UNARY_OP2, - ) - - -@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul.gen_function") -def gen_function( - func_attrs, - exec_cond_template, - dim_info_dict, -): - return common_bias_broadcast.gen_function( - func_attrs, - exec_cond_template, - dim_info_dict, - RCR, - UNARY_OP1, - BINARY_OP1, - BINARY_OP2, - UNARY_OP2, - ) - - -@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul.func_decl") -def gen_function_decl(func_attrs): - return common_bias_broadcast.gen_function_decl(func_attrs) - - -@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul.func_call") -def gen_function_call(func_attrs, indent=" "): - return common_bias_broadcast.gen_function_call(func_attrs, indent) - - -@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul.filter") -def function_filter(cfg, func_attrs, ab_alignment): - """Generates function filter. - - Parameters - ---------- - cfg: str - The filename generated for profiler. - func_attrs : Dict - Stores the operation attributes. - ab_alignment: - Input alignments. - - Returns - ------- - bool - If input cfg should be filtered. - """ - return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul_tanh.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul_tanh.py deleted file mode 100644 index b3d721d6c..000000000 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul_tanh.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -""" -GEMM Specialization for -C = TANH(Mul(Sigmoid(GeMM(A, B) + bias), D0)) -where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] -bias[RowMajor][N], D0[RowMajor][M, N] -""" -from ... import registry -from . 
import common, common_bias_broadcast -from .layout import RCR - -# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 - -UNARY_OP1 = "cutlass::epilogue::thread::Sigmoid" -BINARY_OP1 = "cutlass::multiplies" -BINARY_OP2 = None -UNARY_OP2 = "cutlass::epilogue::thread::Tanh" - - -@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul_tanh.config") -def gemm_rcr_config(func_attrs, dtype="float16"): - return common_bias_broadcast.gemm_bias_broadcast_config(func_attrs, RCR) - - -@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul_tanh.gen_profiler") -def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): - return common_bias_broadcast.gen_profiler( - func_attrs, - workdir, - profiler_filename, - dim_info_dict, - RCR, - UNARY_OP1, - BINARY_OP1, - BINARY_OP2, - UNARY_OP2, - ) - - -@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul_tanh.gen_function") -def gen_function( - func_attrs, - exec_cond_template, - dim_info_dict, -): - return common_bias_broadcast.gen_function( - func_attrs, - exec_cond_template, - dim_info_dict, - RCR, - UNARY_OP1, - BINARY_OP1, - BINARY_OP2, - UNARY_OP2, - ) - - -@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul_tanh.func_decl") -def gen_function_decl(func_attrs): - return common_bias_broadcast.gen_function_decl(func_attrs) - - -@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul_tanh.func_call") -def gen_function_call(func_attrs, indent=" "): - return common_bias_broadcast.gen_function_call(func_attrs, indent) - - -@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul_tanh.filter") -def function_filter(cfg, func_attrs, ab_alignment): - """Generates function filter. - - Parameters - ---------- - cfg: str - The filename generated for profiler. - func_attrs : Dict - Stores the operation attributes. - ab_alignment: - Input alignments. - - Returns - ------- - bool - If input cfg should be filtered. - """ - return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/compiler/transform/fuse_permute_bmm.py b/python/aitemplate/compiler/transform/fuse_permute_bmm.py deleted file mode 100644 index 22a3ee036..000000000 --- a/python/aitemplate/compiler/transform/fuse_permute_bmm.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -""" -Perform fusions for permute+bmm operators. -""" -from typing import Callable, List, Optional, Set, Tuple, Type, Union - -from .. 
import ops -from ..base import IntImm, Operator, Tensor -from ..ops.gemm_universal import ( - bmm_ccr, - bmm_crr, - bmm_rcr, - bmm_rrr, - gemm_rcr, - gemm_rcr_bias, - gemm_rrr, - gemm_rrr_bias, -) -from ..ops.tensor import permute021 -from .fuse_utils import extract_only_one_op -from .transform_utils import ( - copy_src_op_attributes, - copy_tensor_attributes, - remove_dst_op_from_tensor, - remove_tensor_from_sorted_graph, - replace_tensor, - sanitize_sorted_graph, -) - -# pylint: disable=C0103,W0612 - - -def _try_extract_one_mm_op(ops: Set[Union[None, Operator]]) -> Union[None, Operator]: - """ - Helper function that returns the matmul op from src_ops() or dst_ops() call. - Return None if there's no bmm ops - """ - if ops is None: - return None - - for op in ops: - if op._attrs["op"].startswith("bmm") or op._attrs["op"].startswith("gemm"): - return op - - return None - - -def _fuse_permute_bmm_ops( - sorted_graph: List[Tensor], - source: List[Type[Operator]], - targets: List[Union[None, Type[Operator]]], - condition: Optional[Callable], -) -> Tuple[bool, List[Tensor]]: - """ - Function that fuses [permute021 + bmm] into corresponding bmm op. - - Parameters - ---------- - sorted_graph : List[Tensor] - AIT graph to run fusion - source: List[Type[Operator]] - Combination of permute+bmm ops to be fused. - This should be of len-2 - targets: List[Type[Operator]] - To be fused bmm that matches the source. - This should be of len 2, which corresponds to the operator that does - permute A and permute B respectively - condition: Optional[Callable] - If not None, we apply on the gemm op to check whether it requires fusion. - """ - assert len(source) == 2, "Source should have 2 elements, got {} instead".format( - len(source) - ) - - new_sorted_graph = [] - fused = False - to_replace = {} - for tensor in sorted_graph: - if tensor in to_replace: - new_sorted_graph.append(to_replace[tensor]) - replace_tensor(tensor, to_replace[tensor]) - del to_replace[tensor] - continue - new_sorted_graph.append(tensor) - - if fused: - continue - if tensor._attrs["is_output"]: - continue - - permute_op = extract_only_one_op(tensor._attrs["src_ops"]) - bmm_op = _try_extract_one_mm_op(tensor._attrs["dst_ops"]) - if permute_op is None or bmm_op is None: - continue - - if permute_op._attrs["op"] != source[0]()._attrs["op"]: - continue - if bmm_op._attrs["op"] != source[1]()._attrs["op"]: - continue - if condition is not None and not condition(bmm_op): - continue - - assert len(permute_op._attrs["inputs"]) == 1 - assert len(bmm_op._attrs["outputs"]) == 1 - - inputs = list(bmm_op._attrs["inputs"]) - if targets[0] is None and inputs[0] == tensor: - continue - if targets[1] is None and inputs[1] == tensor: - continue - - input_tensor = permute_op._attrs["inputs"][0] - output_tensor = bmm_op._attrs["outputs"][0] - - # TODO: Check whether the input is weight to have better compile time - # optimization on preprocessing of pad etc. - permute_shape = tensor.shape() - prepermute_shape = input_tensor.shape() - - if ( - isinstance(prepermute_shape[-1], IntImm) - and prepermute_shape[-1].value() % 2 == 1 - and isinstance(permute_shape[-1], IntImm) - and permute_shape[-1].value() % 2 == 0 - ): - # We don't run the permute+bmm fusion if the permute op could - # turn an odd alignment into even alignment. 
- continue - - fused = True - - remove_dst_op_from_tensor(bmm_op._attrs["inputs"], bmm_op) - - target = None - if inputs[0] == tensor: - target = targets[0] - inputs[0] = input_tensor - elif inputs[1] == tensor: - target = targets[1] - inputs[1] = input_tensor - else: - raise RuntimeError( - "bmm inputs are {}, not matching permute's output tensor {}".format( - inputs, tensor - ) - ) - - if not tensor.dst_ops(): - # Remove permute configs if this is the last bmm consuming the tensor - remove_dst_op_from_tensor(input_tensor, permute_op) - remove_tensor_from_sorted_graph(tensor) - - new_tensor = target()(*inputs) - copy_tensor_attributes(new_tensor, output_tensor) - copy_src_op_attributes(new_tensor, output_tensor) - to_replace[output_tensor] = new_tensor - - return (fused, sanitize_sorted_graph(new_sorted_graph)) - - -def fuse_permute_bmm(sorted_graph: List[Tensor], workdir: str = None) -> List[Tensor]: - """Fuse [permute021 + bmm] into corresponding bmm op. - - Parameters - ---------- - sorted_graph : List[Tensor] - Input graph - workdir : str, optional - working dir, by default None - - Returns - ------- - List[Tensor] - Fused graph - """ - - def _need_broadcast_gemm(op: Operator): - if not op._attrs["op"].startswith("gemm"): - return False - inputs = op._attrs["inputs"] - return len(inputs[0].shape()) != 2 or len(inputs[1].shape()) != 2 - - permute_mm_patterns = ( - ([permute021, bmm_ccr], [bmm_rcr, bmm_crr], None), - ([permute021, bmm_crr], [bmm_rrr, bmm_ccr], None), - ([permute021, bmm_rcr], [bmm_ccr, bmm_rrr], None), - ([permute021, bmm_rrr], [bmm_crr, bmm_rcr], None), - ([permute021, gemm_rcr], [bmm_ccr, bmm_rrr], _need_broadcast_gemm), - ([permute021, gemm_rrr], [bmm_crr, bmm_rcr], _need_broadcast_gemm), - ( - [permute021, gemm_rcr_bias], - [ops.gemm_universal.bmm_ccr_add, ops.gemm_universal.bmm_rrr_add], - _need_broadcast_gemm, - ), - ( - [permute021, gemm_rrr_bias], - [ops.gemm_universal.bmm_crr_add, None], - _need_broadcast_gemm, - ), - ) - - graph_transformed = True - while graph_transformed: - graph_transformed = False - for source, targets, condition in permute_mm_patterns: - fused, sorted_graph = _fuse_permute_bmm_ops( - sorted_graph, source, targets, condition - ) - graph_transformed |= fused - - return sorted_graph diff --git a/python/aitemplate/utils/logger.py b/python/aitemplate/utils/logger.py deleted file mode 100644 index 7dfdba771..000000000 --- a/python/aitemplate/utils/logger.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
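The pattern table above encodes a simple layout rule: a permute021 feeding operand i of a bmm flips the i-th layout letter between row- and column-major, so the permute can be folded into the bmm variant that already reads that operand transposed. A small PyTorch check of one such rewrite (permute021 + bmm_rrr on operand 0 becoming bmm_crr), under the assumption, not stated in this file, that a "c" operand is stored as its transpose:

import torch

def permute021(x):
    # swap the last two dims of a batched tensor
    return x.permute(0, 2, 1).contiguous()

def bmm_rrr(a, b):            # A[B, M, K] @ B[B, K, N]
    return torch.bmm(a, b)

def bmm_crr(a, b):            # A stored column-major as [B, K, M]
    return torch.bmm(a.transpose(1, 2), b)

B_, M, K, N = 2, 3, 4, 5
A = torch.randn(B_, K, M)     # column-major A for bmm_crr
Bmat = torch.randn(B_, K, N)

# The fusion replaces permute021 + bmm_rrr with bmm_crr on the original A:
assert torch.allclose(bmm_rrr(permute021(A), Bmat), bmm_crr(A, Bmat), atol=1e-6)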
-# -""" -default logger -""" -import logging -import os - - -def info(name, message): - logger = logging.getLogger(name) - logger.info(message) - - -def debug(name, message): - logger = logging.getLogger(name) - logger.debug(message) - - -def warning(name, message): - logger = logging.getLogger(name) - logger.warning(message) - - -def is_debug(): - logger = logging.getLogger("aitemplate") - return logger.level == logging.DEBUG - - -def setup_logger(name): - root_logger = logging.getLogger(name) - info_handle = logging.StreamHandler() - formatter = logging.Formatter("%(asctime)s %(levelname)s <%(name)s> %(message)s") - info_handle.setFormatter(formatter) - root_logger.addHandler(info_handle) - root_logger.propagate = False - - DEFAULT_LOGLEVEL = logging.getLogger().level - log_level_str = os.environ.get("LOGLEVEL", None) - LOG_LEVEL = ( - getattr(logging, log_level_str.upper()) - if log_level_str is not None - else DEFAULT_LOGLEVEL - ) - root_logger.setLevel(LOG_LEVEL) - return root_logger
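For completeness, a minimal usage sketch of the logger helpers as they existed before this deletion (assuming the module was importable as aitemplate.utils.logger, per its path; LOGLEVEL is read once by setup_logger, so it must be set beforehand):

import os

os.environ["LOGLEVEL"] = "DEBUG"   # picked up by setup_logger below

from aitemplate.utils import logger

root = logger.setup_logger("aitemplate")
logger.info("aitemplate.compiler", "compiling model...")
logger.debug("aitemplate.backend", "profiling gemm_rcr_bias")
print(logger.is_debug())           # True, since the "aitemplate" logger level is DEBUG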