Eric Houzelle committed
Commit 0675bbd · Parent(s): ed659da

Upload minimal files only

Files changed (3):
  1. README.md +79 -3
  2. colorizer-tiny.pth +3 -0
  3. model.py +67 -0
README.md CHANGED
@@ -1,3 +1,79 @@
- ---
- license: apache-2.0
- ---
+ # 🎨 Simple Colorizer - Image Colorization Model
+
+ This repository contains a PyTorch-trained U-Net model that automatically colorizes grayscale images.
+
+ ---
+
+ ## 📂 Repository Contents
+
+ - `colorizer-tiny.pth`: Trained model weights
+ - `model.py`: The `ImprovedUNet` architecture definition
+ - `README.md`: This file
+
+ ---
+
+ ## 🚀 Usage Example
+
+ ### 1️⃣ Install Dependencies
+
+ ```bash
+ pip install torch torchvision pillow
+ ```
+
+ ### 2️⃣ Load the Model
+
+ ```python
+ import torch
+ from model import ImprovedUNet
+
+ # Create the model instance
+ model = ImprovedUNet()
+
+ # Load the weights
+ checkpoint = torch.load("colorizer-tiny.pth", map_location="cpu")
+ model.load_state_dict(checkpoint["model_state_dict"])
+ model.eval()
+ ```
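
If a GPU is available, the model and its inputs can be moved there before inference. This is standard PyTorch device placement, not anything specific to this checkpoint:

```python
# Standard PyTorch device placement (optional); input tensors must be moved too.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
```
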
+ ### 3️⃣ Colorize an Image
+ ```python
+ from PIL import Image
+ import torchvision.transforms as T
+
+ img = Image.open("path/to/grayscale_image.jpg").convert("L")
+ transform = T.Compose([
+     T.Resize((256, 256)),
+     T.ToTensor(),
+     T.Normalize(mean=[0.5], std=[0.5])
+ ])
+
+ input_tensor = transform(img).unsqueeze(0)
+
+ with torch.no_grad():
+     output = model(input_tensor)
+
+ # The model ends in Tanh, so outputs lie in [-1, 1]; map them back to [0, 255]
+ output_image = output.squeeze(0).permute(1, 2, 0).numpy()
+ output_image = ((output_image * 0.5 + 0.5) * 255).clip(0, 255).astype("uint8")
+
+ Image.fromarray(output_image).save("colorized_output.png")
+ ```
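
The snippet above saves a 256x256 image. To keep the source resolution, one option (an illustrative sketch, not part of the original README) is to resize the colorized result back to the input image's size:

```python
# Illustrative: upscale the 256x256 result back to the original size.
# PIL's resize takes (width, height); img.size is already in that order.
colorized = Image.fromarray(output_image).resize(img.size, Image.BICUBIC)
colorized.save("colorized_full_size.png")
```
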
+
+ ## ℹ️ Training Information
+
+ - Architecture: Custom U-Net (`ImprovedUNet`)
+ - Input Size: 256x256 pixels
+ - Optimizer: Adam
+ - Loss Function: MSE (see the training-loop sketch below)
+ - Epochs: [Specify the number of epochs]
+
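
No training code ships with this commit, so purely as a hedged illustration of what an Adam + MSE loop matching the settings above could look like (the data loader and learning rate here are hypothetical placeholders):

```python
import torch
import torch.nn as nn
from model import ImprovedUNet

def get_train_loader():
    # Hypothetical: should yield (grayscale, color) pairs normalized to [-1, 1],
    # shaped (B, 1, 256, 256) and (B, 3, 256, 256). Not part of this repository.
    raise NotImplementedError("supply your own gray/color image pairs")

model = ImprovedUNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # lr is an assumption
criterion = nn.MSELoss()

model.train()
for gray, color in get_train_loader():
    optimizer.zero_grad()
    pred = model(gray)             # Tanh output in [-1, 1]
    loss = criterion(pred, color)  # MSE against the normalized color target
    loss.backward()
    optimizer.step()
```
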
+ ## 📈 Results
+
+ Here is an example of an image colorized by the model:
+
+ ![Colorized Example](test_result_PARIS.png)
+
+ ## ✨ Author
+
+ This model was developed by Eric Houzelle.
colorizer-tiny.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:67e357b9c98f66b82ae79e5202110fae3d528b625c484e575463b9714256f74e
+ size 372679415
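
The `.pth` entry above is a Git LFS pointer: only this three-line stub lives in git history, while the ~373 MB weight file sits on the LFS server. If a clone ends up with the pointer instead of the real file, the standard Git LFS commands fetch it:

```bash
git lfs install   # one-time setup of the LFS filters
git lfs pull      # replace pointer files with the actual binaries
```
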
model.py ADDED
@@ -0,0 +1,67 @@
+ import torch
+ import torch.nn as nn
+
+ class ConvBlock(nn.Module):
+     """Two 3x3 convs, each with BatchNorm and ReLU, followed by dropout."""
+     def __init__(self, in_ch, out_ch, dropout=0.1):
+         super().__init__()
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_ch, out_ch, 3, padding=1),
+             nn.BatchNorm2d(out_ch),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(out_ch, out_ch, 3, padding=1),
+             nn.BatchNorm2d(out_ch),
+             nn.ReLU(inplace=True),
+             nn.Dropout2d(dropout)
+         )
+
+     def forward(self, x):
+         return self.conv(x)
+
+ class ImprovedUNet(nn.Module):
+     def __init__(self):
+         super().__init__()
+         # Encoder: channel width doubles at each level
+         self.enc1 = ConvBlock(1, 64, dropout=0.1)
+         self.enc2 = ConvBlock(64, 128, dropout=0.1)
+         self.enc3 = ConvBlock(128, 256, dropout=0.2)
+         self.enc4 = ConvBlock(256, 512, dropout=0.2)
+
+         self.pool = nn.MaxPool2d(2)
+
+         self.bottleneck = ConvBlock(512, 1024, dropout=0.3)
+
+         # Decoder: transposed convs upsample, ConvBlocks fuse the skip connections
+         self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
+         self.dec4 = ConvBlock(1024, 512, dropout=0.2)
+
+         self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
+         self.dec3 = ConvBlock(512, 256, dropout=0.2)
+
+         self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
+         self.dec2 = ConvBlock(256, 128, dropout=0.1)
+
+         self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
+         self.dec1 = ConvBlock(128, 64, dropout=0.1)
+
+         # 1x1 conv maps 64 channels to RGB; Tanh bounds outputs to [-1, 1]
+         self.out_conv = nn.Conv2d(64, 3, 1)
+         self.out_act = nn.Tanh()
+
+     def forward(self, x):
+         e1 = self.enc1(x)
+         e2 = self.enc2(self.pool(e1))
+         e3 = self.enc3(self.pool(e2))
+         e4 = self.enc4(self.pool(e3))
+
+         b = self.bottleneck(self.pool(e4))
+
+         d4 = self.up4(b)
+         d4 = torch.cat([d4, e4], dim=1)  # skip connection from enc4
+         d4 = self.dec4(d4)
+
+         d3 = self.up3(d4)
+         d3 = torch.cat([d3, e3], dim=1)
+         d3 = self.dec3(d3)
+
+         d2 = self.up2(d3)
+         d2 = torch.cat([d2, e2], dim=1)
+         d2 = self.dec2(d2)
+
+         d1 = self.up1(d2)
+         d1 = torch.cat([d1, e1], dim=1)  # skip connection from enc1
+         d1 = self.dec1(d1)
+
+         out = self.out_conv(d1)
+         return self.out_act(out)
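
As a quick sanity check on the architecture (an illustrative snippet, not part of the commit), a dummy grayscale batch can be pushed through the network; with four 2x2 poolings, input sides must be divisible by 16, which 256x256 satisfies:

```python
import torch
from model import ImprovedUNet

model = ImprovedUNet().eval()
dummy = torch.randn(1, 1, 256, 256)  # one fake grayscale image
with torch.no_grad():
    out = model(dummy)

print(out.shape)  # expected: torch.Size([1, 3, 256, 256])
print(float(out.min()), float(out.max()))  # Tanh keeps values within [-1, 1]
```
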