Send predict logic to predict.py, GDS origin, small docstring changes
Dusandinho committed Oct 26, 2024
1 parent 820b17b commit 8a8f951
Showing 1 changed file with 35 additions and 155 deletions.
190 changes: 35 additions & 155 deletions prefab/device.py
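
The core of this commit: Device.predict, Device.correct, and Device.semulate now delegate to a module-level predict_array helper in prefab/predict.py, replacing the private Device._predict_array removed below. A minimal usage sketch, with the signature inferred from the call sites in this diff (predict.py itself is not part of the diff, so treat the exact interface as an assumption):

    from prefab.predict import predict_array  # import added by this commit (see below)

    # Assuming `device` is an existing prefab Device and `model` a prefab Model:
    prediction = predict_array(
        device_array=device.device_array,  # the device geometry array
        model=model,
        model_type="p",  # "p" = Prediction, "c" = Correction, "s" = SEMulate
        binarize=True,   # hard-threshold the returned array
    )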
@@ -1,27 +1,21 @@
"""Provides the Device class for representing photonic devices."""

import base64
import io
import json
import os
from typing import Optional

import cv2
import gdstk
import matplotlib.pyplot as plt
import numpy as np
import requests
import toml
from matplotlib.axes import Axes
from matplotlib.patches import Rectangle
from PIL import Image
from pydantic import BaseModel, Field, conint, root_validator, validator
from scipy.ndimage import distance_transform_edt
from skimage import measure
from tqdm import tqdm

from . import compare, geometry
from .models import Model
from .predict import predict_array

Image.MAX_IMAGE_PIXELS = None

@@ -35,7 +29,7 @@ class BufferSpec(BaseModel):
providing extra space for device fabrication processes or for ensuring that the
device is isolated from surrounding structures.
Attributes
Parameters
----------
mode : dict[str, str]
A dictionary that defines the buffer mode for each side of the device
@@ -210,147 +204,6 @@ def is_binary(self) -> bool:
or np.array_equal(unique_values, [1])
)
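
Only the tail of is_binary is visible in this hunk; the pattern it closes is, presumably, the usual unique-values membership check. A standalone sketch of that idea (not the verbatim method body):

    import numpy as np

    def is_binary(arr: np.ndarray) -> bool:
        # An array is binary when its unique values are a subset of {0, 1}.
        unique_values = np.unique(arr)
        return (
            np.array_equal(unique_values, [0, 1])
            or np.array_equal(unique_values, [0])
            or np.array_equal(unique_values, [1])
        )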

def _encode_array(self, array):
image = Image.fromarray(np.uint8(array * 255))
buffered = io.BytesIO()
image.save(buffered, format="PNG")
encoded_png = base64.b64encode(buffered.getvalue()).decode("utf-8")
return encoded_png

def _decode_array(self, encoded_png):
binary_data = base64.b64decode(encoded_png)
image = Image.open(io.BytesIO(binary_data))
return np.array(image) / 255
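
The two removed helpers above form a base64 PNG round trip that is exact for binary arrays (uint8 quantization makes it lossy for arbitrary floats in [0, 1]). A self-contained check of that property, mirroring the removed code:

    import base64
    import io

    import numpy as np
    from PIL import Image

    def encode_array(array):  # same logic as the removed _encode_array
        image = Image.fromarray(np.uint8(array * 255))
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def decode_array(encoded_png):  # same logic as the removed _decode_array
        return np.array(Image.open(io.BytesIO(base64.b64decode(encoded_png)))) / 255

    pattern = np.random.rand(16, 16).round()  # small binary test pattern
    assert np.allclose(decode_array(encode_array(pattern)), pattern)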

def _predict_array(
self,
model: Model,
model_type: str,
binarize: bool,
gpu: bool = False,
) -> "Device":
try:
with open(os.path.expanduser("~/.prefab.toml")) as file:
content = file.readlines()
access_token = None
refresh_token = None
for line in content:
if "access_token" in line:
access_token = line.split("=")[1].strip().strip('"')
if "refresh_token" in line:
refresh_token = line.split("=")[1].strip().strip('"')
break
if not access_token or not refresh_token:
raise ValueError("Token not found in the configuration file.")
except FileNotFoundError:
raise FileNotFoundError(
"Could not validate user.\n"
"Please update prefab using: pip install --upgrade prefab.\n"
"Signup/login and generate a new token.\n"
"See https://www.prefabphotonics.com/docs/guides/quickstart."
) from None

headers = {
"Authorization": f"Bearer {access_token}",
"X-Refresh-Token": refresh_token,
}

predict_data = {
"device_array": self._encode_array(self.device_array[:, :, 0]),
"model": model.to_json(),
"model_type": model_type,
"binary": binarize,
}
json_data = json.dumps(predict_data)

endpoint_url = (
"https://prefab-photonics--predict-gpu-v1.modal.run"
if gpu
else "https://prefab-photonics--predict-v1.modal.run"
)

try:
with requests.post(
endpoint_url, data=json_data, headers=headers, stream=True
) as response:
response.raise_for_status()
event_type = None
model_descriptions = {
"p": "Prediction",
"c": "Correction",
"s": "SEMulate",
}
progress_bar = tqdm(
total=100,
desc=f"{model_descriptions[model_type]}",
unit="%",
colour="green",
bar_format="{l_bar}{bar:30}{r_bar}{bar:-10b}",
)

for line in response.iter_lines():
if line:
decoded_line = line.decode("utf-8").strip()
if decoded_line.startswith("event:"):
event_type = decoded_line.split(":")[1].strip()
elif decoded_line.startswith("data:"):
try:
data_content = json.loads(
decoded_line.split("data: ")[1]
)
if event_type == "progress":
progress = round(100 * data_content["progress"])
progress_bar.update(progress - progress_bar.n)
elif event_type == "result":
results = []
for key in sorted(data_content.keys()):
if key.startswith("result"):
decoded_image = self._decode_array(
data_content[key]
)
results.append(decoded_image)

if results:
prediction = np.stack(results, axis=-1)
if binarize:
prediction = geometry.binarize_hard(
prediction
)
progress_bar.close()
return prediction
elif event_type == "end":
print("Stream ended.")
progress_bar.close()
break
elif event_type == "auth":
if "new_refresh_token" in data_content["auth"]:
prefab_file_path = os.path.expanduser(
"~/.prefab.toml"
)
with open(
prefab_file_path, "w", encoding="utf-8"
) as toml_file:
toml.dump(
{
"access_token": data_content[
"auth"
]["new_access_token"],
"refresh_token": data_content[
"auth"
]["new_refresh_token"],
},
toml_file,
)
elif event_type == "error":
raise ValueError(f"{data_content['error']}")
except json.JSONDecodeError:
raise ValueError(
"Failed to decode JSON:",
decoded_line.split("data: ")[1],
) from None
except requests.RequestException as e:
raise RuntimeError(f"Request failed: {e}") from e
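
The removed _predict_array above consumes a server-sent-event stream, switching on event:/data: line pairs ("progress", "result", "auth", "end", "error"). A minimal standalone sketch of that parsing loop, decoupled from the prefab endpoint (event names taken from the code above; payload shapes are illustrative):

    import json

    def parse_sse_lines(lines):
        """Yield (event_type, payload) pairs from SSE-formatted text lines."""
        event_type = None
        for raw in lines:
            line = raw.strip()
            if line.startswith("event:"):
                event_type = line.split(":", 1)[1].strip()
            elif line.startswith("data:"):
                payload = json.loads(line.split(":", 1)[1].strip())
                yield event_type, payload

    stream = [
        "event: progress",
        'data: {"progress": 0.5}',
        "event: end",
        "data: {}",
    ]
    for event, payload in parse_sse_lines(stream):
        print(event, payload)  # -> progress {'progress': 0.5}  /  end {}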

def predict(
self,
model: Model,
@@ -393,7 +246,8 @@ def predict(
If the prediction service returns an error or if the response from the
service cannot be processed correctly.
"""
prediction_array = self._predict_array(
prediction_array = predict_array(
device_array=self.device_array,
model=model,
model_type="p",
binarize=binarize,
@@ -445,7 +299,8 @@ def correct(
If the correction service returns an error or if the response from the
service cannot be processed correctly.
"""
correction_array = self._predict_array(
correction_array = predict_array(
device_array=self.device_array,
model=model,
model_type="c",
binarize=binarize,
@@ -487,7 +342,8 @@ def semulate(
A new instance of the Device class with its geometry transformed to simulate
an SEM image style.
"""
semulated_array = self._predict_array(
semulated_array = predict_array(
device_array=self.device_array,
model=model,
model_type="s",
binarize=False,
@@ -550,6 +406,7 @@ def to_gds(
cell_name: str = "prefab_device",
gds_layer: tuple[int, int] = (1, 0),
contour_approx_mode: int = 2,
origin: tuple[float, float] = (0.0, 0.0),
):
"""
Exports the device geometry as a GDSII file.
@@ -572,11 +429,15 @@
The mode of contour approximation used during the conversion. Defaults to 2,
which corresponds to `cv2.CHAIN_APPROX_SIMPLE`, a method that compresses
horizontal, vertical, and diagonal segments and leaves only their endpoints.
origin : tuple[float, float], optional
The x and y coordinates of the origin for the GDSII export. Defaults to
(0.0, 0.0).
"""
gdstk_cell = self.flatten()._device_to_gdstk(
cell_name=cell_name,
gds_layer=gds_layer,
contour_approx_mode=contour_approx_mode,
origin=origin,
)
print(f"Saving GDS to '{gds_path}'...")
gdstk_library = gdstk.Library()
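
The new origin parameter is threaded from to_gds through _device_to_gdstk, where the geometry is first centered and then offset by origin (presumably in µm, matching the nm-to-µm conversion further down). A hedged usage sketch; the output path and origin value are illustrative, the full to_gds signature is collapsed in this diff, and `device` is assumed to be an existing Device instance:

    # Export with the device center placed at (5.0, 5.0) instead of (0.0, 0.0):
    device.to_gds(
        gds_path="device_shifted.gds",
        cell_name="prefab_device",
        gds_layer=(1, 0),
        contour_approx_mode=2,  # cv2.CHAIN_APPROX_SIMPLE
        origin=(5.0, 5.0),      # new in this commit
    )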
@@ -588,6 +449,7 @@ def to_gdstk(
cell_name: str = "prefab_device",
gds_layer: tuple[int, int] = (1, 0),
contour_approx_mode: int = 2,
origin: tuple[float, float] = (0.0, 0.0),
):
"""
Converts the device geometry to a GDSTK cell object.
@@ -607,6 +469,9 @@
The mode of contour approximation used during the conversion. Defaults to 2,
which corresponds to `cv2.CHAIN_APPROX_SIMPLE`, a method that compresses
horizontal, vertical, and diagonal segments and leaves only their endpoints.
origin : tuple[float, float], optional
The x and y coordinates of the origin for the GDSTK cell. Defaults to
(0.0, 0.0).
Returns
-------
@@ -618,6 +483,7 @@
cell_name=cell_name,
gds_layer=gds_layer,
contour_approx_mode=contour_approx_mode,
origin=origin,
)
return gdstk_cell

@@ -626,6 +492,7 @@ def _device_to_gdstk(
cell_name: str,
gds_layer: tuple[int, int],
contour_approx_mode: int,
origin: tuple[float, float],
) -> gdstk.Cell:
approx_mode_mapping = {
1: cv2.CHAIN_APPROX_NONE,
@@ -662,8 +529,21 @@
polygons_to_process = hierarchy_polygons[level]

if polygons_to_process:
center_x_nm = self.device_array.shape[1] / 2
center_y_nm = self.device_array.shape[0] / 2

center_x_um = center_x_nm / 1000
center_y_um = center_y_nm / 1000

adjusted_polygons = [
[
(x - center_x_um + origin[0], y - center_y_um + origin[1])
for x, y in polygon
]
for polygon in polygons_to_process
]
processed_polygons = gdstk.boolean(
polygons_to_process,
adjusted_polygons,
processed_polygons,
operation,
layer=gds_layer[0],
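
The centering arithmetic in this hunk treats one pixel of device_array as one nanometer (as the _nm/_um variable names suggest), converts the array center to µm, and shifts every polygon so that center lands on the requested origin. A minimal standalone sketch of the same transform:

    def recenter_polygon(polygon_um, device_shape_px, origin_um=(0.0, 0.0)):
        """Shift a polygon (coordinates in um) so the device center maps to origin_um.

        device_shape_px is (rows, cols) of the device array, one pixel per nanometer.
        """
        center_x_um = (device_shape_px[1] / 2) / 1000  # px (nm) -> um
        center_y_um = (device_shape_px[0] / 2) / 1000
        return [
            (x - center_x_um + origin_um[0], y - center_y_um + origin_um[1])
            for x, y in polygon_um
        ]

    # A 2000 x 3000 px device has its center at (1.5, 1.0) um, which maps to (0.0, 0.0):
    print(recenter_polygon([(1.5, 1.0)], device_shape_px=(2000, 3000)))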
@@ -1409,8 +1289,8 @@ def flatten(self) -> "Device":
Returns
-------
np.ndarray
The flattened array with values scaled between 0 and 1.
Device
A new instance of the Device with the flattened geometry.
"""
flattened_device_array = geometry.flatten(device_array=self.device_array)
return self.model_copy(update={"device_array": flattened_device_array})