| import os | |
| import ujson | |
| import torch | |
| import random | |
| from collections import defaultdict, OrderedDict | |
| from colbert.parameters import DEVICE | |
| from colbert.modeling.colbert import ColBERT | |
| from colbert.utils.utils import print_message, load_checkpoint | |
def load_model(args, do_print=True, *, base_model='bert-base-multilingual-uncased'):
    """Build a ColBERT model, place it on DEVICE, load checkpoint weights, and set eval mode.

    Args:
        args: Namespace-like configuration object. This function reads
            ``query_maxlen``, ``doc_maxlen``, ``dim``, ``similarity``,
            ``mask_punctuation`` and ``checkpoint`` from it.
        do_print: Passed as ``condition`` to ``print_message`` and as
            ``do_print`` to ``load_checkpoint``; controls progress output.
        base_model: Name or path given to ``ColBERT.from_pretrained`` to build
            the backbone before the checkpoint is loaded. Defaults to the value
            that used to be hard-coded, so existing callers are unaffected.
            NOTE(review): presumably this must match the backbone the
            checkpoint was trained from — confirm against ``load_checkpoint``.

    Returns:
        Tuple ``(colbert, checkpoint)``: the model after ``.eval()`` has been
        called, and whatever ``load_checkpoint`` returned (presumably the
        loaded checkpoint object — verify against its definition).
    """
    colbert = ColBERT.from_pretrained(
        base_model,
        query_maxlen=args.query_maxlen,
        doc_maxlen=args.doc_maxlen,
        dim=args.dim,
        similarity_metric=args.similarity,
        mask_punctuation=args.mask_punctuation,
    )
    # Place the model on the target device before the checkpoint is loaded into it.
    colbert = colbert.to(DEVICE)

    print_message("#> Loading model checkpoint.", condition=do_print)
    checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print)

    # This helper prepares the model for use, not training; assumes ColBERT is a
    # torch nn.Module, where eval() disables dropout and similar training-only behavior.
    colbert.eval()

    return colbert, checkpoint