Skip to content

Commit

Permalink
add github action test
Browse files Browse the repository at this point in the history
  • Loading branch information
remigenet committed Jul 24, 2024
1 parent a553e43 commit 2a60cf8
Show file tree
Hide file tree
Showing 9 changed files with 547 additions and 16 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/tkat_ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: Run tests on multiple backends

on:
push:
branches: [ main, beta ]
pull_request:
branches: [ main ]


jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install poetry
poetry install
- name: Run tests
run: |
poetry run python run_tests.py
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

![TKAT representation](images/model_representation.jpeg)

This folder includes the original code implemented for the [paper](https://arxiv.org/abs/2406.02486) of the same name.
This folder includes the original code implemented for the [paper](https://arxiv.org/abs/2406.02486) of the same name. The model is made in keras3 and is supporting all backend (jax, tensorflow, pytorch).

It is inspired on the Temporal Fusion Transformer by [google-research](https://github.com/google-research/google-research/tree/master/tft) and the [Temporal Kolmogorov Arnold Network](https://github.com/remigenet/TKAN).

Expand Down
27 changes: 22 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,31 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "tkat"
version = "0.1.1"
version = "0.2.0"
description = "Temporal KAN Transformer"
authors = [ "Rémi Genet", "Hugo Inzirillo"]
readme = "README.md"
packages = [{include = "tkat"}]

[tool.poetry.dependencies]
python = ">=3.10,<4.0"
numpy = ">=1.2,<2"
tensorflow = ">=2.8,<3"
tkan = ">=0.3.0,<0.4.0"
python = ">=3.9,<3.12"
keras = ">=3.0.0,<4.0"
keras_efficient_kan = "^0.1.4"
tkan = "^0.4.1"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
pytest-xdist = "^3.3.0"
tensorflow = "^2.15.0"
torch = "^2.0.0"
jax = "^0.4.13"
jaxlib = "^0.4.13"

[tool.pytest.ini_options]
addopts = "-v"
testpaths = ["tests"]
filterwarnings = [
"ignore:Can't initialize NVML:UserWarning",
"ignore:jax.xla_computation is deprecated:DeprecationWarning",
"ignore::DeprecationWarning:jax._src.dtypes"
]
26 changes: 26 additions & 0 deletions run_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
import subprocess

def run_test(backend):
env = os.environ.copy()
env['KERAS_BACKEND'] = backend
result = subprocess.run(['pytest', f'tests/test_{backend}.py'], env=env, capture_output=True, text=True)
print(f"\n--- {backend.upper()} Backend Test Results ---")
print(result.stdout)
if result.stderr:
print("Errors:")
print(result.stderr)
return result.returncode

if __name__ == "__main__":
backends = ['tensorflow', 'torch', 'jax']
exit_codes = []

for backend in backends:
exit_codes.append(run_test(backend))

if any(exit_codes):
exit(1)
else:
print("\nAll tests passed successfully!")
exit(0)
Empty file added tests/__init__.py
Empty file.
152 changes: 152 additions & 0 deletions tests/test_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import os
BACKEND = 'jax'
os.environ['KERAS_BACKEND'] = BACKEND

import pytest
import keras
from keras import ops
from keras import backend
from keras import random
from tkat import TKAT # Assuming you've defined TKAT in a separate file

def generate_random_tensor(shape):
return random.normal(shape=shape, dtype=backend.floatx())

def test_tkat_basic():
assert keras.backend.backend() == BACKEND
batch_size, sequence_length, n_ahead = 32, 10, 5
num_unknow_features, num_know_features = 3, 2
num_embedding, num_hidden, num_heads = 8, 16, 4

tkat_model = TKAT(
sequence_length=sequence_length,
num_unknow_features=num_unknow_features,
num_know_features=num_know_features,
num_embedding=num_embedding,
num_hidden=num_hidden,
num_heads=num_heads,
n_ahead=n_ahead
)

input_shape = (batch_size, sequence_length + n_ahead, num_unknow_features + num_know_features)
input_data = generate_random_tensor(input_shape)
output = tkat_model(input_data)

expected_output_shape = (batch_size, n_ahead)
assert output.shape == expected_output_shape, f"Expected shape {expected_output_shape}, but got {output.shape}"

def test_tkat_variable_selection():
assert keras.backend.backend() == BACKEND
batch_size, sequence_length, n_ahead = 16, 8, 3
num_unknow_features, num_know_features = 4, 3
num_embedding, num_hidden, num_heads = 4, 8, 2

tkat_model = TKAT(
sequence_length=sequence_length,
num_unknow_features=num_unknow_features,
num_know_features=num_know_features,
num_embedding=num_embedding,
num_hidden=num_hidden,
num_heads=num_heads,
n_ahead=n_ahead
)

input_shape = (batch_size, sequence_length + n_ahead, num_unknow_features + num_know_features)
input_data = generate_random_tensor(input_shape)

# Get the embedding layer output
embedding_layer = tkat_model.get_layer('embedding_layer') # Assuming you've named your EmbeddingLayer
embedded_input = embedding_layer(input_data)

# Access the variable selection networks
vsn_past = tkat_model.get_layer('vsn_past_features')
vsn_future = tkat_model.get_layer('vsn_future_features')

# Test VSN outputs
past_features = embedded_input[:, :sequence_length, :, :]
future_features = embedded_input[:, sequence_length:, :, -num_know_features:]

past_output = vsn_past(past_features)
future_output = vsn_future(future_features)

assert past_output.shape == (batch_size, sequence_length, num_hidden)
assert future_output.shape == (batch_size, n_ahead, num_hidden)



def test_tkat_attention():
assert keras.backend.backend() == BACKEND
batch_size, sequence_length, n_ahead = 8, 6, 2
num_unknow_features, num_know_features = 4, 3
num_embedding, num_hidden, num_heads = 4, 8, 2

tkat_model = TKAT(
sequence_length=sequence_length,
num_unknow_features=num_unknow_features,
num_know_features=num_know_features,
num_embedding=num_embedding,
num_hidden=num_hidden,
num_heads=num_heads,
n_ahead=n_ahead
)

input_shape = (batch_size, sequence_length + n_ahead, num_unknow_features + num_know_features)
input_data = generate_random_tensor(input_shape)

# Get the attention layer
attention_layer = next(layer for layer in tkat_model.layers if isinstance(layer, keras.layers.MultiHeadAttention))

# Test attention output
output = tkat_model(input_data)
assert output.shape == (batch_size, n_ahead)

def test_tkat_training():
assert keras.backend.backend() == BACKEND
batch_size, sequence_length, n_ahead = 64, 12, 4
num_unknow_features, num_know_features = 4, 3
num_embedding, num_hidden, num_heads = 8, 16, 4

tkat_model = TKAT(
sequence_length=sequence_length,
num_unknow_features=num_unknow_features,
num_know_features=num_know_features,
num_embedding=num_embedding,
num_hidden=num_hidden,
num_heads=num_heads,
n_ahead=n_ahead
)

input_shape = (batch_size, sequence_length + n_ahead, num_unknow_features + num_know_features)
input_data = generate_random_tensor(input_shape)
target_data = generate_random_tensor((batch_size, n_ahead))

tkat_model.compile(optimizer='adam', loss='mse')
history = tkat_model.fit(input_data, target_data, epochs=2, batch_size=16, verbose=0)

assert len(history.history['loss']) == 2
assert history.history['loss'][1] < history.history['loss'][0]

def test_tkat_prediction():
assert keras.backend.backend() == BACKEND
batch_size, sequence_length, n_ahead = 32, 10, 5
num_unknow_features, num_know_features = 3, 2
num_embedding, num_hidden, num_heads = 8, 16, 4

tkat_model = TKAT(
sequence_length=sequence_length,
num_unknow_features=num_unknow_features,
num_know_features=num_know_features,
num_embedding=num_embedding,
num_hidden=num_hidden,
num_heads=num_heads,
n_ahead=n_ahead
)

input_shape = (batch_size, sequence_length + n_ahead, num_unknow_features + num_know_features)
input_data = generate_random_tensor(input_shape)

predictions = tkat_model.predict(input_data)
assert predictions.shape == (batch_size, n_ahead)

if __name__ == "__main__":
pytest.main([__file__])
Loading

0 comments on commit 2a60cf8

Please sign in to comment.