Update Time

A few recent projects I’ve worked on have been documented elsewhere but haven’t made it to this blog. The point of this post is to summarize these so that they aren’t lost in the internet void.

AI Art Course

The playlist (you can also start from lesson 1)

Part 2 of AIAIART launched last month. You can see all lessons and a link to the YouTube playlist here: https://github.com/johnowhitaker/aiaiart

Image Generation with CLOOB Conditioned Latent Denoising Diffusion GANs

I had fun trying out a new(ish) approach for text-to-image tasks. The neat thing with conditioning on CLOOB embeddings is that you can train without text captions and still get some text guidance ability at inference time (see image above). This got written up as a nice report on Weights and Biases.

Getting Started with the Microsoft Rice Disease Classification Challenge

Images from the training data

An intro to the latest Zindi challenge with starter code and some thoughts on experiment tracking. You may see more of this at some point – for now, you can read the report here.

Fun with Neural Cellular Automata

Building on lesson 8 of the course, this project involved training various neural cellular automata and figuring out how to make them do tricks like taking a video as a driving signal. I’m particularly pleased with the W&B report for this – I logged interactive HTML previews of the NCAs as shaders as they train, and tracked just about everything during the experiments. I also made a Gradio demo that you can try out right now.

Huggan Projects

So many butterflies

We trained some GANs on butterflies! Have fun with the demo space. I also did a similar version with AI-generated orbs as the training data. I love how easy it is to get a demo running with HF spaces + gradio. Feels like cheating!

Fine-tuning a CLOOB-Conditioned Latent Diffusion Model on WikiArt

Prompt: ‘A sunset landscape painting, oil on canvas’ (fine-tuned Wikiart model)

As part of the Huggingface ‘#huggan’ event, I thought it would be interesting to fine-tune a latent diffusion model on the WikiArt dataset, which (as the name suggests) consists of paintings in various genres and styles.

What is CLOOB-Conditioned Latent Diffusion?

Diffusion models are getting a lot of fame at the moment thanks to GLIDE and DALL-E 2 which have recently rocked the internet with their astounding text-to-image capabilities. They are trained by gradually adding noise to an input image over a series of steps, and having the network predict how to ‘undo’ this process. If we start from pure noise and have the network progressively try to ‘fix’ the image we eventually end up with a nice looking output (if all is working well).

An illustration of this kind of model from the website related to one of the key papers that first outlined this idea.
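
The noising step described above can be sketched in a few lines of plain Python. This is a toy illustration rather than the code behind any of the models mentioned here: the linear schedule, the function name add_noise and the num_steps default are all assumptions made for the example.

```python
import math
import random

def add_noise(pixels, t, num_steps=1000, rng=random):
    """Return (noisy_pixels, noise) for step t: 0 is clean, num_steps is pure noise."""
    # Assumed linear schedule: fraction of the original signal that survives.
    keep = 1.0 - t / num_steps
    signal_scale = math.sqrt(keep)
    noise_scale = math.sqrt(1.0 - keep)
    # One Gaussian noise value per pixel.
    noise = [rng.gauss(0.0, 1.0) for _ in pixels]
    noisy = [signal_scale * p + noise_scale * n for p, n in zip(pixels, noise)]
    return noisy, noise
```

A denoising network would typically be shown the noisy values and the step t, and trained to recover either the noise or the clean image.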

To add text-to-image capacity to these models, they are often ‘conditioned’ on some representation of the captions that go along with the images. That is, in addition to seeing a noisy image, they also get an encoding of the text describing the image to help in the de-noising step. Starting from noise again but this time giving a description of the desired output image as the text conditioning ideally steers the network towards generating an image that matches the description.

CLOOB architecture diagram (from the project page – which is worth a read!)

Downsides: these diffusion models are computationally intensive to train, and require images with text labels. Latent diffusion models reduce the computational requirements by doing the denoising in the latent space of an autoencoder rather than on images directly. And since CLOOB maps both images and text to the same space, we can substitute the CLOOB encodings of the image itself in place of actual caption encodings if we want to train with unlabelled images. A neat trick if you ask me!

The best non-closed text-to-image implementation at the moment is probably the latent diffusion model trained by the CompVis team, which you can try out here.

Training/Fine-Tuning a model

@JDP provides training code for CLOOB conditioned latent diffusion (https://github.com/JD-P/cloob-latent-diffusion) based on the similar CLIP conditioned diffusion trained by Katherine Crowson (https://github.com/crowsonkb/v-diffusion-pytorch). One of my #huggan team members, Théo Gigant, uploaded the WikiArt dataset to the huggingface hub, and the images were downloaded, resized and saved to a directory on a 2xA6000 GPU machine provided by Paperspace.

After a few false starts figuring out model loading and other little quirks, we did a ~12 hour training run and logged the results using Weights and Biases. You can view demo outputs from the model as it trains in the report, which thanks to the W&B magic showed them live as the model was training, making for exciting viewing among our team 🙂

Evaluating The Resulting Model

WikiArt is not a huge dataset relative to the model (which has over a billion parameters). One of the main things we were curious about was how the resulting model would be different from the one we started with, which was trained on a much larger and more diverse set of images. Has it ‘overfit’ to the point of being unusable? How much more ‘arty’ do the results look when passing descriptions that don’t necessarily suggest fine art? And has fine-tuning on a relatively ‘clean’ dataset lowered the ability of the model to produce disturbing outputs? To answer these questions, we generated hundreds of images with both models.

I’ve moved the side-by-side comparisons to a gallery at the end of this post. These were the key takeaways for me:

  • Starting from a ‘photorealistic’ autoencoder didn’t stop it from making very painterly outputs. This was useful – we thought we might have to train our own autoencoder first as well.
  • The type of output definitely shifted: almost everything it makes looks like a painting.
  • It lost a lot of more general concepts but does really well with styles/artists/image types present in the dataset. So landscape paintings are great, but ‘a frog’ is not going to give anything recognizable and ‘an avocado armchair’ is a complete fail 🙂
  • It may have over-fit, and this seems to have made it much less likely to generate disturbing content (at the expense of also being bad at a lot of other content types).

Closing Thoughts

Approaches like CLOOB-Conditioned Latent Diffusion are bringing down the barrier to entry and making it possible for individuals or small organisations to have a crack at training diffusion models without $$$ of compute.

Our model during training (left) vs OpenAI’s DALL-E 2 (right) which was unveiled during our project and inspired various memes 🙂

This little experiment of ours has shown that it is possible to train one of these models on a relatively small dataset and end up with something that can create pleasing outputs, even if it can’t quite manage an avocado armchair. And as a bonus, it’s domain-focused enough that I’m happily sharing a live demo that anyone can play with online, without worrying that it’ll be used to generate any highly-realistic fake photographs of celebrity nudity or other such nonsense. What a time to be alive!

Comparison images

Sketchy Unet

The model demo running on Huggingface Spaces

I wanted a fast way to go from an image to something like a rough charcoal sketch. This would be the first step in a longer pipeline that would later add detail and colour, so all it has to do is give a starting point with the right sort of proportions.

Finding a dataset

I found a small dataset that seemed like a good starting point (originally created in ‘APDrawingGAN: Generating Artistic Portrait Drawings From Face Photos With Hierarchical GANs‘ by Ran Yi, Yong-Jin Liu, Yu-Kun Lai, Paul L. Rosin). It’s quick to download, and (with a little datablock wrangling) easy enough to load with fastai. See the notebook for details.

Training the model

I chose to model this as an image-to-image task, and used fastai’s unet_learner function to create a U-net style network based on a Resnet34 backbone. Starting with 128px images and then moving up to 224px, the model is trained to minimise the MSE between the output and the reference sketch. In about 3 minutes (!!) we end up with a model that is doing pretty much exactly what I want:

Images (left), artist’s sketch (center), model outputs (right)

Sharing a Demo

I’ve been playing around with HuggingFace Spaces recently, and this model was a great candidate for a simple demo that should run reasonably fast even on a CPU (like those provided by Spaces). At the end of the training notebook you can see the gradio interface code. Very user-friendly for these quick demos! The trained model was uploaded to huggingface as well, and they somehow detected that my code was downloading it because it shows up as a ‘linked model’ from the space.

It’s neat that I can so easily share everything related to a mini-project like this for others to follow along. The colab notebook provides a free cloud environment to replicate training, the model is hosted by someone with lots of bandwidth and is easy to download, and the demo needs no technical skills and lets anyone try it out in seconds. Hooray for fastai, gradio, huggingface and so many others who work so hard to make our lives easy 🙂

Update: What’s this for?

Waterface demo: https://huggingface.co/spaces/johnowhitaker/waterface

I used this model to ‘sketchify’ images before loading them into an imstack and optimising that to match a CLOOB prompt like ‘A charcoal and watercolor sketch of a person’. After a few steps the result looks pretty OR more likely a little creepy. Ah, the power of AI 🙂 Try it out here.

Turtle Recall: A Contrastive Learning Approach

NB: A scoring glitch caused this approach to look very good on the leaderboard, but local validation and a fix from Zindi later confirmed that it isn’t as magical as it first seemed. Still interesting from an educational point of view but if you’re looking to compete I’d suggest investigating alternate strategies.


Zindi has a competition running to identify individual turtles based on images from different views. This presents an interesting challenge for a few reasons:
1) There are relatively few images per turtle (10-50 each) and these have been taken from multiple angles. Given how similar they are, simply treating this as a normal multi-class classification challenge is hard.
2) There is an emphasis on generalization – it would be great if the organizations involved could add additional turtles without expensive re-training of models.

One potential approach that should help address these problems is to learn useful representations – some way to encode an image in a meaningful way such that the representations of images of one individual are all ‘similar’ by some measure while at the same time being dissimilar to the representations of images from other individuals. If we can pull this off, then given a new image we can encode it and compare the resulting representation with those of all known turtle images. This gives a ranked list of the most likely matches as well as a similarity score that could tell us if we’re looking at a completely new turtle.

To keep this post light on code, I have more info and a working example in this colab notebook. I’m also working on a video and will update this post once that’s done. And a modified version of this might be posted on Zindi learn, which again will be linked here once it’s up.

Contrastive Learning

The goal of contrastive learning is to learn these useful representations in an unsupervised or loosely-supervised fashion (aka self-supervised learning). A typical approach is to take some images, create augmented versions of those images and then embed both the originals and the augmented versions with some encoder network. The objective is to maximise the similarity between an image and its augmented version while minimising the similarity between that image and all the rest of the images in the batch. The trick here is that augmentation is used to create two ‘versions’ of an image. In our turtle case, we also have pictures of the same individual from different angles which can be used in place of (or in addition to) image augmentations to get multiple versions depicting one individual.

Top two rows: 16 turtles. Bottom 2 rows: augmented versions of different views of those same 16 turtles.

In my implementation, we generate a batch by picking batch_size turtles and then creating two sets of images with different pictures of those turtles. A resnet50 backbone acts like the encoder and is used to create embeddings of all of these images. We use a contrastive loss function to calculate a loss and update the network weights.
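
As a rough sketch of what such a loss can look like, here is a SimCLR-flavoured cross-entropy in plain Python. It is not the notebook's implementation: the function names and the temperature of 0.1 are assumptions, and a real training loop would use a tensor library so that gradients can flow back to the encoder.

```python
import math

def normalize(v):
    """Scale a vector to unit length (left unchanged if it is all zeros)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]

def contrastive_loss(view_a, view_b, temperature=0.1):
    """Row i of view_a should match row i of view_b and no other row."""
    a = [normalize(v) for v in view_a]
    b = [normalize(v) for v in view_b]
    total = 0.0
    for i, anchor in enumerate(a):
        # Cosine similarity of this embedding to every embedding in the other set.
        logits = [sum(x * y for x, y in zip(anchor, other)) / temperature for other in b]
        # Numerically stable log-sum-exp, then cross-entropy with target index i.
        max_logit = max(logits)
        log_sum = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))
        total += log_sum - logits[i]
    return total / len(a)
```

The loss is small when each turtle's two embeddings point the same way and large when an embedding is closer to a different turtle.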

You can check the notebook or the video for more details on the implementation here. Once all the bugs were ironed out, the training loop ran and the loss shrank nicely over time. But the question arises: how do we tell if the representations being learnt are actually useful?

Key reference for going deeper: SimCLR – A Simple Framework for Contrastive Learning of Visual Representations

Representational Similarity Matrices

Remember, our end goal is to be able to tell which individual turtle is in a new image. If things are working well, we’ll feed the new image through our encoder model to get a representation and then compare that to the encoded representations of the known turtles. All pictures of a given individual should be ‘similar’ in this space, but should not be similar to images of other individuals. A neat way to visualize this is through something called a Representational Similarity Matrix (RSM). We take, say, 16 images of each of 5 different turtles. We embed them all and compute all possible pair-wise similarities and then plot them as a heatmap:

A Representational Similarity Matrix (RSM) comparing embeddings of 16 images from each of 5 turtles.

The images are obviously identical to themselves – hence the thin bright diagonal. But here you can also see that images of a given turtle seem to be similar to others of that same turtle – for instance, the bottom right 16×16 square shows that all images of the red turtle are quite similar to each other. This also shows us which turtles might be regularly confused (pink and yellow, for example) and which are relatively easy to disambiguate (pink and green).

RSMs are a useful tool for quickly getting a feel for the kind of representations being learnt, and I think more people should use them to add visual feedback when working on this kind of model. Looking at RSMs for images in the training set vs a validation set, or for different views, can shed more light on how everything is working. Of course, they don’t tell the whole story and we should still do some other evaluations on a validation set.
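
Computing an RSM takes very little code. The sketch below assumes cosine similarity as the measure and uses made-up helper names; the embeddings should be ordered so that images of the same turtle sit next to each other, which is what produces the block structure along the diagonal.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two embedding vectors."""
    dot = sum(x * y for x, y in zip(u, v))
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(y * y for y in v))
    return dot / (norm_u * norm_v)

def similarity_matrix(embeddings):
    """All pair-wise similarities, as a list of rows ready to plot as a heatmap."""
    return [[cosine_similarity(u, v) for v in embeddings] for u in embeddings]
```

The resulting list of rows can be handed to any heatmap plotting function.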

So does it work?

I trained a model on a few hundred batches with an embedding size of 100. For the test set, I took the turtle_ids of the most similar images in the training set to each test image and used those as the submission. If there were no images with a similarity above 0.8 I added ‘new_turtle’ as the first guess. This scores ~0.4 in local testing and ~0.36 on the public leaderboard. This is pretty good considering we ignored the image_position label, the label balance and various flaws in the data! However, a classification-based baseline with FastAI scores ~0.6 and the top entries are shockingly close to perfect with mapk scores >0.98 so we have a way to go before this is competitive.
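
The matching step can be sketched roughly like this. The function name, the top_k of 5 and the choice to keep only the best-scoring image per turtle are assumptions for illustration; the 0.8 threshold and the ‘new_turtle’ label are the ones described above.

```python
def rank_matches(similarities, train_ids, top_k=5, new_threshold=0.8):
    """similarities[i] is the similarity between the test image and training image i."""
    # Keep the best similarity seen for each known turtle.
    best = {}
    for sim, turtle_id in zip(similarities, train_ids):
        if sim > best.get(turtle_id, float("-inf")):
            best[turtle_id] = sim
    # Most similar turtles first.
    ranked = sorted(best, key=best.get, reverse=True)
    # Nothing similar enough: guess that this is an unseen individual first.
    if not best or max(best.values()) < new_threshold:
        ranked.insert(0, "new_turtle")
    return ranked[:top_k]
```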

One benefit of our approach: adding a new turtle to the database doesn’t require re-training. Instead, we simply encode any images of that individual we have and add the embeddings to the list of possible matches we’ll use when trying to ID new images.

Where Next?

There are many ways to improve on this:

  • Experiment with parameters such as embedding size, batch size, augmentation types, training approach, regularization etc.
  • Incorporate the image_position labels, either doing separate models for different angles, filtering potential matches based on the test labels or finding some way to feed the label into the model as an extra type of conditioning.
  • Experiment with fine-tuning the model on the classification task. Since it has now (theoretically) learnt good representations, we could likely fine-tune it with a classification loss and get even better competition performance (at the cost of lower generalizability).
  • Explore automated data cleaning. Some images are out-of-domain, showing random background as opposed to turtle faces. Some images are just bad quality, or just don’t work with center-cropping.
  • Try different models as the backbone
  • Investigate label balance

…And many more. I hope this post gets you excited about the competition! Feel free to copy and adapt the notebook (with attribution please) and let me know if you manage to make any improvements. See you on the leaderboard 🙂

AIAIART Course Retrospective

A few weeks ago we wrapped up the first run of ‘AIAIART’, a short course on creating art with deep learning. The course was originally delivered over Discord, but you can access recordings of the lessons on YouTube alongside Colab Notebooks containing the code and examples.

The experience of putting this together and sharing it was highly enjoyable. I always get a kick out of seeing my code or teaching being used by other people to make cool stuff, and our Discord server is a steady stream of fun projects and experiments that make me so happy.

If I had to distil a few key takeaways I’ve gained from this endeavour, they would be:

  • Optimization is magic. Set things up so that <some function> uses <some parameters> to produce <some output> which can be evaluated against <some goal> in a differentiable way, and suddenly you can iteratively update those parameters bit by bit until (if all goes well) you achieve said goal. The secret here is that code for updating an image to look more like a description is practically identical to the code for updating the parameters of a neural network to solve some complicated task. And so while we were busy making art, everyone was secretly learning the much broader skill of solving problems with optimization 🙂
  • You don’t need a PhD to dabble with deep learning. Quite a few students had been playing with various AI art models but hadn’t been able to dig in and understand the code or inner workings. But once we started building up from simple examples, it suddenly ‘clicked’ and what was previously intimidating walls of code became fancier versions of the patterns we’d already seen again and again.
  • I really like teaching. Seeing that ‘aha’ moment makes me so happy – I’m going to have to find ways to do more of this 🙂
  • People are SO COOL! I love seeing how different people can see the same material and get inspired to create wildly different things.
  • AI art is SO COOL! We’re still at the beginning of this movement, but already there are such powerful and amazing models and techniques available to us. With a little bit of tinkering you can learn how to make them sing, and the results can be simply stunning. I look forward to seeing where the next few generations of tech take us.

Anyway, that’s about all I have for this post. Check out the videos or come and hang out in the discord to see what we’re playing with next, and stay tuned since I might turn this V1 course into something a little more polished over the Christmas holidays. Happy arting – J

Playing with Tweet Sentiment Analysis

The average sentiment of the most recent 200 tweets from each country’s capital city.

A mentee of mine has been working on web scraping for NLP projects and her most recent target was Twitter. She’s working on something cool (stay tuned) but in the meantime, I thought I’d share a few of my own experiments. You can follow along and see full code examples in this colab notebook.

Scraping Tweets with Twint

Scraping tweets from a specific user

I used twint – a scraper written in Python which gives a lot of functionality while avoiding the need for API keys, authentication etc. You can target specific users, locations, topics and dates (see their wiki for details) which makes this a powerful tool for finding and downloading tweets. For my tests today, I chose a few well-known Twitter personalities from my feed. I also scraped tweets from capital cities around the world, using the ‘Lang’ configuration option to focus on English tweets to make comparison easier (yes, I know, this is not ideal).

Sentiment Score with roBERTa

NLTK’s SIA can give a quick and easy sentiment score for a piece of text, but many tweets use more obscure language and styles that aren’t well-captured by the default lexicon or the approach as a whole. Luckily, tweet sentiment analysis is a popular task and there are pre-trained deep learning models available that do a pretty good job out-of-the-box. I used a roBERTa model fine-tuned on the TweetEval task. The model card on huggingface had all the code needed to classify a piece of text, making it very simple to get started. I’m so glad this trend of making models accessible with key info is catching on!

The model outputs three scores corresponding to the labels ‘negative’, ‘neutral’ and ‘positive’. We can combine the positive and negative scores to get a combined sentiment score running from -1 (very negative) to +1 (very positive). From this, we can get stats like ‘average sentiment’, but I wanted a better way to see at a glance what a user’s tweets look like. Hexbin plots to the rescue 🙂 These show the distribution of tweets in both sentiment and tweet length. You can see that Musk tends to tweet shorter, more neutral tweets while Gates favours mid-length positive ones and Lomborg tends heavily towards grumpy full-length rants 😂
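
One simple way to do the combining is to subtract the negative score from the positive one, which I'll assume here (the notebook may combine them differently). The function names are made up for this sketch, and the inputs are taken to be probabilities that sum to 1.

```python
def sentiment_score(negative, neutral, positive):
    """Collapse three class probabilities into one score from -1 to +1."""
    # The neutral probability pulls the score towards 0 simply by
    # leaving less mass for the other two classes.
    return positive - negative

def average_sentiment(tweet_scores):
    """Mean combined score over (negative, neutral, positive) tuples."""
    scores = [sentiment_score(*s) for s in tweet_scores]
    return sum(scores) / len(scores)
```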

Scoring Countries

I was curious: what would we see if we grabbed some tweets from the capital city of each country and found the average sentiment score? Where do the positive tweeters live? Ideally, we’d account for different languages, grab a wide selection of tweets covering a longer timeline and do all sorts of other careful analyses. But since this entire project is the result of one night’s insomnia I just grabbed the latest 200 English tweets from each country’s capital (using the countryinfo library to get the coordinates) and went with those. Plotting the average sentiment as a choropleth map using Plotly gives us the title image of this post. Don’t read too much into this – it’s just a demo to show what might be possible with a bit more work.


Data Science gives us the tools to ask questions about the world around us. And thanks to the kind folks who put so much effort into the libraries and tools we can access for free, it doesn’t have to be hard! I hope this post inspires you to ask your own questions. Feel free to modify and share the code, and PLEASE tag me on Twitter @johnowhitaker with your own visualizations and extensions. Happy scraping 🙂

EDIT: I made a Huggingface space where you can try this for yourself: https://huggingface.co/spaces/johnowhitaker/twitter_viz

WhistleGen: Generating Traditional Irish music with ML

Video overview of this project – do check out my channel if you like this!

Earlier this year I did an experiment where I tried to write some code on a small, atomic project every day. The results are documented at https://johnowhitaker.github.io/days_of_code/. In this post I want to share one of my favorite little diversions – my attempt at teaching a computer to compose some whistle music!

Getting the Data

To train a model we will need some data. Previous attempts at music generation have worked on MIDI or raw audio. However, a lot of Irish music is shared in a simplified form called ‘ABC Notation’, using letters and a limited set of symbols to encode the essential melody and leaving embellishments, harmonies and accents largely up to the interpretation of the player. thesession.org is one large central repository of these tunes, but I couldn’t find an easy way to download them in bulk. Web Scraping to the rescue!

A neat(ish) dataset of tunes in ABC notation

You can see the code and details here. Web scraping is one of those cases where there are many valid approaches one could take, but all of them in essence boil down to finding ways of identifying the specific parts of the HTML that surround the data you are interested in. For example, on a page of results from thesession each song is listed as a list item taking the form <li class="manifest-item">. With a bit of patience we can get URLs for each tune and then scrape the relevant info from those URLs with some more effort. At the end of this process, we have a nice neat dataframe with the titles, metadata and note sequences.
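
To give a feel for the first step, here is a small sketch using only Python's built-in html.parser. It is not the code from the notebook: the class name, the sample HTML and the URL paths in it are made up, and only the manifest-item class comes from the real page.

```python
from html.parser import HTMLParser

class TuneLinkParser(HTMLParser):
    """Collect the first link inside each <li class="manifest-item">."""

    def __init__(self):
        super().__init__()
        self.in_item = False
        self.links = []

    def handle_starttag(self, tag, attrs):
        attrs = dict(attrs)
        if tag == "li" and "manifest-item" in (attrs.get("class") or "").split():
            self.in_item = True
        elif tag == "a" and self.in_item and "href" in attrs:
            self.links.append(attrs["href"])
            self.in_item = False  # only keep the first link per item

    def handle_endtag(self, tag):
        if tag == "li":
            self.in_item = False

# Made-up snippet standing in for a downloaded results page.
sample_html = """
<ul>
  <li class="manifest-item"><a href="/tunes/1">First tune</a> <a href="/members/9">someone</a></li>
  <li class="other"><a href="/about">About</a></li>
  <li class="manifest-item"><a href="/tunes/2">Second tune</a></li>
</ul>
"""
parser = TuneLinkParser()
parser.feed(sample_html)
tune_links = parser.links
```

Each collected link would then be fetched in turn to pull out the tune's metadata and notes.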


We’re going to train a ‘language model’ – a concept from the field of NLP, where a model (usually an LSTM or transformer-based architecture) tries to predict the next token in a sequence, allowing it to learn from unstructured data such as large chunks of text or, in this case, music. The end result of this is a generative model that can ‘autocomplete’ sequences. These language models can then be re-purposed for classification, translation etc. but in this case we want a generative model so that is unnecessary.

The text needs to be tokenized. We can simply split into individual characters, but since the notation includes ‘note modifiers’ such as ‘^’, ‘_’ and ‘=’ which are sometimes placed before a note to sharpen it, flatten it or mark it as natural, and some other 2-character symbols (like ‘|:’ for the start of a repeated section), I chose to build a custom tokenizer. The notebook shows how to construct fastai dataloaders that package everything up neatly ready for training.
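
A tokenizer along these lines can be as simple as one regular expression. The pattern below is a guess at a reasonable token set for illustration, not the one used in the notebook: repeat bars and accidental-plus-note pairs become single tokens, and everything else is split into characters.

```python
import re

# Alternatives are tried left to right, so the multi-character tokens win.
TOKEN_PATTERN = re.compile(
    r"\|:|:\||\|\]"       # repeat start, repeat end, final bar line
    r"|[\^_=][A-Ga-g]"    # accidental followed by the note it modifies
    r"|.",                # any other single character
    re.DOTALL,
)

def tokenize(abc_text):
    """Split a tune in ABC notation into a list of tokens."""
    return TOKEN_PATTERN.findall(abc_text)
```

Joining the tokens back together reproduces the original string, which is handy when turning model output back into a playable tune.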

Once the dataloaders are ready, we can simply train this like any other language model. I used the learning rate finder (output shown above) to pick an initial learning rate and then, following the example in the fastai docs, gradually unfroze the model and continued to train it. After a few minutes the model is predicting the next token with ~38% accuracy!

Getting Some Output

Some early WhistleGen output

We can feed our model a few tokens and ask it to continue making guesses for the next token in the sequence: learn.predict('|:G', 100, temperature=0.7). The temperature parameter controls how ‘conservative’ the model is; higher values result in output with more randomness. To convert the string of letters that the model spews out into playable music, I used this handy online editor to preview, edit and download the songs.
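
To make the effect of temperature concrete, here is a generic sketch of temperature sampling in plain Python. It illustrates the general idea rather than fastai's internals, and the function names are made up.

```python
import math
import random

def temperature_softmax(logits, temperature=1.0):
    """Turn raw scores into probabilities; low temperature sharpens them."""
    scaled = [l / temperature for l in logits]
    max_scaled = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - max_scaled) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_token(logits, temperature=0.7, rng=random):
    """Pick the index of the next token at random, weighted by probability."""
    probs = temperature_softmax(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

With a low temperature almost all of the probability goes to the top-scoring token, so the output is safe but repetitive; with a high temperature the unlikely tokens get picked more often.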

The model is OK at guessing sensible notes, but it doesn’t produce much in the way of song structure. I found it best to use the output as a starting point, tweaking the odd bit of timing and adding repeats, separate parts and the odd extra flourish to create a song that is effectively co-written by myself and my AI assistant. It’s surprisingly fun! I hope this inspires you to try something like this yourself – do let me know what you create.

In Brief: Playing with Class Imbalance

We often encounter imbalanced data in the world of machine learning, and have to decide how best to handle this. In ‘real life’ it is up to us to decide how to evaluate performance, which types of errors we care the most about and so on. But in the example we’ll look at today, the situation is slightly different: we have an imbalanced training set, and the test set (we’re working with this competition) has had the class distribution modified to make it more balanced. So, we need to find a way to take this into account when submitting predictions. The following plot shows the difference in distributions:

Class distribution of the training set compared to the validation set

It’s worth pointing out that this is showing the results for the validation set – there is an unseen test set that could very well have its own slightly different class distribution. There isn’t much to say that wasn’t covered in the notebook, so check that out for implementation details. That said, let’s go over the main strategies we could use:

  1. Do nothing and hope for the best… Not great, but when the imbalance is small then some models are pretty decent at making sensible predictions. This isn’t going to win any competitions though!
  2. Drop some fraction of the majority class. This turned out to work surprisingly well – I suspect this mimics the steps the organizers took when preparing the data.
  3. Generate some extra ‘synthetic’ samples for the under-represented class using Synthetic Minority Oversampling Technique (SMOTE)
  4. Combine steps 2 and 3, to avoid relying on too much synthetic data. In this case I chose to use the imblearn library’s RandomUnderSampler to discard some of the majority class.
  5. Take advantage of the sample_weight parameter available in some models. For example, with CatBoost we can explicitly tell the model to assign less weight to samples from the majority class. This lets us use the whole dataset (no need to throw out perfectly good data) and it performed the best in some experiments, losing only to the basic under-sampling technique in the final assessment.

Dropping 5/6 of the rows from the majority class – a frustratingly successful approach!
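
Strategies 2 and 5 are easy to sketch without any libraries. The code below is an illustration and not what the notebook does: the function names are made up, rows are dropped at random with a fixed seed, and the weights use a simple inverse-frequency rule.

```python
import random
from collections import Counter

def undersample_majority(rows, labels, majority_label, keep_fraction=1 / 6, seed=0):
    """Randomly keep roughly keep_fraction of the majority class and all other rows."""
    rng = random.Random(seed)
    kept_rows, kept_labels = [], []
    for row, label in zip(rows, labels):
        if label != majority_label or rng.random() < keep_fraction:
            kept_rows.append(row)
            kept_labels.append(label)
    return kept_rows, kept_labels

def inverse_frequency_weights(labels):
    """One weight per sample, so that every class has the same total weight."""
    counts = Counter(labels)
    n_classes = len(counts)
    return [len(labels) / (n_classes * counts[label]) for label in labels]
```

The list of weights is the kind of thing you would pass as the per-sample weights when fitting a model that supports them.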

Again, check out the notebook for details and code. Here are the results:

Strategy                               Log Loss (local)
Under-sampling the majority class      0.556998
CatBoost with altered sample weights   0.559395
SMOTE + RandomUnderSampler             0.579939
No modifications                       0.674555

The big takeaway here for me was that getting this right makes a huge difference in these types of competition. Without a good strategy even the fanciest model has no hope of matching the top submissions. Fortunately, even basic under-sampling can get great results, and I hope that between my notebook and discussions from others sharing their tips we have an even playing field on this front, allowing competitors to work on the more interesting aspects like feature engineering.

BirdClef Entry: Bird Call Classification with FastAI

The Cornell Lab of Ornithology run an annual competition to identify bird calls in soundscapes. I decided to have a go at this year’s competition to get back into audio classification and try out some new approaches. For this first post I will examine the data, choose methods for picking the right clips within larger recordings and for generating a spectrogram from said clip, and train a simple model to use as a baseline for future experiments.

Finding the Calls

In many recordings, the bird in question is not calling continuously. The final task involves predicting which birds are calling at 5-second intervals, so that is my chosen input length. If we just sample a random 5-second clip from a full recording, we might end up with a clip in which the bird is not calling – not ideal! To get around this, we compute a sort of signal-to-noise measure (in this case, PCEN-based SNR as used by the BirdVox project). With this, we can choose ‘peaks’ where the calls are most prominent.

Identifying ‘peaks’ with a high PCEN-based SNR

The code for this is in my first notebook. For each train file, we store the location of 20 peaks in a csv file which we will then use during training to select the appropriate clips.
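
The peak-picking step itself can be sketched as follows, assuming we already have a list of SNR values (one per frame). This is not the notebook's code: the function name and the min_gap rule for keeping peaks apart are assumptions made for illustration.

```python
def pick_peaks(snr, num_peaks=20, min_gap=5):
    """Indices of the highest values in snr that are at least min_gap frames apart."""
    # Visit frames from loudest to quietest.
    order = sorted(range(len(snr)), key=lambda i: snr[i], reverse=True)
    chosen = []
    for i in order:
        # Skip frames that sit too close to a peak we already picked.
        if all(abs(i - j) >= min_gap for j in chosen):
            chosen.append(i)
            if len(chosen) == num_peaks:
                break
    return sorted(chosen)
```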

Preparing the data for modelling

We could try feeding the raw audio data into a model, but 5 seconds of audio represents quite a lot of data. Some models can handle this, but in most cases a better approach is to find a more compressed representation of the sound. In this case I chose a fairly standard approach: the mel spectrogram. A spectrogram looks like a 2D image, with time on the x axis, frequency on the y axis and intensity represented by colour.

An example spectrogram
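
The ‘mel’ part refers to how the frequency axis is spaced: bands are packed closely at low frequencies and spread out at high ones, roughly following how we hear pitch. The sketch below uses one common variant of the mel formula, which may not be the exact one used by the spectrogram code in the notebook, and the function names are made up.

```python
import math

def hz_to_mel(freq_hz):
    """Convert a frequency in Hz to mels (one widely used formula)."""
    return 2595.0 * math.log10(1.0 + freq_hz / 700.0)

def mel_to_hz(mel):
    """Inverse of hz_to_mel."""
    return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)

def mel_band_edges(f_min, f_max, n_mels):
    """Frequencies (Hz) of n_mels + 2 points spaced evenly on the mel scale."""
    lo, hi = hz_to_mel(f_min), hz_to_mel(f_max)
    step = (hi - lo) / (n_mels + 1)
    return [mel_to_hz(lo + i * step) for i in range(n_mels + 2)]
```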

The model training notebook shows how we set up the dataloaders to read in a specified clip and turn it into a spectrogram that can be fed to the model. This is quite CPU-heavy, which does slow the training down. But I still chose this approach over pre-computing the spectrograms once at the start because it allows for data augmentation such as shifting the window, adding noise etc on the raw audio before it gets converted to a spectrogram.

You can see all the code in the baseline model notebook. Taking inspiration from the pets tutorial, we create our own custom Transform that handles ‘encoding’ a given clip/label pair, which in turn is used to create our DataLoaders. By adding a ‘decodes’ method we also enable functionality such as ‘show_batch()’.


Loss plot over 3 epochs of training

I’m not doing anything fancy for training – our goal here is a simple model to use as a baseline for future tests. A few things I did learn however:

  • To be able to access the output of a Kaggle notebook, you have to re-run it by clicking ‘save’. This can eat up GPU time, so I have started running my interactive tests with small subsets of the data and then relying on the run triggered by a save to actually do the full training.
  • Because this is then running ‘in the background’, saving any info you need is a must. I use the CSVLogger callback to save the stats after each epoch, and do other things like saving loss plots as pngs. Finally, we save the model itself for future use.
  • With this small model and CPU-heavy dataloader, running on CPU was only a couple of times slower than on GPU. With a bit of patience, you could simply run this overnight rather than using up your weekly GPU quota, saving the GPU goodness for fast iteration when experimenting. In fact, after the save failed a few times I ended up switching off the GPU and letting it train on the CPU over 7 or 8 hours.

Again, full code is in the notebook. After 3 epochs (the number of epochs and the learning rate chosen somewhat arbitrarily) we get to an accuracy of ~53% – impressive given the large number of classes. I’m sure a better model and more training would boost this, but that is something we can play with later…


During training we calculate an ‘accuracy’ score based on some clips withheld from the training data. These all have a single label (even though there may well be other calls mixed in) and they are taken from a different source to the actual test data that we will be scored on in the competition. We would assume a better accuracy in our simplified case will mean a better model, but ideally we want a way to evaluate our model in a setting that is as close as possible to the final task.

Fortunately, the competition hosts have provided some labelled audio recordings that match the format of the test set. We can use this in our evaluation notebook to simulate a submission. Our model needs to provide a list of all bird species calling in a given 5-second clip. The way we will approach this for now is to take the model’s output probabilities and pick some threshold above which we will include a given species.
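The thresholding step is simple enough to sketch directly. The function name, the default threshold and the species codes below are placeholders for illustration, not values from the evaluation notebook.

```python
def species_above_threshold(probs, threshold=0.5):
    """Return the species whose predicted probability exceeds the threshold,
    ordered from most to least confident."""
    picked = [(p, name) for name, p in probs.items() if p > threshold]
    return [name for p, name in sorted(picked, reverse=True)]


# Made-up model output for a single 5-second clip.
probs = {"species_a": 0.81, "species_b": 0.35, "species_c": 0.62}
print(species_above_threshold(probs, threshold=0.5))
print(species_above_threshold(probs, threshold=0.9))
```

A lower threshold catches more of the quieter calls at the cost of more false positives, so the value is worth tuning on the labelled evaluation recordings.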

In the future, we will want to take geographic location into account, and ideally train a model directly on this kind of multi-label task. Even without this, our very simple model gets an F1-score of about 0.64 on the provided evaluation set and a leaderboard score of 0.55. The notebook is very rough, but for completeness here is a link.
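For reference, here is one way an F1-score could be computed for this kind of multi-label output. This sketch scores each clip separately and averages the results; that averaging choice, and the decision to score an empty prediction on an empty clip as 1.0, are my assumptions and may not match the competition's exact metric.

```python
def clip_f1(true_species, predicted_species):
    """F1 score for a single clip, comparing sets of species labels."""
    true_set, pred_set = set(true_species), set(predicted_species)
    if not true_set and not pred_set:
        return 1.0  # nothing calling and nothing predicted
    tp = len(true_set & pred_set)
    if tp == 0:
        return 0.0
    precision = tp / len(pred_set)
    recall = tp / len(true_set)
    return 2 * precision * recall / (precision + recall)


def mean_f1(true_rows, predicted_rows):
    """Average the per-clip F1 over all clips."""
    scores = [clip_f1(t, p) for t, p in zip(true_rows, predicted_rows)]
    return sum(scores) / len(scores)


# Three made-up clips: a partial match, an empty clip, and a miss.
true_rows = [["a", "b"], [], ["a"]]
predicted_rows = [["a"], [], ["b"]]
print(mean_f1(true_rows, predicted_rows))
```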

Conclusions and Next Steps

Our submission scores 0.55, placing 167th on the leaderboard. Not terrible, but there is a ways to go before we are up there near the top. If I manage to spend some time on this, there will hopefully be a part 2 in which I explore ways to boost the score… Stay tuned for that 🙂

Language Models for Protein Sequence Classification

A 3D model of a protein kinase

We recently hosted a Zindi hackathon in partnership with Instadeep that challenged participants to predict the functional class of protein kinases (enzymes with some very specific functions in the cell) based on nothing but their amino acid sequences. This kind of sequence classification task has lots of potential applications – there is a lot of un-labelled data lying around on every computational biologist’s computer, and a tool that could guess a given protein’s function would be mighty handy.

Just one problem – it’s not exactly a simple task! There are 20-something amino acids which we represent as letters. Given a sequence like ‘AGASGSUFOFBEASASSSSSASBBBDGDBA’ (frantically monkey-types for emphasis) we need to find a way to a) encode this as something a model can make sense of and b) do the making-sense-of-ing! Fortunately, there’s another field where we need to go from a string of letters to something meaningful: Natural Language Processing. Since I’d just been watching the NLP lesson in the latest amazing fastai course I felt obliged to try out the techniques Jeremy was talking about on this sequence classification task.

The Basic Approach

Tokenized input (left) and class (right)

Treating this as a language task and drawing inspiration from ULMFiT[1], this was my basic approach:

  • I tokenized the sequences using ‘subword tokenization’ which captures not just individual amino acids as tokens but common groupings as well (eg ‘EELR’ is encoded as a single token). I think this basic approach was suggested by the SentencePiece paper[4] and it’s now part of fastai[5].
  • I then created a ‘pretext task’ of sequence completion to train a ‘language model’ (based on the AWD-LSTM architecture[2]). The model learns to predict the next token in a sequence with ~32% accuracy – the hope is that in doing so it also learns useful embeddings and some sort of latent understanding of how these sequences are structured.
  • We keep most of this network as the ‘encoder’ but modify the final layers for the actual task: sequence classification. Thanks to the pre-training, the model can very quickly learn the new task. I can get to 98% accuracy in a couple of minutes by training on only a small subset of the data.
  • Training the model for the sequence classification task takes a while on the full competition dataset, but it eventually reaches 99.8% accuracy with a log_loss on the test set (as used in the competition) of 0.08, which is equivalent to 3rd place.
  • Doing the normal tricks of ensembling, training a second model on reversed sequences etc quite easily bumps this up to glory territory, but that’s the boring bit.
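The ensembling trick in the last bullet boils down to averaging the class probabilities from the two models and scoring the result with log loss. Here is a minimal sketch with made-up predictions; the helper names are hypothetical, and in practice the second set of probabilities would come from a model trained on reversed sequences (e.g. `seq[::-1]`).

```python
import math


def average_probs(probs_a, probs_b):
    """Element-wise mean of two lists of per-class probability rows."""
    return [[(a + b) / 2 for a, b in zip(row_a, row_b)]
            for row_a, row_b in zip(probs_a, probs_b)]


def log_loss(true_labels, probs, eps=1e-15):
    """Mean negative log-probability assigned to the true class."""
    total = 0.0
    for label, row in zip(true_labels, probs):
        # Clip to avoid log(0) on over-confident wrong predictions.
        p = min(max(row[label], eps), 1 - eps)
        total -= math.log(p)
    return total / len(true_labels)


labels = [0, 1]
forward = [[0.9, 0.1], [0.4, 0.6]]   # model trained on normal sequences
backward = [[0.7, 0.3], [0.3, 0.7]]  # model trained on reversed sequences
ensemble = average_probs(forward, backward)
print(log_loss(labels, forward), log_loss(labels, backward), log_loss(labels, ensemble))
```

Averaging doesn't always beat the better of the two models on a given set of examples, but it tends to soften over-confident mistakes, which log loss punishes heavily.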

It was fun to see how well this worked. You can find a more detailed write-up of the initial experiments on that competition dataset here. Spurred by these early results, I figured it was worth looking into this a little more deeply. What have others been doing on this task? Is this approach any good compared to the SOTA? Has anyone tried this particular flow on this kind of problem?

Getting Formal

It should come as no surprise that the idea of treating sequence classification like a language modelling task has already occurred to some people. For example, UDSMProt[7] turns out to use very nearly the same approach as that outlined above (self-five!). Their GitHub is a great resource.

There are other approaches as well – for example, ProtCNN[6] and DEEPred[8] propose their own deep learning architectures to solve these kinds of tasks. And there are some older approaches such as BLAST and its derivatives[9] that have long been standards in this field and still do decently (although they seem to be getting out-performed by these newer techniques).

So, we’re not the first to try this. However, I couldn’t find any papers using anything like the ‘subword’ tokenization. They either use individual amino acids as tokens, or in rare cases some choice of n-grams (for example, triplets of amino acids). The advantage of subword tokenization over these is that it can scale between the complexity of single-acid encodings and massive n-gram approaches by simply adjusting the vocabulary size.
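To make the vocabulary-size point concrete, here is a toy byte-pair-style tokenizer. It is only a sketch of the general idea: SentencePiece (and fastai's wrapper around it) is considerably more sophisticated and may use a different algorithm, and the function names and example sequences here are made up. With zero merges you get single-acid tokens; each extra merge adds one longer token to the vocabulary.

```python
from collections import Counter


def apply_merge(tokens, left, right):
    """Replace every adjacent (left, right) pair with one merged token."""
    out, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and tokens[i] == left and tokens[i + 1] == right:
            out.append(left + right)
            i += 2
        else:
            out.append(tokens[i])
            i += 1
    return out


def learn_merges(sequences, n_merges):
    """Learn up to n_merges merges, most frequent adjacent pair first."""
    corpus = [list(seq) for seq in sequences]
    merges = []
    for _ in range(n_merges):
        pairs = Counter()
        for tokens in corpus:
            pairs.update(zip(tokens, tokens[1:]))
        if not pairs:
            break
        (left, right), count = pairs.most_common(1)[0]
        if count < 2:
            break  # nothing left that repeats
        merges.append((left, right))
        corpus = [apply_merge(tokens, left, right) for tokens in corpus]
    return merges


def tokenize(seq, merges):
    """Tokenize a sequence by applying the learned merges in order."""
    tokens = list(seq)
    for left, right in merges:
        tokens = apply_merge(tokens, left, right)
    return tokens


sequences = ["EELREELR", "AEELRA"]
for n in (0, 1, 3):
    merges = learn_merges(sequences, n_merges=n)
    print(n, tokenize("AEELRA", merges))
```

The number of merges plays the role of the vocabulary size: small values stay close to per-acid encoding, while larger values pick up common groupings as single tokens.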

Your Homework

I did some initial tests – this definitely smells promising, but there is a lot of work to do for this to be useful to anyone, and I don’t currently have the time or compute to give it a proper go. If you’re looking for a fun NLP challenge with the potential to turn into some interesting research, this could be the job for you! Here are my suggestions:

  • Pick one or more benchmarks. Classification of the PFam dataset is a nice one to start with. The ProtCNN paper[6] (quick link) has done a bunch of the ‘standard’ algorithms and shared their split as a kaggle dataset, so you can quickly compare to those results.
  • Get some data for language model training. The SWISSProt dataset is a nice one, and for early tests even just the PFam dataset is enough to try things out.
  • Train some language models. Do single-acid tokenization as a baseline and then try subword tokenization with a few different vocab sizes to compare.
  • See which models do best on the downstream classification task. Lots of experimenting to be done on sequence length, training regime and so on.
  • For bonus points, throw a transformer model or two at this kind of problem. I bet they’d be great, especially if pre-trained on a nice big dataset.
  • If (as I suspect) one of these does very well, document your findings, try everything again in case it was luck and publish it as a blog or, if you’re a masochist, a paper.
  • … profit?

I really hope someone reading this has the motivation to give this a go. If nothing else it’s a great learning project for language modelling and diving into a new domain. Please let me know if you’re interested – I’d love to chat, share ideas and send you the things I have tried. Good luck 🙂


[1] – Howard, J. and Ruder, S., 2018. Universal language model fine-tuning for text classification. arXiv preprint arXiv:1801.06146.

[2] – Merity, S., Keskar, N.S. and Socher, R., 2017. Regularizing and optimizing LSTM language models. arXiv preprint arXiv:1708.02182.

[3] – Smith, L.N., 2017, March. Cyclical learning rates for training neural networks. In 2017 IEEE Winter Conference on Applications of Computer Vision (WACV) (pp. 464-472). IEEE.

[4] – Kudo, T. and Richardson, J., 2018. Sentencepiece: A simple and language independent subword tokenizer and detokenizer for neural text processing. arXiv preprint arXiv:1808.06226.

[5] – Howard, J. and Gugger, S., 2020. Fastai: A layered API for deep learning. Information, 11(2), p.108.

[6] – Bileschi, M.L., Belanger, D., Bryant, D.H., Sanderson, T., Carter, B., Sculley, D., DePristo, M.A. and Colwell, L.J., 2019. Using deep learning to annotate the protein universe. bioRxiv, p.626507. (ProtCNN)

[7] – Strodthoff, N., Wagner, P., Wenzel, M. and Samek, W., 2020. UDSMProt: universal deep sequence models for protein classification. Bioinformatics, 36(8), pp.2401-2409. (UDSMProt)

[8] – Rifaioglu, A.S., Doğan, T., Martin, M.J., Cetin-Atalay, R. and Atalay, V., 2019. DEEPred: automated protein function prediction with multi-task feed-forward deep neural networks. Scientific Reports, 9(1), pp.1-16. (DEEPred)

[9] – Altschul, S.F., Madden, T.L., Schäffer, A.A., Zhang, J., Zhang, Z., Miller, W. and Lipman, D.J., 1997. Gapped BLAST and PSI-BLAST: a new generation of protein database search programs. Nucleic Acids Research, 25(17), pp.3389-3402.