Skip to content

Commit

Permalink
Make mesh shaders usable
Browse files Browse the repository at this point in the history
  • Loading branch information
RobDangerous committed Aug 27, 2024
1 parent f190f2b commit 13e4ad0
Show file tree
Hide file tree
Showing 4 changed files with 312 additions and 15 deletions.
58 changes: 55 additions & 3 deletions Sources/backends/hlsl.c
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,17 @@ static void write_types(char *hlsl, size_t *offset, shader_stage stage, type_id
}
}
}
else if (stage == SHADER_STAGE_MESH && types[i] == output) {
for (size_t j = 0; j < t->members.size; ++j) {
if (j == 0) {
*offset += sprintf(&hlsl[*offset], "\t%s %s : SV_POSITION;\n", type_string(t->members.m[j].type.type), get_name(t->members.m[j].name));
}
else {
*offset +=
sprintf(&hlsl[*offset], "\t%s %s : TEXCOORD%zu;\n", type_string(t->members.m[j].type.type), get_name(t->members.m[j].name), j - 1);
}
}
}
else if (stage == SHADER_STAGE_FRAGMENT && types[i] == input) {
for (size_t j = 0; j < t->members.size; ++j) {
if (j == 0) {
Expand Down Expand Up @@ -416,6 +427,21 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func
error(context, "Mesh function requires a threads attribute with three parameters");
}

attribute *tris_attribute = find_attribute(&f->attributes, add_name("tris"));
if (tris_attribute == NULL || tris_attribute->paramters_count != 1) {
debug_context context = {0};
error(context, "Mesh function requires a tris attribute with one parameter");
}

attribute *vertices_attribute = find_attribute(&f->attributes, add_name("vertices"));
if (vertices_attribute == NULL || vertices_attribute->paramters_count != 2) {
debug_context context = {0};
error(context, "Mesh function requires a vertices attribute with two parameters");
}

type_id vertex_type = (type_id)vertices_attribute->parameters[1];
char *vertex_name = get_name(get_type(vertex_type)->name);

*offset += sprintf(&hlsl[*offset], "[outputtopology(\"triangle\")][numthreads(%i, %i, %i)] %s main(", (int)threads_attribute->parameters[0],
(int)threads_attribute->parameters[1], (int)threads_attribute->parameters[2], type_string(f->return_type.type));
for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) {
Expand All @@ -431,8 +457,12 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func
if (f->parameters_size > 0) {
*offset += sprintf(&hlsl[*offset], ", ");
}
*offset += sprintf(&hlsl[*offset], "in uint3 _kong_group_id : SV_GroupID, in uint3 _kong_group_thread_id : SV_GroupThreadID, in uint3 "
"_kong_dispatch_thread_id : SV_DispatchThreadID, in uint _kong_group_index : SV_GroupIndex) {\n");
*offset +=
sprintf(&hlsl[*offset],
"out indices uint3 _kong_mesh_tris[%i], out vertices %s _kong_mesh_vertices[%i], in uint3 _kong_group_id : SV_GroupID, in uint3 "
"_kong_group_thread_id : SV_GroupThreadID, in uint3 "
"_kong_dispatch_thread_id : SV_DispatchThreadID, in uint _kong_group_index : SV_GroupIndex) {\n",
(int)tris_attribute->parameters[0], vertex_name, (int)vertices_attribute->parameters[0]);
}
else {
debug_context context = {0};
Expand Down Expand Up @@ -636,6 +666,21 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func
sprintf(&hlsl[*offset], "DispatchMesh(_%" PRIu64 ", _%" PRIu64 ", _%" PRIu64 ", _%" PRIu64 ");\n", o->op_call.parameters[0].index,
o->op_call.parameters[1].index, o->op_call.parameters[2].index, o->op_call.parameters[3].index);
}
else if (o->op_call.func == add_name("set_mesh_output_counts")) {
check(o->op_call.parameters_size == 2, context, "set_mesh_output_counts requires two parameters");
*offset += sprintf(&hlsl[*offset], "SetMeshOutputCounts(_%" PRIu64 ", _%" PRIu64 ");\n", o->op_call.parameters[0].index,
o->op_call.parameters[1].index);
}
else if (o->op_call.func == add_name("set_mesh_triangle")) {
check(o->op_call.parameters_size == 2, context, "set_mesh_triangle requires two parameters");
*offset += sprintf(&hlsl[*offset], "_kong_mesh_tris[_%" PRIu64 "] = _%" PRIu64 ";\n", o->op_call.parameters[0].index,
o->op_call.parameters[1].index);
}
else if (o->op_call.func == add_name("set_mesh_vertex")) {
check(o->op_call.parameters_size == 2, context, "set_mesh_vertex requires two parameters");
*offset += sprintf(&hlsl[*offset], "_kong_mesh_vertices[_%" PRIu64 "] = _%" PRIu64 ";\n", o->op_call.parameters[0].index,
o->op_call.parameters[1].index);
}
else {
if (o->op_call.var.type.type == void_id) {
*offset += sprintf(&hlsl[*offset], "%s(", function_string(o->op_call.func));
Expand Down Expand Up @@ -745,7 +790,14 @@ static void hlsl_export_mesh(char *directory, function *main) {
char *hlsl = (char *)calloc(1024 * 1024, 1);
size_t offset = 0;

write_types(hlsl, &offset, SHADER_STAGE_MESH, NO_TYPE, NO_TYPE, main, NULL, 0);
attribute *vertices_attribute = find_attribute(&main->attributes, add_name("vertices"));
if (vertices_attribute == NULL || vertices_attribute->paramters_count != 2) {
debug_context context = {0};
error(context, "Mesh function requires a vertices attribute with two parameters");
}
type_id vertex_output = (type_id)vertices_attribute->parameters[1];

write_types(hlsl, &offset, SHADER_STAGE_MESH, NO_TYPE, vertex_output, main, NULL, 0);

write_globals(hlsl, &offset, main, NULL, 0);

Expand Down
236 changes: 233 additions & 3 deletions Sources/functions.c
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,24 @@ static void add_func_float3_float3(char *name) {
f->block = NULL;
}

static void add_func_void_uint_uint(char *name) {
function_id func = add_function(add_name(name));
function *f = get_function(func);

init_type_ref(&f->return_type, add_name("void"));
f->return_type.type = find_type_by_ref(&f->return_type);

f->parameter_names[0] = add_name("a");
f->parameter_names[1] = add_name("b");
for (int i = 0; i < 2; ++i) {
init_type_ref(&f->parameter_types[0], add_name("uint"));
f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]);
}
f->parameters_size = 2;

f->block = NULL;
}

void functions_init(void) {
function *new_functions = realloc(functions, functions_size * sizeof(function));
debug_context context = {0};
Expand Down Expand Up @@ -121,7 +139,13 @@ void functions_init(void) {
f->parameter_names[0] = add_name("x");
init_type_ref(&f->parameter_types[0], add_name("float"));
f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]);
f->parameters_size = 1;

f->parameter_names[1] = add_name("y");
init_type_ref(&f->parameter_types[1], add_name("float"));
f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]);

