train.py
"""dimension annotation
b: batch
t: token position
d: gpt d_model
v: gpt vocab size
l: SAE n latent
k: topk
n: training step
Difference to paper training spec:
- total training token is 8 epoch of 1.31b, paper is 8 epoch of 6.4b
- We project away gradient information parallel to the decoder vectors, to account for interaction between Adam and decoder normalization.
- weight EMA
- ghost grads
"""
import argparse
from pathlib import Path
import torch
import numpy as np
from geom_median.numpy import compute_geometric_median
import transformer_lens.utils as utils
from sparse_autoencoder.model import Autoencoder, TopK
from tqdm import tqdm
import wandb
wandb.require("core")
K = 32  # top-k: number of active latents per token
seq_len = 64  # sequence length, the default used for all experiments in the paper
d_model = 768  # GPT-2 small residual stream width
n_latents = 2**17  # 131,072 (~128k) SAE latents
data_dir = Path("data")
data_dir.mkdir(parents=True, exist_ok=True)
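

# Optional helper (not called below): a minimal sketch of the gradient projection mentioned in
# the module docstring as a difference from the paper, i.e. removing the gradient component
# parallel to each decoder column before the optimizer step. The function name is ours, and it
# assumes the decoder weight has shape (d_model, n_latents) with one column per latent,
# matching the Autoencoder used in this script.
def project_away_decoder_parallel_grad(sae: Autoencoder) -> None:
    """Remove, in place, the gradient component parallel to each decoder column."""
    grad = sae.decoder.weight.grad  # (d_model, n_latents), populated by loss.backward()
    if grad is None:
        return
    weight = sae.decoder.weight.data  # (d_model, n_latents)
    unit_cols = weight / weight.norm(dim=0, keepdim=True)  # unit-norm decoder directions
    parallel = (unit_cols * grad).sum(dim=0, keepdim=True) * unit_cols  # per-column parallel part
    grad.sub_(parallel)  # keep only the component orthogonal to each decoder column
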
if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")
    wandb.init(project="topk_sae", name="sae 128k (update trn data)")

    parser = argparse.ArgumentParser()
    parser.add_argument("--target_layer", type=int, default=8)
    parser.add_argument("--n_step", type=int, default=10_000)
    parser.add_argument("--batch_size", type=int, default=2048)
    parser.add_argument("--n_epoch", type=int, default=8)
    args = parser.parse_args()

    device = utils.get_device()
    trn_data_path = (
        data_dir
        / f"act_nbd_layer_{args.target_layer}_n_{args.n_step}_bs_{args.batch_size}.bin"
    )
    act_nbd = np.memmap(
        str(trn_data_path),
        dtype=np.float32,
        mode="r+",
        shape=(args.n_step, args.batch_size, d_model),
    )
    print(f"trn data loaded with shape {act_nbd.shape}")
    # initialization
    sae = Autoencoder(n_latents, d_model, activation=TopK(K), normalize=True)
    sample_act = np.array(act_nbd[-100:]).reshape(-1, d_model)  # (100 * batch_size, d_model): last 100 steps used as sample activations
    mse_scale = 1 / ((sample_act - sample_act.mean(0)) ** 2).mean()  # normalize the MSE by the variance of the data
    mse_scale = torch.tensor(mse_scale, dtype=torch.float32, device=device)
    sae.encoder.weight.data = sae.decoder.weight.data.T.clone()  # tied init: encoder is the transpose of the decoder
    sae.decoder.weight.data /= sae.decoder.weight.data.norm(dim=0)  # initialize decoder columns to unit norm
    geometric_median_d = compute_geometric_median(sample_act).median
    geometric_median_d = torch.tensor(geometric_median_d, dtype=torch.float32, device=device)
    sae.pre_bias.data = geometric_median_d  # initialize the pre-bias b_pre to the geometric median of the sample activations
    sae = sae.to(device)
    optimizer = torch.optim.Adam(sae.parameters(), lr=4e-4)
    for epoch in range(args.n_epoch):
        print(f"... on epoch {epoch + 1}/{args.n_epoch}")
        with tqdm(range(args.n_step), unit="step") as pbar:
            for step in pbar:
                act_bd = act_nbd[step]
                act_bd = torch.from_numpy(act_bd).to(device)
                _, _, recon_bd = sae(act_bd)
                loss = ((recon_bd - act_bd) ** 2).mean() * mse_scale  # normalized MSE reconstruction loss
                loss.backward()
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                sae.decoder.weight.data /= sae.decoder.weight.data.norm(dim=0)  # renormalize decoder columns to unit norm after the update
                pbar.set_postfix({"loss": f"{loss.item():.3f}"})
                wandb.log(dict(loss=loss.item()))
    model_dir = data_dir / "sae"
    model_dir.mkdir(parents=True, exist_ok=True)  # make sure the output directory exists
    model_path = model_dir / "sae_128k.pt"
    torch.save(sae.state_dict(), model_path)
    print(f"Model saved to {model_path}")
    wandb.finish()
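
# To reuse the trained SAE elsewhere, a minimal loading sketch (assuming the same
# hyperparameters and the save path produced above):
#   sae = Autoencoder(n_latents, d_model, activation=TopK(K), normalize=True)
#   sae.load_state_dict(torch.load(data_dir / "sae" / "sae_128k.pt", map_location="cpu"))
#   sae.eval()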