Skip to content

Commit

Permalink
Working distortion map
Browse files Browse the repository at this point in the history
  • Loading branch information
AlvaroHG committed Nov 22, 2024
1 parent ef9e76e commit 40c978c
Show file tree
Hide file tree
Showing 8 changed files with 552 additions and 64 deletions.
25 changes: 22 additions & 3 deletions test_distortion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
import os
from ai2thor.interact import InteractiveControllerPrompt

import numpy as np

def load_scene(scene_name, house_path=None, run_in_editor=False, platform=None, local_build=False, commit_id=None, fov=120, distortion=False, image_dir=None, width=300, height=300):
if image_dir is not None:
Expand Down Expand Up @@ -40,6 +40,7 @@ def load_scene(scene_name, house_path=None, run_in_editor=False, platform=None,
renderDistortionImage=distortion,
renderSemanticSegmentation=True,
renderInstanceSegmentation=True,
enableDistortionMap=distortion,
fieldOfView=120,
**args,
)
Expand Down Expand Up @@ -103,6 +104,24 @@ def load_scene(scene_name, house_path=None, run_in_editor=False, platform=None,
intensityY=0.93
)

evt = controller.step(
action="GetDistortionMaps",
mainCamera=True,
thidPartyCameraIndices=[0]
)
result = evt.metadata['actionReturn']
# keys = [key for (key, val) in result.items()]
maps = []

print(f"---Action {controller.last_action['action']} success: {evt.metadata['lastActionSuccess']} result {result.keys()}")
print(f"[x,y] at (0,0) (bottom left corner) len {result['mainCamera'][0][0]}")
tex_height = len(result['mainCamera'])
tex_width = len(result['mainCamera'][0])
print(f"[x,y] at (height, width) (top right corner) len {result['thirdPartyCameras'][0][tex_height-1][tex_width-1]}")

print(f'Error: {evt.metadata["errorMessage"]}')


