Skip to content

Commit

Permalink
Add option for debugging.
Browse files Browse the repository at this point in the history
  • Loading branch information
kaminow committed Mar 1, 2023
1 parent a6d1a89 commit 3b4e872
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions mtenn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,26 @@ def forward(self, input_list):
"""
## Get predictions for all inputs in the list, and combine them in a
## tensor (while keeping track of gradients)
all_reps = torch.stack(
[super(GroupedModel, self).forward(inp) for inp in input_list]
)
all_reps = []
orig_dev = None
for i, inp in enumerate(input_list):
if "MTENN_VERBOSE" in os.environ:
print(f"pose {i}", flush=True)
print(
"size",
", ".join(
[
f"{k}: {v.shape} ({v.dtype})"
for k, v in inp.items()
if type(v) is torch.Tensor
]
),
sum([len(p.flatten()) for p in self.parameters()]),
f"{torch.cuda.memory_allocated():,}",
flush=True,
)
all_reps.append(super(GroupedModel, self).forward(inp))
all_reps = torch.stack(all_reps).flatten()

## Combine each prediction according to `self.combination`
comb_pred = self.combination(all_reps)
Expand Down

0 comments on commit 3b4e872

Please sign in to comment.