| | from PIL import Image
|
| | from torch import Tensor, stack
|
| | from typing import Union, List
|
| |
|
| | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
| | from timm import create_model
|
| | from timm.data import resolve_data_config
|
| | from timm.data.transforms_factory import create_transform
|
| |
|
class EfficientNetImageProcessor(BaseImageProcessor):
    """Image processor that applies timm's data transforms for a named model.

    The transform pipeline is built by ``timm.data.transforms_factory.create_transform``
    from the data config that ``timm.data.resolve_data_config`` resolves for
    ``model_name``. Processed images are returned under ``"pixel_values"``.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        model_name: str,
        **kwargs,
    ):
        """Resolve and store the timm data configuration for ``model_name``.

        Parameters
        ----------
        model_name : str
            Name of a timm model, passed to ``timm.create_model``.
        **kwargs
            Forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        self.model_name = model_name
        # A model instance is created with pretrained=False purely so timm can
        # resolve its default data config from it; the instance is not kept.
        self.config = resolve_data_config({}, model=create_model(model_name, pretrained=False))
        # NOTE(review): super().__init__ runs last, so a kwarg with the same
        # name as an attribute above (e.g. a serialized "config") may overwrite
        # it -- confirm this ordering is intended.
        super().__init__(**kwargs)

    def preprocess(
        self,
        images: Union[List[Union[Image.Image, Tensor]], Image.Image, Tensor],
        return_tensors: Union[str, None] = None,
    ) -> BatchFeature:
        """
        Preprocesses input images by applying transformations and returning them as a BatchFeature.

        Parameters
        ----------
        images : Union[List[PIL.Image.Image, torch.Tensor], PIL.Image.Image, torch.Tensor]
            A single image, or a list/tuple of images, each a ``PIL.Image.Image``
            or ``torch.Tensor``.
        return_tensors : str or None, optional
            Tensor type forwarded to ``BatchFeature`` as ``tensor_type``. ``None``
            (the default) leaves the stacked ``torch.Tensor`` unchanged.

        Returns
        -------
        BatchFeature
            A batch whose ``"pixel_values"`` entry is the stack of transformed images.

        Raises
        ------
        ValueError
            If ``images`` is an empty list/tuple.
        TypeError
            If any image is neither a ``PIL.Image.Image`` nor a ``torch.Tensor``.
        """
        # Accept any list/tuple as a batch; wrap a single image into a batch of one.
        if not isinstance(images, (list, tuple)):
            images = [images]

        if len(images) == 0:
            raise ValueError("Received an empty list of images")

        # Validate every element up front (not just the first) so a bad input
        # fails with a clear message instead of deep inside the transform.
        for index, image in enumerate(images):
            if not isinstance(image, (Image.Image, Tensor)):
                raise TypeError(
                    f"Expected image at index {index} to be of type PIL.Image.Image or torch.Tensor, "
                    f"but got {type(image).__name__} instead."
                )

        # NOTE(review): whether torch.Tensor inputs are accepted depends on the
        # transform timm builds (torchvision's ToTensor rejects tensors) --
        # confirm against the installed timm version.
        transforms = create_transform(**self.config)
        pixel_values = stack([transforms(image) for image in images])

        # return_tensors exists because BaseImageProcessor.__call__ forwards its
        # kwargs (e.g. return_tensors="pt" from pipelines) to preprocess.
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)
|
| |
|
# Public API of this module: only the image processor class is exported.
__all__ = ["EfficientNetImageProcessor"]