# xpos = dict(x=0.0, y=0.900992214679718, z=0.0786)
# # sr = controller.step(
# # action="Teleport", position=xpos, rotation=dict(x=0, y=0, z=0), forceAction=True
Expand Down Expand Up @@ -217,6 +236,6 @@ def load_scene(scene_name, house_path=None, run_in_editor=False, platform=None,
fov=float(args.fov),
distortion=args.distortion,
image_dir=args.output,
width=float(args.width),
height=float(args.height)
width=int(args.width),
height=int(args.height)
) # platform="CloudRendering")
184 changes: 151 additions & 33 deletions unity/Assets/Scripts/AgentManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
using Newtonsoft.Json.Serialization;
using Thor.Procedural.Data;
using Thor.Rendering;
using UnityEditor.AssetImporters;
using UnityEngine;
using UnityEngine.Experimental.Rendering;
using UnityEngine.Networking;
Expand Down Expand Up @@ -53,6 +54,7 @@ public class AgentManager : MonoBehaviour, ActionInvokable {
private bool renderNormalsImage;
private bool renderFlowImage;
private bool renderDistortionImage;
private bool enableDistortionMap;

private IEnumerable<string> activeCapturePassList;
private Socket sock = null;
Expand Down Expand Up @@ -93,7 +95,8 @@ private enum serverTypes {
"ChangeResolution",
"CoordinateFromRaycastThirdPartyCamera",
"ChangeQuality",
"SetDistortionShaderParams"
"SetDistortionShaderParams",
"GetDistortionMaps"
};
public HashSet<string> errorAllowedActions = new HashSet<string> { "Reset" };

Expand Down Expand Up @@ -341,19 +344,9 @@ public void Initialize(ServerAction action) {
action.renderInstanceSegmentation;
this.renderFlowImage = action.renderFlowImage;
this.renderDistortionImage = action.renderDistortionImage;
this.enableDistortionMap = action.enableDistortionMap;
this.fastActionEmit = action.fastActionEmit;

// TODO Refactor so that we know where these strings come from
activeCapturePassList = new Dictionary<string, bool>() {
{"_img", true},
{"_depth", this.renderDepthImage},
{"_id", this.renderInstanceSegmentation},
{"_class", this.renderSemanticSegmentation},
{"_normals", this.renderNormalsImage},
{"_flow", this.renderFlowImage},
{"_distortion", this.renderDistortionImage}
}.Where(x => x.Value).Select(x => x.Key);

PhysicsSceneManager.SetDefaultSimulationParams(action.defaultPhysicsSimulationParams);
Time.fixedDeltaTime = (
action.defaultPhysicsSimulationParams?.fixedDeltaTime
Expand Down Expand Up @@ -386,7 +379,8 @@ public void Initialize(ServerAction action) {
ResetSceneBounds();
}
// this.updateImageSynthesis(true, activeCapturePassList);
this.updateRenderingManagers(activeCapturePassList, true);
this.UpdateActivePasses();
this.UpdateRenderingManagers(activeCapturePassList, true);
this.agentManagerState = AgentState.ActionComplete;
}

Expand Down Expand Up @@ -578,7 +572,7 @@ private float ClampFieldOfView(
return (fov <= min || fov > max) ? defaultVal : fov;
}

public void updateRenderingManagers(IEnumerable<string> activePassList, bool cameraChange = false) {
public void UpdateRenderingManagers(IEnumerable<string> activePassList, bool cameraChange = false) {
var renderingManagers = agents.Select(agent => agent.m_Camera.GetComponent<RenderingManager>())
.Concat(thirdPartyCameras.Select(cam => cam.GetComponent<RenderingManager>()));
foreach (var renderingManager in renderingManagers) {
Expand All @@ -587,6 +581,21 @@ public void updateRenderingManagers(IEnumerable<string> activePassList, bool cam
}
}

/// <summary>
/// Rebuilds <c>activeCapturePassList</c> from the render flags currently stored on this manager.
/// The "_img" pass is always included; every other pass is included only when its flag is set.
/// </summary>
private void UpdateActivePasses() {
    // TODO Refactor so that we know where these strings come from
    Debug.Log($"--------- enableDistortionMap {this.enableDistortionMap}");

    // Pass name paired with the flag that switches it on, in emission order.
    var passFlags = new (string name, bool enabled)[] {
        ("_img", true),
        ("_depth", this.renderDepthImage),
        ("_id", this.renderInstanceSegmentation),
        ("_class", this.renderSemanticSegmentation),
        ("_normals", this.renderNormalsImage),
        ("_flow", this.renderFlowImage),
        ("_distortion", this.renderDistortionImage),
        ("_distortion_map", this.enableDistortionMap)
    };

    // Flag values are captured when the table above is built, so later flag changes
    // only take effect on the next call to this method.
    activeCapturePassList = passFlags
        .Where(entry => entry.enabled)
        .Select(entry => entry.name);
}

public void updateImageSynthesis(bool status, IEnumerable<string> activePassList, bool cameraChange = false) {
foreach (var agent in this.agents) {
agent.updateImageSynthesis(status, activePassList);
Expand Down Expand Up @@ -744,15 +753,7 @@ private void updateCameraProperties(
var renderingManager = camera.GetComponent<RenderingManager>();
// renderingManager.OnCameraChange();

activeCapturePassList = new Dictionary<string, bool>() {
{"_img", true},
{"_depth", this.renderDepthImage},
{"_id", this.renderInstanceSegmentation},
{"_class", this.renderSemanticSegmentation},
{"_normals", this.renderNormalsImage},
{"_flow", this.renderFlowImage},
{"_distortion", this.renderDistortionImage}
}.Where(x => x.Value).Select(x => x.Key);
this.UpdateActivePasses();
renderingManager.EnablePasses(activeCapturePassList, cameraChange: true);

// this.updateRenderingManagers(activeCapturePassList, true);
Expand Down Expand Up @@ -2096,6 +2097,7 @@ public void SetCriticalErrorState() {
this.agentManagerState = AgentState.Error;
}


public ActionFinished SetDistortionShaderParams(bool mainCamera = true, IEnumerable<int> thidPartyCameraIndices = null, float zoomPercent = 1.0f, float k1 = 0.0f, float k2 = 0.0f, float k3 = 0.0f, float k4 = 0.0f, float strength = 1.0f, float intensityX = 1.0f, float intensityY = 1.0f) {

IEnumerable<RenderingManager> renderingManagers = mainCamera ? new List<RenderingManager>() {this.primaryAgent.m_Camera.GetComponent<RenderingManager>()} : new List<RenderingManager>();
Expand All @@ -2105,23 +2107,138 @@ public ActionFinished SetDistortionShaderParams(bool mainCamera = true, IEnumera
// return new ActionFinished(success: false, errorMessage: "No RenderingManager, make sure you pass 'renderDistortionImage = true' to the agent constructor.");
// }
foreach (var renderingManager in renderingManagers) {
var distortion = renderingManager.GetCapturePass("_distortion");
var distortion = renderingManager.GetCapturePass<RenderToTexture>("_distortion");
if (distortion == null) {
return new ActionFinished(success: false, errorMessage: "No Distortion pass, make sure you pass 'renderDistortionImage = true' to the agent constructor.");
}
var material = distortion.material;
material.SetFloat("_ZoomPercent", zoomPercent);
material.SetFloat("_k1", k1);
material.SetFloat("_k2", k2);
material.SetFloat("_k3", k3);
material.SetFloat("_k4", k4);
material.SetFloat("_DistortionIntensityX", intensityX);
material.SetFloat("_DistortionIntensityY", intensityY);
material.SetFloat("_LensDistortionStrength", strength);
// var material = distortion.material;
var mats = new List<Material>() {distortion.material, renderingManager.distortionMap.material};
foreach (var material in mats) {
material.SetFloat("_ZoomPercent", zoomPercent);
material.SetFloat("_k1", k1);
material.SetFloat("_k2", k2);
material.SetFloat("_k3", k3);
material.SetFloat("_k4", k4);
material.SetFloat("_DistortionIntensityX", intensityX);
material.SetFloat("_DistortionIntensityY", intensityY);
material.SetFloat("_LensDistortionStrength", strength);
}
}
return ActionFinished.Success;

}

/// <summary>
/// Splits a raw distortion-map buffer into separate x and y float arrays.
/// </summary>
/// <remarks>
/// The buffer is read as consecutive 8-byte texels: a 4-byte float for x followed by a
/// 4-byte float for y, using the platform byte order of <see cref="System.BitConverter"/>.
/// If the buffer holds fewer than <paramref name="width"/> * <paramref name="height"/> texels the
/// remaining entries stay at 0; extra trailing bytes are ignored.
/// </remarks>
/// <param name="bytes">Raw texel bytes.</param>
/// <param name="width">Texture width in texels.</param>
/// <param name="height">Texture height in texels.</param>
/// <returns>Two arrays of length width * height holding the x and y components per texel.</returns>
private (float[] x, float[] y) decode(byte[] bytes, int width, int height) {
    const int bytesPerFloat = sizeof(float);
    const int bytesPerTexel = 2 * bytesPerFloat;

    float[] x = new float[width * height];
    float[] y = new float[width * height];
    Debug.Log($"------------- width {width} height {height} bytes {bytes.Length} ");

    // Never read past the buffer and never write past the output arrays.
    int texelCount = System.Math.Min(bytes.Length / bytesPerTexel, x.Length);
    for (int i = 0; i < texelCount; i++) {
        // Each texel starts at its own 8-byte boundary; indexing by i alone would
        // read overlapping bytes and produce garbage values.
        int byteIndex = i * bytesPerTexel;
        x[i] = System.BitConverter.ToSingle(bytes, byteIndex);
        y[i] = System.BitConverter.ToSingle(bytes, byteIndex + bytesPerFloat);
    }
    return (x, y);
}

/// <summary>
/// Payload returned as <c>actionReturn</c> by <c>GetDistortionMaps</c>.
/// Each map is an array with one entry per texture row; a row is a [width, 2] array
/// where column 0 holds the decoded x value and column 1 the decoded y value of a texel.
/// </summary>
public class DistortionMapReturn {
// Map of the primary agent's camera; left unassigned (null) when the main camera was not requested.
public float[][,] mainCamera;
// One map per requested third party camera, in the order they were processed.
public List<float[][,]> thirdPartyCameras;

}

/// <summary>
/// Reads back the distortion map of the requested cameras and returns it as nested float arrays.
/// </summary>
/// <remarks>
/// The raw buffer of each camera is interpreted as row-major 8-byte texels (4-byte float x
/// followed by 4-byte float y). Every map is returned as <c>float[height][width, 2]</c> with
/// <c>[col, 0] = x</c> and <c>[col, 1] = y</c>.
/// </remarks>
/// <param name="mainCamera">When true, includes the primary agent's camera in the result.</param>
/// <param name="thidPartyCameraIndices">
/// Indices into the third party camera list to include; null selects none. Maps are returned
/// in camera-list order, not in the order the indices were given.
/// </param>
/// <returns>
/// A successful <c>ActionFinished</c> whose <c>actionReturn</c> is a <c>DistortionMapReturn</c>,
/// or a failed one with an error message when a camera has no distortion map or its
/// readback buffer is too small for the texture size.
/// </returns>
public ActionFinished GetDistortionMaps(bool mainCamera = true, IEnumerable<int> thidPartyCameraIndices = null) {
    // TODO multiagent support
    const int bytesPerFloat = sizeof(float);
    const int bytesPerTexel = 2 * bytesPerFloat;

    var renderingManagers = new List<(bool isMainCamera, RenderingManager manager)>();
    if (mainCamera) {
        renderingManagers.Add((true, this.primaryAgent.m_Camera.GetComponent<RenderingManager>()));
    }
    if (thidPartyCameraIndices != null) {
        // Materialize once: the incoming sequence may be lazy and is probed for every camera.
        var requestedIndices = new HashSet<int>(thidPartyCameraIndices);
        renderingManagers.AddRange(
            this.thirdPartyCameras
                .Where((cam, cameraIndex) => requestedIndices.Contains(cameraIndex))
                .Select(cam => (false, cam.GetComponent<RenderingManager>()))
        );
    }

    var result = new DistortionMapReturn() {
        thirdPartyCameras = new List<float[][,]>()
    };

    foreach (var (isMainCamera, renderingManager) in renderingManagers) {
        if (renderingManager == null || renderingManager.distortionMap == null) {
            return new ActionFinished(
                success: false,
                errorMessage: "No distortion map, make sure you pass 'enableDistortionMap = true' to the agent constructor."
            );
        }

        var rt = renderingManager.distortionMap.GetRenderTexture();
        int width = rt.width;
        int height = rt.height;

        // Read the buffer back exactly once per camera.
        var bytes = renderingManager.getDistortionMapBytes();
        long expectedBytes = (long)width * height * bytesPerTexel;
        if (bytes == null || bytes.Length < expectedBytes) {
            return new ActionFinished(
                success: false,
                errorMessage: $"Distortion map readback returned {bytes?.Length ?? 0} bytes, expected at least {expectedBytes} for a {width}x{height} texture."
            );
        }

        var map = new float[height][,];
        for (int row = 0; row < height; row++) {
            var texels = new float[width, 2];
            for (int col = 0; col < width; col++) {
                int offset = (row * width + col) * bytesPerTexel;
                texels[col, 0] = System.BitConverter.ToSingle(bytes, offset);
                texels[col, 1] = System.BitConverter.ToSingle(bytes, offset + bytesPerFloat);
            }
            map[row] = texels;
        }

        if (isMainCamera) {
            result.mainCamera = map;
        }
        else {
            result.thirdPartyCameras.Add(map);
        }
    }
    return new ActionFinished(success: true, actionReturn: result);
}
}

[Serializable]
Expand Down Expand Up @@ -2875,6 +2992,7 @@ public class ServerAction {
public bool renderNormalsImage;
public bool renderFlowImage;
public bool renderDistortionImage;
public bool enableDistortionMap;
public float cameraY = 0.675f;
public bool placeStationary = true; // when placing/spawning an object, do we spawn it stationary (kinematic true) or spawn and let physics resolve final position

Expand Down
Loading

0 comments on commit 40c978c

Please sign in to comment.