Skip to content

Commit

Permalink
Updates GEMM with proper GFLOPs setup
Browse files Browse the repository at this point in the history
  • Loading branch information
odgaard committed Jan 24, 2024
1 parent 1670336 commit dcfca38
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 298 deletions.
23 changes: 17 additions & 6 deletions batbench/backends/kernelbackend/kerneltuner_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ class KernelBackend:
DEFAULT_OBJECTIVE = TIME

def __init__(self, spec, config_space, args: Arguments,
cuda_backend="Cupy", metrics=None, objective=DEFAULT_OBJECTIVE):
cuda_backend="Cupy", metrics=None):
self.spec = spec
self.config_space = config_space
self.kernel_spec = self.spec["KernelSpecification"]
self.objective = self.spec['General'].get('Objective', self.DEFAULT_OBJECTIVE)
self.minimize = self.spec['General'].get('Minimize', True)
self.metrics = metrics
self.objective = objective
self.args = args
self.function_args = self.args.get_function_args()

Expand Down Expand Up @@ -148,11 +149,17 @@ def evaluate_gridsize(self, gridsizes, dimension):
def extract_param_names(self, gridsize):
return [node.id for node in ast.walk(ast.parse(gridsize)) if isinstance(node, ast.Name)]


def wrap_variables_in_gridsize(self, gridsize, paramnames):
for paramname in paramnames:
# prevents multiple occurrences and avoids matching substrings
if not re.search(f"\b{paramname}\b", gridsize):
gridsize = gridsize.replace(paramname, f"p['{paramname}']")
# Using a regular expression to ensure that whole words are matched
pattern = r'\b' + re.escape(paramname) + r'\b'
replacement = f"p['{paramname}']"

# Check if the parameter name is already wrapped
wrapped_pattern = f"p\\['{paramname}'\\]"
if not re.search(wrapped_pattern, gridsize):
gridsize = re.sub(pattern, replacement, gridsize)
return gridsize

def validate_problemsize_length(self, problemsizes, gridsizes):
Expand All @@ -164,6 +171,7 @@ def update_invalid_result(self, result, msg, error=None):
result.validity = msg
result.correctness = 0
result.runtimes = [0]
result.objective = 10000 if self.minimize else 0
if error:
result.error = error
return result
Expand All @@ -172,7 +180,9 @@ def update_invalid_result(self, result, msg, error=None):
def update_result(self, result, kt_result):
result.runtimes = [t/1000 for t in kt_result["times"]]
result.runtime = sum(result.runtimes)
result.objective = kt_result[self.objective]/1000
result.objective = kt_result[self.objective]
if self.objective == self.TIME:
result.objective /= 1000
result.compile_time = kt_result["compile_time"]/1000
#result.time = kt_result["verification_time"]
#result.time = kt_result["benchmark_time"]
Expand All @@ -189,6 +199,7 @@ def run_reference(self, tuning_config):
self.opts["compiler_options"], None, self.opts["block_size_names"],
self.opts["quiet"], None)
answer_list = [None] * len(res)

for key in self.args.output_args:
idx = self.args.args[key]["index"]
self.args.add_reference_value(key, res[idx])
Expand Down
46 changes: 24 additions & 22 deletions batbench/benchmarks/GEMM/GEMM-CAFF.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
{
"General": {
"BenchmarkName": "GEMM",
"OutputFormat": "JSON"
"OutputFormat": "JSON",
"Objective": "GFLOPs",
"Minimize": false
},
"ConfigurationSpace": {
"TuningParameters": [
Expand Down Expand Up @@ -68,25 +70,25 @@
{
"Name": "STRM",
"Type": "int",
"Values": "[0, 1]",
"Values": "[0]",
"Default": 0
},
{
"Name": "STRN",
"Type": "int",
"Values": "[0, 1]",
"Values": "[0]",
"Default": 0
},
{
"Name": "SA",
"Type": "int",
"Values": "[0]",
"Values": "[0, 1]",
"Default": 0
},
{
"Name": "SB",
"Type": "int",
"Values": "[0]",
"Values": "[0, 1]",
"Default": 0
},
{
Expand All @@ -102,20 +104,20 @@
"Parameters": ["KWG", "KWI"]
},
{
"Expression": "(MWG % (MDIMC * VWM)) == 0",
"Parameters": ["MWG", "MDIMC", "VWM"]
"Expression": "(MWG % MDIMC) == 0",
"Parameters": ["MWG", "MDIMC"]
},
{
"Expression": "(NWG % (NDIMC * VWN)) == 0",
"Parameters": ["NWG", "NDIMC", "VWN"]
"Expression": "(NWG % NDIMC) == 0",
"Parameters": ["NWG", "NDIMC"]
},
{
"Expression": "(MWG % (MDIMA * VWM)) == 0",
"Parameters": ["MWG", "MDIMA", "VWM"]
"Expression": "(MWG % MDIMA) == 0",
"Parameters": ["MWG", "MDIMA"]
},
{
"Expression": "(NWG % (NDIMB * VWN)) == 0",
"Parameters": ["NWG", "NDIMB", "VWN"]
"Expression": "(NWG % NDIMB) == 0",
"Parameters": ["NWG", "NDIMB"]
},
{
"Expression": "(KWG % ((MDIMC * NDIMC) // MDIMA)) == 0",
Expand Down Expand Up @@ -145,40 +147,40 @@
"Z": "1"
},
"GlobalSize": {
"X": "(16384 * MDIMC) // MWG",
"Y": "(16384 * NDIMC) // NWG",
"X": "4096 // MWG",
"Y": "4096 // NWG",
"Z": "1"
},
"SharedMemory": 16384,
"SharedMemory": 49152,
"Stream": null,
"Arguments": [
{
"Name": "kSizeM",
"Type": "int32",
"MemoryType": "Scalar",
"AccessType": "ReadOnly",
"FillValue": 16384
"FillValue": 4096
},
{
"Name": "kSizeN",
"Type": "int32",
"MemoryType": "Scalar",
"AccessType": "ReadOnly",
"FillValue": 16384
"FillValue": 4096
},
{
"Name": "kSizeK",
"Type": "int32",
"MemoryType": "Scalar",
"AccessType": "ReadOnly",
"FillValue": 16384
"FillValue": 4096
},
{
"Name": "agm",
"Type": "float",
"MemoryType": "Vector",
"AccessType": "ReadOnly",
"Size": 16384,
"Size": 16777216,
"FillType": "Random",
"FillValue": 1.0
},
Expand All @@ -187,7 +189,7 @@
"Type": "float",
"MemoryType": "Vector",
"AccessType": "ReadOnly",
"Size": 16384,
"Size": 16777216,
"FillType": "Random",
"FillValue": 1.0
},
Expand All @@ -196,7 +198,7 @@
"Type": "float",
"MemoryType": "Vector",
"AccessType": "WriteOnly",
"Size": 16384,
"Size": 16777216,
"FillType": "Constant",
"FillValue": 0.0,
"Output": 1
Expand Down
Loading

0 comments on commit dcfca38

Please sign in to comment.