Skip to content

Commit

Permalink
Hotfix.
Browse files Browse the repository at this point in the history
  • Loading branch information
xehivs committed Jan 26, 2024
1 parent ba50a28 commit 35dd142
Showing 1 changed file with 30 additions and 24 deletions.
54 changes: 30 additions & 24 deletions torchosr/models/OverlaySoftmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,34 +48,40 @@ def train(self, dataloader, loss_fn, optimizer):
# Establish all combinations and enhance it to the target number of samples
pairing = np.array(list(itertools.combinations(present, self.n_mixed)))
n_review = len(pairing)
n_cycles = np.ceil(n_target_samples / n_review).astype(int)
pairing = repmat(pairing, n_cycles, 1)[:n_target_samples]

# Shuffle pairings if necessary
if self.shuffle:
np.random.shuffle(pairing)

# Prepare storage
overlayed_X = []

# Iterate pairing
for pid, elements in enumerate(pairing):
# Select sample sublocation
locs = pid % distribution[elements]
# Secure on single-class batches
if n_review > 0:
n_cycles = np.ceil(n_target_samples / n_review).astype(int)
pairing = repmat(pairing, n_cycles, 1)[:n_target_samples]

# Establish sample location
ids = [np.where(y[:,element] == 1)[0][loc] for (element, loc) in zip(elements, locs)]
# Shuffle pairings if necessary
if self.shuffle:
np.random.shuffle(pairing)

# Integrate and store sample
x = X[ids].mean(0)
overlayed_X.append(x)
# Prepare storage
overlayed_X = []

overlayed_X = torch.stack(overlayed_X)
overlayed_y = torch.zeros((overlayed_X.shape[0], self.n_known+1))
overlayed_y[:,-1] = 1

_X = torch.cat((X, overlayed_X))
_y = torch.cat((y, overlayed_y))
# Iterate pairing
for pid, elements in enumerate(pairing):
# Select sample sublocation
locs = pid % distribution[elements]

# Establish sample location
ids = [np.where(y[:,element] == 1)[0][loc] for (element, loc) in zip(elements, locs)]

# Integrate and store sample
x = X[ids].mean(0)
overlayed_X.append(x)

overlayed_X = torch.stack(overlayed_X)
overlayed_y = torch.zeros((overlayed_X.shape[0], self.n_known+1))
overlayed_y[:,-1] = 1

_X = torch.cat((X, overlayed_X))
_y = torch.cat((y, overlayed_y))
else:
_X = X
_y = y

# Compute prediction and loss
pred = self(_X)
Expand Down

0 comments on commit 35dd142

Please sign in to comment.