My job, involving full-stack development, testing, and validation, kept me busy, but I did get to dabble in data analysis to optimize our client's limited server resource allocation.
Getting back into PyTorch, and inspired by the AI in Medicine Lab (https://aimlab.ca/) at UBC, I wanted to return to my undergraduate roots in machine learning algorithms and review what I had learned.
So, back to image analysis...the underlying technique behind much of the work done by DALL-E, Midjourney and Stable Diffusion is an architecture called U-Net.
The U-Net architecture was originally proposed to tackle medical image segmentation, and in general it trains as you would expect (a minimal sketch follows the list):
1) The model gets an input image
2) It makes a guess at segmenting the different aspects of the image
3) We compute the loss by comparing that guess against the ground-truth mask, and use the error to adjust the model's parameters to improve the output
4) Repeat
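To make these steps concrete, here is a minimal PyTorch sketch of that loop. Everything in it is a placeholder assumption for illustration: a 1x1 convolution stands in for a full U-Net, the "dataset" is random tensors, and cross-entropy loss with Adam is just one common choice.

```python
import torch
import torch.nn as nn

# Placeholders so the loop runs end to end; in practice `model` would be a U-Net
# producing [N, num_classes, H, W] logits and `train_loader` would yield real data.
num_classes, num_epochs = 3, 2
model = nn.Conv2d(3, num_classes, kernel_size=1)          # stand-in for a real U-Net
train_loader = [(torch.randn(4, 3, 64, 64),
                 torch.randint(0, num_classes, (4, 64, 64)))]

criterion = nn.CrossEntropyLoss()                          # per-pixel classification loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(num_epochs):
    for images, masks in train_loader:
        logits = model(images)               # 1) input image -> 2) segmentation guess
        loss = criterion(logits, masks)      # 3) measure the error against the mask
        optimizer.zero_grad()
        loss.backward()                      #    use the error to adjust the parameters
        optimizer.step()                     # 4) repeat over batches and epochs
```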
Not only do U-Net models segment the objects in an image, they are very good at creating near pixel-perfect masks around those objects - so it's a bit like a variation of classification in which each pixel gets a class it belongs to.
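In tensor terms, that per-pixel classification looks something like this (the shapes and class count here are hypothetical):

```python
import torch

# Hypothetical output of a segmentation network over a batch of 4 images, 3 classes:
logits = torch.randn(4, 3, 256, 256)    # [N, num_classes, H, W]
pred_mask = logits.argmax(dim=1)        # [N, H, W]: a class label for every pixel
print(pred_mask.shape)                  # torch.Size([4, 256, 256])
```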
So what makes this convolutional model different and so much better?
The architecture contains two parts: the 'Encoders' on the left and the 'Decoders' on the right.
The encoders condense the information in the image at each layer, reducing its spatial size. At every stage, as we go down the layers, the size of the image halves while the number of channels doubles.
'Channel' here refers to the number of feature maps, and they double because each convolutional layer applies multiple filters to its input, capturing different aspects of the image.
In the early layers, filters capture basic edges, textures, and so on, while in the deeper layers the model learns higher-level features such as object shapes and spatial structure.
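As a rough sketch, one encoder stage could look like the following in PyTorch. This uses padded 3x3 convolutions (a common variant, rather than the unpadded convolutions of the original paper), and the channel sizes are illustrative:

```python
import torch
import torch.nn as nn

class EncoderStage(nn.Module):
    """One encoder stage: two 3x3 convolutions, then 2x2 max pooling.
    Pooling halves the spatial size; the convolutions set the stage's channel
    count, which doubles from one stage to the next."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.conv(x)          # kept for the skip connection later
        downsampled = self.pool(features)
        return features, downsampled

# e.g. a 3-channel 256x256 image -> 64 feature maps, 128x128 after pooling
stage1 = EncoderStage(3, 64)
skip1, x = stage1(torch.randn(1, 3, 256, 256))
print(skip1.shape, x.shape)  # [1, 64, 256, 256] and [1, 64, 128, 128]
```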
Before I get into the decoding part, what about the arrows going from the left to the right? They signify the feature maps we got from the encoding steps being fed into the corresponding decoders.
These skip connections help restore lost spatial information by combining the finer details preserved in the encoder's feature maps with the decoder's high-level semantic features, helping the decoder layers locate features accurately.
Overall, decoders use transposed convolutions (or upsampling layers followed by convolutions) to reconstruct the feature maps back to the original image size.
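Here is a matching decoder stage sketch, under the same assumptions as the encoder sketch above (padded convolutions, illustrative channel sizes): a transposed convolution doubles the spatial size and halves the channels, the corresponding encoder feature map is concatenated in as the skip connection, and two convolutions refine the combined result.

```python
import torch
import torch.nn as nn

class DecoderStage(nn.Module):
    """One decoder stage: transposed convolution to upsample, then concatenate
    the matching encoder feature map (skip connection) and convolve again."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        x = self.up(x)                    # double H and W, halve the channels
        x = torch.cat([x, skip], dim=1)   # skip connection restores fine detail
        return self.conv(x)

# e.g. 128-channel 64x64 features plus the 64-channel 128x128 skip from the encoder
stage = DecoderStage(128, 64)
out = stage(torch.randn(1, 128, 64, 64), torch.randn(1, 64, 128, 128))
print(out.shape)  # torch.Size([1, 64, 128, 128])
```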
These fine details are what help the model identify exactly where the objects are after the encoders have classified them.
If encoders identify the 'what's of the image, decoders identify the 'where's. Encoders extract what is in the image, focusing on feature representation, while decoders reconstruct where objects are, ensuring precise localization.