From 0e27f302a01b35cb775dce3c39f15c724167419e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 2 Oct 2024 09:39:26 -0700 Subject: [PATCH 1/5] Do full inference test against test vectors for test_* models --- tests/test_models.py | 53 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index fd09ceb2e6..b5b538d9c4 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -26,7 +26,7 @@ has_fx_feature_extraction = False import timm -from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value +from timm import list_models, list_pretrained, create_model, set_scriptable, get_pretrained_cfg_value from timm.layers import Format, get_spatial_dim, get_channel_dim from timm.models import get_notrace_modules, get_notrace_functions @@ -39,7 +39,8 @@ torch_device = os.environ.get('TORCH_DEVICE', 'cpu') timeout = os.environ.get('TIMEOUT') timeout120 = int(timeout) if timeout else 120 -timeout300 = int(timeout) if timeout else 300 +timeout240 = int(timeout) if timeout else 240 +timeout360 = int(timeout) if timeout else 360 if hasattr(torch._C, '_jit_set_profiling_executor'): # legacy executor is too slow to compile large models for unit tests @@ -118,6 +119,50 @@ def _get_input_size(model=None, model_name='', target=None): return input_size +@pytest.mark.base +@pytest.mark.timeout(timeout240) +@pytest.mark.parametrize('model_name', list_pretrained('test_*')) +@pytest.mark.parametrize('batch_size', [1]) +def test_model_inference(model_name, batch_size): + """Run a single forward pass with each model""" + from PIL import Image + from huggingface_hub import snapshot_download + import tempfile + import safetensors + + model = create_model(model_name, pretrained=True) + model.eval() + pp = timm.data.create_transform(**timm.data.resolve_data_config(model=model)) + + with tempfile.TemporaryDirectory() as temp_dir: + snapshot_download( + repo_id='timm/' + model_name, repo_type='model', local_dir=temp_dir, allow_patterns='test/*' + ) + rand_tensors = safetensors.torch.load_file(os.path.join(temp_dir, 'test', 'rand_tensors.safetensors')) + owl_tensors = safetensors.torch.load_file(os.path.join(temp_dir, 'test', 'owl_tensors.safetensors')) + test_owl = Image.open(os.path.join(temp_dir, 'test', 'test_owl.jpg')) + + with torch.no_grad(): + rand_output = model(rand_tensors['input']) + rand_features = model.forward_features(rand_tensors['input']) + rand_pre_logits = model.forward_head(rand_features, pre_logits=True) + assert torch.allclose(rand_output, rand_tensors['output']) + assert torch.allclose(rand_features, rand_tensors['features']) + assert torch.allclose(rand_pre_logits, rand_tensors['pre_logits']) + + def _test_owl(owl_input): + owl_output = model(owl_input) + owl_features = model.forward_features(owl_input) + owl_pre_logits = model.forward_head(owl_features.clone(), pre_logits=True) + assert owl_output.softmax(1).argmax(1) == 24 # owl + assert torch.allclose(owl_output, owl_tensors['output']) + assert torch.allclose(owl_features, owl_tensors['features']) + assert torch.allclose(owl_pre_logits, owl_tensors['pre_logits']) + + _test_owl(owl_tensors['input']) # test with original pp owl tensor + _test_owl(pp(test_owl).unsqueeze(0)) # re-process from original jpg + + @pytest.mark.base @pytest.mark.timeout(timeout120) @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) @@ -182,7 +227,7 @@ def test_model_backward(model_name, batch_size): ) @pytest.mark.cfg -@pytest.mark.timeout(timeout300) +@pytest.mark.timeout(timeout360) @pytest.mark.parametrize('model_name', list_models( exclude_filters=EXCLUDE_FILTERS + NON_STD_FILTERS, include_tags=True)) @pytest.mark.parametrize('batch_size', [1]) @@ -260,7 +305,7 @@ def test_model_default_cfgs(model_name, batch_size): @pytest.mark.cfg -@pytest.mark.timeout(timeout300) +@pytest.mark.timeout(timeout360) @pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS, exclude_filters=NON_STD_EXCLUDE_FILTERS, include_tags=True)) @pytest.mark.parametrize('batch_size', [1]) def test_model_default_cfgs_non_std(model_name, batch_size): From fde671940382d3668200f0c8b0cab3be2ae6e49a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 2 Oct 2024 09:53:17 -0700 Subject: [PATCH 2/5] relax tolerance on inference test --- tests/test_models.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index b5b538d9c4..02cf7707df 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -146,18 +146,18 @@ def test_model_inference(model_name, batch_size): rand_output = model(rand_tensors['input']) rand_features = model.forward_features(rand_tensors['input']) rand_pre_logits = model.forward_head(rand_features, pre_logits=True) - assert torch.allclose(rand_output, rand_tensors['output']) - assert torch.allclose(rand_features, rand_tensors['features']) - assert torch.allclose(rand_pre_logits, rand_tensors['pre_logits']) + assert torch.allclose(rand_output, rand_tensors['output'], rtol=1e-3, atol=1e-5) + assert torch.allclose(rand_features, rand_tensors['features'], rtol=1e-3, atol=1e-5) + assert torch.allclose(rand_pre_logits, rand_tensors['pre_logits'], rtol=1e-3, atol=1e-5) def _test_owl(owl_input): owl_output = model(owl_input) owl_features = model.forward_features(owl_input) owl_pre_logits = model.forward_head(owl_features.clone(), pre_logits=True) assert owl_output.softmax(1).argmax(1) == 24 # owl - assert torch.allclose(owl_output, owl_tensors['output']) - assert torch.allclose(owl_features, owl_tensors['features']) - assert torch.allclose(owl_pre_logits, owl_tensors['pre_logits']) + assert torch.allclose(owl_output, owl_tensors['output'], rtol=1e-3, atol=1e-5) + assert torch.allclose(owl_features, owl_tensors['features'], rtol=1e-3, atol=1e-5) + assert torch.allclose(owl_pre_logits, owl_tensors['pre_logits'], rtol=1e-3, atol=1e-5) _test_owl(owl_tensors['input']) # test with original pp owl tensor _test_owl(pp(test_owl).unsqueeze(0)) # re-process from original jpg From 95907e69c28770c5ea3323018b76027ccb20acba Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 2 Oct 2024 10:05:27 -0700 Subject: [PATCH 3/5] Further reduce atol for model comparison, move python 3.11 + torch 2.2 -> python 3.12 + torch 2.4.1 --- .github/workflows/tests.yml | 6 +++--- tests/test_models.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1279b219d9..9cc2243c95 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,11 +16,11 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python: ['3.10', '3.11'] - torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.1.0', vision: '0.16.0'}] + python: ['3.10', '3.12'] + torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.4.1', vision: '0.19.1'}] testmarker: ['-k "not test_models"', '-m base', '-m cfg', '-m torchscript', '-m features', '-m fxforward', '-m fxbackward'] exclude: - - python: '3.11' + - python: '3.12' torch: {base: '1.13.0', vision: '0.14.0'} runs-on: ${{ matrix.os }} diff --git a/tests/test_models.py b/tests/test_models.py index 02cf7707df..f2a1d7e408 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -146,18 +146,18 @@ def test_model_inference(model_name, batch_size): rand_output = model(rand_tensors['input']) rand_features = model.forward_features(rand_tensors['input']) rand_pre_logits = model.forward_head(rand_features, pre_logits=True) - assert torch.allclose(rand_output, rand_tensors['output'], rtol=1e-3, atol=1e-5) - assert torch.allclose(rand_features, rand_tensors['features'], rtol=1e-3, atol=1e-5) - assert torch.allclose(rand_pre_logits, rand_tensors['pre_logits'], rtol=1e-3, atol=1e-5) + assert torch.allclose(rand_output, rand_tensors['output'], rtol=1e-3, atol=1e-4) + assert torch.allclose(rand_features, rand_tensors['features'], rtol=1e-3, atol=1e-4) + assert torch.allclose(rand_pre_logits, rand_tensors['pre_logits'], rtol=1e-3, atol=1e-4) def _test_owl(owl_input): owl_output = model(owl_input) owl_features = model.forward_features(owl_input) owl_pre_logits = model.forward_head(owl_features.clone(), pre_logits=True) assert owl_output.softmax(1).argmax(1) == 24 # owl - assert torch.allclose(owl_output, owl_tensors['output'], rtol=1e-3, atol=1e-5) - assert torch.allclose(owl_features, owl_tensors['features'], rtol=1e-3, atol=1e-5) - assert torch.allclose(owl_pre_logits, owl_tensors['pre_logits'], rtol=1e-3, atol=1e-5) + assert torch.allclose(owl_output, owl_tensors['output'], rtol=1e-3, atol=1e-4) + assert torch.allclose(owl_features, owl_tensors['features'], rtol=1e-3, atol=1e-4) + assert torch.allclose(owl_pre_logits, owl_tensors['pre_logits'], rtol=1e-3, atol=1e-4) _test_owl(owl_tensors['input']) # test with original pp owl tensor _test_owl(pp(test_owl).unsqueeze(0)) # re-process from original jpg From 036b1f0cc90b4edd5280d6528cdaa14df147c8a1 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 2 Oct 2024 10:18:34 -0700 Subject: [PATCH 4/5] no 2.4.1 for cpu --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9cc2243c95..1ca879cd4e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,7 +17,7 @@ jobs: matrix: os: [ubuntu-latest] python: ['3.10', '3.12'] - torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.4.1', vision: '0.19.1'}] + torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.4.0', vision: '0.19.0'}] testmarker: ['-k "not test_models"', '-m base', '-m cfg', '-m torchscript', '-m features', '-m fxforward', '-m fxbackward'] exclude: - python: '3.12' From 1a2d8bb5a082e63bd08ea7ed45670f7f46b190f6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 2 Oct 2024 10:26:22 -0700 Subject: [PATCH 5/5] Update pip install to use whl/cpu --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1ca879cd4e..0dd62aaf66 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,7 +17,7 @@ jobs: matrix: os: [ubuntu-latest] python: ['3.10', '3.12'] - torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.4.0', vision: '0.19.0'}] + torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.4.1', vision: '0.19.1'}] testmarker: ['-k "not test_models"', '-m base', '-m cfg', '-m torchscript', '-m features', '-m fxforward', '-m fxbackward'] exclude: - python: '3.12' @@ -46,7 +46,7 @@ jobs: sudo sed -i 's/azure\.//' /etc/apt/sources.list sudo apt update sudo apt install -y google-perftools - pip install --no-cache-dir torch==${{ matrix.torch.base }}+cpu torchvision==${{ matrix.torch.vision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install --no-cache-dir torch==${{ matrix.torch.base }}+cpu torchvision==${{ matrix.torch.vision }}+cpu --index-url https://download.pytorch.org/whl/cpu - name: Install requirements run: | pip install -r requirements.txt