Skip to content

Commit

Permalink
feat: chex-ify testsuite (#12)
Browse files Browse the repository at this point in the history
* drop `.pre-commit-config.yaml`.
* fix `README` instructions for running the cli.
* add `.vscode/settings.json` with a simple vscode config.
* delete redundant `jflux/__main__.py`.
* enforce ruff style rules `E`, `F` and `W`.
* `chex`-ify testsuite.
  • Loading branch information
SauravMaheshkar authored Oct 9, 2024
1 parent 0c9fb04 commit bf2d84c
Show file tree
Hide file tree
Showing 17 changed files with 284 additions and 240 deletions.
2 changes: 1 addition & 1 deletion .github/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ $ uv sync
## Running

```shell
$ uv jflux
$ uv run jflux
```

## References
Expand Down
20 changes: 0 additions & 20 deletions .pre-commit-config.yaml

This file was deleted.

15 changes: 15 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"[python]": {
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.fixAll": "explicit",
"source.organizeImports": "explicit"
},
"editor.defaultFormatter": "charliermarsh.ruff",
},
}
4 changes: 0 additions & 4 deletions jflux/__main__.py

This file was deleted.

13 changes: 9 additions & 4 deletions jflux/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@

import jax
import jax.numpy as jnp
import numpy as np
from einops import rearrange
from fire import Fire
from flax import nnx
from jax.typing import DTypeLike
from PIL import Image

from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack
Expand Down Expand Up @@ -124,7 +123,8 @@ def main(
by the index of the sample
prompt: Prompt used for sampling
device: Pytorch device
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
num_steps: number of sampling steps
(default 4 for schnell, 50 for guidance distilled)
loop: start an interactive session and sample multiple times
guidance: guidance value used for guidance distillation
add_sampling_metadata: Add the prompt to the image Exif metadata
Expand Down Expand Up @@ -216,7 +216,12 @@ def main(
x = x.clip(-1, 1)
x = rearrange(x[0], "c h w -> h w c")

img = Image.fromarray((127.5 * (x + 1.0)))
x = 127.5 * (x + 1.0)
x_numpy = np.array(x.astype(jnp.uint8))
img = Image.fromarray(x_numpy)

img.save(fn, quality=95, subsampling=0)
idx += 1

if loop:
print("-" * 80)
Expand Down
2 changes: 1 addition & 1 deletion jflux/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, params: FluxParams):
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" # noqa: E501
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
Expand Down
3 changes: 3 additions & 0 deletions jflux/port.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from einops import rearrange


##############################################################################################
# AUTOENCODER MODEL PORTING
##############################################################################################
Expand Down Expand Up @@ -481,3 +482,5 @@ def port_flux(flux, tensors):
tensors=tensors,
prefix="final_layer",
)

return flux
2 changes: 0 additions & 2 deletions jflux/util.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import os
from dataclasses import dataclass

import jax
import torch # need for t5 and clip
from flax import nnx
from huggingface_hub import hf_hub_download
from jax import numpy as jnp
from jax.typing import DTypeLike
from safetensors import safe_open

from jflux.model import Flux, FluxParams
Expand Down
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ dependencies = [
"einops>=0.8.0",
"fire>=0.6.0",
"flax>=0.9.0",
"jflux",
# FIXME: Allow for local installation without GPUs as well `jax[cuda12]`
"jax>=0.4.31",
"mypy>=1.11.2",
Expand All @@ -22,6 +21,7 @@ dependencies = [
jflux = "jflux.cli:app"

[tool.uv]
package = true
dev-dependencies = [
"flux",
"pytest>=8.3.3",
Expand All @@ -32,7 +32,10 @@ jflux = { workspace = true }
flux = { git = "https://github.com/black-forest-labs/flux.git" }

[tool.ruff.lint]
select = ["I001"]
select = ["E", "F", "I001", "W"]

[tool.ruff.lint.isort]
lines-after-imports = 2

[tool.ruff.lint.pydocstyle]
convention = "google"
Expand Down
Empty file added tests/__init__.py
Empty file.
Loading

0 comments on commit bf2d84c

Please sign in to comment.