The easiest way to train a U-NET Image Segmentation model using TensorFlow and labelme


Foreword

Earlier, I covered how to train a custom image segmentation model with Mask R-CNN using the TensorFlow Object Detection API. However, Mask R-CNN is a complex network with a high computational cost. In practice, if the segmentation task is not very complex, a more suitable choice is U-NET, which this article explains.

There are many articles about U-NET, but very few walk through the whole process, from a custom dataset to model definition, training, and prediction. This article therefore covers each step with a demo so that you can quickly get to grips with U-NET.

Installation

Clone the source code from https://github.com/CatchZeng/tensorflow-unet-labelme to your local machine and enter its directory.

If you are using macOS, run the following command before installing.

❯ brew install pyqt

Run the following command to create a conda virtual environment named unet, activate it, and install the dependencies.

❯ conda create -n unet -y python=3.9 && conda activate unet && pip install -r requirements.txt

Dataset

This article reuses the cup, teapot, and humidifier example from The easiest way to Train a Custom Image Segmentation Model Using TensorFlow Object Detection API Mask R-CNN.

Annotate images

Data annotation has already been explained in The easiest way to Train a Custom Image Segmentation Model Using TensorFlow Object Detection API Mask R-CNN, so I won’t repeat it here.

Generate VOC format dataset

A U-NET dataset consists of raw images (jpg) and label (mask) images (png), and is usually organized in the VOC format.

Put the annotated data into the datasets/train folder of tensorflow-unet-labelme, and create a datasets/labels.txt file listing the category names (see https://github.com/wkentaro/labelme/tree/main/examples/semantic_segmentation for the format).

❯ tree -L 3
.
├── Makefile
├── README.md
├── datasets
│   ├── README.md
│   ├── labels.txt
│   ├── train
│   │   ├── 1.jpg
│   │   ├── 1.json
......
│   │   ├── 9.jpg
│   │   └── 9.json
├── labelme2voc.py
├── train.gif
├── unet.ipynb
└── voc_annotation.py

Note: You can refer to https://github.com/CatchZeng/tensorflow-unet-labelme/tree/master/datasets.
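As a sketch of how labels.txt is interpreted, the function below maps class names to the ids written into the masks. It assumes the repository's labelme2voc.py follows the convention of the labelme semantic_segmentation example, where the first line is __ignore__ (id -1) and the second is _background_ (id 0); parse_labels and the sample file contents are hypothetical.

```python
def parse_labels(lines):
    """Map class names in labels.txt to the ids written into the mask PNGs.

    Assumed convention (labelme semantic_segmentation example):
    line 1 is __ignore__ (id -1), line 2 is _background_ (id 0),
    and every following class gets the next id.
    """
    names = [line.strip() for line in lines if line.strip()]  # skip blank lines
    if names[:2] != ["__ignore__", "_background_"]:
        raise ValueError("labels.txt should start with __ignore__ and _background_")
    return {name: i - 1 for i, name in enumerate(names)}


# Hypothetical labels.txt for the cup / teapot / humidifier demo
example = """__ignore__
_background_
cup
teapot
humidifier
"""
print(parse_labels(example.splitlines()))
# {'__ignore__': -1, '_background_': 0, 'cup': 1, 'teapot': 2, 'humidifier': 3}
```

Under this convention, background plus three objects gives mask ids 0 to 3.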

Execute the following command to generate a VOC format dataset.

❯ make voc
❯ tree -L 5
.
├── Makefile
├── README.md
├── datasets
│   ├── README.md
│   ├── labels.txt
│   ├── test
│   │   └── 1.jpg
│   ├── train
│   │   ├── 1.jpg
│   │   ├── 1.json
......
│   │   ├── 9.jpg
│   │   └── 9.json
│   └── train_voc
│       ├── ImageSets
│       │   └── Segmentation
│       │       ├── test.txt
│       │       ├── train.txt   # List of training set image names
│       │       ├── trainval.txt
│       │       └── val.txt  # list of validation set image names
│       ├── JPEGImages # raw images
│       │   ├── 1.jpg
......
│       │   └── 9.jpg
│       ├── SegmentationClass
│       │   ├── 1.npy
......
│       │   └── 9.npy
│       ├── SegmentationClassPNG # mask images
│       │   ├── 1.png
......
│       │   └── 9.png
│       ├── SegmentationClassVisualization
│       │   ├── 1.jpg
......
│       │   └── 9.jpg
│       └── class_names.txt
├── labelme2voc.py
├── train.gif
├── unet.ipynb
└── voc_annotation.py

The generated datasets/train_voc is the dataset used for training.
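Before training, it can be worth checking that every name in a split list has both an image and a mask. The helper below is a hypothetical sketch (check_voc_split is not part of the repository) that assumes the directory names shown in the tree above.

```python
import os


def check_voc_split(dataset_path, split="train"):
    """Return the names listed in ImageSets/Segmentation/<split>.txt whose
    JPEG image or PNG mask is missing (an empty list means the split is consistent)."""
    list_file = os.path.join(dataset_path, "ImageSets", "Segmentation", split + ".txt")
    with open(list_file, "r", encoding="utf8") as f:
        names = [line.split()[0] for line in f if line.strip()]  # skip blank lines
    missing = []
    for name in names:
        jpg = os.path.join(dataset_path, "JPEGImages", name + ".jpg")
        png = os.path.join(dataset_path, "SegmentationClassPNG", name + ".png")
        if not (os.path.isfile(jpg) and os.path.isfile(png)):
            missing.append(name)
    return missing
```

For example, check_voc_split("datasets/train_voc", "val") should return an empty list when the validation split is complete.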

Training

Open unet.ipynb, select the unet environment as the Python interpreter, and run the notebook to start training.

Detailed code

The dataset is split into a training set and a validation set, whose image names are read from ImageSets/Segmentation/train.txt and ImageSets/Segmentation/val.txt respectively. Each split is then wrapped in a tf.keras.utils.Sequence object by the UnetDataset class, so that it can be passed directly to model.fit later.

import os

# INPUT_SHAPE, BATCH_SIZE, NUM_CLASSES and VAL_SUBSPLITS are set elsewhere in unet.ipynb
dataset_path = 'datasets/train_voc'

# read the lists of training and validation image names
with open(os.path.join(dataset_path, "ImageSets/Segmentation/train.txt"),
          "r",
          encoding="utf8") as f:
    train_lines = f.readlines()
with open(os.path.join(dataset_path, "ImageSets/Segmentation/val.txt"),
          "r",
          encoding="utf8") as f:
    val_lines = f.readlines()

# wrap each split in a tf.keras.utils.Sequence;
# the True/False argument enables random augmentation for the training split only
train_batches = UnetDataset(train_lines, INPUT_SHAPE, BATCH_SIZE, NUM_CLASSES,
                            True, dataset_path)
val_batches = UnetDataset(val_lines, INPUT_SHAPE, BATCH_SIZE, NUM_CLASSES,
                          False, dataset_path)

STEPS_PER_EPOCH = len(train_lines) // BATCH_SIZE
VALIDATION_STEPS = len(val_lines) // BATCH_SIZE // VAL_SUBSPLITS
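STEPS_PER_EPOCH and VALIDATION_STEPS use floor division, so any remainder is dropped and a split with fewer lines than BATCH_SIZE yields zero steps. The split sizes below are hypothetical and only illustrate the arithmetic.

```python
def steps(num_lines, batch_size, subsplits=1):
    """Step count as computed for STEPS_PER_EPOCH / VALIDATION_STEPS (floor division)."""
    return num_lines // batch_size // subsplits


# Hypothetical split sizes with BATCH_SIZE = 8
print(steps(40, 8))  # 5
print(steps(10, 8))  # 1: only 8 of the 10 validation images fit into full steps
print(steps(6, 8))   # 0: fewer lines than BATCH_SIZE
```

With a small custom dataset, it is worth checking these values before calling model.fit.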

The UnetDataset class inherits from tf.keras.utils.Sequence. Its __getitem__ method returns one batch of batch_size samples, consisting of the original images (images) and the label images (targets). Because the model has a fixed input shape, every sample is resized in the process_data method. Data augmentation can also be added there during training; here a simple random horizontal flip is used.

import math
import os

import numpy as np
import tensorflow as tf
from PIL import Image

# cvtColor, resize_image, resize_label and normalize are helper functions
# from the repository (not shown here).


class UnetDataset(tf.keras.utils.Sequence):

    def __init__(self, annotation_lines, input_shape, batch_size, num_classes,
                 train, dataset_path):
        self.annotation_lines = annotation_lines
        self.length = len(self.annotation_lines)
        self.input_shape = input_shape
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.train = train
        self.dataset_path = dataset_path

    def __len__(self):
        # number of batches per epoch
        return math.ceil(len(self.annotation_lines) / float(self.batch_size))

    def __getitem__(self, index):
        images = []
        targets = []
        for i in range(index * self.batch_size, (index + 1) * self.batch_size):
            # wrap around so that the last batch is always full
            i = i % self.length
            name = self.annotation_lines[i].split()[0]
            jpg = Image.open(
                os.path.join(self.dataset_path, "JPEGImages", name + ".jpg"))
            png = Image.open(
                os.path.join(self.dataset_path, "SegmentationClassPNG",
                             name + ".png"))
            jpg, png = self.process_data(jpg,
                                         png,
                                         self.input_shape,
                                         random=self.train)
            images.append(jpg)
            targets.append(png)
        images = np.array(images)
        targets = np.array(targets)
        return images, targets

    def rand(self, a=0, b=1):
        # uniform random number in [a, b)
        return np.random.rand() * (b - a) + a

    def process_data(self, image, label, input_shape, random=True):
        image = cvtColor(image)
        label = Image.fromarray(np.array(label))
        h, w, _ = input_shape

        # resize to the model's fixed input size
        image, _, _ = resize_image(image, (w, h))
        label, _, _ = resize_label(label, (w, h))

        if random:
            # random horizontal flip, applied to the image and the mask together
            flip = self.rand() < .5
            if flip:
                image = image.transpose(Image.FLIP_LEFT_RIGHT)
                label = label.transpose(Image.FLIP_LEFT_RIGHT)

        # convert to numpy arrays
        image = np.array(image, np.float32)
        image = normalize(image)
        label = np.array(label)
        # clamp any value >= num_classes (for example 255) to num_classes
        label[label >= self.num_classes] = self.num_classes

        return image, label
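The indexing and augmentation logic in UnetDataset can be checked without TensorFlow or image files. This NumPy sketch mirrors the wrap-around batch indexing, the joint image/mask flip, and the label clamping; the function names are hypothetical, and a seeded np.random.Generator replaces np.random.rand only to make the flip reproducible.

```python
import numpy as np


def batch_indices(index, batch_size, length):
    """Sample indices read by __getitem__ for batch `index`; the modulo wraps
    around to the start so the last batch is always full."""
    return [i % length for i in range(index * batch_size, (index + 1) * batch_size)]


def random_flip(image, label, rng, p=0.5):
    """Flip an (H, W, C) image and its (H, W) mask left-right with the SAME decision."""
    if rng.random() < p:
        return image[:, ::-1], label[:, ::-1]
    return image, label


def clip_labels(label, num_classes):
    """Same as `label[label >= num_classes] = num_classes`, but on a copy."""
    label = label.copy()
    label[label >= num_classes] = num_classes
    return label


print(batch_indices(2, 4, 10))  # [8, 9, 0, 1]
```

Using one flip decision for both arrays is what keeps image and mask aligned; flipping only one of them would silently corrupt the labels.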

The model definition is relatively simple and follows the paper: it consists mainly of downsampling, upsampling, and concatenation (the skip connections).

The code is based on https://www.tensorflow.org/tutorials/images/segmentation, which you can consult for a detailed explanation.

import tensorflow as tf

# INPUT_SHAPE, down_stack (encoder) and up_stack (list of upsampling blocks)
# are defined elsewhere in the notebook, following the TensorFlow tutorial.


def unet_model(output_channels: int):
    inputs = tf.keras.layers.Input(shape=INPUT_SHAPE)

    # Downsampling through the model
    skips = down_stack(inputs)
    x = skips[-1]
    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        concat = tf.keras.layers.Concatenate()
        x = concat([x, skip])

    # This is the last layer of the model: one output channel per class,
    # upsampled by 2 (64x64 -> 128x128 for a 128x128 input)
    last = tf.keras.layers.Conv2DTranspose(filters=output_channels,
                                           kernel_size=3,
                                           strides=2,
                                           padding='same')
    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)
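Concatenate only works when the upsampled tensor and the skip tensor have the same height and width. The sketch below traces spatial sizes through unet_model, assuming (as with the MobileNetV2 encoder in the TensorFlow tutorial) five encoder outputs that each halve the resolution, up_stack blocks that each double it, and the final stride-2 Conv2DTranspose. unet_resolutions is a hypothetical helper, and your INPUT_SHAPE may differ from the 128 used here.

```python
def unet_resolutions(input_size, num_levels=5):
    """Trace the spatial size (one side) through a U-Net shaped like unet_model.

    Assumes each encoder level halves the size exactly, so the input side
    must be divisible by 2 ** num_levels.
    """
    if input_size % 2 ** num_levels:
        raise ValueError(f"{input_size} is not divisible by {2 ** num_levels}")
    skips = [input_size // 2 ** k for k in range(1, num_levels + 1)]  # down_stack outputs
    x = skips[-1]
    trace = []
    for skip in reversed(skips[:-1]):
        x *= 2                 # up(x)
        assert x == skip       # Concatenate needs matching height and width
        trace.append(x)
    return skips, trace, x * 2  # last Conv2DTranspose(strides=2)


print(unet_resolutions(128))
# ([64, 32, 16, 8, 4], [8, 16, 32, 64], 128)
```

Under these assumptions, the input side has to be a multiple of 32 for the skip connections to line up and for the output to match the input size.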

For callbacks, the author mainly uses DisplayCallback and ModelCheckpointCallback.

DisplayCallback displays sample predictions at the end of each epoch, so that you can observe how the model is doing.

ModelCheckpointCallback saves the weights (or the whole model) to the logs folder at the end of each epoch. Because filepath is formatted with the epoch number and that epoch's logs, the file name can record each epoch's accuracy and loss, so you can pick a better model after training.

import warnings

import numpy as np
import tensorflow as tf
from IPython.display import clear_output

# show_predictions() is defined elsewhere in the notebook


class DisplayCallback(tf.keras.callbacks.Callback):

    def on_epoch_end(self, epoch, logs=None):
        # redraw sample predictions after every epoch
        clear_output(wait=True)
        show_predictions()
        print('\nSample Prediction after epoch {}\n'.format(epoch + 1))


class ModelCheckpointCallback(tf.keras.callbacks.Callback):

    def __init__(self,
                 filepath,
                 monitor='val_loss',
                 verbose=0,
                 save_best_only=False,
                 save_weights_only=False,
                 mode='auto',
                 period=1):
        super(ModelCheckpointCallback, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.period = period  # save every `period` epochs
        self.epochs_since_last_save = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn(
                'ModelCheckpoint mode %s is unknown, '
                'fallback to auto mode.' % (mode), RuntimeWarning)
            mode = 'auto'

        # np.inf instead of np.Inf, which was removed in NumPy 2.0
        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.inf
        else:
            # 'auto': maximize accuracy-like metrics, minimize everything else
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.inf
            else:
                self.monitor_op = np.less
                self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            # fill placeholders such as {epoch} or {val_loss} in the file name
            filepath = self.filepath.format(epoch=epoch + 1, **logs)
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn(
                        'Can save best model only with %s available, '
                        'skipping.' % (self.monitor), RuntimeWarning)
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            print(
                                '\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                                ' saving model to %s' %
                                (epoch + 1, self.monitor, self.best, current,
                                 filepath))
                        self.best = current
                        if self.save_weights_only:
                            self.model.save_weights(filepath, overwrite=True)
                        else:
                            self.model.save(filepath, overwrite=True)
                    else:
                        if self.verbose > 0:
                            print('\nEpoch %05d: %s did not improve' %
                                  (epoch + 1, self.monitor))
            else:
                if self.verbose > 0:
                    print('\nEpoch %05d: saving model to %s' %
                          (epoch + 1, filepath))
                if self.save_weights_only:
                    self.model.save_weights(filepath, overwrite=True)
                else:
                    self.model.save(filepath, overwrite=True)
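The save_best_only branch of ModelCheckpointCallback can be illustrated without Keras. The sketch below replays a list of per-epoch logs dictionaries through the same 'auto'-mode rule (metrics containing 'acc' or starting with 'fmeasure' are maximized, everything else is minimized) and formats filepath with the epoch and logs. The template string, the history values, and best_checkpoints itself are hypothetical.

```python
import math

# Hypothetical filepath template; the notebook passes its own string.
FILEPATH = "logs/ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5"


def best_checkpoints(history, monitor="val_loss", filepath=FILEPATH):
    """Return the file names a save_best_only checkpoint would write, in order."""
    maximize = "acc" in monitor or monitor.startswith("fmeasure")
    best = -math.inf if maximize else math.inf
    saved = []
    for epoch, logs in enumerate(history):
        current = logs.get(monitor)
        if current is None:
            continue  # the callback warns and skips in this case
        improved = current > best if maximize else current < best
        if improved:
            best = current
            saved.append(filepath.format(epoch=epoch + 1, **logs))
    return saved


history = [
    {"loss": 0.9, "val_loss": 0.8, "val_accuracy": 0.5},
    {"loss": 0.7, "val_loss": 0.85, "val_accuracy": 0.7},
    {"loss": 0.5, "val_loss": 0.6, "val_accuracy": 0.6},
]
print(best_checkpoints(history))
# ['logs/ep001-loss0.900-val_loss0.800.h5', 'logs/ep003-loss0.500-val_loss0.600.h5']
```

Epoch 2 is skipped because its val_loss got worse; monitoring val_accuracy instead would save epochs 1 and 2.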

For prediction, the image is first resized to the model's INPUT_SHAPE. Note that the image to be predicted does not necessarily have the same aspect ratio as INPUT_SHAPE. To keep distortion from degrading the prediction, the author scales the image while preserving its aspect ratio and **fills the leftover area with gray placeholders** during the resize, as shown in the following figure.

After prediction, the gray placeholder region is cropped away again and the result is resized back to the original image size.
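A minimal NumPy sketch of this letterbox resize and its inverse crop follows. It assumes resize_image scales by min(w / iw, h / ih) and centers the result, which matches the crop offsets used in detect_image; the gray value 128 and the nearest-neighbour resize (used so the sketch needs no image library) are assumptions, and letterbox and unpad are hypothetical names.

```python
import numpy as np


def letterbox(image, size, fill=128):
    """Fit an (H, W, C) image into size=(w, h) keeping its aspect ratio and
    padding the rest with gray. Returns (padded, nw, nh)."""
    ih, iw = image.shape[:2]
    w, h = size
    scale = min(w / iw, h / ih)
    nw, nh = int(iw * scale), int(ih * scale)
    # nearest-neighbour resize via index arrays
    rows = (np.arange(nh) * ih / nh).astype(int)
    cols = (np.arange(nw) * iw / nw).astype(int)
    resized = image[rows][:, cols]
    padded = np.full((h, w, image.shape[2]), fill, dtype=image.dtype)
    top, left = (h - nh) // 2, (w - nw) // 2  # center the resized image
    padded[top:top + nh, left:left + nw] = resized
    return padded, nw, nh


def unpad(pred, size, nw, nh):
    """Crop the centered (nh, nw) region back out of an (h, w, ...) prediction."""
    w, h = size
    top, left = (h - nh) // 2, (w - nw) // 2
    return pred[top:top + nh, left:left + nw]
```

For a 64x128 image and a 128x128 input, the image keeps its size and gets 32 gray rows above and below; unpad removes exactly those rows, just as detect_image slices its prediction.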

Note: Because this case is only a demo, the model was trained for just 20 epochs and its accuracy is mediocre. In practical applications, you can train for more epochs to improve accuracy.

import copy

import cv2
import numpy as np
from PIL import Image

# model, colors, INPUT_SHAPE, cvtColor, resize_image and normalize
# are defined elsewhere in the notebook / repository


def detect_image(image_path):
    image = Image.open(image_path)
    image = cvtColor(image)
    old_img = copy.deepcopy(image)
    ori_h = np.array(image).shape[0]
    ori_w = np.array(image).shape[1]

    # resize and add gray placeholders
    image_data, nw, nh = resize_image(image, (INPUT_SHAPE[1], INPUT_SHAPE[0]))
    image_data = normalize(np.array(image_data, np.float32))
    # add the batch dimension
    image_data = np.expand_dims(image_data, 0)

    pr = model.predict(image_data)[0]

    # remove gray placeholders
    top = int((INPUT_SHAPE[0] - nh) // 2)
    left = int((INPUT_SHAPE[1] - nw) // 2)
    pr = pr[top:top + nh, left:left + nw]
    # back to the original size, then take the most likely class per pixel
    pr = cv2.resize(pr, (ori_w, ori_h), interpolation=cv2.INTER_LINEAR)
    pr = pr.argmax(axis=-1)

    # slower equivalent of the palette lookup below:
    # seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
    # for c in range(NUM_CLASSES):
    #     seg_img[:, :, 0] += ((pr[:, :] == c ) * colors[c][0]).astype('uint8')
    #     seg_img[:, :, 1] += ((pr[:, :] == c ) * colors[c][1]).astype('uint8')
    #     seg_img[:, :, 2] += ((pr[:, :] == c ) * colors[c][2]).astype('uint8')
    seg_img = np.reshape(
        np.array(colors, np.uint8)[np.reshape(pr, [-1])], [ori_h, ori_w, -1])

    image = Image.fromarray(seg_img)
    # overlay the colored mask (weight 0.7) on the original image
    image = Image.blend(old_img, image, 0.7)

    return image
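The last step of detect_image turns the class-id map into colors by indexing the palette array with every pixel at once, which replaces the commented-out per-class loop. This sketch checks that the two approaches agree; colorize, colorize_loop, and the four-color palette are hypothetical.

```python
import numpy as np


def colorize(pr, colors):
    """(H, W) class-id map -> (H, W, 3) uint8 image via palette fancy indexing."""
    palette = np.array(colors, np.uint8)  # shape (num_classes, 3)
    h, w = pr.shape
    return np.reshape(palette[np.reshape(pr, [-1])], [h, w, -1])


def colorize_loop(pr, colors):
    """Per-class loop equivalent to the commented-out code in detect_image."""
    seg_img = np.zeros((pr.shape[0], pr.shape[1], 3), np.uint8)
    for c, color in enumerate(colors):
        for ch in range(3):
            seg_img[:, :, ch] += ((pr == c) * color[ch]).astype("uint8")
    return seg_img


# Hypothetical palette: background, then one color per class
COLORS = [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0)]
```

The indexed version does a single vectorized lookup instead of one pass per class and channel, which matters for large images.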

Summary

This article used a demo to walk through the entire process: creating a dataset, training a U-NET model, and running predictions on images. To consolidate what you have learned, pick a scenario of your own, build a custom dataset, and go through the steps again. That's it for this article, see you in the next one.

Written by CatchZeng
AI (Machine Learning) and DevOps enthusiast.