diff --git a/.github/workflows/building.yml b/.github/workflows/building.yml deleted file mode 100644 index 24aff3a..0000000 --- a/.github/workflows/building.yml +++ /dev/null @@ -1,112 +0,0 @@ -name: CI - -on: - push: - branches: - - master - - ci -jobs: - wheel: - runs-on: ${{ matrix.os }} - strategy: - fail-fast: true - matrix: - os: - - ubuntu-18.04 -# - macos-10.15 -# - windows-2019 -# python-version: ['3.7', '3.8', '3.9', '3.10'] - python-version: ['3.9'] - torch-version: [1.11.0] -# cuda-version: ['cpu', 'cu102', 'cu113', 'cu115'] -# cuda-version: ['cu113'] - cuda-version: ['cpu'] - - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install Rust toolchain - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - profile: minimal - default: true - - - name: Install CUDA ${{ matrix.cuda-version }} - if: ${{ matrix.cuda-version != 'cpu' }} - run: | - bash .github/workflows/cuda/${{ matrix.cuda-version }}-${{ runner.os }}.sh - - - name: Install PyTorch ${{ matrix.torch-version }}+${{ matrix.cuda-version }} - run: | - pip install numpy typing-extensions dataclasses - pip install --no-index --no-cache-dir torch==${{ matrix.torch-version}} -f https://download.pytorch.org/whl/${{ matrix.cuda-version }}/torch_stable.html - python -c "import torch; print('PyTorch:', torch.__version__)" - python -c "import torch; print('CUDA:', torch.version.cuda)" - export PY_SITE_DIR=$(python -c "import site; print(site.getsitepackages()[0])") - echo "LIBTORCH=$PY_SITE_DIR/torch" >> $GITHUB_ENV - echo "LD_LIBRARY_PATH=$PY_SITE_DIR/torch/lib:$LD_LIBRARY_PATH" >> $GITHUB_ENV - - - uses: messense/maturin-action@v1 - with: - manylinux: auto - command: build - args: --release -o dist - env: - LIBTORCH_CXX11_ABI: "0" - - - name: Upload wheels - uses: actions/upload-artifact@v2 - with: - name: wheels - path: dist - -# windows: -# runs-on: windows-latest -# steps: -# - uses: actions/checkout@v2 -# - uses: messense/maturin-action@v1 -# with: -# command: build -# args: --release --no-sdist -o dist -# - name: Upload wheels -# uses: actions/upload-artifact@v2 -# with: -# name: wheels -# path: dist -# -# macos: -# runs-on: macos-latest -# steps: -# - uses: actions/checkout@v2 -# - uses: messense/maturin-action@v1 -# with: -# command: build -# args: --release --no-sdist -o dist --universal2 -# - name: Upload wheels -# uses: actions/upload-artifact@v2 -# with: -# name: wheels -# path: dist - - release: - name: Release - runs-on: ubuntu-latest - if: "startsWith(github.ref, 'refs/tags/')" - needs: [ wheel ] - steps: - - uses: actions/download-artifact@v2 - with: - name: wheels - - name: Publish to PyPI - uses: messense/maturin-action@v1 - env: - MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} - with: - command: upload - args: --skip-existing * - -# act-CI-wheel \ No newline at end of file diff --git a/.github/workflows/cuda/cpu-env.sh b/.github/workflows/cuda/cpu-env.sh new file mode 100644 index 0000000..98bdafd --- /dev/null +++ b/.github/workflows/cuda/cpu-env.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +export FORCE_CUDA=0 +export TORCH_CUDA_ARCH_LIST=0 diff --git a/.github/workflows/cuda/cu101-Linux-env.sh b/.github/workflows/cuda/cu101-Linux-env.sh deleted file mode 100644 index 2559816..0000000 --- a/.github/workflows/cuda/cu101-Linux-env.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -CUDA_HOME=/usr/local/cuda-10.1 -LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} -PATH=${CUDA_HOME}/bin:${PATH} - -export FORCE_CUDA=1 -export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" diff --git a/.github/workflows/cuda/cu101-Linux.sh b/.github/workflows/cuda/cu101-Linux.sh deleted file mode 100755 index ffb1dca..0000000 --- a/.github/workflows/cuda/cu101-Linux.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -OS=ubuntu1804 - -wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin -sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600 -wget -nv https://developer.download.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda-repo-${OS}-10-1-local-10.1.243-418.87.00_1.0-1_amd64.deb -sudo dpkg -i cuda-repo-${OS}-10-1-local-10.1.243-418.87.00_1.0-1_amd64.deb -sudo apt-key add /var/cuda-repo-10-1-local-10.1.243-418.87.00/7fa2af80.pub - -sudo apt-get -qq update -sudo apt install cuda-nvcc-10-1 cuda-libraries-dev-10-1 -sudo apt clean - -rm -f https://developer.download.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda-repo-${OS}-10-1-local-10.1.243-418.87.00_1.0-1_amd64.deb diff --git a/.github/workflows/cuda/cu101-Windows-env.sh b/.github/workflows/cuda/cu101-Windows-env.sh deleted file mode 100644 index 24ace97..0000000 --- a/.github/workflows/cuda/cu101-Windows-env.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v10.1 -PATH=${CUDA_HOME}/bin:$PATH -PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH - -export FORCE_CUDA=1 -export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" diff --git a/.github/workflows/cuda/cu101-Windows.sh b/.github/workflows/cuda/cu101-Windows.sh deleted file mode 100755 index 362cd2b..0000000 --- a/.github/workflows/cuda/cu101-Windows.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -# Install NVIDIA drivers, see: -# https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102 -curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip" -7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32" - -export CUDA_SHORT=10.1 -export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}/Prod/local_installers/ -export CUDA_FILE=cuda_${CUDA_SHORT}.243_426.00_win10.exe - -# Install CUDA: -curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}" -echo "" -echo "Installing from ${CUDA_FILE}..." -PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow" -echo "Done!" -rm -f "${CUDA_FILE}" diff --git a/.github/workflows/cuda/cu101-env.sh b/.github/workflows/cuda/cu101-env.sh new file mode 100644 index 0000000..c1fe508 --- /dev/null +++ b/.github/workflows/cuda/cu101-env.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +export FORCE_CUDA=1 +export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" diff --git a/.github/workflows/cuda/cu102-Linux-env.sh b/.github/workflows/cuda/cu102-Linux-env.sh deleted file mode 100644 index a8f60a8..0000000 --- a/.github/workflows/cuda/cu102-Linux-env.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -CUDA_HOME=/usr/local/cuda-10.2 -LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} -PATH=${CUDA_HOME}/bin:${PATH} - -export FORCE_CUDA=1 -export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" diff --git a/.github/workflows/cuda/cu102-Linux.sh b/.github/workflows/cuda/cu102-Linux.sh deleted file mode 100755 index 85e1ed2..0000000 --- a/.github/workflows/cuda/cu102-Linux.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -OS=ubuntu1804 - -wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin -sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600 -wget -nv https://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda-repo-${OS}-10-2-local-10.2.89-440.33.01_1.0-1_amd64.deb -sudo dpkg -i cuda-repo-${OS}-10-2-local-10.2.89-440.33.01_1.0-1_amd64.deb -sudo apt-key add /var/cuda-repo-10-2-local-10.2.89-440.33.01/7fa2af80.pub - -sudo apt-get -qq update -sudo apt install cuda-nvcc-10-2 cuda-libraries-dev-10-2 -sudo apt clean - -rm -f https://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda-repo-${OS}-10-2-local-10.2.89-440.33.01_1.0-1_amd64.deb diff --git a/.github/workflows/cuda/cu102-Windows-env.sh b/.github/workflows/cuda/cu102-Windows-env.sh deleted file mode 100644 index 1888e2c..0000000 --- a/.github/workflows/cuda/cu102-Windows-env.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v10.2 -PATH=${CUDA_HOME}/bin:$PATH -PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH - -export FORCE_CUDA=1 -export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" diff --git a/.github/workflows/cuda/cu102-Windows.sh b/.github/workflows/cuda/cu102-Windows.sh deleted file mode 100755 index 368420b..0000000 --- a/.github/workflows/cuda/cu102-Windows.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -# Install NVIDIA drivers, see: -# https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102 -curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip" -7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32" - -export CUDA_SHORT=10.2 -export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}/Prod/local_installers -export CUDA_FILE=cuda_${CUDA_SHORT}.89_441.22_win10.exe - -# Install CUDA: -curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}" -echo "" -echo "Installing from ${CUDA_FILE}..." -PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow" -echo "Done!" -rm -f "${CUDA_FILE}" diff --git a/.github/workflows/cuda/cu102-env.sh b/.github/workflows/cuda/cu102-env.sh new file mode 100644 index 0000000..c1fe508 --- /dev/null +++ b/.github/workflows/cuda/cu102-env.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +export FORCE_CUDA=1 +export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" diff --git a/.github/workflows/cuda/cu111-Linux-env.sh b/.github/workflows/cuda/cu111-Linux-env.sh deleted file mode 100644 index bd30537..0000000 --- a/.github/workflows/cuda/cu111-Linux-env.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -CUDA_HOME=/usr/local/cuda-11.1 -LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} -PATH=${CUDA_HOME}/bin:${PATH} - -export FORCE_CUDA=1 -export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" diff --git a/.github/workflows/cuda/cu111-Linux.sh b/.github/workflows/cuda/cu111-Linux.sh deleted file mode 100755 index 31621e3..0000000 --- a/.github/workflows/cuda/cu111-Linux.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -OS=ubuntu1804 - -wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin -sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600 -wget -nv https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda-repo-${OS}-11-1-local_11.1.1-455.32.00-1_amd64.deb -sudo dpkg -i cuda-repo-${OS}-11-1-local_11.1.1-455.32.00-1_amd64.deb -sudo apt-key add /var/cuda-repo-${OS}-11-1-local/7fa2af80.pub - -sudo apt-get -qq update -sudo apt install cuda-nvcc-11-1 cuda-libraries-dev-11-1 -sudo apt clean - -rm -f https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda-repo-${OS}-11-1-local_11.1.1-455.32.00-1_amd64.deb diff --git a/.github/workflows/cuda/cu111-Windows-env.sh b/.github/workflows/cuda/cu111-Windows-env.sh deleted file mode 100644 index 0e672a3..0000000 --- a/.github/workflows/cuda/cu111-Windows-env.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.1 -PATH=${CUDA_HOME}/bin:$PATH -PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH - -export FORCE_CUDA=1 -export TORCH_CUDA_ARCH_LIST="6.0+PTX" diff --git a/.github/workflows/cuda/cu111-Windows.sh b/.github/workflows/cuda/cu111-Windows.sh deleted file mode 100755 index 4cd9fe3..0000000 --- a/.github/workflows/cuda/cu111-Windows.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -# Install NVIDIA drivers, see: -# https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102 -curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip" -7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32" - -export CUDA_SHORT=11.1 -export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.1/local_installers -export CUDA_FILE=cuda_${CUDA_SHORT}.1_456.81_win10.exe - -# Install CUDA: -curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}" -echo "" -echo "Installing from ${CUDA_FILE}..." -PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow" -echo "Done!" -rm -f "${CUDA_FILE}" diff --git a/.github/workflows/cuda/cu111-env.sh b/.github/workflows/cuda/cu111-env.sh new file mode 100644 index 0000000..c38498b --- /dev/null +++ b/.github/workflows/cuda/cu111-env.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +export FORCE_CUDA=1 +export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" diff --git a/.github/workflows/cuda/cu113-Linux-env.sh b/.github/workflows/cuda/cu113-Linux-env.sh deleted file mode 100644 index a3befec..0000000 --- a/.github/workflows/cuda/cu113-Linux-env.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -CUDA_HOME=/usr/local/cuda-11.3 -LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} -PATH=${CUDA_HOME}/bin:${PATH} - -export FORCE_CUDA=1 -export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" diff --git a/.github/workflows/cuda/cu113-Linux.sh b/.github/workflows/cuda/cu113-Linux.sh deleted file mode 100755 index 1cdd9d9..0000000 --- a/.github/workflows/cuda/cu113-Linux.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -OS=ubuntu1804 - -wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin -sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600 -wget -nv https://developer.download.nvidia.com/compute/cuda/11.3.0/local_installers/cuda-repo-${OS}-11-3-local_11.3.0-465.19.01-1_amd64.deb -sudo dpkg -i cuda-repo-${OS}-11-3-local_11.3.0-465.19.01-1_amd64.deb -sudo apt-key add /var/cuda-repo-${OS}-11-3-local/7fa2af80.pub - -sudo apt-get -qq update -sudo apt install cuda-nvcc-11-3 cuda-libraries-dev-11-3 -sudo apt clean - -rm -f https://developer.download.nvidia.com/compute/cuda/11.3.0/local_installers/cuda-repo-${OS}-11-3-local_11.3.0-465.19.01-1_amd64.deb diff --git a/.github/workflows/cuda/cu113-Windows-env.sh b/.github/workflows/cuda/cu113-Windows-env.sh deleted file mode 100644 index 3a662fb..0000000 --- a/.github/workflows/cuda/cu113-Windows-env.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.3 -PATH=${CUDA_HOME}/bin:$PATH -PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH - -export FORCE_CUDA=1 -export TORCH_CUDA_ARCH_LIST="6.0+PTX" diff --git a/.github/workflows/cuda/cu113-Windows.sh b/.github/workflows/cuda/cu113-Windows.sh deleted file mode 100755 index 3cd7133..0000000 --- a/.github/workflows/cuda/cu113-Windows.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -# Install NVIDIA drivers, see: -# https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102 -curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip" -7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32" - -export CUDA_SHORT=11.3 -export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.0/local_installers -export CUDA_FILE=cuda_${CUDA_SHORT}.0_465.89_win10.exe - -# Install CUDA: -curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}" -echo "" -echo "Installing from ${CUDA_FILE}..." -PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} thrust_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow" -echo "Done!" -rm -f "${CUDA_FILE}" diff --git a/.github/workflows/cuda/cu113-env.sh b/.github/workflows/cuda/cu113-env.sh new file mode 100644 index 0000000..c38498b --- /dev/null +++ b/.github/workflows/cuda/cu113-env.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +export FORCE_CUDA=1 +export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" diff --git a/.github/workflows/cuda/cu115-Linux-env.sh b/.github/workflows/cuda/cu115-Linux-env.sh deleted file mode 100644 index 1c148a2..0000000 --- a/.github/workflows/cuda/cu115-Linux-env.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -CUDA_HOME=/usr/local/cuda-11.5 -LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} -PATH=${CUDA_HOME}/bin:${PATH} - -export FORCE_CUDA=1 -export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" diff --git a/.github/workflows/cuda/cu115-Linux.sh b/.github/workflows/cuda/cu115-Linux.sh deleted file mode 100755 index 02bcb4d..0000000 --- a/.github/workflows/cuda/cu115-Linux.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -OS=ubuntu1804 - -wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin -sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600 -wget -nv https://developer.download.nvidia.com/compute/cuda/11.5.2/local_installers/cuda-repo-${OS}-11-5-local_11.5.2-495.29.05-1_amd64.deb -sudo dpkg -i cuda-repo-${OS}-11-5-local_11.5.2-495.29.05-1_amd64.deb -sudo apt-key add /var/cuda-repo-${OS}-11-5-local/7fa2af80.pub - -sudo apt-get -qq update -sudo apt install cuda-nvcc-11-5 cuda-libraries-dev-11-5 -sudo apt clean - -rm -f https://developer.download.nvidia.com/compute/cuda/11.5.2/local_installers/cuda-repo-${OS}-11-5-local_11.5.2-495.29.05-1_amd64.deb diff --git a/.github/workflows/cuda/cu115-Windows-env.sh b/.github/workflows/cuda/cu115-Windows-env.sh deleted file mode 100644 index 3a662fb..0000000 --- a/.github/workflows/cuda/cu115-Windows-env.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.3 -PATH=${CUDA_HOME}/bin:$PATH -PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH - -export FORCE_CUDA=1 -export TORCH_CUDA_ARCH_LIST="6.0+PTX" diff --git a/.github/workflows/cuda/cu115-Windows.sh b/.github/workflows/cuda/cu115-Windows.sh deleted file mode 100755 index db2559c..0000000 --- a/.github/workflows/cuda/cu115-Windows.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -# TODO We currently use CUDA 11.3 to build CUDA 11.5 Windows wheels - -# Install NVIDIA drivers, see: -# https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102 -curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip" -7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32" - -export CUDA_SHORT=11.3 -export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.0/local_installers -export CUDA_FILE=cuda_${CUDA_SHORT}.0_465.89_win10.exe - -# Install CUDA: -curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}" -echo "" -echo "Installing from ${CUDA_FILE}..." -PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} thrust_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow" -echo "Done!" -rm -f "${CUDA_FILE}" diff --git a/.github/workflows/cuda/cu115-env.sh b/.github/workflows/cuda/cu115-env.sh new file mode 100644 index 0000000..c38498b --- /dev/null +++ b/.github/workflows/cuda/cu115-env.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +export FORCE_CUDA=1 +export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..d9e12ee --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,96 @@ +name: Release + +on: + push: + branches: + - ci + workflow_dispatch: + + +jobs: + pypi: + strategy: + fail-fast: true + matrix: + os: + - ubuntu-18.04 + python-version: ['3.7', '3.9', '3.10'] + torch-version: [ 1.11.0 ] + cuda-version: [ 'cpu' ] + + runs-on: ${{ matrix.os }} + container: + image: pytorch/manylinux-${{ matrix.cuda-version }} + + steps: + - uses: actions/checkout@v2 + - name: Setup Env ${{ matrix.torch-version }}+${{ matrix.cuda-version }} + run: bash .github/workflows/scripts/setup_pypi.sh + env: + PYTHON_VERSION: ${{ matrix.python-version }} + TORCH_VERSION: ${{ matrix.torch-version }} + CUDA_VERSION: ${{ matrix.cuda-version }} + + - name: Build wheel + run: | + pip install -r requirements-dev.txt + python setup.py bdist_wheel --dist-dir=dist + python .github/workflows/scripts/auditwheel repair dist/*.whl --plat=manylinux_2_17_x86_64 + env: + LIBTORCH_CXX11_ABI: "0" + + - name: Publish to PyPI Test + if: matrix.cuda-version == 'cpu' + run: maturin upload --repository-url https://test.pypi.org/legacy/ --skip-existing wheelhouse/* + env: + MATURIN_PYPI_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} + + - name: Publish to PyPI + if: matrix.cuda-version == 'cpu' && startsWith(github.ref, 'refs/tags') + run: maturin upload --skip-existing wheelhouse/* + env: + MATURIN_PYPI_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} + + conda: + strategy: + fail-fast: true + matrix: + os: + - ubuntu-18.04 + python-version: ['3.7', '3.9', '3.10'] + torch-version: [ 1.11.0 ] + cuda-version: ['cpu', 'cu102', 'cu113', 'cu115'] + + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - name: Install Conda environment with Micromamba + uses: mamba-org/provision-with-micromamba@main + with: + channels: conda-forge,defaults + environment-name: ${{ matrix.torch-version }}+${{ matrix.cuda-version }} + environment-file: false + micromamba-version: 0.23.3 + extra-specs: | + python=${{ matrix.python-version }} + boa=0.11.0 + anaconda-client + + - name: Build package ${{ matrix.torch-version }}+${{ matrix.cuda-version }} + run: | + source .github/workflows/cuda/$CUDA_VERSION-env.sh + ./conda/tch_geometric/build_conda.sh ${{ matrix.python-version }} ${{ matrix.torch-version }} ${{ matrix.cuda-version }} + env: + PYTHON_VERSION: ${{ matrix.python-version }} + TORCH_VERSION: ${{ matrix.torch-version }} + CUDA_VERSION: ${{ matrix.cuda-version }} + shell: + bash -l {0} + + - name: Publish Conda package on personal channel + run: | + anaconda upload --force --label main $HOME/conda-bld/*/*.tar.bz2 + env: + ANACONDA_API_TOKEN: ${{ secrets.ANACONDA_TOKEN }} + shell: + bash -l {0} diff --git a/.github/workflows/scripts/auditwheel b/.github/workflows/scripts/auditwheel new file mode 100644 index 0000000..aab4cbb --- /dev/null +++ b/.github/workflows/scripts/auditwheel @@ -0,0 +1,22 @@ +#!/bin/python +# -*- coding: utf-8 -*- + +# Monkey patch to not ship libjvm.so in pypi wheels +import sys +import re + +from auditwheel.main import main +from auditwheel.policy import _POLICIES as POLICIES + +# libjvm is loaded dynamically; do not include it +for p in POLICIES: + p['lib_whitelist'].append('libtorch_cuda_cu.so') + p['lib_whitelist'].append('libtorch_cuda_cpp.so') + p['lib_whitelist'].append('libtorch_cpu.so') + p['lib_whitelist'].append('libtorch_python.so') + p['lib_whitelist'].append('libtorch.so') + p['lib_whitelist'].append('libtorch_cuda.so') + +if __name__ == "__main__": + sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) + sys.exit(main()) \ No newline at end of file diff --git a/.github/workflows/scripts/setup_pypi.sh b/.github/workflows/scripts/setup_pypi.sh new file mode 100644 index 0000000..9741dd8 --- /dev/null +++ b/.github/workflows/scripts/setup_pypi.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +PY_MAPPING='{ "3.7": "cp37-cp37m", "3.8": "cp38-cp38", "3.9": "cp39-cp39", "3.10": "cp310-cp310" }' +export PYTHON_ALIAS="$(python -c "print(${PY_MAPPING}['${PYTHON_VERSION}'])")" + +echo Setup Python $PYTHON_VERSION - $PYTHON_ALIAS +export PATH="/opt/python/$PYTHON_ALIAS/bin:$PATH" +export PYTHON_INCLUDE_DIRS=$(python -c "from sysconfig import get_paths as gp; print(gp()['include'])") +export PYTHON_SITE_DIR=$(python -c "import site; print(site.getsitepackages()[0])") + +echo Install Rust +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y +export PATH=$HOME/.cargo/bin:$PATH + +echo Install PyTorch $TORCH_VERSION+$CUDA_VERSION +export TORCH_CUDA_VERSION=$(echo $CUDA_VERSION | sed "s/cuda/cu/g") +pip install numpy typing-extensions dataclasses toml auditwheel +pip install --no-index --no-cache-dir torch==$TORCH_VERSION -f https://download.pytorch.org/whl/$CUDA_VERSION/torch_stable.html +python -c "import torch; print('PyTorch:', torch.__version__)" +python -c "import torch; print('CUDA:', torch.version.cuda)" + +echo Updating auditwheel +cat .github/workflows/auditwheel > $(which auditwheel) + +echo Setting Vars +echo "PATH=$PATH" >> $GITHUB_ENV +echo "PYTHON_INCLUDE_DIRS=$PYTHON_INCLUDE_DIRS" >> $GITHUB_ENV +echo "TORCH_CUDA_VERSION=$TORCH_CUDA_VERSION" >> $GITHUB_ENV +echo "PYO3_PYTHON=/opt/python/$PYTHON_ALIAS/bin/python" >> $GITHUB_ENV +echo "LIBTORCH=$PYTHON_SITE_DIR/torch" >> $GITHUB_ENV +echo "LD_LIBRARY_PATH=$PYTHON_SITE_DIR/torch/lib:$LD_LIBRARY_PATH" >> $GITHUB_ENV diff --git a/.gitignore b/.gitignore index d39d824..b56ab1c 100644 --- a/.gitignore +++ b/.gitignore @@ -70,4 +70,6 @@ docs/_build/ # Pyenv .python-version -.cargo/config.toml \ No newline at end of file +.cargo/config.toml +dist +/wheelhouse \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 60f3e83..f698c09 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10,9 +10,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "anyhow" -version = "1.0.55" +version = "1.0.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "159bb86af3a200e19a068f4224eae4c8bb2d0fa054c7e5d1cacd5cef95e684cd" +checksum = "08f9b8508dccb7687a1d6c4ce66b2b0ecef467c94667de27d8d7fe1f8d2a9cdc" [[package]] name = "atty" @@ -161,9 +161,9 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.2" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e54ea8bc3fb1ee042f5aace6e3c6e025d3874866da222930f70ce62aceba0bfa" +checksum = "5aaa7bd5fb665c6864b5f963dd9097905c54125909c7aa94c9e18507cdbe6c53" dependencies = [ "cfg-if", "crossbeam-utils", @@ -182,10 +182,11 @@ dependencies = [ [[package]] name = "crossbeam-epoch" -version = "0.9.7" +version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c00d6d2ea26e8b151d99093005cb442fb9a37aeaca582a03ec70946f49ab5ed9" +checksum = "1145cf131a2c6ba0615079ab6a638f7e1973ac9c2634fcbeaaad6114246efe8c" dependencies = [ + "autocfg", "cfg-if", "crossbeam-utils", "lazy_static", @@ -195,9 +196,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.7" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e5bed1f1c269533fa816a0a5492b3545209a205ca1a54842be180eb63a16a6" +checksum = "0bf124c720b7686e3c2663cf54062ab0f68a88af2fb6a030e87e30bf721fcb38" dependencies = [ "cfg-if", "lazy_static", @@ -233,9 +234,9 @@ checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" [[package]] name = "flate2" -version = "1.0.22" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e6988e897c1c9c485f43b47a529cef42fde0547f9d8d41a7062518f1d8fc53f" +checksum = "b39522e96686d38f4bc984b9198e3a0613264abaebaff2c5c918bfa6b6da09af" dependencies = [ "cfg-if", "crc32fast", @@ -245,9 +246,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d39cd93900197114fa1fcb7ae84ca742095eed9442088988ae74fa744e930e77" +checksum = "9be70c98951c83b8d2f8f60d7065fa6d5146873094452a1008da8c2f1e4205ad" dependencies = [ "cfg-if", "libc", @@ -318,15 +319,15 @@ checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4" [[package]] name = "itoa" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1aab8fc367588b89dcee83ab0fd66b72b50b72fa1904d7095045ace2b0c81c35" +checksum = "112c678d4050afce233f4f2852bb2eb519230b3cf12f33585275537d7e41578d" [[package]] name = "js-sys" -version = "0.3.56" +version = "0.3.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a38fc24e30fd564ce974c02bf1d337caddff65be6cc4735a1f7eab22a7440f04" +checksum = "671a26f820db17c2a2750743f1dd03bafd15b98c9f30c7c2628c024c05d73397" dependencies = [ "wasm-bindgen", ] @@ -339,24 +340,25 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.119" +version = "0.2.126" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bf2e165bb3457c8e098ea76f3e3bc9db55f87aa90d52d0e6be741470916aaa4" +checksum = "349d5a591cd28b49e1d1037471617a32ddcda5731b99419008085f72d5a53836" [[package]] name = "lock_api" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88943dd7ef4a2e5a4bfa2753aaab3013e34ce2533d1996fb18ef591e315e2b3b" +checksum = "327fa5b6a6940e4699ec49a9beae1ea4845c6bab9314e4f84ac68742139d8c53" dependencies = [ + "autocfg", "scopeguard", ] [[package]] name = "log" -version = "0.4.14" +version = "0.4.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710" +checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" dependencies = [ "cfg-if", ] @@ -372,9 +374,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a" +checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" [[package]] name = "memoffset" @@ -387,12 +389,11 @@ dependencies = [ [[package]] name = "miniz_oxide" -version = "0.4.4" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a92518e98c078586bc6c934028adcca4c92a53d6a958196de835170a01d84e4b" +checksum = "d2b29bd4bc3f33391105ebee3589c19197c4271e3e5a9ec9bfe8127eeff8f082" dependencies = [ "adler", - "autocfg", ] [[package]] @@ -410,18 +411,18 @@ dependencies = [ [[package]] name = "num-complex" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085" +checksum = "97fbc387afefefd5e9e39493299f3069e14a140dd34dc19b4c1c1a8fddb6a790" dependencies = [ "num-traits", ] [[package]] name = "num-integer" -version = "0.1.44" +version = "0.1.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2cc698a63b549a70bc047073d2949cce27cd1c7b0a4a862d08a8031bc2801db" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" dependencies = [ "autocfg", "num-traits", @@ -429,9 +430,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290" +checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" dependencies = [ "autocfg", ] @@ -448,9 +449,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.9.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da32515d9f6e6e489d7bc9d84c71b060db7247dc035bbe44eac88cf87486d8d5" +checksum = "7b10983b38c53aebdf33f542c6275b0f58a238129d00c4ae0e6fb59738d783ca" [[package]] name = "oorandom" @@ -504,9 +505,9 @@ dependencies = [ [[package]] name = "pkg-config" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58893f751c9b0412871a09abd62ecd2a00298c6c83befa223ef98c52aef40cbe" +checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae" [[package]] name = "plotters" @@ -550,18 +551,18 @@ checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" [[package]] name = "proc-macro2" -version = "1.0.36" +version = "1.0.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7342d5883fbccae1cc37a2353b09c87c9b0f3afd73f5fb9bba687a1f733b029" +checksum = "c54b25569025b7fc9651de43004ae593a75ad88543b17178aa5e1b9c4f15f56f" dependencies = [ - "unicode-xid", + "unicode-ident", ] [[package]] name = "pyo3" -version = "0.15.1" +version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cf01dbf1c05af0a14c7779ed6f3aa9deac9c3419606ac9de537a2d649005720" +checksum = "d41d50a7271e08c7c8a54cd24af5d62f73ee3a6f6a314215281ebdec421d5752" dependencies = [ "cfg-if", "indoc", @@ -575,18 +576,18 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.15.1" +version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbf9e4d128bfbddc898ad3409900080d8d5095c379632fbbfbb9c8cfb1fb852b" +checksum = "779239fc40b8e18bc8416d3a37d280ca9b9fb04bda54b98037bb6748595c2410" dependencies = [ "once_cell", ] [[package]] name = "pyo3-macros" -version = "0.15.1" +version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67701eb32b1f9a9722b4bc54b548ff9d7ebfded011c12daece7b9063be1fd755" +checksum = "00b247e8c664be87998d8628e86f282c25066165f1f8dda66100c48202fdb93a" dependencies = [ "pyo3-macros-backend", "quote", @@ -595,9 +596,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.15.1" +version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f44f09e825ee49a105f2c7b23ebee50886a9aee0746f4dd5a704138a64b0218a" +checksum = "5a8c2812c412e00e641d99eeb79dd478317d981d938aa60325dfa7157b607095" dependencies = [ "proc-macro2", "pyo3-build-config", @@ -607,9 +608,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.15" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "864d3e96a899863136fc6e99f3d7cae289dafe43bf2c5ac19b70df7210c0a145" +checksum = "a1feb54ed693b93a84e14094943b84b7c4eae204c512b7ccb95ab0c66d278ad1" dependencies = [ "proc-macro2", ] @@ -652,9 +653,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" [[package]] name = "rayon" -version = "1.5.1" +version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06aca804d41dbc8ba42dfd964f0d01334eceb64314b9ecf7c5fad5188a06d90" +checksum = "bd99e5772ead8baa5215278c9b15bf92087709e9c1b2d1f97cdb5a183c933a7d" dependencies = [ "autocfg", "crossbeam-deque", @@ -664,31 +665,30 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.9.1" +version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d78120e2c850279833f1dd3582f730c4ab53ed95aeaaaa862a2a5c71b1656d8e" +checksum = "258bcdb5ac6dad48491bb2992db6b7cf74878b0384908af124823d118c99683f" dependencies = [ "crossbeam-channel", "crossbeam-deque", "crossbeam-utils", - "lazy_static", "num_cpus", ] [[package]] name = "redox_syscall" -version = "0.2.10" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8383f39639269cde97d255a32bdb68c047337295414940c68bdd30c2e13203ff" +checksum = "62f25bc4c7e55e0b0b7a1d43fb893f4fa1361d0abe38b9ce4f323c2adfe6ef42" dependencies = [ "bitflags", ] [[package]] name = "regex" -version = "1.5.4" +version = "1.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461" +checksum = "d83f127d94bdbcda4c8cc2e50f6f84f4b611f69c902699ca385a39c3a75f9ff1" dependencies = [ "regex-syntax", ] @@ -701,9 +701,9 @@ checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" [[package]] name = "regex-syntax" -version = "0.6.25" +version = "0.6.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" +checksum = "49b3de9ec5dc0a3417da371aab17d729997c15010e7fd24ff707773a33bddb64" [[package]] name = "rustc_version" @@ -716,9 +716,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f" +checksum = "f3f6f92acf49d1b98f7a81226834412ada05458b7364277387724a237f062695" [[package]] name = "same-file" @@ -737,15 +737,15 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "semver" -version = "1.0.6" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a3381e03edd24287172047536f20cabde766e2cd3e65e6b00fb3af51c4f38d" +checksum = "8cb243bdfdb5936c8dc3c45762a19d12ab4550cdc753bc247637d4ec35a040fd" [[package]] name = "serde" -version = "1.0.136" +version = "1.0.137" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce31e24b01e1e524df96f1c2fdd054405f8d7376249a5110886fb4b658484789" +checksum = "61ea8d54c77f8315140a05f4c7237403bf38b72704d031543aa1d16abbf517d1" [[package]] name = "serde_cbor" @@ -759,9 +759,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.136" +version = "1.0.137" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08597e7152fcd306f41838ed3e37be9eaeed2b61c42e2117266a554fab4662f9" +checksum = "1f26faba0c3959972377d3b2d306ee9f71faee9714294e41bb777f83f88578be" dependencies = [ "proc-macro2", "quote", @@ -770,11 +770,11 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.79" +version = "1.0.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e8d9fa5c3b304765ce1fd9c4c8a3de2c8db365a5b91be52f186efc675681d95" +checksum = "9b7ce2b32a1aed03c558dc61a5cd328f15aff2dbc17daad8fb8af04d2100e15c" dependencies = [ - "itoa 1.0.1", + "itoa 1.0.2", "ryu", "serde", ] @@ -787,19 +787,19 @@ checksum = "f2dd574626839106c320a323308629dcb1acfc96e32a8cba364ddc61ac23ee83" [[package]] name = "syn" -version = "1.0.86" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a65b3f4ffa0092e9887669db0eae07941f023991ab58ea44da8fe8e2d511c6b" +checksum = "fbaf6116ab8924f39d52792136fb74fd60a80194cf1b1c6ffa6453eef1c3f942" dependencies = [ "proc-macro2", "quote", - "unicode-xid", + "unicode-ident", ] [[package]] name = "tch" -version = "0.7.0" -source = "git+https://github.com/EgorDm/tch-rs.git?branch=main#31b26f86e79c5495c56713e1ae829836156546e0" +version = "0.7.2" +source = "git+https://github.com/EgorDm/tch-rs.git?branch=main#40ad7d0c5c1022a647f39162037818a25e58df42" dependencies = [ "half", "lazy_static", @@ -838,18 +838,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.30" +version = "1.0.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417" +checksum = "bd829fe32373d27f76265620b5309d0340cb8550f523c1dda251d6298069069a" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.30" +version = "1.0.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b" +checksum = "0396bc89e626244658bef819e22d0cc459e795a5ebe878e6ec336d1674a8d79a" dependencies = [ "proc-macro2", "quote", @@ -878,8 +878,8 @@ dependencies = [ [[package]] name = "torch-sys" -version = "0.7.0" -source = "git+https://github.com/EgorDm/tch-rs.git?branch=main#31b26f86e79c5495c56713e1ae829836156546e0" +version = "0.7.2" +source = "git+https://github.com/EgorDm/tch-rs.git?branch=main#40ad7d0c5c1022a647f39162037818a25e58df42" dependencies = [ "anyhow", "cc", @@ -888,22 +888,22 @@ dependencies = [ ] [[package]] -name = "unicode-width" -version = "0.1.9" +name = "unicode-ident" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ed742d4ea2bd1176e236172c8429aaf54486e7ac098db29ffe6529e0ce50973" +checksum = "d22af068fba1eb5edcb4aea19d382b2a3deb4c8f9d475c589b6ada9e0fd493ee" [[package]] -name = "unicode-xid" -version = "0.2.2" +name = "unicode-width" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" +checksum = "3ed742d4ea2bd1176e236172c8429aaf54486e7ac098db29ffe6529e0ce50973" [[package]] name = "unindent" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514672a55d7380da379785a4d70ca8386c8883ff7eaae877be4d2081cebe73d8" +checksum = "52fee519a3e570f7df377a06a1a7775cdbfb7aa460be7e08de2b1f0e69973a44" [[package]] name = "walkdir" @@ -924,9 +924,9 @@ checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" [[package]] name = "wasm-bindgen" -version = "0.2.79" +version = "0.2.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25f1af7423d8588a3d840681122e72e6a24ddbcb3f0ec385cac0d12d24256c06" +checksum = "27370197c907c55e3f1a9fbe26f44e937fe6451368324e009cba39e139dc08ad" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -934,9 +934,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.79" +version = "0.2.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b21c0df030f5a177f3cba22e9bc4322695ec43e7257d865302900290bcdedca" +checksum = "53e04185bfa3a779273da532f5025e33398409573f348985af9a1cbf3774d3f4" dependencies = [ "bumpalo", "lazy_static", @@ -949,9 +949,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.79" +version = "0.2.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f4203d69e40a52ee523b2529a773d5ffc1dc0071801c87b3d270b471b80ed01" +checksum = "17cae7ff784d7e83a2fe7611cfe766ecf034111b49deb850a3dc7699c08251f5" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -959,9 +959,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.79" +version = "0.2.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa8a30d46208db204854cadbb5d4baf5fcf8071ba5bf48190c3e59937962ebc" +checksum = "99ec0dc7a4756fffc231aab1b9f2f578d23cd391390ab27f952ae0c9b3ece20b" dependencies = [ "proc-macro2", "quote", @@ -972,15 +972,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.79" +version = "0.2.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d958d035c4438e28c70e4321a2911302f10135ce78a9c7834c0cab4123d06a2" +checksum = "d554b7f530dee5964d9a9468d95c1f8b8acae4f282807e7d27d4b03099a46744" [[package]] name = "web-sys" -version = "0.3.56" +version = "0.3.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c060b319f29dd25724f09a2ba1418f142f539b2be99fbf4d2d5a8f7330afb8eb" +checksum = "7b17e741662c70c8bd24ac5c5b18de314a2c26c32bf8346ee1e6f53de919c283" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/Cargo.toml b/Cargo.toml index 42e4675..3ddb37a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,3 @@ criterion = "0.3" [features] extension-module = ["pyo3/extension-module", "tch/torch_python", "pyo3"] default = ["extension-module"] - -[[bench]] -name = "internal" -harness = false diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..148f624 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,3 @@ +include Cargo.toml +recursive-include src * +recursive-include tch_geometric * \ No newline at end of file diff --git a/Makefile b/Makefile index 4b162d7..e3fd27d 100644 --- a/Makefile +++ b/Makefile @@ -7,4 +7,7 @@ develop: @$(ACTIVE_ENV); $(VENV_ACTIVATE); maturin develop release: - @$(ACTIVE_ENV); $(VENV_ACTIVATE); maturin develop --release \ No newline at end of file + @$(ACTIVE_ENV); $(VENV_ACTIVATE); maturin develop --release + +build-release: + @$(ACTIVE_ENV); $(VENV_ACTIVATE); maturin build --release -o dist \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..dfe6426 --- /dev/null +++ b/README.md @@ -0,0 +1,39 @@ + +
+
+

