Update README.md
Browse files
README.md
CHANGED
|
@@ -48,20 +48,22 @@ class ConvStem(nn.Module):
|
|
| 48 |
Adapted from https://github.com/Xiyue-Wang/TransPath/blob/main/ctran.py#L6-L44
|
| 49 |
"""
|
| 50 |
|
| 51 |
-
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
|
| 52 |
super().__init__()
|
| 53 |
|
| 54 |
-
|
| 55 |
-
assert
|
|
|
|
| 56 |
|
| 57 |
img_size = to_2tuple(img_size)
|
| 58 |
patch_size = to_2tuple(patch_size)
|
|
|
|
| 59 |
self.img_size = img_size
|
| 60 |
self.patch_size = patch_size
|
| 61 |
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
| 62 |
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
| 63 |
-
self.flatten = flatten
|
| 64 |
|
|
|
|
| 65 |
stem = []
|
| 66 |
input_dim, output_dim = 3, embed_dim // 8
|
| 67 |
for l in range(2):
|
|
@@ -73,15 +75,17 @@ class ConvStem(nn.Module):
|
|
| 73 |
stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
|
| 74 |
self.proj = nn.Sequential(*stem)
|
| 75 |
|
|
|
|
| 76 |
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 77 |
|
| 78 |
def forward(self, x):
|
| 79 |
B, C, H, W = x.shape
|
|
|
|
|
|
|
| 80 |
assert H == self.img_size[0] and W == self.img_size[1], \
|
| 81 |
-
|
|
|
|
| 82 |
x = self.proj(x)
|
| 83 |
-
if self.flatten:
|
| 84 |
-
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
| 85 |
x = x.permute(0, 2, 3, 1) # BCHW -> BHWC
|
| 86 |
x = self.norm(x)
|
| 87 |
return x
|
|
|
|
| 48 |
Adapted from https://github.com/Xiyue-Wang/TransPath/blob/main/ctran.py#L6-L44
|
| 49 |
"""
|
| 50 |
|
| 51 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, **kwargs):
|
| 52 |
super().__init__()
|
| 53 |
|
| 54 |
+
# Check input constraints
|
| 55 |
+
assert patch_size == 4, "Patch size must be 4"
|
| 56 |
+
assert embed_dim % 8 == 0, "Embedding dimension must be a multiple of 8"
|
| 57 |
|
| 58 |
img_size = to_2tuple(img_size)
|
| 59 |
patch_size = to_2tuple(patch_size)
|
| 60 |
+
|
| 61 |
self.img_size = img_size
|
| 62 |
self.patch_size = patch_size
|
| 63 |
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
| 64 |
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
|
|
|
| 65 |
|
| 66 |
+
# Create stem network
|
| 67 |
stem = []
|
| 68 |
input_dim, output_dim = 3, embed_dim // 8
|
| 69 |
for l in range(2):
|
|
|
|
| 75 |
stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
|
| 76 |
self.proj = nn.Sequential(*stem)
|
| 77 |
|
| 78 |
+
# Apply normalization layer (if provided)
|
| 79 |
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 80 |
|
| 81 |
def forward(self, x):
|
| 82 |
B, C, H, W = x.shape
|
| 83 |
+
|
| 84 |
+
# Check input image size
|
| 85 |
assert H == self.img_size[0] and W == self.img_size[1], \
|
| 86 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
| 87 |
+
|
| 88 |
x = self.proj(x)
|
|
|
|
|
|
|
| 89 |
x = x.permute(0, 2, 3, 1) # BCHW -> BHWC
|
| 90 |
x = self.norm(x)
|
| 91 |
return x
|