f->parameters_size = 2;

f->block = NULL;
}

Expand All @@ -130,10 +154,21 @@ void functions_init(void) {
function *f = get_function(float3_constructor_id);
init_type_ref(&f->return_type, add_name("float3"));
f->return_type.type = find_type_by_ref(&f->return_type);

f->parameter_names[0] = add_name("x");
init_type_ref(&f->parameter_types[0], add_name("float"));
f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]);
f->parameters_size = 1;

f->parameter_names[1] = add_name("y");
init_type_ref(&f->parameter_types[1], add_name("float"));
f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]);

f->parameter_names[2] = add_name("z");
init_type_ref(&f->parameter_types[2], add_name("float"));
f->parameter_types[2].type = find_type_by_ref(&f->parameter_types[2]);

f->parameters_size = 3;

f->block = NULL;
}

Expand All @@ -142,10 +177,163 @@ void functions_init(void) {
function *f = get_function(float4_constructor_id);
init_type_ref(&f->return_type, add_name("float4"));
f->return_type.type = find_type_by_ref(&f->return_type);

f->parameter_names[0] = add_name("x");
init_type_ref(&f->parameter_types[0], add_name("float"));
f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]);
f->parameters_size = 1;

f->parameter_names[1] = add_name("y");
init_type_ref(&f->parameter_types[1], add_name("float"));
f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]);

