cardvaultplus / inference_example.py
sugiv's picture
Add comprehensive inference example script
24dc13b verified
#!/usr/bin/env python3
"""
CardVault+ Inference Example
Simple example showing how to use the CardVault+ model for card extraction
"""
import functools
import json
import os

import torch
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForVision2Seq
def create_sample_card():
    """Render a synthetic credit-card-like image for testing.

    Returns:
        A 400x250 RGB PIL image with a light-blue background on which a
        placeholder bank name, card number, holder name and expiry date are
        drawn in black (no font is specified, so Pillow's default is used).
    """
    # (x, y) anchor and text for each printed card element, in draw order.
    card_fields = (
        ((20, 50), "SAMPLE BANK"),
        ((20, 100), "1234 5678 9012 3456"),
        ((20, 150), "JOHN DOE"),
        ((300, 150), "12/25"),
    )
    canvas = Image.new('RGB', (400, 250), color='lightblue')
    pen = ImageDraw.Draw(canvas)
    for anchor, label in card_fields:
        pen.text(anchor, label, fill='black')
    return canvas
@functools.lru_cache(maxsize=1)
def _load_model(model_id):
    """Load the processor and model for *model_id*, memoizing the last one.

    Loading weights dominates the cost of a call, so the most recently used
    model stays cached (maxsize=1 keeps at most one model resident).
    """
    print("Loading CardVault+ model...")
    # float16 only pays off (and is only broadly supported) on CUDA; many CPU
    # kernels are slow or missing for half precision, so use float32 there.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",
    )
    return processor, model


def _load_image(image_path_or_pil):
    """Return an RGB image from a path, an image object, or None (sample card)."""
    if image_path_or_pil is None:
        print("Creating sample card image...")
        return create_sample_card()
    if isinstance(image_path_or_pil, (str, os.PathLike)):
        # The context manager closes the file; convert() forces the lazy
        # decode to happen before that and yields 3-channel RGB.
        with Image.open(image_path_or_pil) as img:
            return img.convert("RGB")
    image = image_path_or_pil
    # Vision processors expect RGB; RGBA/L/P inputs would otherwise give the
    # wrong channel count. Non-PIL inputs (no .mode) are passed through.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")
    return image


def _parse_json(text):
    """Parse the span from the first '{' to the last '}' of *text* as JSON.

    Returns the decoded value, or None if there is no such span or it is not
    valid JSON.
    """
    start = text.find('{')
    end = text.rfind('}')
    if start == -1 or end < start:
        return None
    try:
        return json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        # The model produced something brace-delimited that isn't valid JSON;
        # report "no JSON" rather than crashing the extraction.
        return None


def extract_card_info(image_path_or_pil=None, *, model_id="sugiv/cardvaultplus",
                      max_new_tokens=150):
    """Extract structured information from a card image.

    Args:
        image_path_or_pil: A file path (str or path-like), an already loaded
            image, or None to use a generated sample card.
        model_id: Hugging Face model identifier to load (keyword-only).
        max_new_tokens: Generation budget for the answer (keyword-only).

    Returns:
        dict with keys:
            'full_response': the decoded model answer (prompt excluded),
            'extracted_json': the parsed JSON value, or None,
            'success': True when JSON was parsed.
    """
    processor, model = _load_model(model_id)
    image = _load_image(image_path_or_pil)

    prompt = "<image>Extract structured information from this card/document in JSON format."
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    # Move inputs to the device of the model's first parameter. Floating
    # tensors (e.g. pixel_values) are also cast to the model dtype, since
    # float32 pixels fed to float16 weights raise a dtype mismatch; integer
    # ids and masks keep their dtype.
    device = next(model.parameters()).device
    dtype = model.dtype
    model_inputs = {}
    for key, value in inputs.items():
        if isinstance(value, torch.Tensor) and value.is_floating_point():
            value = value.to(device=device, dtype=dtype)
        elif hasattr(value, 'to'):
            value = value.to(device)
        model_inputs[key] = value

    print("Extracting information...")
    with torch.no_grad():
        outputs = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=processor.tokenizer.eos_token_id,
        )

    generated = outputs[0]
    # Decoder-only models return prompt + continuation; strip the echoed
    # prompt so only the answer is decoded. Encoder-decoder models already
    # return just the generated tokens.
    prompt_ids = model_inputs.get("input_ids")
    if prompt_ids is not None and not getattr(model.config, "is_encoder_decoder", False):
        generated = generated[prompt_ids.shape[-1]:]
    response = processor.decode(generated, skip_special_tokens=True).strip()

    extracted_json = _parse_json(response)
    return {
        'full_response': response,
        'extracted_json': extracted_json,
        'success': extracted_json is not None
    }
if __name__ == "__main__":
    import sys

    # Usage: python inference_example.py [path/to/your/card.jpg]
    # With no argument, a synthetic sample card is generated and used.
    # Programmatic use: extract_card_info("path/to/your/card.jpg")
    image_arg = sys.argv[1] if len(sys.argv) > 1 else None
    result = extract_card_info(image_arg)
    print("="*50)
    print("CardVault+ Extraction Results")
    print("="*50)
    print(f"Success: {result['success']}")
    print(f"Full Response: {result['full_response']}")
    # Compare against None: a parsed but empty value ({} / []) is still a
    # successful extraction and should be shown.
    if result['extracted_json'] is not None:
        print("Extracted JSON:")
        print(json.dumps(result['extracted_json'], indent=2))