Skip to content

Commit

Permalink
feat: add repeat expression
Browse files Browse the repository at this point in the history
  • Loading branch information
tzole1155 committed Jul 19, 2024
1 parent f323f91 commit b77409b
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 0 deletions.
1 change: 1 addition & 0 deletions hydra_plugins/moai_dsl_plugin/moai_dsl_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
| "view" "(" name "," SIGNED_INT ("," SIGNED_INT)* ")" -> reshape
| "transpose" "(" name "," SIGNED_INT ("," SIGNED_INT)* ")" -> transpose
| "flatten" "(" name "," SIGNED_INT ["," SIGNED_INT] ")" -> flatten
| "repeat_interleave" "(" name "," SIGNED_INT "," SIGNED_INT ")" -> repeat
| "zeros" "(" name ")" -> zeros_like
| "ones" "(" name ")" -> ones_like
| "rand" "(" name ")" -> rand_like
Expand Down
5 changes: 5 additions & 0 deletions moai/core/execution/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,11 @@ def flatten(self, key, *dims):
# self.results.append(f'result{self.index}')
# self.index += 1

def repeat(self, key, *dims):
key = self.extract(key)
dims = list(map(int, dims))
self._transform_operation("repeat_interleave", key, dims)

def unsqueeze(self, key, *dims):
if not isinstance(key, str) or isinstance(key, Token): # NOTE: is lark.Tree
key = self.extract(key)
Expand Down
17 changes: 17 additions & 0 deletions tests/dsl/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,23 @@ def test_flatten(self, parser, highdim_tensors):
x = self._parse_and_run(parser, expression, highdim_tensors)
assert x.sum() == 300.0

def test_repeat_interleave(self, parser, highdim_tensors):
expression = "repeat_interleave(fourdim, 2, 1)"
x = self._parse_and_run(parser, expression, highdim_tensors)
assert torch.equal(x, highdim_tensors["fourdim"].repeat_interleave(2, 1))
expression = "repeat_interleave(fourdim, 2, 0)"
x = self._parse_and_run(parser, expression, highdim_tensors)
assert torch.equal(x, highdim_tensors["fourdim"].repeat_interleave(2, 0))
expression = "repeat_interleave(single, 2, 1)"
x = self._parse_and_run(parser, expression, highdim_tensors)
assert x.sum() == 3600 * 2 # 3600 is the sum of single and 2 is the repeat
expression = "repeat_interleave(fourdim, 2, 0)"
x = self._parse_and_run(parser, expression, highdim_tensors)
assert (x[5:] != highdim_tensors["fourdim"]).sum() == 0
expression = "repeat_interleave(fourdim, 2, 0) + ones(10, 3, 2, 6)"
x = self._parse_and_run(parser, expression, highdim_tensors)
assert x.sum() == 120 * 2 + 10 * 3 * 2 * 6

def test_trig(self, parser, trig_tensors):
expression = "sin(pi2)"
x = self._parse_and_run(parser, expression, trig_tensors)
Expand Down

0 comments on commit b77409b

Please sign in to comment.