f->parameter_names[2] = add_name("z");
init_type_ref(&f->parameter_types[2], add_name("float"));
f->parameter_types[2].type = find_type_by_ref(&f->parameter_types[2]);

f->parameter_names[3] = add_name("w");
init_type_ref(&f->parameter_types[3], add_name("float"));
f->parameter_types[3].type = find_type_by_ref(&f->parameter_types[3]);

f->parameters_size = 4;

f->block = NULL;
}

{
float2_constructor_id = add_function(add_name("int2"));
function *f = get_function(float2_constructor_id);
init_type_ref(&f->return_type, add_name("int2"));
f->return_type.type = find_type_by_ref(&f->return_type);

f->parameter_names[0] = add_name("x");
init_type_ref(&f->parameter_types[0], add_name("int"));
f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]);

f->parameter_names[1] = add_name("y");
init_type_ref(&f->parameter_types[1], add_name("int"));
f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]);

f->parameters_size = 2;

f->block = NULL;
}

{
float3_constructor_id = add_function(add_name("int3"));
function *f = get_function(float3_constructor_id);
init_type_ref(&f->return_type, add_name("int3"));
f->return_type.type = find_type_by_ref(&f->return_type);

f->parameter_names[0] = add_name("x");
init_type_ref(&f->parameter_types[0], add_name("int"));
f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]);

f->parameter_names[1] = add_name("y");
init_type_ref(&f->parameter_types[1], add_name("int"));
f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]);

f->parameter_names[2] = add_name("z");
init_type_ref(&f->parameter_types[2], add_name("int"));
f->parameter_types[2].type = find_type_by_ref(&f->parameter_types[2]);

f->parameters_size = 3;

f->block = NULL;
}

{
float4_constructor_id = add_function(add_name("int4"));
function *f = get_function(float4_constructor_id);
init_type_ref(&f->return_type, add_name("int4"));
f->return_type.type = find_type_by_ref(&f->return_type);

f->parameter_names[0] = add_name("x");
init_type_ref(&f->parameter_types[0], add_name("int"));
f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]);

f->parameter_names[1] = add_name("y");
init_type_ref(&f->parameter_types[1], add_name("int"));
f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]);

f->parameter_names[2] = add_name("z");
init_type_ref(&f->parameter_types[2], add_name("int"));
f->parameter_types[2].type = find_type_by_ref(&f->parameter_types[2]);

f->parameter_names[3] = add_name("w");
init_type_ref(&f->parameter_types[3], add_name("int"));
f->parameter_types[3].type = find_type_by_ref(&f->parameter_types[3]);

f->parameters_size = 4;

f->block = NULL;
}

{
float2_constructor_id = add_function(add_name("uint2"));
function *f = get_function(float2_constructor_id);
init_type_ref(&f->return_type, add_name("uint2"));
f->return_type.type = find_type_by_ref(&f->return_type);

f->parameter_names[0] = add_name("x");
init_type_ref(&f->parameter_types[0], add_name("uint"));
f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]);

f->parameter_names[1] = add_name("y");
init_type_ref(&f->parameter_types[1], add_name("uint"));
f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]);

f->parameters_size = 2;

f->block = NULL;
}

