From e803655a1794caac390d9220107a93baafa9ef23 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Fri, 16 Aug 2024 13:51:15 +0000 Subject: [PATCH] chore(deps): update dependency flax to v0.8.5 --- poetry.lock | 46 ++++++++++++++++++++++------------------------ 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/poetry.lock b/poetry.lock index 59535a35e..fa10dbb6d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1868,21 +1868,21 @@ dotenv = ["python-dotenv"] [[package]] name = "flax" -version = "0.8.1" +version = "0.8.5" description = "Flax: A neural network library for JAX designed for flexibility" optional = false python-versions = ">=3.9" files = [ - {file = "flax-0.8.1-py3-none-any.whl", hash = "sha256:8cf9ef11859eef252470377556a8cc48db287fc6647407ab34f1fc01461925dd"}, - {file = "flax-0.8.1.tar.gz", hash = "sha256:ce3d99e9b4c0d2e4d9fc28bc56cced8ba953adfd695aabd24f096b4c8a7e2f92"}, + {file = "flax-0.8.5-py3-none-any.whl", hash = "sha256:c96e46d1c48a300d010ebf5c4846f163bdd7acc6efff5ff2bfb1cb5b08aa65d8"}, + {file = "flax-0.8.5.tar.gz", hash = "sha256:4a9cb7950ece54b0addaa73d77eba24e46138dbe783d01987be79d20ccb2b09b"}, ] [package.dependencies] -jax = ">=0.4.19" +jax = ">=0.4.27" msgpack = "*" numpy = [ - {version = ">=1.23.2", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, ] optax = "*" orbax-checkpoint = "*" @@ -1893,7 +1893,7 @@ typing-extensions = ">=4.2" [package.extras] all = ["matplotlib"] -testing = ["black[jupyter] (==23.7.0)", "clu", "clu (<=0.0.9)", "einops", "gymnasium[accept-rom-license,atari]", "jaxlib", "jraph (>=0.0.6dev0)", "ml-collections", "mypy", "nbstripout", "opencv-python", "pytest", "pytest-cov", "pytest-custom-exit-code", "pytest-xdist", "pytype", "sentencepiece", "tensorflow", "tensorflow-datasets", "tensorflow-text (>=2.11.0)", "torch"] +testing = ["black[jupyter] (==23.7.0)", "clu", "clu (<=0.0.9)", "einops", "gymnasium[accept-rom-license,atari]", "jaxlib", "jaxtyping", "jraph (>=0.0.6dev0)", "ml-collections", "mypy", "nbstripout", "opencv-python", "penzai (>=0.1.2)", "pytest", "pytest-cov", "pytest-custom-exit-code", "pytest-xdist", "pytype", "sentencepiece", "tensorflow (>=2.12.0)", "tensorflow-datasets", "tensorflow-text (>=2.11.0)", "torch"] [[package]] name = "flyteidl" @@ -3248,40 +3248,38 @@ testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", [[package]] name = "jax" -version = "0.4.23" +version = "0.4.28" description = "Differentiate, compile, and transform Numpy code." optional = false python-versions = ">=3.9" files = [ - {file = "jax-0.4.23-py3-none-any.whl", hash = "sha256:a7a07ccd1577111e3b82378c79a8ed0f9d6613b1e98fb6bf3c0b459198f73eaa"}, - {file = "jax-0.4.23.tar.gz", hash = "sha256:2a229a5a758d1b803891b2eaed329723f6b15b4258b14dc0ccb1498c84963685"}, + {file = "jax-0.4.28-py3-none-any.whl", hash = "sha256:6a181e6b5a5b1140e19cdd2d5c4aa779e4cb4ec627757b918be322d8e81035ba"}, + {file = "jax-0.4.28.tar.gz", hash = "sha256:dcf0a44aff2e1713f0a2b369281cd5b79d8c18fc1018905c4125897cb06b37e9"}, ] [package.dependencies] ml-dtypes = ">=0.2.0" numpy = [ - {version = ">=1.23.2", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, ] opt-einsum = "*" scipy = [ - {version = ">=1.9", markers = "python_version < \"3.12\""}, {version = ">=1.11.1", markers = "python_version >= \"3.12\""}, + {version = ">=1.9", markers = "python_version < \"3.12\""}, ] [package.extras] australis = ["protobuf (>=3.13,<4)"] -ci = ["jaxlib (==0.4.22)"] -cpu = ["jaxlib (==0.4.23)"] -cuda = ["jaxlib (==0.4.23+cuda11.cudnn86)"] -cuda11-cudnn86 = ["jaxlib (==0.4.23+cuda11.cudnn86)"] -cuda11-local = ["jaxlib (==0.4.23+cuda11.cudnn86)"] -cuda11-pip = ["jaxlib (==0.4.23+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)", "nvidia-nccl-cu11 (>=2.18.3)"] -cuda12 = ["jax-cuda12-plugin (==0.4.23)", "jaxlib (==0.4.23)", "nvidia-cublas-cu12 (>=12.2.5.6)", "nvidia-cuda-cupti-cu12 (>=12.2.142)", "nvidia-cuda-nvcc-cu12 (>=12.2.140)", "nvidia-cuda-runtime-cu12 (>=12.2.140)", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12 (>=11.0.8.103)", "nvidia-cusolver-cu12 (>=11.5.2)", "nvidia-cusparse-cu12 (>=12.1.2.141)", "nvidia-nccl-cu12 (>=2.18.3)", "nvidia-nvjitlink-cu12 (>=12.2)"] -cuda12-local = ["jaxlib (==0.4.23+cuda12.cudnn89)"] -cuda12-pip = ["jaxlib (==0.4.23+cuda12.cudnn89)", "nvidia-cublas-cu12 (>=12.2.5.6)", "nvidia-cuda-cupti-cu12 (>=12.2.142)", "nvidia-cuda-nvcc-cu12 (>=12.2.140)", "nvidia-cuda-runtime-cu12 (>=12.2.140)", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12 (>=11.0.8.103)", "nvidia-cusolver-cu12 (>=11.5.2)", "nvidia-cusparse-cu12 (>=12.1.2.141)", "nvidia-nccl-cu12 (>=2.18.3)", "nvidia-nvjitlink-cu12 (>=12.2)"] -minimum-jaxlib = ["jaxlib (==0.4.19)"] -tpu = ["jaxlib (==0.4.23)", "libtpu-nightly (==0.1.dev20231213)", "requests"] +ci = ["jaxlib (==0.4.27)"] +cpu = ["jaxlib (==0.4.28)"] +cuda = ["jaxlib (==0.4.28+cuda12.cudnn89)"] +cuda12 = ["jax-cuda12-plugin (==0.4.28)", "jaxlib (==0.4.28)", "nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=8.9.2.26,<9.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] +cuda12-cudnn89 = ["jaxlib (==0.4.28+cuda12.cudnn89)"] +cuda12-local = ["jaxlib (==0.4.28+cuda12.cudnn89)"] +cuda12-pip = ["jaxlib (==0.4.28+cuda12.cudnn89)", "nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=8.9.2.26,<9.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] +minimum-jaxlib = ["jaxlib (==0.4.27)"] +tpu = ["jaxlib (==0.4.28)", "libtpu-nightly (==0.1.dev20240508)", "requests"] [[package]] name = "jaxlib" @@ -3316,8 +3314,8 @@ files = [ ml-dtypes = ">=0.2.0" numpy = ">=1.22" scipy = [ - {version = ">=1.9", markers = "python_version < \"3.12\""}, {version = ">=1.11.1", markers = "python_version >= \"3.12\""}, + {version = ">=1.9", markers = "python_version < \"3.12\""}, ] [package.extras] @@ -5623,8 +5621,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""}, {version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1"