-
Notifications
You must be signed in to change notification settings - Fork 0
/
example_infilling.py
79 lines (72 loc) · 2.13 KB
/
example_infilling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import fire
from llama import Llama
def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.0,
    top_p: float = 0.9,
    max_seq_len: int = 192,
    max_gen_len: int = 128,
    max_batch_size: int = 4,
):
    """Run the built-in infilling examples and print each completed text.

    Every example prompt contains one ``<FILL>`` marker. The text before the
    marker is sent as the prefix and the text after it as the suffix to
    ``generator.text_infilling``; the ``"full_text"`` entry of each returned
    result is printed underneath the prompt it belongs to.

    Args:
        ckpt_dir: Forwarded to ``Llama.build`` as ``ckpt_dir``.
        tokenizer_path: Forwarded to ``Llama.build`` as ``tokenizer_path``.
        temperature: Forwarded to ``text_infilling`` as ``temperature``.
        top_p: Forwarded to ``text_infilling`` as ``top_p``.
        max_seq_len: Forwarded to ``Llama.build`` as ``max_seq_len``.
        max_gen_len: Forwarded to ``text_infilling`` as ``max_gen_len``.
        max_batch_size: Forwarded to ``Llama.build`` as ``max_batch_size``.
    """
    fill_marker = "<FILL>"

    # NOTE(review): the prompt bodies below carry no leading indentation in
    # this copy of the file -- confirm that matches the intended prompt text.
    example_prompts = [
        '''def remove_non_ascii(s: str) -> str:
""" <FILL>
return result
''',
        """# Installation instructions:
```bash
<FILL>
```
This downloads the LLaMA inference code and installs the repository as a local pip package.
""",
        """class InterfaceManagerFactory(AbstractManagerFactory):
def __init__(<FILL>
def main():
factory = InterfaceManagerFactory(start=datetime.now())
managers = []
for i in range(10):
managers.append(factory.build(id=i))
""",
        """/-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/
theorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) :
π₁ P = 0 ↔ <FILL> = 0 :=
begin
split,
{ intros h f,
rw pi_1_etalisation at h,
simp [h],
refl
},
{ intro h,
have := @quasi_adjoint C D P,
simp [←pi_1_etalisation, this, h],
refl
}
end
""",
    ]

    # Split each prompt once; piece 0 is the prefix, piece 1 the suffix.
    split_prompts = [text.split(fill_marker) for text in example_prompts]
    before_fill = [pieces[0] for pieces in split_prompts]
    after_fill = [pieces[1] for pieces in split_prompts]

    llama = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )

    infilled = llama.text_infilling(
        prefixes=before_fill,
        suffixes=after_fill,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
    )

    for original_text, outcome in zip(example_prompts, infilled):
        print("\n================= Prompt text =================\n")
        print(original_text)
        print("\n================= Filled text =================\n")
        print(outcome["full_text"])
# Script entry point: python-fire exposes the parameters of ``main`` as
# command-line arguments when this file is executed directly.
if __name__ == "__main__":
    fire.Fire(main)