Skip to content

Commit

Permalink
Compile the first ray shaders with dxc
Browse files Browse the repository at this point in the history
  • Loading branch information
RobDangerous committed Aug 24, 2024
1 parent 592b41a commit fc95cef
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 18 deletions.
2 changes: 2 additions & 0 deletions Sources/backends/d3d12.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ static const wchar_t *shader_string(shader_stage stage) {
return L"ps_6_0";
case SHADER_STAGE_COMPUTE:
return L"cs_6_0";
case SHADER_STAGE_RAY_GENERATION:
return L"lib_6_3";
default: {
debug_context context = {0};
error(context, "Unsupported shader stage/version combination");
Expand Down
230 changes: 213 additions & 17 deletions Sources/backends/hlsl.c
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,16 @@ static void find_referenced_globals(function *f, global_id *globals, size_t *glo
}
}

static void write_types(char *hlsl, size_t *offset, shader_stage stage, type_id input, type_id output, function *main) {
static void write_types(char *hlsl, size_t *offset, shader_stage stage, type_id input, type_id output, function *main, function **rayshaders,
size_t rayshaders_count) {
type_id types[256];
size_t types_size = 0;
find_referenced_types(main, types, &types_size);
if (main != NULL) {
find_referenced_types(main, types, &types_size);
}
for (size_t rayshader_index = 0; rayshader_index < rayshaders_count; ++rayshader_index) {
find_referenced_types(rayshaders[rayshader_index], types, &types_size);
}

for (size_t i = 0; i < types_size; ++i) {
type *t = get_type(types[i]);
Expand Down Expand Up @@ -229,10 +235,15 @@ static void write_types(char *hlsl, size_t *offset, shader_stage stage, type_id

static int global_register_indices[512];

static void write_globals(char *hlsl, size_t *offset, function *main) {
static void write_globals(char *hlsl, size_t *offset, function *main, function **rayshaders, size_t rayshaders_count) {
global_id globals[256];
size_t globals_size = 0;
find_referenced_globals(main, globals, &globals_size);
if (main != NULL) {
find_referenced_globals(main, globals, &globals_size);
}
for (size_t rayshader_index = 0; rayshader_index < rayshaders_count; ++rayshader_index) {
find_referenced_globals(rayshaders[rayshader_index], globals, &globals_size);
}

for (size_t i = 0; i < globals_size; ++i) {
global g = get_global(globals[i]);
Expand Down Expand Up @@ -261,14 +272,59 @@ static void write_globals(char *hlsl, size_t *offset, function *main) {
}
}

static void write_functions(char *hlsl, size_t *offset, shader_stage stage, function *main) {

static function *raygen_shaders[256];
static size_t raygen_shaders_size = 0;

static function *raymiss_shaders[256];
static size_t raymiss_shaders_size = 0;

static function *rayclosesthit_shaders[256];
static size_t rayclosesthit_shaders_size = 0;

static bool is_raygen_shader(function *f) {
for (size_t rayshader_index = 0; rayshader_index < raygen_shaders_size; ++rayshader_index) {
if (f == raygen_shaders[rayshader_index]) {
return true;
}
}
return false;
}

static bool is_raymiss_shader(function *f) {
for (size_t rayshader_index = 0; rayshader_index < raymiss_shaders_size; ++rayshader_index) {
if (f == raymiss_shaders[rayshader_index]) {
return true;
}
}
return false;
}

static bool is_rayclosesthit_shader(function *f) {
for (size_t rayshader_index = 0; rayshader_index < rayclosesthit_shaders_size; ++rayshader_index) {
if (f == rayclosesthit_shaders[rayshader_index]) {
return true;
}
}
return false;
}

static void write_functions(char *hlsl, size_t *offset, shader_stage stage, function *main, function **rayshaders, size_t rayshaders_count) {
function *functions[256];
size_t functions_size = 0;

functions[functions_size] = main;
functions_size += 1;
if (main != NULL) {
functions[functions_size] = main;
functions_size += 1;

find_referenced_functions(main, functions, &functions_size);
find_referenced_functions(main, functions, &functions_size);
}

for (size_t rayshader_index = 0; rayshader_index < rayshaders_count; ++rayshader_index) {
functions[functions_size] = rayshaders[rayshader_index];
functions_size += 1;
find_referenced_functions(rayshaders[rayshader_index], functions, &functions_size);
}

for (size_t i = 0; i < functions_size; ++i) {
function *f = functions[i];
Expand Down Expand Up @@ -370,6 +426,57 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func
error(context, "Unsupported shader stage");
}
}
else if (is_raygen_shader(f)) {
*offset += sprintf(&hlsl[*offset], "[shader(\"raygeneration\")]\n");

*offset += sprintf(&hlsl[*offset], "%s %s(", type_string(f->return_type.type), get_name(f->name));
for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) {
if (parameter_index == 0) {
*offset += sprintf(&hlsl[*offset], "%s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
else {
*offset += sprintf(&hlsl[*offset], ", %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
}
*offset += sprintf(&hlsl[*offset], ") {\n");
}
else if (is_raymiss_shader(f)) {
*offset += sprintf(&hlsl[*offset], "[shader(\"miss\")]\n");

*offset += sprintf(&hlsl[*offset], "%s %s(", type_string(f->return_type.type), get_name(f->name));
for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) {
if (parameter_index == 0) {
*offset += sprintf(&hlsl[*offset], "inout %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
else {
*offset += sprintf(&hlsl[*offset], ", %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
}
*offset += sprintf(&hlsl[*offset], ") {\n");
}
else if (is_rayclosesthit_shader(f)) {
debug_context context = {0};
check(f->parameters_size == 2, context, "rayclosesthit shader requires two arguments");
check(f->parameter_types[1].type == float2_id, context, "Second parameter of a rayclosesthit shader needs to be a float2");

*offset += sprintf(&hlsl[*offset], "[shader(\"closesthit\")]\n");

*offset += sprintf(&hlsl[*offset], "%s %s(", type_string(f->return_type.type), get_name(f->name));
for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) {
if (parameter_index == 0) {
*offset += sprintf(&hlsl[*offset], "inout %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
else if (parameter_index == 1) {
*offset += sprintf(&hlsl[*offset], ", BuiltInTriangleIntersectionAttributes _kong_triangle_intersection_attributes");
//type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
else {
*offset += sprintf(&hlsl[*offset], ", %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
}
*offset += sprintf(&hlsl[*offset], ") {\n");
*offset += sprintf(&hlsl[*offset], "\t%s _%" PRIu64 " = _kong_triangle_intersection_attributes.barycentrics;\n", type_string(f->parameter_types[1].type), parameter_ids[1]);
}
else {
*offset += sprintf(&hlsl[*offset], "%s %s(", type_string(f->return_type.type), get_name(f->name));
for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) {
Expand Down Expand Up @@ -522,11 +629,11 @@ static void hlsl_export_vertex(char *directory, api_kind d3d, function *main) {
check(vertex_input != NO_TYPE, context, "vertex input missing");
check(vertex_output != NO_TYPE, context, "vertex output missing");

write_types(hlsl, &offset, SHADER_STAGE_VERTEX, vertex_input, vertex_output, main);
write_types(hlsl, &offset, SHADER_STAGE_VERTEX, vertex_input, vertex_output, main, NULL, 0);

write_globals(hlsl, &offset, main);
write_globals(hlsl, &offset, main, NULL, 0);

write_functions(hlsl, &offset, SHADER_STAGE_VERTEX, main);
write_functions(hlsl, &offset, SHADER_STAGE_VERTEX, main, NULL, 0);

char *output = NULL;
size_t output_size = 0;
Expand Down Expand Up @@ -567,11 +674,11 @@ static void hlsl_export_fragment(char *directory, api_kind d3d, function *main)
debug_context context = {0};
check(pixel_input != NO_TYPE, context, "fragment input missing");

write_types(hlsl, &offset, SHADER_STAGE_FRAGMENT, pixel_input, NO_TYPE, main);
write_types(hlsl, &offset, SHADER_STAGE_FRAGMENT, pixel_input, NO_TYPE, main, NULL, 0);

write_globals(hlsl, &offset, main);
write_globals(hlsl, &offset, main, NULL, 0);

write_functions(hlsl, &offset, SHADER_STAGE_FRAGMENT, main);
write_functions(hlsl, &offset, SHADER_STAGE_FRAGMENT, main, NULL, 0);

uint8_t *output = NULL;
size_t output_size = 0;
Expand Down Expand Up @@ -606,11 +713,11 @@ static void hlsl_export_compute(char *directory, api_kind d3d, function *main) {
char *hlsl = (char *)calloc(1024 * 1024, 1);
size_t offset = 0;

write_types(hlsl, &offset, SHADER_STAGE_COMPUTE, NO_TYPE, NO_TYPE, main);
write_types(hlsl, &offset, SHADER_STAGE_COMPUTE, NO_TYPE, NO_TYPE, main, NULL, 0);

write_globals(hlsl, &offset, main);
write_globals(hlsl, &offset, main, NULL, 0);

write_functions(hlsl, &offset, SHADER_STAGE_COMPUTE, main);
write_functions(hlsl, &offset, SHADER_STAGE_COMPUTE, main, NULL, 0);

debug_context context = {0};

Expand Down Expand Up @@ -643,6 +750,49 @@ static void hlsl_export_compute(char *directory, api_kind d3d, function *main) {
write_bytecode(hlsl, directory, filename, var_name, output, output_size);
}

static void hlsl_export_all_ray_shaders(char *directory) {
char *hlsl = (char *)calloc(1024 * 1024, 1);
debug_context context = {0};
check(hlsl != NULL, context, "Could not allocate the hlsl string");
size_t offset = 0;

function *all_rayshaders[256 * 3];
size_t all_rayshaders_size = 0;
for (size_t rayshader_index = 0; rayshader_index < raygen_shaders_size; ++rayshader_index){
all_rayshaders[all_rayshaders_size] = raygen_shaders[rayshader_index];
all_rayshaders_size += 1;
}
for (size_t rayshader_index = 0; rayshader_index < raymiss_shaders_size; ++rayshader_index) {
all_rayshaders[all_rayshaders_size] = raymiss_shaders[rayshader_index];
all_rayshaders_size += 1;
}
for (size_t rayshader_index = 0; rayshader_index < rayclosesthit_shaders_size; ++rayshader_index) {
all_rayshaders[all_rayshaders_size] = rayclosesthit_shaders[rayshader_index];
all_rayshaders_size += 1;
}

write_types(hlsl, &offset, SHADER_STAGE_RAY_GENERATION, NO_TYPE, NO_TYPE, NULL, all_rayshaders, all_rayshaders_size);

write_globals(hlsl, &offset, NULL, all_rayshaders, all_rayshaders_size);

write_functions(hlsl, &offset, SHADER_STAGE_RAY_GENERATION, NULL, all_rayshaders, all_rayshaders_size);

uint8_t *output = NULL;
size_t output_size = 0;
int result = compile_hlsl_to_d3d12(hlsl, &output, &output_size, SHADER_STAGE_RAY_GENERATION, false);
check(result == 0, context, "HLSL compilation failed");

char *name = "ray";

char filename[512];
sprintf(filename, "kong_%s", name);

char var_name[256];
sprintf(var_name, "%s_code", name);

write_bytecode(hlsl, directory, filename, var_name, output, output_size);
}

void hlsl_export(char *directory, api_kind d3d) {
int cbuffer_index = 0;
int texture_index = 0;
Expand Down Expand Up @@ -718,6 +868,48 @@ void hlsl_export(char *directory, api_kind d3d) {
}
}

for (type_id i = 0; get_type(i) != NULL; ++i) {
type *t = get_type(i);
if (!t->built_in && has_attribute(&t->attributes, add_name("raypipe"))) {
name_id raygen_shader_name = NO_NAME;
name_id raymiss_shader_name = NO_NAME;
name_id rayclosesthit_shader_name = NO_NAME;

for (size_t j = 0; j < t->members.size; ++j) {
if (t->members.m[j].name == add_name("gen")) {
raygen_shader_name = t->members.m[j].value.identifier;
}
else if (t->members.m[j].name == add_name("miss")) {
raymiss_shader_name = t->members.m[j].value.identifier;
}
else if (t->members.m[j].name == add_name("closest")) {
rayclosesthit_shader_name = t->members.m[j].value.identifier;
}
}

debug_context context = {0};
check(raygen_shader_name != NO_NAME, context, "Ray generation shader missing");
check(raymiss_shader_name != NO_NAME, context, "Miss shader missing");
check(rayclosesthit_shader_name != NO_NAME, context, "Closest hit shader missing");

for (function_id i = 0; get_function(i) != NULL; ++i) {
function *f = get_function(i);
if (f->name == raygen_shader_name) {
raygen_shaders[raygen_shaders_size] = f;
raygen_shaders_size += 1;
}
else if (f->name == raymiss_shader_name) {
raymiss_shaders[raymiss_shaders_size] = f;
raymiss_shaders_size += 1;
}
else if (f->name == rayclosesthit_shader_name) {
rayclosesthit_shaders[rayclosesthit_shaders_size] = f;
rayclosesthit_shaders_size += 1;
}
}
}
}

for (size_t i = 0; i < vertex_shaders_size; ++i) {
hlsl_export_vertex(directory, d3d, vertex_shaders[i]);
}
Expand All @@ -729,4 +921,8 @@ void hlsl_export(char *directory, api_kind d3d) {
for (size_t i = 0; i < compute_shaders_size; ++i) {
hlsl_export_compute(directory, d3d, compute_shaders[i]);
}

if (d3d == API_DIRECT3D12) {
hlsl_export_all_ray_shaders(directory);
}
}
9 changes: 8 additions & 1 deletion Sources/shader_stage.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
#pragma once

typedef enum shader_stage { SHADER_STAGE_VERTEX, SHADER_STAGE_FRAGMENT, SHADER_STAGE_COMPUTE } shader_stage;
typedef enum shader_stage {
SHADER_STAGE_VERTEX,
SHADER_STAGE_FRAGMENT,
SHADER_STAGE_COMPUTE,
SHADER_STAGE_RAY_GENERATION,
SHADER_STAGE_RAY_MISS,
SHADER_STAGE_RAY_CLOSEST_HIT
} shader_stage;

0 comments on commit fc95cef

Please sign in to comment.