|
|
|
|
|
""" |
|
|
CardVault+ Inference Example

Simple example showing how to use the CardVault+ model for card extraction.
|
|
""" |
|
|
|
|
|
import json
import os

import torch
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForVision2Seq
|
|
|
|
|
def create_sample_card():
    """Build an in-memory mock credit card picture for trying out the model.

    Returns:
        A 400x250 RGB PIL image with a light-blue background and four
        black text fields drawn on it (no explicit font is passed).
    """
    card = Image.new('RGB', (400, 250), color='lightblue')
    canvas = ImageDraw.Draw(card)

    # (x, y) anchor and content of every text field printed on the card,
    # in drawing order.
    fields = (
        ((20, 50), "SAMPLE BANK"),
        ((20, 100), "1234 5678 9012 3456"),
        ((20, 150), "JOHN DOE"),
        ((300, 150), "12/25"),
    )
    for position, text in fields:
        canvas.text(position, text, fill='black')

    return card
|
|
|
|
|
def _parse_first_json_object(text):
    """Return the first JSON object embedded in *text*, or None if there is none."""
    start = text.find('{')
    if start == -1:
        return None
    try:
        # raw_decode stops at the end of the first complete JSON value, so
        # any trailing text after the object (even text containing '}')
        # does not break parsing.
        parsed, _ = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError:
        # Best effort: malformed or truncated JSON simply means "no result".
        return None
    return parsed


def extract_card_info(image_path_or_pil=None):
    """Extract structured information from a card image.

    Loads the ``sugiv/cardvaultplus`` processor and model on every call,
    runs generation on the image with a fixed extraction prompt, decodes
    the output and tries to parse a JSON object out of the decoded text.

    Args:
        image_path_or_pil: One of
            * ``None`` - a sample card from ``create_sample_card()`` is used;
            * a ``str`` or ``os.PathLike`` path - the file is opened with
              PIL and converted to an RGB copy;
            * anything else - passed to the processor unchanged (expected
              to be an already-loaded PIL image).

    Returns:
        A dict with three keys:
            * ``'full_response'`` - the decoded model output (str);
            * ``'extracted_json'`` - the parsed JSON object, or ``None``
              when no valid JSON object was found in the response;
            * ``'success'`` - ``True`` iff ``'extracted_json'`` is not ``None``.
    """
    print("Loading CardVault+ model...")
    model_id = "sugiv/cardvaultplus"
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )

    if image_path_or_pil is None:
        print("Creating sample card image...")
        image = create_sample_card()
    elif isinstance(image_path_or_pil, (str, os.PathLike)):
        # Read the pixels while the file is open so the handle is released
        # right away; convert() returns an independent RGB copy.
        with Image.open(image_path_or_pil) as opened:
            image = opened.convert("RGB")
    else:
        image = image_path_or_pil

    prompt = "<image>Extract structured information from this card/document in JSON format."
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    # Move every tensor-like input to the device holding the model's first
    # parameter; values without a .to() method are passed through as-is.
    device = next(model.parameters()).device
    inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}

    print("Extracting information...")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,
            pad_token_id=processor.tokenizer.eos_token_id,
        )

    response = processor.decode(outputs[0], skip_special_tokens=True)
    extracted_json = _parse_first_json_object(response)

    return {
        'full_response': response,
        'extracted_json': extracted_json,
        'success': extracted_json is not None,
    }
|
|
|
|
|
if __name__ == "__main__":
    # Demo run: called without an image argument.
    result = extract_card_info()

    print("="*50)
    print("CardVault+ Extraction Results")
    print("="*50)
    print(f"Success: {result['success']}")
    print(f"Full Response: {result['full_response']}")

    # Compare against None rather than relying on truthiness, so that an
    # empty but valid JSON object ({}) is still printed.
    if result['extracted_json'] is not None:
        print("Extracted JSON:")
        print(json.dumps(result['extracted_json'], indent=2))
|
|
|
|
|
|
|
|
|
|
|
|