good-cosmos-425
Section 1
wandb.log()
from pathlib import Path
from functools import partial
import json

import numpy as np
import matplotlib.pyplot as plt
import PIL

from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.callback import Callback

import wandb
from wandb.fastai import WandbCallback

# BDD100K semantic segmentation classes; index 19 ("void") is ignored in metrics
segmentation_classes = [
    'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
    'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
    'truck', 'bus', 'train', 'motorcycle', 'bicycle', 'void'
]

# Log sample predictions to W&B at the end of every epoch
class LogImagesCallback(Callback):

    def __init__(self, learn):
        self.learn = learn

    def on_epoch_end(self, **kwargs):
        num_log = 20
        input_batch = self.learn.data.valid_ds[:num_log]

        raw = []
        prediction = []
        ground_truth = []
        new_seg = []
        for i, img in enumerate(input_batch):
            # raw camera view
            source_img = img[0]
            x = image2np(source_img.data * 255).astype(np.uint8)
            raw_source = PIL.Image.fromarray(x)
            raw.append(raw_source)

            # model prediction, colorized via the "tab20" colormap
            o = self.learn.predict(img[0])[0]
            xo = image2np(o.data).astype(np.uint8)
            plt.imsave("label.png", xo, cmap="tab20")
            f = open_image("label.png")
            x = image2np(f.data * 255).astype(np.uint8)
            raw_x = PIL.Image.fromarray(x)
            prediction.append(raw_x)

            # experimental: attach the raw mask and class names as metadata
            new_seg.append(wandb.Image(raw_source, metadata=wandb.Metadata({
                "type": "segmentation/beta",
                "segmentation": json.dumps(f.data.tolist()),
                "classes": segmentation_classes
            })))

            # ground-truth label, colorized the same way
            img_label = img[1]
            x_label = image2np(img_label.data).astype(np.uint8)
            plt.imsave("label_x.png", x_label, cmap="tab20")
            f = open_image("label_x.png")
            x = image2np(f.data * 255).astype(np.uint8)
            raw_x_label = PIL.Image.fromarray(x)
            ground_truth.append(raw_x_label)

        wandb.log({"camera view": [wandb.Image(e) for e in raw],
                   "prediction": [wandb.Image(e) for e in prediction],
                   "ground truth": [wandb.Image(e) for e in ground_truth]})

        for i, s in enumerate(new_seg):
            wandb.log({"segmentation_" + str(i): s})
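
# Aside: the "segmentation/beta" metadata above was an experimental interface.
# A hedged sketch of the same overlay via wandb's later, stable mask API;
# log_masks_stable is a hypothetical helper, and pred_mask is assumed to be
# a 2-D array of class indices (like `xo` above).
def log_masks_stable(raw_source, pred_mask):
    class_labels = dict(enumerate(segmentation_classes))
    wandb.log({"seg_overlay": wandb.Image(raw_source, masks={
        "prediction": {"mask_data": pred_mask, "class_labels": class_labels}})})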

wandb.init(project="deep-drive", entity="stacey")

config = wandb.config            # shorthand
config.framework = "fast.ai"     # AI framework used (for when we create other versions)
config.img_size = (360, 640)     # dimensions of resized image - can be 1 dim or tuple

config.batch_size = 8            # batch size during training
config.epochs = 10               # number of epochs for training

# encoder of the U-Net (contracting path)
config.encoder = "resnet34"
encoders = {
    "resnet18": models.resnet18,
    "resnet34": models.resnet34,
    "squeezenet1_0": models.squeezenet1_0,
    "squeezenet1_1": models.squeezenet1_1,
    "alexnet": models.alexnet,
}
encoder = encoders[config.encoder]

config.pretrained = True         # whether we use a frozen pre-trained encoder

config.weight_decay = 0.097      # weight decay applied on layers
config.bn_weight_decay = True    # whether weight decay is applied on batch norm layers
config.one_cycle = True          # use the "1cycle" policy -> https://arxiv.org/abs/1803.09820

config.learning_rate = 0.001
save_model = False

config.training_stages = 2

# paths to the BDD100K segmentation data
path_data = Path('../../../../BigData/bdd100K/bdd100k/seg')
path_lbl = path_data / 'labels'
path_img = path_data / 'images'

# map each image to its corresponding label mask
get_y_fn = lambda x: path_lbl / x.parts[-2] / f'{x.stem}_train_id.png'

# load a 5% sample of the data, split into train/val by folder
src = (SegmentationItemList.from_folder(path_img)
       .use_partial_data(0.05)
       .split_by_folder(train='train', valid='val')
       .label_from_func(get_y_fn, classes=segmentation_classes))

data = (src.transform(get_transforms(), size=config.img_size, tfm_y=True)
        .databunch(bs=config.batch_size)
        .normalize(imagenet_stats))

config.num_train = len(data.train_ds)
config.num_valid = len(data.valid_ds)
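
# Optional sanity check: preview a few transformed image/mask pairs before
# training (standard fastai v1 call, shown here as a sketch):
# data.show_batch(rows=2, figsize=(10, 7))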

# overall pixel accuracy, ignoring "void" pixels
void_code = 19

def acc(input, target):
    target = target.squeeze(1)
    mask = target != void_code
    try:
        return (input.argmax(dim=1)[mask] == target[mask]).float().mean()
    except:
        return torch.tensor([0.0])

# accuracy over traffic infrastructure: pole (5), traffic light (6), traffic sign (7)
def traffic_acc(input, target):
    target = target.squeeze(1)
    mask_pole = target == 5
    mask_light = target == 6
    mask_sign = target == 7
    mask_traffic = mask_pole | mask_light | mask_sign
    try:
        return (input.argmax(dim=1)[mask_traffic] == target[mask_traffic]).float().mean()
    except:
        return torch.tensor([0.0])

# accuracy over road pixels (class 0)
def road_acc(input, target):
    target = target.squeeze(1)
    mask = target == 0
    try:
        return (input.argmax(dim=1)[mask] == target[mask]).float().mean()
    except:
        return torch.tensor([0.0])

# accuracy over car pixels (class 13)
def car_acc(input, target):
    target = target.squeeze(1)
    mask = target == 13
    try:
        return (input.argmax(dim=1)[mask] == target[mask]).float().mean()
    except:
        return torch.tensor([0.0])

# accuracy over person pixels (class 11)
def human_acc(input, target):
    target = target.squeeze(1)
    mask_human = target == 11
    try:
        return (input.argmax(dim=1)[mask_human] == target[mask_human]).float().mean()
    except:
        return torch.tensor([0.0])
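
# The per-class accuracies above all follow one pattern; a sketch of a small
# factory that could generate them (make_class_acc is a hypothetical helper,
# not part of the original run):
def make_class_acc(class_idx):
    def class_acc(input, target):
        target = target.squeeze(1)
        mask = target == class_idx
        try:
            return (input.argmax(dim=1)[mask] == target[mask]).float().mean()
        except:
            return torch.tensor([0.0])
    return class_acc

# e.g. car_acc is equivalent to make_class_acc(13)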

SMOOTH = 1e-6

def iou(input, target):
    target = target.squeeze(1)                                          # BATCH x 1 x H x W => BATCH x H x W
    intersection = (input.argmax(dim=1) & target).float().sum((1, 2))   # will be zero if Truth=0 or Prediction=0
    union = (input.argmax(dim=1) | target).float().sum((1, 2))          # will be zero if both are 0
    iou = (intersection + SMOOTH) / (union + SMOOTH)                    # smooth the division to avoid 0/0
    return iou.mean()

def human_iou(input, target):
    target = target.squeeze(1)                                          # BATCH x 1 x H x W => BATCH x H x W
    mask_human = target == 11
    intersection = (input.argmax(dim=1)[mask_human] == target[mask_human]).float()
    union = (input.argmax(dim=1)[mask_human] | target[mask_human]).float()
    iou = (intersection + SMOOTH) / (union + SMOOTH)                    # smooth the division to avoid 0/0
    return iou.mean()
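
# Caveat: `&` and `|` above act bitwise on integer class indices, so iou()
# only approximates a true IoU. A sketch of a per-class IoU over boolean
# masks instead (class_iou is a hypothetical helper, assuming `input` holds
# BATCH x C x H x W logits):
def class_iou(input, target, class_idx):
    target = target.squeeze(1)
    pred = input.argmax(dim=1) == class_idx
    gt = target == class_idx
    intersection = (pred & gt).float().sum((1, 2))
    union = (pred | gt).float().sum((1, 2))
    return ((intersection + SMOOTH) / (union + SMOOTH)).mean()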

learn = unet_learner(
    data,
    arch=encoder,
    pretrained=config.pretrained,
    metrics=[iou, acc, car_acc, traffic_acc, human_acc, human_iou, road_acc],
    wd=config.weight_decay,
    bn_wd=config.bn_weight_decay,
    callback_fns=partial(WandbCallback, save_model=save_model, monitor='iou'))  # optionally: input_type='images'

# two-stage training with the 1cycle policy: first with the encoder frozen,
# then unfrozen at lower learning rates
if config.one_cycle:
    learn.fit_one_cycle(
        config.epochs // 2,
        max_lr=slice(config.learning_rate),
        callbacks=[LogImagesCallback(learn)])
    learn.unfreeze()
    learn.fit_one_cycle(
        config.epochs // 2,
        max_lr=slice(config.learning_rate / 100,
                     config.learning_rate / 10),
        callbacks=[LogImagesCallback(learn)])
else:
    learn.fit(
        config.epochs,
        lr=slice(config.learning_rate))

# save the trained model for later inference
learn.export()
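
# To reuse the exported model later, a sketch (fastai v1's load_learner reads
# the export.pkl that learn.export() writes under the learner's data path;
# "frame.jpg" is a hypothetical input):
# inference_learn = load_learner(path_img)
# pred_mask = inference_learn.predict(open_image("frame.jpg"))[0]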