Commit
·
0f3b775
1
Parent(s):
9302ce1
Update README.md
Browse files
README.md
CHANGED
|
@@ -155,6 +155,52 @@ Users (both direct and downstream) should be made aware of the risks, biases and
|
|
| 155 |
Use the code below to get started with the model.
|
| 156 |
|
| 157 |
```python
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
```
|
| 159 |
|
| 160 |
You may also refer to <https://huggingface.co/spaces/liujch1998/crystal/blob/main/app.py#L10-L86> for implementation.
|
|
|
|
| 155 |
Use the code below to get started with the model.
|
| 156 |
|
| 157 |
```python
|
| 158 |
+
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
| 159 |
+
|
| 160 |
+
tokenizer = AutoTokenizer.from_pretrained('liujch1998/crystal-11b')
|
| 161 |
+
model = AutoModelForSeq2SeqLM.from_pretrained('liujch1998/crystal-11b')
|
| 162 |
+
model.eval()
|
| 163 |
+
|
| 164 |
+
max_question_len, max_knowledge_len, max_answer_len = 128, 32, 2
|
| 165 |
+
k = 1 # number of knowledge statements to generate
|
| 166 |
+
top_p = 0.0001
|
| 167 |
+
|
| 168 |
+
question = 'If the mass of an object gets bigger what will happen to the amount of matter contained within it? \\n (A) gets bigger (B) gets smaller'
|
| 169 |
+
choices = ['A', 'B']
|
| 170 |
+
|
| 171 |
+
choices_ids = tokenizer(choices, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_answer_len).input_ids # (C, AL)
|
| 172 |
+
|
| 173 |
+
prompt = question + ' \\n Knowledge: '
|
| 174 |
+
prompt_tok = tokenizer(prompt, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_question_len) # (1, QL)
|
| 175 |
+
knowledges_ids = self.model.generate(
|
| 176 |
+
input_ids=prompt_tok.input_ids,
|
| 177 |
+
attention_mask=prompt_tok.attention_mask,
|
| 178 |
+
max_length=max_knowledge_len + 1,
|
| 179 |
+
min_length=3,
|
| 180 |
+
do_sample=True,
|
| 181 |
+
num_return_sequences=k,
|
| 182 |
+
top_p=top_p,
|
| 183 |
+
) # (K, KL); begins with 0 ([BOS]); ends with 1 ([EOS])
|
| 184 |
+
knowledges_ids = knowledges_ids[:, 1:].contiguous() # no beginning; ends with 1 ([EOS])
|
| 185 |
+
knowledges = tokenizer.batch_decode(knowledges_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
| 186 |
+
|
| 187 |
+
prompts = [question + (f' \\n Knowledge: {knowledge} \\n Answer: ' if knowledge != '' else ' \\n Answer:') for knowledge in knowledges]
|
| 188 |
+
prompts_tok = self.tokenizer(prompts, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_question_len + max_knowledge_len) # (K, QL+KL)
|
| 189 |
+
output = model(
|
| 190 |
+
input_ids=prompts_tok.input_ids,
|
| 191 |
+
attention_mask=prompts_tok.attention_mask,
|
| 192 |
+
labels=choices_ids[0].unsqueeze(0).repeat(len(knowledges), 1),
|
| 193 |
+
)
|
| 194 |
+
logitsss = output.logits # (K, AL, V)
|
| 195 |
+
logitss = logitsss[:, 0, :] # (K, V)
|
| 196 |
+
choice_ids = choices_ids[:, 0] # (C)
|
| 197 |
+
answer_logitss = logitss.gather(dim=1, index=choice_ids.unsqueeze(0).expand(len(knowledges), -1)) # (K, C)
|
| 198 |
+
answer_probss = answer_logitss.softmax(dim=1) # (K, C)
|
| 199 |
+
answer_probs = answer_probss.max(dim=0).values # (C)
|
| 200 |
+
pred = answer_probs.argmax(dim=0).item()
|
| 201 |
+
pred = choices[pred]
|
| 202 |
+
|
| 203 |
+
print(f'Question: {question}\nKnowledge: {knowledges[0]}\nAnswer: {pred}')
|
| 204 |
```
|
| 205 |
|
| 206 |
You may also refer to <https://huggingface.co/spaces/liujch1998/crystal/blob/main/app.py#L10-L86> for implementation.
|