Add standalone predict.py module for ndarrays
Dusandinho committed Oct 26, 2024
1 parent b0736c1 commit 820b17b
Showing 2 changed files with 262 additions and 1 deletion.
3 changes: 2 additions & 1 deletion prefab/__init__.py
@@ -7,14 +7,15 @@

 __version__ = "1.1.2"

-from . import compare, geometry, read, shapes
+from . import compare, geometry, predict, read, shapes
 from .device import BufferSpec, Device
 from .models import models

 __all__ = [
     "Device",
     "BufferSpec",
     "geometry",
+    "predict",
     "read",
     "shapes",
     "compare",
260 changes: 260 additions & 0 deletions prefab/predict.py
@@ -0,0 +1,260 @@
import base64
import io
import json
import os

import numpy as np
import requests
import toml
from PIL import Image
from tqdm import tqdm

from .geometry import binarize_hard
from .models import Model

BASE_URL = "https://prefab-photonics--predict"


def predict_array(
    device_array: np.ndarray,
    model: Model,
    model_type: str,
    binarize: bool,
    gpu: bool = False,
) -> np.ndarray:
    """
    Predicts the output array for a given device array using a specified model.

    This function sends the device array to a prediction service, which uses a
    machine learning model to predict the outcome of the nanofabrication process.
    The prediction can be performed on a GPU if specified.

    Parameters
    ----------
    device_array : np.ndarray
        The input device array to be predicted.
    model : Model
        The model to use for prediction.
    model_type : str
        The type of model to use (e.g., 'p', 'c', 's').
    binarize : bool
        Whether to binarize the output.
    gpu : bool, optional
        Whether to use a GPU for prediction. Defaults to False.

    Returns
    -------
    np.ndarray
        The predicted output array.

    Raises
    ------
    RuntimeError
        If the request to the prediction service fails.
    """
    headers = _prepare_headers()
    predict_data = _prepare_predict_data(device_array, model, model_type, binarize)
    endpoint_url = (
        f"{BASE_URL}-gpu-v1.modal.run" if gpu else f"{BASE_URL}-v1.modal.run"
    )

    try:
        with requests.post(
            endpoint_url,
            data=json.dumps(predict_data),
            headers=headers,
            stream=True,
        ) as response:
            response.raise_for_status()
            return _process_response(response, model_type, binarize)
    except requests.RequestException as e:
        raise RuntimeError(f"Request failed: {e}") from e
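
# A minimal usage sketch. The model key "ANT_NanoSOI" and the toy pattern are
# illustrative placeholders, not guaranteed by this module:
#
#   import numpy as np
#   from prefab import predict
#   from prefab.models import models
#
#   device_array = np.zeros((200, 200))
#   device_array[80:120, :] = 1.0  # a simple waveguide-like pattern
#   prediction = predict.predict_array(
#       device_array, models["ANT_NanoSOI"], model_type="p", binarize=True
#   )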


def predict_array_with_grad(
    device_array: np.ndarray, model: Model, model_type: str
) -> tuple[np.ndarray, np.ndarray]:
    """
    Predicts the output array and its gradient for a given device array using a
    specified model.

    This function sends the device array to a prediction service, which uses a
    machine learning model to predict both the outcome and the gradient of the
    nanofabrication process.

    Parameters
    ----------
    device_array : np.ndarray
        The input device array to be predicted.
    model : Model
        The model to use for prediction.
    model_type : str
        The type of model to use (e.g., 'p', 'c', 's').

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        A tuple containing the predicted output array and its gradient.

    Raises
    ------
    RuntimeError
        If the request to the prediction service fails.
    """
    headers = _prepare_headers()
    predict_data = _prepare_predict_data(device_array, model, model_type, False)
    endpoint_url = f"{BASE_URL}-with-grad-v1.modal.run"

    try:
        response = requests.post(
            endpoint_url, data=json.dumps(predict_data), headers=headers
        )
        response.raise_for_status()
    except requests.RequestException as e:
        raise RuntimeError(f"Request failed: {e}") from e

    response_json = response.json()
    prediction_array = _decode_array(response_json["prediction_array"])
    gradient_array = _decode_array(response_json["gradient_array"])

    # The gradient is shipped as an 8-bit PNG normalized to [0, 1]; rescale it
    # back to its original range.
    gradient_min = response_json["gradient_min"]
    gradient_max = response_json["gradient_max"]
    gradient_range = gradient_max - gradient_min
    gradient_array = gradient_array * gradient_range + gradient_min

    return (prediction_array, gradient_array)
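
# A sketch of inspecting the returned gradient (assumes matplotlib is installed
# and that the gradient has the same 2D shape as the input device array):
#
#   import matplotlib.pyplot as plt
#
#   prediction, gradient = predict_array_with_grad(device_array, model, "p")
#   plt.imshow(gradient, cmap="RdBu")
#   plt.colorbar()
#   plt.show()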


def _encode_array(array):
    """Encode a numpy array as a PNG image and return the base64 encoded string."""
    image = Image.fromarray(np.uint8(array * 255))
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    encoded_png = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return encoded_png


def _decode_array(encoded_png):
    """Decode a base64 encoded PNG image and return a numpy array."""
    binary_data = base64.b64decode(encoded_png)
    image = Image.open(io.BytesIO(binary_data))
    return np.array(image) / 255
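
# Note that PNG transport quantizes values to 8 bits, so an encode/decode round
# trip is only accurate to about 1/255:
#
#   a = np.random.rand(64, 64)
#   assert np.allclose(_decode_array(_encode_array(a)), a, atol=1 / 255)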


def _read_tokens():
    """Read access and refresh tokens from the configuration file."""
    token_file_path = os.path.expanduser("~/.prefab.toml")
    try:
        with open(token_file_path) as file:
            tokens = toml.load(file)
            access_token = tokens.get("access_token")
            refresh_token = tokens.get("refresh_token")
            if not access_token or not refresh_token:
                raise ValueError("Tokens not found in the configuration file.")
            return access_token, refresh_token
    except FileNotFoundError:
        raise FileNotFoundError(
            "Could not validate user.\n"
            "Please update prefab using: pip install --upgrade prefab.\n"
            "Sign up or log in and generate a new token.\n"
            "See https://www.prefabphotonics.com/docs/guides/quickstart."
        ) from None
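
# The token file is expected to look like this (values are placeholders):
#
#   # ~/.prefab.toml
#   access_token = "..."
#   refresh_token = "..."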


def _prepare_headers():
    """Prepare HTTP headers for the request."""
    access_token, refresh_token = _read_tokens()
    return {
        "Authorization": f"Bearer {access_token}",
        "X-Refresh-Token": refresh_token,
    }


def _prepare_predict_data(device_array, model, model_type, binarize):
    """Prepare the data payload for the prediction request."""
    return {
        "device_array": _encode_array(np.squeeze(device_array)),
        "model": model.to_json(),
        "model_type": model_type,
        "binary": binarize,
    }
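
# Serialized to JSON, the payload takes roughly this shape (strings truncated;
# the exact model fields depend on Model.to_json()):
#
#   {"device_array": "iVBORw0KGgo...", "model": ..., "model_type": "p",
#    "binary": true}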


def _process_response(response, model_type, binarize):
    """Process the streaming response from the prediction request."""
    event_type = None
    model_descriptions = {
        "p": "Prediction",
        "c": "Correction",
        "s": "SEMulate",
    }
    progress_bar = tqdm(
        total=100,
        desc=model_descriptions.get(model_type, "Processing"),
        unit="%",
        colour="green",
        bar_format="{l_bar}{bar:30}{r_bar}{bar:-10b}",
    )

    for line in response.iter_lines():
        if line:
            decoded_line = line.decode("utf-8").strip()
            if decoded_line.startswith("event:"):
                event_type = decoded_line.split(":", 1)[1].strip()
            elif decoded_line.startswith("data:"):
                data_content = _parse_data_line(decoded_line)
                result = _handle_event(
                    event_type, data_content, progress_bar, binarize
                )
                if result is not None:
                    progress_bar.close()
                    return result
    progress_bar.close()
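
# The stream is parsed as server-sent-event-style lines; an illustrative
# exchange (not a verbatim server transcript) looks like:
#
#   event: progress
#   data: {"progress": 0.5}
#
#   event: result
#   data: {"result0": "<base64-encoded PNG>"}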


def _parse_data_line(decoded_line):
    """Parse a data line from the response stream."""
    data_line = decoded_line.split(":", 1)[1].strip()
    try:
        return json.loads(data_line)
    except json.JSONDecodeError:
        raise ValueError(f"Failed to decode JSON: {data_line}") from None


def _handle_event(event_type, data_content, progress_bar, binarize):
    """Handle different types of events received from the server."""
    if event_type == "progress":
        _update_progress(progress_bar, data_content)
    elif event_type == "result":
        return _process_result(data_content, binarize)
    elif event_type == "end":
        print("Stream ended.")
    elif event_type == "auth":
        _update_tokens(data_content.get("auth", {}))
    elif event_type == "error":
        raise ValueError(f"{data_content['error']}")


def _update_progress(progress_bar, data_content):
    """Update the progress bar based on the progress event."""
    progress = round(100 * data_content.get("progress", 0))
    progress_bar.update(progress - progress_bar.n)


def _process_result(data_content, binarize):
    """Process the result event and return the prediction."""
    results = [
        _decode_array(data_content[key])
        for key in sorted(data_content.keys())
        if key.startswith("result")
    ]
    if results:
        prediction = np.stack(results, axis=-1)
        if binarize:
            prediction = binarize_hard(prediction)
        return prediction
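
# Result images arrive one PNG per channel ("result0", "result1", ...) and are
# stacked into shape (H, W, n_channels); a single-channel prediction can be
# squeezed back to 2D with np.squeeze(prediction).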


def _update_tokens(auth_data):
    """Update tokens if new tokens are provided in the auth event."""
    new_access_token = auth_data.get("new_access_token")
    new_refresh_token = auth_data.get("new_refresh_token")
    if new_access_token and new_refresh_token:
        prefab_file_path = os.path.expanduser("~/.prefab.toml")
        with open(prefab_file_path, "w", encoding="utf-8") as toml_file:
            toml.dump(
                {
                    "access_token": new_access_token,
                    "refresh_token": new_refresh_token,
                },
                toml_file,
            )
