Commit beb7d47: add tiling

TNTwise committed Sep 7, 2024
1 parent 298520a
Showing 6 changed files with 87 additions and 59 deletions.
2 changes: 2 additions & 0 deletions REAL-Video-Enhancer.py
@@ -339,6 +339,8 @@ def startRender(self):
videoWidth=self.videoWidth,
videoHeight=self.videoHeight,
videoFps=self.videoFps,
+tilingEnabled=self.tilingCheckBox.isChecked(),
+tilesize=self.tileSizeComboBox.currentText(),
videoFrameCount=self.videoFrameCount,
method=method,
backend=self.backendComboBox.currentText(),
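Note: QComboBox.currentText() returns the selection as a string, while run() below types tilesize as an int, so a conversion has to happen somewhere on the way to the backend. A minimal sketch of that parsing step (the helper name and default are hypothetical, not part of this commit):

    def parseTileSize(comboBoxText: str, default: int = 256) -> int:
        # Combo box entries are strings such as "256"; fall back to a
        # default when the text is empty or not numeric.
        try:
            return int(comboBoxText)
        except ValueError:
            return default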
4 changes: 2 additions & 2 deletions backend/src/SceneDetect.py
@@ -77,7 +77,7 @@ def getPySceneDetectTransitions(self) -> Queue:

for frame_num in tqdm(range(self.totalInputFrames - 1)):

-frame = self.getFrameTo100x100img(self.readQueue.get())
+frame = bytesTo100x100img(self.readQueue.get(), width=self.width, height=self.height)

detectedFrameList = adaptiveDetector.process_frame(
frame_num=frame_num, frame_img=frame
@@ -96,7 +96,7 @@ def getMeanTransitions(self):
sceneChangeQueue = Queue()
detector = NPMeanSequential()
for frame_num in tqdm(range(self.totalInputFrames - 1)):
-frame = bytesTo100x100img(self.readQueue.get())
+frame = bytesTo100x100img(self.readQueue.get(), width=self.width, height=self.height)
if detector.sceneDetect(frame):
sceneChangeQueue.put(frame_num-1)
return sceneChangeQueue
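Both detectors operate on the same 100x100 downscale, which keeps per-frame comparison cheap regardless of source resolution. A rough sketch of the mean-difference idea behind NPMeanSequential (threshold and internals are assumptions for illustration, not the project's actual implementation):

    import numpy as np

    class MeanDiffDetector:
        # Assumed simplification: flag a scene change when the mean absolute
        # difference between consecutive downscaled frames exceeds a threshold.
        def __init__(self, threshold: float = 10.0):
            self.prev = None
            self.threshold = threshold

        def sceneDetect(self, frame: np.ndarray) -> bool:
            changed = False
            if self.prev is not None:
                diff = np.abs(frame.astype(np.int16) - self.prev)
                changed = float(diff.mean()) > self.threshold
            self.prev = frame.astype(np.int16)
            return changed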
112 changes: 60 additions & 52 deletions backend/src/UpscaleTorch.py
@@ -3,6 +3,7 @@
import numpy as np
import cv2
import torch as torch
+import torch.nn.functional as F


from src.Util import (
@@ -55,7 +56,7 @@ def __init__(
self,
modelPath: str,
device="default",
-tile_pad: int = 10,
+tile_pad: int = 0,
precision: str = "auto",
width: int = 1920,
height: int = 1080,
@@ -81,9 +82,23 @@ def __init__(
self.device = device
model = self.loadModel(modelPath=modelPath, device=device, dtype=self.dtype)

-self.width = width
-self.height = height
+self.videoWidth = width
+self.videoHeight = height
+self.tilesize = tilesize
+self.tile = [self.tilesize, self.tilesize]
+match self.scale:
+    case 1:
+        modulo = 4
+    case 2:
+        modulo = 2
+    case _:
+        modulo = 1
+if all(t > 0 for t in self.tile):
+    self.pad_w = math.ceil(min(self.tile[0] + 2 * tile_pad, width) / modulo) * modulo
+    self.pad_h = math.ceil(min(self.tile[1] + 2 * tile_pad, height) / modulo) * modulo
+else:
+    self.pad_w = width
+    self.pad_h = height

if backend == "tensorrt":
import tensorrt as trt
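The padded tile is clamped to the frame size and rounded up to a multiple the architecture accepts (4 for 1x models, 2 for 2x, otherwise 1); this fixed pad_h x pad_w shape is also what the TensorRT engine below is built against. A worked sketch of the same computation, assuming a hypothetical 256 px tile with 10 px of overlap on a 2x model:

    import math

    def paddedTileSize(tilesize: int, tile_pad: int, frame_dim: int, scale: int) -> int:
        # Mirrors the rounding in UpscaleTorch.__init__ above.
        modulo = {1: 4, 2: 2}.get(scale, 1)
        return math.ceil(min(tilesize + 2 * tile_pad, frame_dim) / modulo) * modulo

    print(paddedTileSize(256, 10, 1920, 2))  # min(276, 1920) -> ceil(276 / 2) * 2 = 276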
@@ -93,7 +108,7 @@ def __init__(
os.path.realpath(trt_cache_dir),
(
f"{os.path.basename(modelPath)}"
+ f"_{width}x{height}"
+ f"_{self.pad_w}x{self.pad_h}"
+ f"_{'fp16' if self.dtype == torch.float16 else 'fp32'}"
+ f"_{torch.cuda.get_device_name(device)}"
+ f"_trt-{trt.__version__}"
@@ -109,7 +124,7 @@ def __init__(
if not os.path.isfile(trt_engine_path):
inputs = [
torch.zeros(
-    (1, 3, self.height, self.width),
+    (1, 3, self.pad_h, self.pad_w),
dtype=self.dtype,
device=device,
)
@@ -170,7 +185,7 @@ def loadModel(
def bytesToFrame(self, frame):
return (
torch.frombuffer(frame, dtype=torch.uint8)
-    .reshape(self.height, self.width, 3)
+    .reshape(self.videoHeight, self.videoWidth, 3)
.to(self.device, dtype=self.dtype)
.permute(2, 0, 1)
.unsqueeze(0)
@@ -212,76 +227,69 @@ def getScale(self):
@torch.inference_mode()
def renderTiledImage(
self,
-    image: torch.Tensor,
+    img: torch.Tensor,
) -> torch.Tensor:
"""It will first crop input images to tiles, and then process each tile.
Finally, all the processed tiles are merged into one images.
scale = self.scale
tile = self.tile
tile_pad = self.tile_pad

Modified from: https://github.com/ata4/esrgan-launcher
"""
-batch, channel, height, width = image.shape
-output_height = height * self.scale
-output_width = width * self.scale
-output_shape = (batch, channel, output_height, output_width)
+batch, channel, height, width = img.shape
+output_shape = (batch, channel, height * scale, width * scale)

# start with black image
-output = image.new_zeros(output_shape)
-tiles_x = math.ceil(width / self.tilesize)
-tiles_y = math.ceil(height / self.tilesize)
+output = img.new_zeros(output_shape)
+
+tiles_x = math.ceil(width / tile[0])
+tiles_y = math.ceil(height / tile[1])

# loop over all tiles
for y in range(tiles_y):
for x in range(tiles_x):
# extract tile from input image
-ofs_x = x * self.tilesize
-ofs_y = y * self.tilesize
+ofs_x = x * tile[0]
+ofs_y = y * tile[1]

# input tile area on total image
input_start_x = ofs_x
-input_end_x = min(ofs_x + self.tilesize, width)
+input_end_x = min(ofs_x + tile[0], width)
input_start_y = ofs_y
-input_end_y = min(ofs_y + self.tilesize, height)
+input_end_y = min(ofs_y + tile[1], height)

# input tile area on total image with padding
-input_start_x_pad = max(input_start_x - self.tile_pad, 0)
-input_end_x_pad = min(input_end_x + self.tile_pad, width)
-input_start_y_pad = max(input_start_y - self.tile_pad, 0)
-input_end_y_pad = min(input_end_y + self.tile_pad, height)
+input_start_x_pad = max(input_start_x - tile_pad, 0)
+input_end_x_pad = min(input_end_x + tile_pad, width)
+input_start_y_pad = max(input_start_y - tile_pad, 0)
+input_end_y_pad = min(input_end_y + tile_pad, height)

# input tile dimensions
input_tile_width = input_end_x - input_start_x
input_tile_height = input_end_y - input_start_y
tile_idx = y * tiles_x + x + 1
-input_tile = image[
-    :,
-    :,
-    input_start_y_pad:input_end_y_pad,
-    input_start_x_pad:input_end_x_pad,
-]
-
-# upscale tile
-with torch.no_grad():
-    output_tile = self.renderImage(input_tile)
+input_tile = img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
+
+h, w = input_tile.shape[2:]
+input_tile = F.pad(input_tile, (0, self.pad_w - w, 0, self.pad_h - h), "replicate")
+
+# process tile
+output_tile = self.model(input_tile)
+
+output_tile = output_tile[:, :, : h * scale, : w * scale]

# output tile area on total image
-output_start_x = input_start_x * self.scale
-output_end_x = input_end_x * self.scale
-output_start_y = input_start_y * self.scale
-output_end_y = input_end_y * self.scale
+output_start_x = input_start_x * scale
+output_end_x = input_end_x * scale
+output_start_y = input_start_y * scale
+output_end_y = input_end_y * scale

# output tile area without padding
-output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
-output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
-output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
-output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
+output_start_x_tile = (input_start_x - input_start_x_pad) * scale
+output_end_x_tile = output_start_x_tile + input_tile_width * scale
+output_start_y_tile = (input_start_y - input_start_y_pad) * scale
+output_end_y_tile = output_start_y_tile + input_tile_height * scale

# put tile into output image
-output[
-    :, :, output_start_y:output_end_y, output_start_x:output_end_x
-] = output_tile[
-    :,
-    :,
-    output_start_y_tile:output_end_y_tile,
-    output_start_x_tile:output_end_x_tile,
-]
+output[:, :, output_start_y:output_end_y, output_start_x:output_end_x] = output_tile[
+    :, :, output_start_y_tile:output_end_y_tile, output_start_x_tile:output_end_x_tile
+]

return output
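Each tile is replicate-padded up to the fixed pad_h x pad_w input size and the model output is cropped back to the tile's true footprint, so edge tiles smaller than the nominal tile size still feed the network a constant shape. A self-contained sketch of that pad-process-crop round trip, using a nearest-neighbor stand-in for the model:

    import torch
    import torch.nn.functional as F

    SCALE = 2

    def model(t: torch.Tensor) -> torch.Tensor:
        # Stand-in for the loaded 2x upscaler.
        return F.interpolate(t, scale_factor=SCALE, mode="nearest")

    pad_h, pad_w = 128, 128           # hypothetical fixed engine input size
    tile = torch.rand(1, 3, 100, 96)  # an edge tile, smaller than the full tile

    h, w = tile.shape[2:]
    padded = F.pad(tile, (0, pad_w - w, 0, pad_h - h), "replicate")
    out = model(padded)[:, :, : h * SCALE, : w * SCALE]  # crop the padding back off
    print(out.shape)  # torch.Size([1, 3, 200, 192])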
4 changes: 2 additions & 2 deletions backend/src/Util.py
@@ -50,8 +50,8 @@ def log(message: str):
with open(os.path.join(cwd, "backend_log.txt"), "a") as f:
f.write(message + "\n")

-def bytesTo100x100img(self, image: bytes) -> np.ndarray:
-    frame = np.frombuffer(image, dtype=np.uint8).reshape(self.height, self.width, 3)
+def bytesTo100x100img(image: bytes, width, height) -> np.ndarray:
+    frame = np.frombuffer(image, dtype=np.uint8).reshape(height, width, 3)
frame = cv2.resize(
frame, dsize=(100, 100)
)
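With the helper now a module-level function, callers supply the source resolution explicitly instead of reading it off an instance. A minimal usage sketch (the zero-filled buffer is a stand-in for real frame bytes):

    import numpy as np
    from src.Util import bytesTo100x100img

    raw = np.zeros((1080, 1920, 3), dtype=np.uint8).tobytes()
    small = bytesTo100x100img(raw, width=1920, height=1080)
    print(small.shape)  # (100, 100, 3)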
20 changes: 19 additions & 1 deletion src/ui/ProcessTab.py
@@ -82,14 +82,21 @@ def QConnect(self):
combobox.currentIndexChanged.connect(
self.switchInterpolationAndUpscale
)

+# hide the tile size selector by default
+self.parent.tileSizeContainer.setVisible(False)
+# toggle tile size container visibility with the tiling checkbox
+self.parent.tilingCheckBox.stateChanged.connect(
+    lambda: self.parent.tileSizeContainer.setVisible(self.parent.tilingCheckBox.isChecked())
+)

self.parent.inputFileText.textChanged.connect(self.parent.updateVideoGUIDetails)
self.parent.interpolationMultiplierSpinBox.valueChanged.connect(
self.parent.updateVideoGUIDetails
)
self.parent.modelComboBox.currentIndexChanged.connect(self.parent.updateVideoGUIDetails)



def killRenderProcess(self):
try: # kills render process if necessary
self.renderProcess.terminate()
@@ -114,8 +121,10 @@ def switchInterpolationAndUpscale(self):

if method.lower() == "interpolate":
self.parent.interpolationContainer.setVisible(True)
+self.parent.upscaleContainer.setVisible(False)
else:
self.parent.interpolationContainer.setVisible(False)
+self.parent.upscaleContainer.setVisible(True)

self.parent.updateVideoGUIDetails()

@@ -127,6 +136,8 @@ def run(
videoHeight: int,
videoFps: float,
videoFrameCount: int,
+tilesize: int,
+tilingEnabled: bool,
method: str,
backend: str,
interpolationTimes: int,
@@ -138,6 +149,8 @@ def run(
self.videoWidth = videoWidth
self.videoHeight = videoHeight
self.videoFps = videoFps
+self.tilingEnabled = tilingEnabled
+self.tilesize = tilesize
self.videoFrameCount = videoFrameCount
models = self.getTotalModels(method=method, backend=backend)

@@ -230,6 +243,11 @@ def renderToPipeThread(self, method: str, backend: str, interpolateTimes: int):
"--interpolateFactor",
"1",
]
+if self.tilingEnabled:
+    command += [
+        "--tilesize",
+        f"{self.tilesize}",
+    ]
if method == "Interpolate":
command += [
"--interpolateModel",
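The backend presumably exposes a matching --tilesize flag; the parser wiring below is an assumption for illustration, since this commit only shows the frontend building the command:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--tilesize", type=int, default=0,
                        help="tile edge length in pixels; 0 disables tiling")
    args = parser.parse_args(["--tilesize", "256"])
    print(args.tilesize)  # 256, passed through to the upscaler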
4 changes: 2 additions & 2 deletions testRVEInterface.ui
@@ -692,7 +692,7 @@ li.checked::marker { content: "\2612"; }
<widget class="QWidget" name="tilingContainer" native="true">
<layout class="QHBoxLayout" name="horizontalLayout_16">
<property name="spacing">
-<number>0</number>
+<number>6</number>
</property>
<property name="leftMargin">
<number>0</number>
@@ -747,7 +747,7 @@ li.checked::marker { content: &quot;\2612&quot;; }
</size>
</property>
<property name="toolTip">
-<string>&lt;html&gt;&lt;head/&gt;&lt;body&gt;&lt;p&gt;Perform processing without outputing new video. This tests the raw performance of the inference.&lt;/p&gt;&lt;/body&gt;&lt;/html&gt;</string>
+<string>&lt;html&gt;&lt;head/&gt;&lt;body&gt;&lt;p&gt;Splits upscaling into tiles.&lt;/p&gt;&lt;p&gt;Lowers VRAM usage, but also slows down rendering.&lt;/p&gt;&lt;p&gt;Only use when a render fails due to VRAM limits.&lt;/p&gt;&lt;/body&gt;&lt;/html&gt;</string>
</property>
<property name="text">
<string/>
