Skip to content

Commit

Permalink
Add various functions and operators and check for payload types
Browse files Browse the repository at this point in the history
  • Loading branch information
RobDangerous committed Sep 16, 2024
1 parent b30d02d commit 81e147c
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 14 deletions.
12 changes: 12 additions & 0 deletions Sources/backends/cstyle.c
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,12 @@ void cstyle_write_opcode(char *code, size_t *offset, opcode *o, type_string_func
o->op_binary.result.index, o->op_binary.left.index, o->op_binary.right.index);
break;
}
case OPCODE_MOD: {
indent(code, offset, *indentation);
*offset += sprintf(&code[*offset], "%s _%" PRIu64 " = _%" PRIu64 " %% _%" PRIu64 ";\n", type_string(o->op_binary.result.type.type),
o->op_binary.result.index, o->op_binary.left.index, o->op_binary.right.index);
break;
}
case OPCODE_EQUALS: {
indent(code, offset, *indentation);
*offset += sprintf(&code[*offset], "%s _%" PRIu64 " = _%" PRIu64 " == _%" PRIu64 ";\n", type_string(o->op_binary.result.type.type),
Expand Down Expand Up @@ -205,6 +211,12 @@ void cstyle_write_opcode(char *code, size_t *offset, opcode *o, type_string_func
o->op_binary.result.index, o->op_binary.left.index, o->op_binary.right.index);
break;
}
case OPCODE_XOR: {
indent(code, offset, *indentation);
*offset += sprintf(&code[*offset], "%s _%" PRIu64 " = _%" PRIu64 " ^ _%" PRIu64 ";\n", type_string(o->op_binary.result.type.type),
o->op_binary.result.index, o->op_binary.left.index, o->op_binary.right.index);
break;
}
case OPCODE_IF: {
indent(code, offset, *indentation);
*offset += sprintf(&code[*offset], "if (_%" PRIu64 ")\n", o->op_if.condition.index);
Expand Down
127 changes: 122 additions & 5 deletions Sources/backends/hlsl.c
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,18 @@ static void write_root_signature(char *hlsl, size_t *offset) {
*offset += sprintf(&hlsl[*offset], "\")]\n");
}

static type_id payload_types[256];
static size_t payload_types_count = 0;

static bool is_payload_type(type_id t) {
for (size_t payload_index = 0; payload_index < payload_types_count; ++payload_index) {
if (payload_types[payload_index] == t) {
return true;
}
}
return false;
}

static void write_functions(char *hlsl, size_t *offset, shader_stage stage, function *main, function **rayshaders, size_t rayshaders_count) {
function *functions[256];
size_t functions_size = 0;
Expand All @@ -435,6 +447,83 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func
find_referenced_functions(rayshaders[rayshader_index], functions, &functions_size);
}

// find payloads
for (size_t i = 0; i < functions_size; ++i) {
function *f = functions[i];

uint8_t *data = f->code.o;
size_t size = f->code.size;

size_t index = 0;
while (index < size) {
opcode *o = (opcode *)&data[index];
switch (o->type) {
case OPCODE_CALL: {
if (o->op_call.func == add_name("trace_ray")) {
debug_context context = {0};
check(o->op_call.parameters_size == 3, context, "trace_ray requires three parameters");

type_id payload_type = o->op_call.parameters[2].type.type;

bool found = false;
for (size_t payload_index = 0; payload_index < payload_types_count; ++payload_index) {
if (payload_types[payload_index] == payload_type) {
found = true;
break;
}
}

if (!found) {
payload_types[payload_types_count] = payload_type;
payload_types_count += 1;
}
}
}
}
index += o->size;
}
}

// function declarations
for (size_t i = 0; i < functions_size; ++i) {
function *f = functions[i];

if (f != main && !is_raygen_shader(f) && !is_raymiss_shader(f) && !is_rayclosesthit_shader(f) && !is_rayintersection_shader(f) &&
!is_rayanyhit_shader(f)) {

uint64_t parameter_ids[256] = {0};
for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) {
for (size_t i = 0; i < f->block->block.vars.size; ++i) {
if (f->parameter_names[parameter_index] == f->block->block.vars.v[i].name) {
parameter_ids[parameter_index] = f->block->block.vars.v[i].variable_id;
break;
}
}
}

*offset += sprintf(&hlsl[*offset], "%s %s(", type_string(f->return_type.type), get_name(f->name));
for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) {
char *payload_prefix = "";
if (is_payload_type(f->parameter_types[parameter_index].type)) {
payload_prefix = "inout ";
}

if (parameter_index == 0) {

*offset += sprintf(&hlsl[*offset], "%s%s _%" PRIu64, payload_prefix, type_string(f->parameter_types[parameter_index].type),
parameter_ids[parameter_index]);
}
else {
*offset += sprintf(&hlsl[*offset], ", %s%s _%" PRIu64, payload_prefix, type_string(f->parameter_types[parameter_index].type),
parameter_ids[parameter_index]);
}
}
*offset += sprintf(&hlsl[*offset], ");\n");
}
}