{
float3_constructor_id = add_function(add_name("uint3"));
function *f = get_function(float3_constructor_id);
init_type_ref(&f->return_type, add_name("uint3"));
f->return_type.type = find_type_by_ref(&f->return_type);

f->parameter_names[0] = add_name("x");
init_type_ref(&f->parameter_types[0], add_name("uint"));
f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]);

f->parameter_names[1] = add_name("y");
init_type_ref(&f->parameter_types[1], add_name("uint"));
f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]);

f->parameter_names[2] = add_name("z");
init_type_ref(&f->parameter_types[2], add_name("uint"));
f->parameter_types[2].type = find_type_by_ref(&f->parameter_types[2]);

f->parameters_size = 3;

f->block = NULL;
}

{
float4_constructor_id = add_function(add_name("uint4"));
function *f = get_function(float4_constructor_id);
init_type_ref(&f->return_type, add_name("uint4"));
f->return_type.type = find_type_by_ref(&f->return_type);

f->parameter_names[0] = add_name("x");
init_type_ref(&f->parameter_types[0], add_name("uint"));
f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]);

f->parameter_names[1] = add_name("y");
init_type_ref(&f->parameter_types[1], add_name("uint"));
f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]);

f->parameter_names[2] = add_name("z");
init_type_ref(&f->parameter_types[2], add_name("uint"));
f->parameter_types[2].type = find_type_by_ref(&f->parameter_types[2]);

f->parameter_names[3] = add_name("w");
init_type_ref(&f->parameter_types[3], add_name("uint"));
f->parameter_types[3].type = find_type_by_ref(&f->parameter_types[3]);

f->parameters_size = 4;

f->block = NULL;
}

Expand Down Expand Up @@ -215,6 +403,48 @@ void functions_init(void) {
add_func_float_float("saturate");
add_func_uint3("ray_index");
add_func_float3("ray_dimensions");

add_func_void_uint_uint("set_mesh_output_counts");

{
function_id func = add_function(add_name("set_mesh_triangle"));
function *f = get_function(func);
init_type_ref(&f->return_type, add_name("void"));
f->return_type.type = find_type_by_ref(&f->return_type);
f->return_type.array_size = 1;

f->parameter_names[0] = add_name("x");
init_type_ref(&f->parameter_types[0], add_name("uint"));
f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]);

f->parameter_names[1] = add_name("y");
init_type_ref(&f->parameter_types[1], add_name("uint3"));
f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]);

f->parameters_size = 2;

f->block = NULL;
}

{
function_id func = add_function(add_name("set_mesh_vertex"));
function *f = get_function(func);
init_type_ref(&f->return_type, add_name("void"));
f->return_type.type = find_type_by_ref(&f->return_type);
f->return_type.array_size = 1;

f->parameter_names[0] = add_name("x");
init_type_ref(&f->parameter_types[0], add_name("uint"));
f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]);

f->parameter_names[1] = add_name("y");
init_type_ref(&f->parameter_types[1], add_name("void"));
f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]);

f->parameters_size = 2;

f->block = NULL;
}
}

static void grow_if_needed(uint64_t size) {
Expand Down
8 changes: 7 additions & 1 deletion Sources/parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ static double attribute_parameter_to_number(name_id attribute_name, name_id para
if (attribute_name == add_name("topology") && parameter_name == add_name("triangle")) {
return 0;
}

type_id type = find_type_by_name(parameter_name);
if (type != NO_TYPE) {
return (double)type;
}

debug_context context = {0};
error(context, "Unknown attribute parameter %s", get_name(parameter_name));
return 0;
Expand Down Expand Up @@ -875,7 +881,7 @@ static expression *parse_call(state_t *state, name_id func_name) {

advance_state(state);

bool dynamic = square && current(state).kind == TOKEN_NUMBER;
bool dynamic = square && current(state).kind != TOKEN_NUMBER;

expression *right = parse_member(state, square);

Expand Down
Loading

0 comments on commit 13e4ad0

Please sign in to comment.