
Title


Section 1

The full training script below logs semantic segmentation predictions to W&B with wandb.log():

# Train a neural network on the BDD100K dataset with fast.ai
from pathlib import Path
from functools import partial
import json

import PIL
import torch
import wandb
from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.callback import Callback
from wandb.fastai import WandbCallback

# Segmentation classes extracted from the dataset source code
# See https://github.com/ucbdrive/bdd-data/blob/master/bdd_data/label.py
segmentation_classes = [
    'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
    'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
    'truck', 'bus', 'train', 'motorcycle', 'bicycle', 'void'
]
class LogImagesCallback(Callback):

  def __init__(self, learn):
    self.learn = learn

  def on_epoch_end(self, **kwargs):
    num_log = 20
    input_batch = self.learn.data.valid_ds[:num_log]

    raw = []
    prediction = []
    ground_truth = []
    new_seg = []
    for i, img in enumerate(input_batch):

      # log the original image
      source_img = img[0]
      x = image2np(source_img.data * 255).astype(np.uint8)
      raw_source = PIL.Image.fromarray(x)
      raw.append(raw_source)

      # predict from the original image (note: self.learn, not the global learn)
      o = self.learn.predict(img[0])[0]
      xo = image2np(o.data).astype(np.uint8)
      # round-trip through matplotlib to apply the "tab20" color map
      plt.imsave("label.png", xo, cmap="tab20")
      f = open_image("label.png")

      x = image2np(f.data * 255).astype(np.uint8)
      raw_x = PIL.Image.fromarray(x)
      prediction.append(raw_x)

      # draft: new segmentation style
      new_seg.append(wandb.Image(raw_source, metadata=wandb.Metadata({
           "type": "segmentation/beta",
           "segmentation": json.dumps(f.data.tolist()),
           "classes": segmentation_classes
      })))

      # log the ground truth: convert to the same matplotlib "tab20" color map
      # via image save (instead of the fastai default)
      img_label = img[1]
      x_label = image2np(img_label.data).astype(np.uint8)
      plt.imsave("label_x.png", x_label, cmap="tab20")
      f = open_image("label_x.png")
      x = image2np(f.data * 255).astype(np.uint8)
      raw_x_label = PIL.Image.fromarray(x)
      ground_truth.append(raw_x_label)

    wandb.log({"camera view": [wandb.Image(e) for e in raw],
               "prediction": [wandb.Image(e) for e in prediction],
               "ground truth": [wandb.Image(e) for e in ground_truth]})

    # draft: new segmentation style
    for i, s in enumerate(new_seg):
      wandb.log({"segmentation_" + str(i): s})
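# Hedged aside (not part of the original run): newer wandb releases can render
# interactive segmentation masks natively via the `masks` argument of
# wandb.Image, which takes a 2D array of class ids plus an id -> name map.
# A minimal sketch of a helper that could replace the plt.imsave round-trip
# (e.g. called as wb_seg_image(raw_source, xo) inside the loop above):
def wb_seg_image(bg_img, mask_2d):
    class_labels = {i: name for i, name in enumerate(segmentation_classes)}
    return wandb.Image(bg_img, masks={
        "prediction": {"mask_data": mask_2d, "class_labels": class_labels}})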
# Initialize the W&B project
wandb.init(project="deep-drive", entity="stacey")

# Define hyperparameters
config = wandb.config           # shorthand
config.framework = "fast.ai"    # AI framework used (for when we create other versions)
config.img_size = (360, 640)    # dimensions of the resized image - a single int or a tuple

config.batch_size = 8           # batch size during training
config.epochs = 10              # number of epochs for training

config.encoder = "resnet34"
if config.encoder == "resnet18":
  encoder = models.resnet18     # encoder of the unet (contracting path)
elif config.encoder == "resnet34":
  encoder = models.resnet34
elif config.encoder == "squeezenet1_0":
  encoder = models.squeezenet1_0
elif config.encoder == "squeezenet1_1":
  encoder = models.squeezenet1_1
elif config.encoder == "alexnet":
  encoder = models.alexnet

#config.encoder = encoder.__name__
#config.encoder = "resnet18"
#encoder = models.resnet18
config.pretrained = True        # whether we use a frozen pre-trained encoder

# SWEEPS UNCOMMENT
config.weight_decay = 0.097     # weight decay applied on layers
config.bn_weight_decay = True   # whether weight decay is applied on batch norm layers
config.one_cycle = True         # use the "1cycle" policy -> https://arxiv.org/abs/1803.09820
# SWEEPS UNCOMMENT
config.learning_rate = 0.001    # learning rate
save_model = False

# Custom values to filter runs
# SWEEPS UNCOMMENT
config.training_stages = 2
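# Hedged aside: the "SWEEPS UNCOMMENT" markers above suggest these values are
# meant to be driven by a hyperparameter sweep. Assuming the standard wandb
# sweep API, a sweep over them could be declared like this (an illustration,
# not the author's actual sweep configuration):
sweep_config = {
    "method": "random",
    "metric": {"name": "iou", "goal": "maximize"},
    "parameters": {
        "learning_rate": {"min": 0.0001, "max": 0.01},
        "weight_decay": {"min": 0.0, "max": 0.1},
        "training_stages": {"values": [1, 2]},
    },
}
# sweep_id = wandb.sweep(sweep_config, project="deep-drive")
# wandb.agent(sweep_id, function=train_fn)  # train_fn would wrap the steps below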
# Data paths
path_data = Path('../../../../BigData/bdd100K/bdd100k/seg')
path_lbl = path_data / 'labels'
path_img = path_data / 'images'

# Associate a label with an input image
get_y_fn = lambda x: path_lbl / x.parts[-2] / f'{x.stem}_train_id.png'

# Load the data into train & validation sets
# (use_partial_data(0.05) trains on 5% of the data for faster iteration)
src = (SegmentationItemList.from_folder(path_img).use_partial_data(0.05)
#src = (SegmentationItemList.from_folder(path_img)
       .split_by_folder(train='train', valid='val')
       .label_from_func(get_y_fn, classes=segmentation_classes))

# Resize, augment, load in batches & normalize (so we can use pre-trained networks)
data = (src.transform(get_transforms(), size=config.img_size, tfm_y=True)
        .databunch(bs=config.batch_size)
        .normalize(imagenet_stats))

config.num_train = len(data.train_ds)
config.num_valid = len(data.valid_ds)
########################################
# Accuracy metrics
#---------------------------------------
void_code = 19
# overall accuracy: across all classes, ignoring unlabeled pixels
def acc(input, target):
    target = target.squeeze(1)
    mask = target != void_code
    try:
      i = (input.argmax(dim=1)[mask] == target[mask]).float()
      m_i = i.mean()
      return m_i
    except:
      return torch.tensor([0.0])

def traffic_acc(input, target):
    target = target.squeeze(1)
    mask_pole = target == 5
    mask_light = target == 6
    mask_sign = target == 7
    mask_traffic = mask_pole | mask_light | mask_sign
    try:
      i = (input.argmax(dim=1)[mask_traffic] == target[mask_traffic]).float()
      m_i = i.mean()
      return m_i
    except:
      return torch.tensor([0.0])

def road_acc(input, target):
    target = target.squeeze(1)
    mask = target == 0
    try:
        intersection = input.argmax(dim=1)[mask] == target[mask]
        mean_intersection = intersection.float().mean()
        return mean_intersection
    except:
        return torch.tensor([0.0])

def car_acc(input, target):
    target = target.squeeze(1)
    mask = target == 13
    try:
        intersection = input.argmax(dim=1)[mask] == target[mask]
        mean_intersection = intersection.float().mean()
        return mean_intersection
    except:
        return torch.tensor([0.0])

def human_acc(input, target):
    target = target.squeeze(1)
    mask_human = target == 11
    #mask_rider = target == 12
    #mask_human = mask_person | mask_rider
    # this measures the similarity of truth & guess where either has human pixels
    try:
        intersection = (input.argmax(dim=1)[mask_human] == target[mask_human]).float()
        mean_intersection = intersection.mean()
        return mean_intersection
    except:
        return torch.tensor([0.0])

# cases we care about for human iou:
# 1. Truth: human, Guess: not human => most important
# 2. Truth: human, Guess: human => true positive, counts toward accuracy as pixel intersection
# 3. Truth: not human, Guess: human => less important
# 4. Truth: not human, Guess: not human => true negative
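# Hedged aside: the four per-class accuracies above repeat one masked-accuracy
# pattern. A sketch of a factory that could generate them (an illustration,
# not the original code):
def masked_acc(mask_fn, name):
    def metric(input, target):
        target = target.squeeze(1)
        mask = mask_fn(target)
        if mask.sum() == 0:         # no pixels of this class in the batch
            return torch.tensor([0.0])
        return (input.argmax(dim=1)[mask] == target[mask]).float().mean()
    metric.__name__ = name          # fastai displays metrics by __name__
    return metric
# e.g. car_acc could be written as masked_acc(lambda t: t == 13, 'car_acc')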
########################################
# IoU metrics
#---------------------------------------
SMOOTH = 1e-6
def iou(input, target):
    # You can comment out this line if you are passing tensors of equal shape,
    # but output from a UNet or similar will most probably
    # have the BATCH x 1 x H x W shape
    target = target.squeeze(1)  # BATCH x 1 x H x W => BATCH x H x W
    intersection = (input.argmax(dim=1) & target).float().sum((1, 2))  # will be zero if Truth=0 or Prediction=0
    union = (input.argmax(dim=1) | target).float().sum((1, 2))         # will be zero if both are 0
    iou = (intersection + SMOOTH) / (union + SMOOTH)  # we smooth our division to avoid 0/0

    #thresholded = torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10  # this is equal to comparing with thresholds
    #return thresholded.mean()
    return iou.mean()

def human_iou(input, target):
    target = target.squeeze(1)  # BATCH x 1 x H x W => BATCH x H x W
    mask_human = target == 11
    intersection = (input.argmax(dim=1)[mask_human] == target[mask_human]).float()
    union = (input.argmax(dim=1)[mask_human] | target[mask_human]).float()
    iou = (intersection + SMOOTH) / (union + SMOOTH)  # we smooth our division to avoid 0/0

    #thresholded = torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10  # this is equal to comparing with thresholds
    #return thresholded.mean()
    return iou.mean()
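# Hedged aside: the IoU functions above adapt a binary-mask snippet to integer
# class labels, so the bitwise &/| operate on class ids rather than 0/1 masks.
# For multi-class labels the usual formulation compares per-class boolean
# masks; a sketch (an illustration, not the author's metric):
def class_iou(input, target, class_id=11):
    target = target.squeeze(1)
    pred = input.argmax(dim=1)
    intersection = ((pred == class_id) & (target == class_id)).float().sum()
    union = ((pred == class_id) | (target == class_id)).float().sum()
    return (intersection + SMOOTH) / (union + SMOOTH)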
# Create the network
learn = unet_learner(
    data,
    arch=encoder,
    pretrained=config.pretrained,
    metrics=[iou, acc, car_acc, traffic_acc, human_acc, human_iou, road_acc],
    wd=config.weight_decay,
    bn_wd=config.bn_weight_decay,
    callback_fns=partial(WandbCallback, save_model=save_model, monitor='iou'))  #, input_type='images'

# Train
if config.one_cycle:
    learn.fit_one_cycle(
        config.epochs // 2,
        max_lr=slice(config.learning_rate),
        callbacks=[LogImagesCallback(learn)])
    learn.unfreeze()
    learn.fit_one_cycle(
        config.epochs // 2,
        max_lr=slice(config.learning_rate / 100,
                     config.learning_rate / 10),
        callbacks=[LogImagesCallback(learn)])
else:
    learn.fit(
        config.epochs,
        lr=slice(config.learning_rate))

# Save the trained learner
learn.export()
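# Hedged aside: learn.export() writes export.pkl next to the data; assuming
# the standard fastai v1 inference API, the model could be reloaded later with:
# inference_learner = load_learner(learn.path)
# pred = inference_learner.predict(open_image('frame.jpg'))  # 'frame.jpg' is a hypothetical file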
[Parallel coordinates chart: run hyperparameters (learning rate, encoder (alexnet/resnet18/resnet34), batch size, weight decay, training stages) vs. iou and human iou]