*offset += sprintf(&hlsl[*offset], "\n");

for (size_t i = 0; i < functions_size; ++i) {
function *f = functions[i];

Expand Down Expand Up @@ -706,11 +795,18 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func
else {
*offset += sprintf(&hlsl[*offset], "%s %s(", type_string(f->return_type.type), get_name(f->name));
for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) {
char *payload_prefix = "";
if (is_payload_type(f->parameter_types[parameter_index].type)) {
payload_prefix = "inout ";
}

if (parameter_index == 0) {
*offset += sprintf(&hlsl[*offset], "%s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
*offset += sprintf(&hlsl[*offset], "%s%s _%" PRIu64, payload_prefix, type_string(f->parameter_types[parameter_index].type),
parameter_ids[parameter_index]);
}
else {
*offset += sprintf(&hlsl[*offset], ", %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
*offset += sprintf(&hlsl[*offset], ", %s%s _%" PRIu64, payload_prefix, type_string(f->parameter_types[parameter_index].type),
parameter_ids[parameter_index]);
}
}
*offset += sprintf(&hlsl[*offset], ") {\n");
Expand Down Expand Up @@ -823,19 +919,40 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func
check(o->op_call.parameters_size == 0, context, "group_index can not have a parameter");
*offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = _kong_group_index;\n", type_string(o->op_call.var.type.type), o->op_call.var.index);
}
else if (o->op_call.func == add_name("instance_id")) {
check(o->op_call.parameters_size == 0, context, "instance_id can not have a parameter");
*offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = InstanceID();\n", type_string(o->op_call.var.type.type), o->op_call.var.index);
}
else if (o->op_call.func == add_name("world_ray_direction")) {
check(o->op_call.parameters_size == 0, context, "group_index can not have a parameter");
check(o->op_call.parameters_size == 0, context, "world_ray_direction can not have a parameter");
*offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = WorldRayDirection();\n", type_string(o->op_call.var.type.type), o->op_call.var.index);
}
else if (o->op_call.func == add_name("world_ray_origin")) {
check(o->op_call.parameters_size == 0, context, "world_ray_origin can not have a parameter");
*offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = WorldRayOrigin();\n", type_string(o->op_call.var.type.type), o->op_call.var.index);
}
else if (o->op_call.func == add_name("ray_length")) {
check(o->op_call.parameters_size == 0, context, "ray_length can not have a parameter");
*offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = RayTCurrent();\n", type_string(o->op_call.var.type.type), o->op_call.var.index);
}
else if (o->op_call.func == add_name("ray_index")) {
check(o->op_call.parameters_size == 0, context, "group_index can not have a parameter");
check(o->op_call.parameters_size == 0, context, "ray_index can not have a parameter");
*offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = DispatchRaysIndex();\n", type_string(o->op_call.var.type.type), o->op_call.var.index);
}
else if (o->op_call.func == add_name("ray_dimensions")) {
check(o->op_call.parameters_size == 0, context, "group_index can not have a parameter");
check(o->op_call.parameters_size == 0, context, "ray_dimensions can not have a parameter");
*offset +=
sprintf(&hlsl[*offset], "%s _%" PRIu64 " = DispatchRaysDimensions();\n", type_string(o->op_call.var.type.type), o->op_call.var.index);
}
else if (o->op_call.func == add_name("object_to_world3x3")) {
check(o->op_call.parameters_size == 0, context, "object_to_world3x3 can not have a parameter");
*offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = (float3x3)ObjectToWorld4x3();\n", type_string(o->op_call.var.type.type),
o->op_call.var.index);
}
else if (o->op_call.func == add_name("primitive_index")) {
check(o->op_call.parameters_size == 0, context, "primitive_index can not have a parameter");
*offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = PrimitiveIndex();\n", type_string(o->op_call.var.type.type), o->op_call.var.index);
}
else if (o->op_call.func == add_name("trace_ray")) {
check(o->op_call.parameters_size == 3, context, "trace_ray requires three parameters");
*offset += sprintf(&hlsl[*offset], "TraceRay(_%" PRIu64 ", RAY_FLAG_NONE, 0xFF, 0, 0, 0, _%" PRIu64 ", _%" PRIu64 ");\n",
Expand Down
18 changes: 12 additions & 6 deletions Sources/compiler.c
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ variable emit_expression(opcodes *code, block *parent, expression *e) {
case OPERATOR_LESS:
case OPERATOR_LESS_EQUAL:
case OPERATOR_AND:
case OPERATOR_OR: {
case OPERATOR_OR:
case OPERATOR_XOR: {
variable right_var = emit_expression(code, parent, right);
variable left_var = emit_expression(code, parent, left);
type_ref t;
Expand Down Expand Up @@ -156,6 +157,9 @@ variable emit_expression(opcodes *code, block *parent, expression *e) {
case OPERATOR_OR:
o.type = OPCODE_OR;
break;
case OPERATOR_XOR:
o.type = OPCODE_XOR;
break;
default: {
debug_context context = {0};
error(context, "Unexpected operator");
Expand All @@ -172,7 +176,8 @@ variable emit_expression(opcodes *code, block *parent, expression *e) {
case OPERATOR_MINUS:
case OPERATOR_PLUS:
case OPERATOR_DIVIDE:
case OPERATOR_MULTIPLY: {
case OPERATOR_MULTIPLY:
case OPERATOR_MOD: {
variable right_var = emit_expression(code, parent, right);
variable left_var = emit_expression(code, parent, left);
variable result_var = allocate_variable(e->type, VARIABLE_LOCAL);
Expand All @@ -191,6 +196,9 @@ variable emit_expression(opcodes *code, block *parent, expression *e) {
case OPERATOR_MULTIPLY:
o.type = OPCODE_MULTIPLY;
break;
case OPERATOR_MOD:
o.type = OPCODE_MOD;
break;
default: {
debug_context context = {0};
error(context, "Unexpected operator");
Expand All @@ -208,10 +216,6 @@ variable emit_expression(opcodes *code, block *parent, expression *e) {
debug_context context = {0};
error(context, "! is not a binary operator");
}
case OPERATOR_MOD: {
debug_context context = {0};
error(context, "not implemented");
}
case OPERATOR_ASSIGN:
case OPERATOR_MINUS_ASSIGN:
case OPERATOR_PLUS_ASSIGN:
Expand Down Expand Up @@ -402,6 +406,8 @@ variable emit_expression(opcodes *code, block *parent, expression *e) {
}
case OPERATOR_OR:
error(context, "not implemented");
case OPERATOR_XOR:
error(context, "not implemented");
case OPERATOR_AND:
error(context, "not implemented");
case OPERATOR_MOD:
Expand Down
2 changes: 2 additions & 0 deletions Sources/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ typedef struct opcode {
OPCODE_CALL,
OPCODE_MULTIPLY,
OPCODE_DIVIDE,
OPCODE_MOD,
OPCODE_ADD,
OPCODE_SUB,
OPCODE_EQUALS,
Expand All @@ -46,6 +47,7 @@ typedef struct opcode {
OPCODE_LESS_EQUAL,
OPCODE_AND,
OPCODE_OR,
OPCODE_XOR,
OPCODE_IF,
OPCODE_WHILE_START,
OPCODE_WHILE_CONDITION,
Expand Down
55 changes: 55 additions & 0 deletions Sources/functions.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ static void add_func_float3_float_float_float(char *name) {
f->block = NULL;
}

static void add_func_float(char *name) {
function_id func = add_function(add_name(name));
function *f = get_function(func);
init_type_ref(&f->return_type, add_name("float"));
f->return_type.type = find_type_by_ref(&f->return_type);
f->parameters_size = 0;
f->block = NULL;
}

static void add_func_float3(char *name) {
function_id func = add_function(add_name(name));
function *f = get_function(func);
Expand All @@ -44,6 +53,24 @@ static void add_func_float3(char *name) {
f->block = NULL;
}

static void add_func_float3x3(char *name) {
function_id func = add_function(add_name(name));
function *f = get_function(func);
init_type_ref(&f->return_type, add_name("float3x3"));
f->return_type.type = find_type_by_ref(&f->return_type);
f->parameters_size = 0;
f->block = NULL;
}

static void add_func_uint(char *name) {
function_id func = add_function(add_name(name));
function *f = get_function(func);
init_type_ref(&f->return_type, add_name("uint"));
f->return_type.type = find_type_by_ref(&f->return_type);
f->parameters_size = 0;
f->block = NULL;
}

static void add_func_uint3(char *name) {
function_id func = add_function(add_name(name));
function *f = get_function(func);
Expand Down Expand Up @@ -89,6 +116,24 @@ static void add_func_float3_float3(char *name) {
f->block = NULL;
}

static void add_func_float3_float3_float3(char *name) {
function_id func = add_function(add_name(name));
function *f = get_function(func);
init_type_ref(&f->return_type, add_name("float3"));
f->return_type.type = find_type_by_ref(&f->return_type);

f->parameter_names[0] = add_name("a");
init_type_ref(&f->parameter_types[0], add_name("float3"));
f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]);

f->parameter_names[1] = add_name("b");
init_type_ref(&f->parameter_types[1], add_name("float3"));
f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]);

f->parameters_size = 2;
f->block = NULL;
}

static void add_func_void_uint_uint(char *name) {
function_id func = add_function(add_name(name));
function *f = get_function(func);
Expand Down Expand Up @@ -531,16 +576,26 @@ void functions_init(void) {
add_func_uint3("group_thread_id");
add_func_uint3("dispatch_thread_id");
add_func_int("group_index");
add_func_int("instance_id");

add_func_float3_float_float_float("lerp");
add_func_float3("world_ray_origin");
add_func_float3("world_ray_direction");
add_func_float("ray_length");
add_func_float3_float3("normalize");
add_func_float_float("saturate");
add_func_float_float("sin");
add_func_float_float("cos");
add_func_float_float2("length");
add_func_uint3("ray_index");
add_func_float3("ray_dimensions");
add_func_float_float("frac");
add_func_float3x3("object_to_world3x3");
add_func_float3_float3_float3("reflect");
add_func_uint("primitive_index");
add_func_float3_float3("abs");
add_func_float3_float3_float3("dot");
add_func_float3_float3("saturate");

add_func_void_uint_uint("set_mesh_output_counts");

Expand Down
4 changes: 3 additions & 1 deletion Sources/kong.c
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ void resolve_types_in_expression(statement *parent, expression *e) {
case OPERATOR_LESS:
case OPERATOR_LESS_EQUAL:
case OPERATOR_OR:
case OPERATOR_AND: {
case OPERATOR_AND:
case OPERATOR_XOR: {
e->type.type = bool_id;
break;
}
Expand Down Expand Up @@ -371,6 +372,7 @@ void resolve_types_in_expression(statement *parent, expression *e) {
case OPERATOR_DIVIDE:
case OPERATOR_MULTIPLY:
case OPERATOR_OR:
case OPERATOR_XOR:
case OPERATOR_AND:
case OPERATOR_MOD:
case OPERATOR_ASSIGN:
Expand Down
2 changes: 1 addition & 1 deletion Sources/parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ static expression *parse_logical(state_t *state) {
while (!done) {
if (current(state).kind == TOKEN_OPERATOR) {
operatorr op = current(state).op;
if (op == OPERATOR_OR || op == OPERATOR_AND) {
if (op == OPERATOR_OR || op == OPERATOR_AND || op == OPERATOR_XOR) {
advance_state(state);
expression *right = parse_equality(state);
expression *expression = expression_allocate();
Expand Down
Loading

0 comments on commit 81e147c

Please sign in to comment.