tch-geometric

+

+ License + CI +

+

+ Pytorch Geometric extension library +

+
+

+ Features • + Installation • + Examples +

+ + +## Features +Pytorch Geometric extension library with additional graph sampling algorithms. + +Supports: + +* Node2Vec (`random_walk`) +* Temporal Random Walk (`temporal_random_walk`) +* Biased Temporal Random Walk (CTDNE) (`biased_tempo_random_walk`) +* Negative Sampling (`negative_sample_neighbors_homogenous` and `negative_sample_neighbors_heterogenous`) +* GraphSAGE budget sampling (`budget_sampling`) +* Temporal Heterogenous Graph Transformer (HGT) sampling (`hgt_sampling`) +* GraphSAGE (`neighbor_sampling_heterogenous` and `neighbor_sampling_homogenous`) + +## TODO: +* Cite appropriately +* Add usage guide + + +## Examples +Check examples folder \ No newline at end of file diff --git a/benches/internal.rs b/benches/internal.rs deleted file mode 100644 index fec3026..0000000 --- a/benches/internal.rs +++ /dev/null @@ -1,31 +0,0 @@ -use std::convert::TryFrom; -use criterion::{criterion_group, criterion_main, Criterion}; -use rand::SeedableRng; -use tch::Tensor; -use tch_geometric::data::{CsrGraphStorage, CsrGraph, load_karate_graph}; -use tch_geometric::algo::negative_sampling::{negative_sample_neighbors_homogenous}; - -pub fn internal_benchmark(c: &mut Criterion) { - let (x, _, coo_graph) = load_karate_graph(); - - let mut rng = rand::rngs::SmallRng::from_seed([0; 32]); - - let node_count = x.size()[0]; - let graph_data = CsrGraphStorage::try_from(&coo_graph).unwrap(); - let graph = CsrGraph::::try_from(&graph_data).unwrap(); - let inputs = Tensor::of_slice(&[0_i64, 1, 2, 3, 4, 5, 6, 7, 8, 9]); - - c.bench_function("negative_sample_neighbors", |b| b.iter(|| { - negative_sample_neighbors_homogenous( - &mut rng, - &graph, - (node_count, node_count), - &inputs, - 10, - 5, - ).unwrap() - })); -} - -criterion_group!(benches, internal_benchmark); -criterion_main!(benches); diff --git a/conda/tch_geometric/build_conda.sh b/conda/tch_geometric/build_conda.sh new file mode 100755 index 0000000..3124035 --- /dev/null +++ b/conda/tch_geometric/build_conda.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +export PYTHON_VERSION=$1 +export TORCH_VERSION=$2 +export CUDA_VERSION=$3 + +export CONDA_PYTORCH_CONSTRAINT="pytorch==${TORCH_VERSION%.*}.*" + +if [ "${CUDA_VERSION}" = "cpu" ]; then + export CONDA_CUDATOOLKIT_CONSTRAINT="cpuonly # [not osx]" +else + case $CUDA_VERSION in + cu115) + export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit==11.5.*" + ;; + cu113) + export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit==11.3.*" + ;; + cu111) + export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit==11.1.*" + ;; + cu102) + export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit==10.2.*" + ;; + cu101) + export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit==10.1.*" + ;; + *) + echo "Unrecognized CUDA_VERSION=$CUDA_VERSION" + exit 1 + ;; + esac +fi + +echo "PyTorch $TORCH_VERSION+$CUDA_VERSION" +echo "- $CONDA_PYTORCH_CONSTRAINT" +echo "- $CONDA_CUDATOOLKIT_CONSTRAINT" + +conda mambabuild . -c pytorch -c default -c nvidia -c conda-forge --output-folder "$HOME/conda-bld" \ No newline at end of file diff --git a/conda/tch_geometric/meta.yaml b/conda/tch_geometric/meta.yaml new file mode 100644 index 0000000..036875a --- /dev/null +++ b/conda/tch_geometric/meta.yaml @@ -0,0 +1,50 @@ +package: + name: tch_geometric + version: 0.1.0 + +source: + path: ../.. + +requirements: + build: + - {{ compiler('c') }} # [win] + + host: + - pip + - python {{ environ.get('PYTHON_VERSION') }} + - {{ environ.get('CONDA_PYTORCH_CONSTRAINT') }} + - {{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT') }} + - conda-forge::toml==0.10.2 + - conda-forge::setuptools-rust==1.3.0 + - conda-forge::rust==1.61.0 + + run: + - python {{ environ.get('PYTHON_VERSION') }} + - {{ environ.get('CONDA_PYTORCH_CONSTRAINT') }} + - {{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT') }} + +build: + string: py{{ environ.get('PYTHON_VERSION').replace('.', '') }}_torch_{{ environ['TORCH_VERSION'] }}_{{ environ['CUDA_VERSION'] }} + script: | + export PYTHON_INCLUDE_DIRS=$(python -c "from sysconfig import get_paths as gp; print(gp()['include'])") + export PYTHON_SITE_DIR=$(python -c "import site; print(site.getsitepackages()[0])") + export LIBTORCH="$PYTHON_SITE_DIR/torch" + export LD_LIBRARY_PATH="$PYTHON_SITE_DIR/torch/lib:$LD_LIBRARY_PATH" + export LIBTORCH_CXX11_ABI="0" + pip install . + script_env: + - FORCE_CUDA + - TORCH_CUDA_ARCH_LIST + - TORCH_CUDA_VERSION + +test: + commands: + - | + export PYTHON_SITE_DIR=$(python -c "import site; print(site.getsitepackages()[0])") + export LD_LIBRARY_PATH="$PYTHON_SITE_DIR/torch/lib:$LD_LIBRARY_PATH" + python -c "import tch_geometric" + +about: + home: https://github.com/EgorDm/tch-geometric + license: MIT + summary: Pytorch Geometric extension library with additional graph sampling algorithms. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index d2d2687..e5d914f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,19 @@ [build-system] -requires = ["maturin>=0.12,<0.13"] -build-backend = "maturin" +requires = ["setuptools>=61.0.0", "wheel", "setuptools_rust>=1.0.0"] +#requires = ["maturin>=0.12,<0.13"] +#build-backend = "maturin" [project] name = "tch-geometric" requires-python = ">=3.6" +dynamic = ["version"] classifiers = [ "Programming Language :: Rust", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] +[tool.setuptools] +packages = ["tch_geometric"] +include-package-data = true +zip-safe = false \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..5263d9a --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,7 @@ +setuptools>=61.0.0 +setuptools_rust~=1.0.0 +pip>=21.3 +wheel +toml +auditwheel +maturin \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..7cc8241 --- /dev/null +++ b/setup.py @@ -0,0 +1,19 @@ +import os + +from setuptools import setup +from setuptools_rust import RustExtension +import toml + +version = toml.load('Cargo.toml')['package']['version'] +if os.getenv('VERSION_SUFFIX', False): + version = f'{version}+{os.getenv("VERSION_SUFFIX")}' + +setup( + rust_extensions=[ + RustExtension( + "tch_geometric.tch_geometric", + debug=os.environ.get("BUILD_DEBUG") == "1", + ) + ], + version=version +) diff --git a/tch_geometric/__init__.py b/tch_geometric/__init__.py index c6bc1ff..d0a2cf0 100644 --- a/tch_geometric/__init__.py +++ b/tch_geometric/__init__.py @@ -1,6 +1,2 @@ -import tch_geometric.tch_geometric as native -import tch_geometric.loader as loader - - - - +import torch +from tch_geometric.tch_geometric import * diff --git a/tch_geometric/data/__init__.py b/tch_geometric/data/__init__.py deleted file mode 100644 index 0eeed7c..0000000 --- a/tch_geometric/data/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .conversion import * \ No newline at end of file diff --git a/tch_geometric/data/conversion.py b/tch_geometric/data/conversion.py deleted file mode 100644 index 096dc46..0000000 --- a/tch_geometric/data/conversion.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import Union, Tuple, Dict - -from torch import Tensor -from torch_geometric.data import HeteroData, Data -from torch_geometric.data.storage import EdgeStorage -from torch_geometric.typing import EdgeType - -import tch_geometric.tch_geometric as native - -RelType = str -Size = Tuple[int, int] - - -def edge_type_to_str(edge_type: Union[EdgeType, str]) -> RelType: - # It is faster to have keys consisting of single string - return edge_type if isinstance(edge_type, str) else '__'.join(edge_type) - - -def to_sparse(data: Union[Data, EdgeStorage], sparse_fn) -> Tuple[Tensor, Tensor, Tensor, Size]: - if not hasattr(data, 'edge_index'): - raise AttributeError("Data object does not contain attribute 'edge_index'") - - size = data.size() - ptrs, indices, perm = sparse_fn(data.edge_index, size) - return ptrs, indices, perm, size - - -def to_csr(data: Union[Data, EdgeStorage]) -> Tuple[Tensor, Tensor, Tensor, Size]: - return to_sparse(data, native.to_csr) - - -def to_csc(data: Union[Data, EdgeStorage]) -> Tuple[Tensor, Tensor, Tensor, Size]: - return to_sparse(data, native.to_csc) - - -def to_hetero_sparse(data: HeteroData, sparse_fn) \ - -> Tuple[Dict[RelType, Tensor], Dict[RelType, Tensor], Dict[RelType, Tensor], Dict[RelType, Size]]: - ptrs_dict, indices_dict, perm_dict, size_dict = {}, {}, {}, {} - - for store in data.edge_stores: - key = edge_type_to_str(store._key) - ptrs_dict[key], indices_dict[key], perm_dict[key], size_dict[key] = sparse_fn(store) - - return ptrs_dict, indices_dict, perm_dict, size_dict - - -def to_hetero_csc(data: HeteroData) \ - -> Tuple[Dict[RelType, Tensor], Dict[RelType, Tensor], Dict[RelType, Tensor], Dict[RelType, Size]]: - return to_hetero_sparse(data, to_csc) - - -def to_hetero_csr(data: HeteroData) \ - -> Tuple[Dict[RelType, Tensor], Dict[RelType, Tensor], Dict[RelType, Tensor], Dict[RelType, Size]]: - return to_hetero_sparse(data, to_csr) - - -def to_hetero_sparse_attr(data: HeteroData, attr: str, perm_dict: Dict[RelType, Tensor]) \ - -> Dict[RelType, Tensor]: - attr_dict = {} - for store in data.edge_stores: - key = edge_type_to_str(store._key) - if hasattr(store, attr): - attr_data = getattr(store, attr)[perm_dict[key]] - attr_dict[key] = attr_data - - return attr_dict diff --git a/tch_geometric/data/subgraph.py b/tch_geometric/data/subgraph.py deleted file mode 100644 index 6d406ea..0000000 --- a/tch_geometric/data/subgraph.py +++ /dev/null @@ -1,81 +0,0 @@ -from collections import defaultdict -from typing import Dict, Any -from copy import deepcopy - -import torch -from torch import Tensor -from torch_geometric.data import HeteroData -from torch_geometric.typing import EdgeType, NodeType - -from tch_geometric.utils import zip_dict - - -def subgraph_from_edgelist( - edge_index_dict: Dict[EdgeType, Tensor], - node_attrs: Dict[str, Dict[NodeType, Any]] = None, - edge_attrs: Dict[str, Dict[EdgeType, Any]] = None, -): - if node_attrs is None: - node_attrs = {} - if edge_attrs is None: - edge_attrs = {} - - nodes: Dict[NodeType, Any] = defaultdict(list) - rows = dict() - cols = dict() - node_counts = defaultdict(lambda: 0) - sample_count = defaultdict(lambda: 0) - - for edge_type, edge_index in edge_index_dict.items(): - (src, _, _) = edge_type - start = node_counts[src] - edge_count = edge_index.shape[1] - nodes[src].append(edge_index[0, :]) - rows[edge_type] = torch.arange(start, start + edge_count, dtype=torch.long) - - node_counts[src] += edge_count - - sample_count = deepcopy(node_counts) - - for edge_type, edge_index in edge_index_dict.items(): - (_, _, dst) = edge_type - start = node_counts[dst] - edge_count = edge_index.shape[1] - nodes[dst].append(edge_index[1, :]) - cols[edge_type] = torch.arange(start, start + edge_count, dtype=torch.long) - - node_counts[dst] += edge_count - - for node_type, vals in nodes.items(): - nodes[node_type] = torch.cat(vals, dim=0) - - subgraph = create_subgraph( - nodes, rows, cols, - node_attrs=dict(sample_count=sample_count, **node_attrs), - edge_attrs=edge_attrs - ) - - return subgraph - - -def create_subgraph(nodes_dict, rows_dict, cols_dict, edge_attrs=None, node_attrs=None): - if edge_attrs is None: - edge_attrs = {} - if node_attrs is None: - node_attrs = {} - - # Build the subgraph - subgraph = HeteroData() - for node_type, node in nodes_dict.items(): - subgraph[node_type].x = node - for node_attr, node_attr_vals in node_attrs.items(): - setattr(subgraph[node_type], node_attr, node_attr_vals[node_type]) - - for rel_type, (row, col) in zip_dict(rows_dict, cols_dict): - edge_type = tuple(rel_type.split('__')) if isinstance(rel_type, str) else rel_type - subgraph[edge_type].edge_index = torch.stack([row, col], dim=0) - - for edge_attr, edge_attr_vals in edge_attrs.items(): - setattr(subgraph[edge_type], edge_attr, edge_attr_vals[edge_type]) - - return subgraph \ No newline at end of file diff --git a/tch_geometric/loader/__init__.py b/tch_geometric/loader/__init__.py deleted file mode 100644 index 6f998ea..0000000 --- a/tch_geometric/loader/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .custom_loader import * \ No newline at end of file diff --git a/tch_geometric/loader/budget_loader.py b/tch_geometric/loader/budget_loader.py deleted file mode 100644 index 163c482..0000000 --- a/tch_geometric/loader/budget_loader.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Callable, Union, List, Dict, Optional, Tuple - -from torch.utils.data import Dataset -from torch_geometric.data import HeteroData -from torch_geometric.typing import NodeType - -from tch_geometric.transforms.budget_sampling import BudgetSamplerTransform -from tch_geometric.transforms.hgt_sampling import HGTSamplerTransform -from tch_geometric.loader.custom_loader import CustomLoader - - -class BudgetLoader(CustomLoader): - def __init__( - self, - dataset: Dataset, - neighbor_sampler: BudgetSamplerTransform = None, - data: HeteroData = None, - num_neighbors: Union[List[int], Dict[NodeType, List[int]]] = None, - window: Optional[Tuple[int, int]] = None, - forward: bool = False, - relative: bool = True, - batch_size_tmp: int = None, - temporal: bool = True, - **kwargs - ): - super().__init__(dataset, batch_size_tmp=batch_size_tmp, **kwargs) - - self.neighbor_sampler = neighbor_sampler - if not self.neighbor_sampler: - self.neighbor_sampler = BudgetSamplerTransform(data, num_neighbors, window, forward, relative) - - self.temporal = temporal - - def sample(self, inputs: HeteroData): - inputs_dict = inputs.x_dict - inputs_timestamps_dict = inputs.timestamp_dict if self.neighbor_sampler.window else None - - return self.neighbor_sampler(inputs_dict, inputs_timestamps_dict) diff --git a/tch_geometric/loader/custom_loader.py b/tch_geometric/loader/custom_loader.py deleted file mode 100644 index cb239b3..0000000 --- a/tch_geometric/loader/custom_loader.py +++ /dev/null @@ -1,67 +0,0 @@ -from abc import abstractmethod -from typing import Callable - -from torch.utils.data import Dataset, RandomSampler, SequentialSampler, BatchSampler, DataLoader -from torch_geometric.loader.base import BaseDataLoader -from pytorch_lightning.trainer.connectors.data_connector import DataConnector -from pytorch_lightning.utilities import data - -def _is_dataloader_shuffled(dataloader: DataLoader): - return ( - hasattr(dataloader, "sampler") - and not ( # Added this condition - isinstance(dataloader.sampler, BatchSampler) and - isinstance(dataloader.sampler.sampler, SequentialSampler), - ) - ) - -# I am going insane from this warning. It's a false positive. Therefore we monkey patch it. -_check_eval_shuffling_og = DataConnector._check_eval_shuffling -def _check_eval_shuffling(cls, dataloader, mode): - if not _is_dataloader_shuffled(dataloader): - return - - _check_eval_shuffling_og(dataloader, mode) -DataConnector._check_eval_shuffling = _check_eval_shuffling - - -class CustomLoader(BaseDataLoader): - def __init__( - self, - dataset: Dataset, - transform: Callable = None, - shuffle: bool = False, - generator=None, - batch_size: int = 1, - drop_last: bool = False, - batch_size_tmp: int = None, - **kwargs, - ): - kwargs.pop('collate_fn', None) - kwargs.pop('sampler', None) - batch_size = batch_size or batch_size_tmp - self.transform = transform - self.shuffle = shuffle - self.batch_size = batch_size - self.batch_size_tmp = batch_size - - # Default sampler to set autocollate to False - if shuffle: - sampler = RandomSampler(dataset, generator=generator) - else: - sampler = SequentialSampler(dataset) - batch_sampler = BatchSampler(sampler, batch_size, drop_last) - - super().__init__( - dataset, - collate_fn=self.sample, - sampler=batch_sampler, - batch_size=None, - **kwargs, - ) - - def sample(self, inputs): - return inputs - - def transform_fn(self, out): - return out if self.transform is None else self.transform(out) diff --git a/tch_geometric/loader/hgt_loader.py b/tch_geometric/loader/hgt_loader.py deleted file mode 100644 index d900765..0000000 --- a/tch_geometric/loader/hgt_loader.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Callable, Union, List, Dict, Optional - -from torch.utils.data import Dataset -from torch_geometric.data import HeteroData -from torch_geometric.typing import NodeType - -from tch_geometric.transforms.hgt_sampling import HGTSamplerTransform -from tch_geometric.loader.custom_loader import CustomLoader - - -class HGTLoader(CustomLoader): - def __init__( - self, - dataset: Dataset, - neighbor_sampler: HGTSamplerTransform = None, - data: HeteroData = None, - num_samples: Union[List[int], Dict[NodeType, List[int]]] = None, - temporal: bool = False, - batch_size_tmp: int = None, - **kwargs - ): - super().__init__(dataset, batch_size_tmp=batch_size_tmp, **kwargs) - - self.neighbor_sampler = neighbor_sampler - if not self.neighbor_sampler: - self.neighbor_sampler = HGTSamplerTransform(data, num_samples, temporal) - - def sample(self, inputs: HeteroData): - inputs_dict = inputs.x_dict - inputs_timestamps_dict = inputs.timestamp_dict if self.neighbor_sampler.temporal else None - timerange = None - - return self.neighbor_sampler(inputs_dict, inputs_timestamps_dict, timerange) diff --git a/tch_geometric/tch_geometric.pyi b/tch_geometric/tch_geometric.pyi index dffba8a..e9c4cdf 100644 --- a/tch_geometric/tch_geometric.pyi +++ b/tch_geometric/tch_geometric.pyi @@ -3,7 +3,7 @@ from typing import Union, Tuple, List, Optional, Dict from torch import Tensor from torch_geometric.typing import NodeType, EdgeType -from tch_geometric.transforms import EdgeSampler, EdgeFilter +from tch_geometric.utils import EdgeSampler, EdgeFilter LayerOffset = (int, int, int) RelType = str diff --git a/tch_geometric/transforms/__init__.py b/tch_geometric/transforms/__init__.py deleted file mode 100644 index 02a1b75..0000000 --- a/tch_geometric/transforms/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .negative_sampling import * -from .neighbor_sampling import * diff --git a/tch_geometric/transforms/budget_sampling.py b/tch_geometric/transforms/budget_sampling.py deleted file mode 100644 index 583b1e5..0000000 --- a/tch_geometric/transforms/budget_sampling.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Union, List, Dict, Optional, Tuple - -import torch -from torch_geometric.data import Data, HeteroData -from torch_geometric.loader.utils import filter_hetero_data -from torch_geometric.typing import NodeType - -import tch_geometric.tch_geometric as native -from tch_geometric.data import to_hetero_csc, to_hetero_sparse_attr -from tch_geometric.types import MixedData, validate_mixeddata, HeteroTensor, Timerange - -NAN_TIMESTAMP = -1 -NAN_TIMEDELTA = -99999 - - -class BudgetSamplerTransform: - def __init__( - self, - data: Union[Data, HeteroData], - num_neighbors: Union[List[int], Dict[NodeType, List[int]]], - window: Optional[Tuple[int, int]] = None, - forward: bool = False, - relative: bool = True, - ) -> None: - super().__init__() - assert isinstance(data, HeteroData) - - self.data = data - self.window = window - self.forward = forward - self.relative = relative - self.temporal = window is not None - - self.col_ptrs_dict, self.row_indices_dict, self.perm_dict, self.size_dict = to_hetero_csc(data) - self.node_types, self.edge_types = data.metadata() - - if self.temporal: - assert 'timestamp' in data.keys - self.row_timestamps_dict = to_hetero_sparse_attr(data, 'timestamp', self.perm_dict) - - if isinstance(num_neighbors, (list, tuple)): - num_neighbors = {key: num_neighbors for key in self.node_types} - assert isinstance(num_neighbors, dict) - self.num_neighbors = num_neighbors - - self.num_hops = max([len(v) for v in self.num_neighbors.values()]) - - def __call__( - self, - inputs_dict: HeteroTensor, - inputs_timestamps_dict: Optional[HeteroTensor] = None, - temporal: bool = True, - ) -> Union[Data, HeteroData]: - validate_mixeddata(inputs_dict, hetero=True, dtype=torch.int64) - temporal = self.temporal - - # Sample the data - sample_fn = native.budget_sampling - nodes, nodes_timestamps, rows, cols, edges = sample_fn( - self.node_types, - self.edge_types, - self.col_ptrs_dict, - self.row_indices_dict, - self.row_timestamps_dict if temporal else None, - inputs_dict, - inputs_timestamps_dict, - self.num_neighbors, - self.num_hops, - self.window if temporal else None, - self.forward, - self.relative, - ) - batch_size = {key: value.numel() for key, value in inputs_dict.items()} - - # Transform data to HeteroData - data = filter_hetero_data( - self.data, nodes, rows, cols, edges, self.perm_dict - ) - for node_type, batch_size in batch_size.items(): - data[node_type].batch_size = batch_size - - if temporal: - for store in data.node_stores: - node_type = store._key - if node_type in nodes_timestamps: - store.timestamp = nodes_timestamps[node_type] - - return data diff --git a/tch_geometric/transforms/constrastive_merge.py b/tch_geometric/transforms/constrastive_merge.py deleted file mode 100644 index 34dc491..0000000 --- a/tch_geometric/transforms/constrastive_merge.py +++ /dev/null @@ -1,73 +0,0 @@ -from collections import defaultdict - -import torch -from torch_geometric.data import HeteroData, Data - -from tch_geometric.utils import zip_dict - - -class ContrastiveMergeTransform: - def __init__(self) -> None: - super().__init__() - - def __call__(self, pos: HeteroData, neg: HeteroData) -> HeteroData: - result = HeteroData() - - nodes_start_neg = {} - for (node_type, (pos_store, neg_store)) in zip_dict(pos._node_store_dict, neg._node_store_dict): - nodes_start_neg[node_type] = pos_store.num_nodes - result[node_type].x = torch.cat([pos_store.x, neg_store.x], dim=0) - - for (edge_type, (pos_store, neg_store)) in zip_dict(pos._edge_store_dict, neg._edge_store_dict): - (src, rel, dst) = edge_type - - pos_edge_store = result[(src, f'{rel}_pos', dst)] - pos_edge_store.edge_index = pos_store.edge_index - pos_edge_store.type = 'pos' - - neg_edge_store = result[(src, f'{rel}_neg', dst)] - neg_edge_store.edge_index = neg_store.edge_index - neg_edge_store.edge_index[1, :] += nodes_start_neg[dst] - neg_edge_store.type = 'neg' - - return result - - -class EdgeTypeAggregateTransform: - def __init__(self) -> None: - super().__init__() - - def __call__( - self, - data: HeteroData - ) -> HeteroData: - result = HeteroData() - node_type_default = 'n' - - # Merge all nodes into a single tensor while preserving the type specific offsets - offset = 0 - node_offsets = {} - xs = [] - for store in data.node_stores: - node_type = store._key - xs.append(store.x) - node_offsets[node_type] = offset - offset += store.num_nodes - result[node_type_default].x = torch.cat(xs, dim=0) - - # Merge all edges into a single tensor and correct edges with correct node offsets - edge_indexes_dict = defaultdict(list) - for store in data.edge_stores: - (src, _, dst) = store._key - type = store.type - - edge_index = store.edge_index - edge_index[0, :] += node_offsets[src] - edge_index[1, :] += node_offsets[dst] - - edge_indexes_dict[(node_type_default, type, node_type_default)].append(edge_index) - - for edge_type, edge_indexes in edge_indexes_dict.items(): - result[edge_type].edge_index = torch.cat(edge_indexes, dim=1) - - return result \ No newline at end of file diff --git a/tch_geometric/transforms/hgt_sampling.py b/tch_geometric/transforms/hgt_sampling.py deleted file mode 100644 index 2b071cb..0000000 --- a/tch_geometric/transforms/hgt_sampling.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import Union, List, Dict, Optional - -import torch -from torch_geometric.data import Data, HeteroData -from torch_geometric.loader.utils import filter_hetero_data -from torch_geometric.typing import NodeType - -import tch_geometric.tch_geometric as native -from tch_geometric.data import to_hetero_csc, to_hetero_sparse_attr -from tch_geometric.types import MixedData, validate_mixeddata, HeteroTensor, Timerange - -NAN_TIMESTAMP = -1 -NAN_TIMEDELTA = -99999 - -class HGTSamplerTransform: - def __init__( - self, - data: Union[Data, HeteroData], - num_samples: Union[List[int], Dict[NodeType, List[int]]], - temporal: bool = False, - ) -> None: - super().__init__() - assert isinstance(data, HeteroData) - - self.data = data - self.num_samples = num_samples - self.temporal = temporal - - self.col_ptrs_dict, self.row_indices_dict, self.perm_dict, self.size_dict = to_hetero_csc(data) - self.node_types, self.edge_types = data.metadata() - - if temporal: - assert 'timestamp' in data.keys - self.row_timestamps_dict = to_hetero_sparse_attr(data, 'timestamp', self.perm_dict) - - if isinstance(num_samples, (list, tuple)): - num_samples = {key: num_samples for key in self.node_types} - assert isinstance(num_samples, dict) - self.num_samples = num_samples - - self.num_hops = max([len(v) for v in self.num_samples.values()]) - - def __call__( - self, - inputs_dict: HeteroTensor, - inputs_timestamps_dict: Optional[HeteroTensor] = None, - timerange: Optional[Timerange] = None, - ) -> Union[Data, HeteroData]: - validate_mixeddata(inputs_dict, hetero=True, dtype=torch.int64) - - # Correct amount of samples by the batch size - num_inputs = sum([len(v) for v in inputs_dict.values() ]) - num_samples = { - k: [n * num_inputs for n in v] - for k, v in self.num_samples.items() - } - - # Sample the data - sample_fn = native.hgt_sampling - nodes, nodes_timestamps, rows, cols, edges = sample_fn( - self.node_types, - self.edge_types, - self.col_ptrs_dict, - self.row_indices_dict, - self.row_timestamps_dict if self.temporal else None, - inputs_dict, - inputs_timestamps_dict if self.temporal else None, - num_samples, - self.num_hops, - timerange if self.temporal else None, - ) - batch_size = {key: value.numel() for key, value in inputs_dict.items()} - - # Transform data to HeteroData - data = filter_hetero_data( - self.data, nodes, rows, cols, edges, self.perm_dict - ) - for node_type, batch_size in batch_size.items(): - data[node_type].batch_size = batch_size - - if self.temporal: - for store in data.node_stores: - node_type = store._key - if node_type in nodes_timestamps: - store.timestamp = nodes_timestamps[node_type] - - for store in data.edge_stores: - (src, _, dst) = store._key - timestamp_src = data[src].timestamp[store.edge_index[0, :]] - timestamp_dst = data[dst].timestamp[store.edge_index[1, :]] - store.timedelta = timestamp_dst - timestamp_src - store.timedelta[torch.logical_or(timestamp_src == NAN_TIMESTAMP, timestamp_dst == NAN_TIMESTAMP)] = NAN_TIMEDELTA - - return data diff --git a/tch_geometric/transforms/negative_sampling.py b/tch_geometric/transforms/negative_sampling.py deleted file mode 100644 index 7b52a27..0000000 --- a/tch_geometric/transforms/negative_sampling.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import Union, Tuple - -import torch -from torch_geometric.data import Data, HeteroData - -import tch_geometric.tch_geometric as native -from tch_geometric.data import to_csc, to_csr, to_hetero_csc, to_hetero_csr -from tch_geometric.data.subgraph import create_subgraph -from tch_geometric.types import MixedData, validate_mixeddata - - -class NegativeSamplerTransform: - def __init__( - self, - data: Union[Data, HeteroData], - num_neg: int, - try_count: int, - inbound: bool = True, - ) -> None: - super().__init__() - self.data = data - self.num_neg = num_neg - self.try_count = try_count - self.inbound = inbound - - # Convert the graph data into a suitable format for sampling. - if isinstance(data, Data): - convert_fn = to_csc if inbound else to_csr - self.ptrs, self.indices, self.perm, self.size = convert_fn(data) - - elif isinstance(data, HeteroData): - convert_fn = to_hetero_csc if inbound else to_hetero_csr - self.ptrs_dict, self.indices_dict, self.perm_dict, self.size_dict = convert_fn(data) - self.node_types, self.edge_types = data.metadata() - - else: - raise TypeError(f'Invalid graph type: {type(data)}') - - def __call__(self, inputs: MixedData) -> HeteroData: - if isinstance(self.data, Data): - validate_mixeddata(inputs, hetero=False, dtype=torch.int64) - - # Sample negative edges. - sample_fn = native.negative_sample_neighbors_homogenous - node, row, col, sample_count = sample_fn( - self.ptrs, - self.indices, - self.size, - inputs, - self.num_neg, - self.try_count, - ) - - # Build the subgraph - subgraph = Data() - subgraph.x = node - subgraph.edge_index = torch.stack([row, col], dim=0) - subgraph.sample_count = sample_count - - return subgraph - - elif isinstance(self.data, HeteroData): - validate_mixeddata(inputs, hetero=True, dtype=torch.int64) - - # Sample negative edges. - sample_fn = native.negative_sample_neighbors_heterogenous - nodes, rows, cols, sample_counts = sample_fn( - self.node_types, - self.edge_types, - self.ptrs_dict, - self.indices_dict, - self.size_dict, - inputs, - self.num_neg, - self.try_count, - self.inbound, - ) - - subgraph = create_subgraph(nodes, rows, cols, node_attrs=dict(sample_count=sample_counts)) - - return subgraph diff --git a/tch_geometric/transforms/neighbor_sampling.py b/tch_geometric/transforms/neighbor_sampling.py deleted file mode 100644 index bd24479..0000000 --- a/tch_geometric/transforms/neighbor_sampling.py +++ /dev/null @@ -1,143 +0,0 @@ -from dataclasses import dataclass -from typing import Union, List, Dict, Optional, Tuple - -import torch -from torch_geometric.data import Data, HeteroData -from torch_geometric.loader.utils import filter_data, filter_hetero_data -from torch_geometric.typing import EdgeType - -import tch_geometric.tch_geometric as native -from tch_geometric.data import to_csc, to_hetero_csc, edge_type_to_str -from tch_geometric.types import MixedData, validate_mixeddata - -NumNeighbors = Union[List[int], Dict[EdgeType, List[int]]] - - -@dataclass -class EdgeSampler: - def validate(self, hetero: bool = False) -> None: - raise NotImplementedError - - -@dataclass -class UniformEdgeSampler(EdgeSampler): - with_replacement: bool = False - - def validate(self, hetero: bool = False) -> None: - pass - - -@dataclass -class WeightedEdgeSampler(EdgeSampler): - weights: MixedData - - def validate(self, hetero: bool = False) -> None: - validate_mixeddata(self.weights, hetero=hetero, dtype=torch.float64) - - -TEMPORAL_SAMPLE_STATIC: int = 0 -TEMPORAL_SAMPLE_RELATIVE: int = 1 -TEMPORAL_SAMPLE_DYNAMIC: int = 2 - - -@dataclass -class EdgeFilter: - def validate(self, hetero: bool = False) -> None: - raise NotImplementedError - - -@dataclass -class TemporalEdgeFilter: - window: Tuple[int, int] - timestamps: MixedData - forward: bool = False - mode: int = TEMPORAL_SAMPLE_STATIC - - def validate(self, hetero: bool = False) -> None: - validate_mixeddata(self.timestamps, hetero=hetero, dtype=torch.int64) - - -class NeighborSamplerTransform: - def __init__( - self, - data: Union[Data, HeteroData], - num_neighbors: NumNeighbors, - edge_sampler: Optional[EdgeSampler] = None, - edge_filter: Optional[EdgeFilter] = None, - ) -> None: - super().__init__() - self.data = data - self.num_neighbors = num_neighbors - self.edge_sampler = edge_sampler - self.edge_filter = edge_filter - - # Convert the graph data into a suitable format for sampling. - if isinstance(data, Data): - self.col_ptrs, self.row_indices, self.perm, self.size = to_csc(data) - assert isinstance(num_neighbors, (list, tuple)) - - elif isinstance(data, HeteroData): - self.col_ptrs_dict, self.row_indices_dict, self.perm_dict, self.size_dict = to_hetero_csc(data) - self.node_types, self.edge_types = data.metadata() - - if isinstance(num_neighbors, (list, tuple)): - num_neighbors = {key: num_neighbors for key in self.edge_types} - assert isinstance(num_neighbors, dict) - self.num_neighbors = { - edge_type_to_str(key): value - for key, value in num_neighbors.items() - } - - self.num_hops = max([len(v) for v in self.num_neighbors.values()]) - - else: - raise TypeError(f'NeighborLoader found invalid type: {type(data)}') - - def __call__(self, inputs: MixedData, input_states: Optional[MixedData] = None) -> Union[Data, HeteroData]: - if isinstance(self.data, Data): - validate_mixeddata(inputs, hetero=False, dtype=torch.int64) - - sample_fn = native.neighbor_sampling_homogenous - node, row, col, edge, layer_offsets = sample_fn( - self.col_ptrs, - self.row_indices, - inputs, - self.num_neighbors, - self.edge_sampler, - self.edge_filter, - ) - batch_size = inputs.numel() - - data = filter_data(self.data, node, row, col, edge, self.perm) - data.batch_size = batch_size - - return data - - elif isinstance(self.data, HeteroData): - validate_mixeddata(inputs, hetero=True, dtype=torch.int64) - - if input_states: - validate_mixeddata(input_states, hetero=True, dtype=torch.int64) - edge_filter = (self.edge_filter, input_states) if self.edge_filter else None - - sample_fn = native.neighbor_sampling_heterogenous - nodes, rows, cols, edges, layer_offsets = sample_fn( - self.node_types, - self.edge_types, - self.col_ptrs_dict, - self.row_indices_dict, - inputs, - self.num_neighbors, - self.num_hops, - self.edge_sampler, - edge_filter, - ) - batch_size = {key: value.numel() for key, value in inputs.items()} - - data = filter_hetero_data( - self.data, nodes, rows, cols, edges, self.perm_dict - ) - for node_type, batch_size in batch_size.items(): - data[node_type].batch_size = batch_size - - return data diff --git a/tch_geometric/types.py b/tch_geometric/types.py deleted file mode 100644 index 0c091e9..0000000 --- a/tch_geometric/types.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import Union, Dict, Tuple - -from torch import Tensor - -MixedData = Union[Tensor, Dict[str, Tensor]] -HeteroTensor = Dict[str, Tensor] - -Timerange = Tuple[int, int] - -def validate_mixeddata(data: MixedData, hetero: bool = False, dtype=None): - if hetero: - assert isinstance(data, dict) - for v in data.values(): - assert v.dtype == dtype - else: - assert data.dtype == dtype diff --git a/tch_geometric/utils.py b/tch_geometric/utils.py new file mode 100644 index 0000000..3fe02d1 --- /dev/null +++ b/tch_geometric/utils.py @@ -0,0 +1,67 @@ +from dataclasses import dataclass +from typing import List +from typing import Union, Dict, Tuple + +import torch +from torch import Tensor +from torch_geometric.typing import EdgeType + +NumNeighbors = Union[List[int], Dict[EdgeType, List[int]]] + +MixedData = Union[Tensor, Dict[str, Tensor]] +HeteroTensor = Dict[str, Tensor] + +Timerange = Tuple[int, int] + + +def validate_mixeddata(data: MixedData, hetero: bool = False, dtype=None): + if hetero: + assert isinstance(data, dict) + for v in data.values(): + assert v.dtype == dtype + else: + assert data.dtype == dtype + + +@dataclass +class EdgeSampler: + def validate(self, hetero: bool = False) -> None: + raise NotImplementedError + + +@dataclass +class UniformEdgeSampler(EdgeSampler): + with_replacement: bool = False + + def validate(self, hetero: bool = False) -> None: + pass + + +@dataclass +class WeightedEdgeSampler(EdgeSampler): + weights: MixedData + + def validate(self, hetero: bool = False) -> None: + validate_mixeddata(self.weights, hetero=hetero, dtype=torch.float64) + + +TEMPORAL_SAMPLE_STATIC: int = 0 +TEMPORAL_SAMPLE_RELATIVE: int = 1 +TEMPORAL_SAMPLE_DYNAMIC: int = 2 + + +@dataclass +class EdgeFilter: + def validate(self, hetero: bool = False) -> None: + raise NotImplementedError + + +@dataclass +class TemporalEdgeFilter: + window: Tuple[int, int] + timestamps: MixedData + forward: bool = False + mode: int = TEMPORAL_SAMPLE_STATIC + + def validate(self, hetero: bool = False) -> None: + validate_mixeddata(self.timestamps, hetero=hetero, dtype=torch.int64) diff --git a/tch_geometric/utils/__init__.py b/tch_geometric/utils/__init__.py deleted file mode 100644 index c8abad0..0000000 --- a/tch_geometric/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .iter import * diff --git a/tch_geometric/utils/iter.py b/tch_geometric/utils/iter.py deleted file mode 100644 index c00b46b..0000000 --- a/tch_geometric/utils/iter.py +++ /dev/null @@ -1,5 +0,0 @@ -def zip_dict(d1, d2): - """ - Zip two dictionaries together. - """ - yield from ((k, (d1[k], d2[k])) for k in d1.keys() & d2.keys())