DistilHN: Summarizing News Articles with Transformers

In this series, I’d like to explore how to take an idea within machine learning from proof of concept to production. This first post is going to get things going with a little mini-project that I did in the downtime between Christmas activities, creating a website called DistilHN.com using a bit of machine learning magic and some basic web scraping. Let’s get started.

The DistilHN page

The Idea

I’ve been thinking about how to make a better news feed. When confronted with a clickbait headline, I often want a little more info, but don’t feel like clicking through to the article (and dismissing the cookie popup, and scrolling past the ads, and declining their invite to sign up for the newsletter, and …) just to see what it’s about. So, this is the idea: use AI to generate a short summary that you can read before deciding whether you’re going to commit to the full article or just skip straight to the comments section on Hacker News.

Scraping Text

I started working on a way to get the main text from an arbitrary website using Beautiful Soup, writing heuristics for which elements were worth including or ignoring. It turns out this is a very hard problem! After a while I had something that sort of worked for some sites, but in desperation I decided to take another look around online to see if someone else had already done the hard work.

Extracting text from a website using Trafilatura

Enter the Trafilatura library, purpose-built for this exact task! It makes it super easy to grab the text from any website, as shown in the screenshot above. Aside: all the code shown in this post is also available as a notebook on Google Colab here.
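
For reference, a minimal sketch of that extraction step might look like the following. It assumes the trafilatura package is installed, and the function name extract_main_text is my own label rather than anything from the notebook.

```python
def extract_main_text(url):
    """Download a page and return its main text, or None if anything fails."""
    # Imported lazily so the function can be defined without the package installed.
    import trafilatura

    downloaded = trafilatura.fetch_url(url)  # raw HTML, or None on failure
    if downloaded is None:
        return None
    return trafilatura.extract(downloaded)  # main article text, or None


# Usage (makes a network request; the URL is a placeholder):
# text = extract_main_text("https://example.com/")
```

Both calls can come back empty on pages that block scrapers or have very little text, so it is worth checking for None before passing anything to the summarizer.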

Summarization

For the actual summarization step, I picked this model from Facebook, which was fine-tuned for news article summarization. You can run it locally with a huggingface pipeline, but I chose to use the free inference API since we’re not going to need to run this thousands of times an hour and we may as well do as little work as possible ourselves! We set up a query, specify the text we want to summarize and the min and max length for the summary, post the request and wait for the summary to come back.
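
Here is a rough sketch of what such a request could look like, using only the standard library. The model name (facebook/bart-large-cnn), the endpoint URL, the default lengths and the shape of the response are all assumptions on my part, so check them against the current Hugging Face documentation before relying on this.

```python
import json
import urllib.request

# Assumed model and endpoint: the post only says "this model from Facebook".
API_URL = "https://api-inference.huggingface.co/models/facebook/bart-large-cnn"


def build_payload(text, min_length=30, max_length=130):
    """Package the text and the summary length limits as a JSON-ready dict."""
    return {
        "inputs": text,
        "parameters": {"min_length": min_length, "max_length": max_length},
    }


def summarize(text, api_token, min_length=30, max_length=130):
    """POST the request and return the summary string (makes a network call)."""
    body = json.dumps(build_payload(text, min_length, max_length)).encode("utf-8")
    request = urllib.request.Request(
        API_URL,
        data=body,
        headers={
            "Authorization": f"Bearer {api_token}",
            "Content-Type": "application/json",
        },
    )
    with urllib.request.urlopen(request) as response:
        result = json.loads(response.read().decode("utf-8"))
    # Assumed response shape: [{"summary_text": "..."}]
    return result[0]["summary_text"]


# Usage (needs a valid API token):
# print(summarize(article_text, api_token="hf_..."))
```

Keeping the payload construction separate from the network call makes it easy to check the request shape without spending any API quota.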

Summarizing the text with the HuggingFace Inference API

This was a bit of a revelation for me. In the past I’d be downloading and training models as soon as I started a project like this, but here is an existing solution that does the job perfectly. If we want to scale up, Huggingface has paid inference options or we can switch to running the model ourselves. But for this proof-of-concept, the inference API makes our lives easy 🙂

Sharing

It’s one thing to run something like this once in a notebook. To make this a permanent solution, we need a few things:

  • Some server to run a script every hour or so to fetch and summarize the latest articles.
  • A website or something so that we can share our project with others, including a place to host it
  • Ideally, an RSS feed that users can read from their RSS app of choice.

I decided to start by wrapping up the scraping and summarization code into a script and having it write the results to an RSS feed (using the feedgenerator Python library). This way I’d have the content in a known format and a usable output before I started hacking on the front end.
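
To give a feel for what ends up in feed.xml, here is a small sketch that builds an RSS 2.0 document. Note that it uses the standard library's ElementTree rather than the feedgenerator library mentioned above, and the channel description and example item are placeholders.

```python
import xml.etree.ElementTree as ET


def build_feed(title, link, description, items):
    """Return an RSS 2.0 document (as a string) for a list of item dicts."""
    rss = ET.Element("rss", version="2.0")
    channel = ET.SubElement(rss, "channel")
    ET.SubElement(channel, "title").text = title
    ET.SubElement(channel, "link").text = link
    ET.SubElement(channel, "description").text = description
    for item in items:
        node = ET.SubElement(channel, "item")
        ET.SubElement(node, "title").text = item["title"]
        ET.SubElement(node, "link").text = item["link"]
        # The generated summary goes in the description field.
        ET.SubElement(node, "description").text = item["summary"]
    return ET.tostring(rss, encoding="unicode")


feed_xml = build_feed(
    "DistilHN",
    "https://www.distilhn.com/",
    "Article summaries (placeholder description)",
    [{"title": "Example story", "link": "https://example.com/", "summary": "A short summary."}],
)
```

Writing feed_xml to a file on each scheduled run gives the front end a single, predictable thing to read.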

My PythonAnywhere Dashboard – the script has only used ~20 seconds of CPU time so far today!

While you could host something like this yourself on a small VPS, I chose to go the easy route and use a website called PythonAnywhere which handles some of the admin for you. They have a tutorial for hosting a static site and make it easy to upload files like the aforementioned script and set them to run on a schedule. I did end up making a minimal flask app too in case I want to develop this further, but for the initial demo, I just exposed the index.html and feed.xml files to the web via the PythonAnywhere web UI. This is great for getting demos up quickly, and since this is just serving a static site it should scale extremely well.

Speaking of index.html, I made a simple HTML page and modified a JavaScript snippet from this tutorial to load in the items from the RSS feed and add them to the page. I’m not particularly comfortable with HTML/CSS so styling this took ages, and it still looks a little clunky. ChatGPT and GitHub Copilot turned out to be SUPER useful for this step – I find myself relying on Copilot’s suggestions much more when working with languages that I am less familiar with, and being able to just type /* Make the image appear at the top, centered */ and then hit tab to get the CSS I needed for something is delightful compared to my usual fiddle->test->google->repeat cycle.

Taking This Further

You can see the final website at https://www.distilhn.com/. I’m quite pleased with how it turned out, even if there are still a few things to iron out. I’m already working on a more ambitious follow-on project, pulling news from across the globe and filtering it using more ML magic… but that will have to wait for a future post 🙂 Until then, have fun with the website, and let me know if you have ideas for improvements! Happy hacking.

How Predictable: Evaluating Song Lyrics with Language Models

I was briefly nerd-sniped this morning by the following tweet:

Can we quantify how ‘predictable’ a set of lyrics are?

Language Models and Token Probabilities

A language model is a neural network trained to predict the next token in a sequence. Specifically, given an input sequence it outputs a probability for each token in its vocabulary. So, given the phrase “Today is a nice ” the model outputs one value for every token, and we can look up the probability associated with the token for “day” – which will likely be fairly high (~0.5 in my tests).
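
As a toy illustration of that lookup, the sketch below turns a handful of raw scores into probabilities with a softmax and reads off the one for “day”. The four-word vocabulary and the scores are made up; a real model like GPT-2 does the same thing over tens of thousands of tokens.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    highest = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - highest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Toy vocabulary and made-up scores for the context "Today is a nice ".
vocab = ["day", "surprise", "banana", "one"]
logits = [3.0, 1.5, -2.0, 2.0]

probs = softmax(logits)
p_day = probs[vocab.index("day")]  # probability assigned to the token "day"
```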

We can look at the probabilities predicted for each successive word in a set of lyrics, and take the average as a measure of ‘predictability’. Here’s the full code I used:

import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

lyrics = """
    And my thoughts stay runnin', runnin' (Runnin')
    The heartbreaks keep comin', comin' (Comin')
    Oh, somebody tell me that I'll be okay
"""
input_ids = tokenizer(lyrics, return_tensors="pt").input_ids
word_probs = []
min_length = 5 # How much do we give to start with

for i in range(min_length, len(input_ids[0])-1):
    ids = input_ids[:,:i]  # The first i tokens; the last of these is the one to predict
    with torch.no_grad():
        # Feed everything except the last token and ask for scores for one new token
        generated_outputs = gpt2.generate(ids[:,:-1], do_sample=True, output_scores=True,
                                          return_dict_in_generate=True,
                                          max_new_tokens=1,
                                          pad_token_id=tokenizer.eos_token_id)
    scores = generated_outputs.scores[0]  # One score per token in the vocabulary
    probs = scores.softmax(-1)
    # Probability the model gave to the token that actually comes next
    word_probs.append(probs[0][ids[0][-1]])

torch.mean(torch.tensor(word_probs))

My starting point was this post by Patrick Von Platen showing how to generate probabilities per token with GPT-2.

Results

The first test: ‘Remind Me’ by Meghan Trainor. The mean probability given by the model for the next word given the lyrics up to that point: 0.58!

Trying a few other songs I could think of with less repetitive lyrics:

  • ‘Levitate’ (21 Pilots): 0.34
  • ‘Mom’s Spaghetti’ (MNM): 0.35
  • The code example above: 0.45
  • ‘I’m Gonna Be (500 Miles)’ (The Proclaimers): 0.59

One caveat: anything written before 2019 might be in the model’s training data, so the model might already ‘know’ the lyrics, which makes the measure less informative.

Historical Trends

EDIT: Someone (me) didn’t preview their data well enough, the lyrics I used for this were either badly scraped or very processed, so these scores won’t compare well to the previous section and I need to re-do this with a proper dataset before we can say anything concrete about trends!

Plotting the median estimated predictability per decade for a random sample of ~6k songs

I downloaded a bunch of song lyrics via this dataset and sampled some from different years (1950 – 2019). For each, I estimated the predictability as described above. I found very little correlation (correlation coefficient 0.037 EDIT: 0.06 with a larger sample size) between predictability and year released, but there does seem to be a slight uptick in median predictability over time, especially going into the 2010s, which I’m sure will validate those grumbling about ‘music these days’…

Conclusion

This was fun! Go play with the code and see if your least favourite song is actually as predictable as you think it is. Or perhaps run it over the top 100 current hits and see which is best. I should get back to work now, but I hope you’ve enjoyed this little diversion 🙂

Update Time

A few recent projects I’ve worked on have been documented elsewhere but haven’t made it to this blog. The point of this post is to summarize these so that they aren’t lost in the internet void.

AI Art Course

The playlist (you can also start from lesson 1)

Part 2 of AIAIART launched last month. You can see all lessons and a link to the YouTube playlist here: https://github.com/johnowhitaker/aiaiart

Image Generation with CLOOB Conditioned Latent Denoising Diffusion GANs

I had fun trying out a new(ish) approach for text-to-image tasks. The neat thing with conditioning on CLOOB embeddings is that you can train without text captions and still get some text guidance ability at inference time (see image above). This got written up as a nice report on Weights and Biases.

Getting Started with the Microsoft Rice Disease Classification Challenge

Images from the training data

An intro to the latest Zindi challenge with starter code and some thoughts on experiment tracking. You may see more of this at some point – for now, you can read the report here.

Fun with Neural Cellular Automata

Building on lesson 8 of the course, this project involved training various neural cellular automata and figuring out how to make them do tricks like taking a video as a driving signal. I’m particularly pleased with the W&B report for this – I logged interactive HTML previews of the NCAs as shaders as they train, and tracked just about everything during the experiments. I also made a Gradio demo that you can try out right now.

Huggan Projects

So many butterflies

We trained some GANs on butterflies! Have fun with the demo space. I also did a similar version with AI-generated orbs as the training data. I love how easy it is to get a demo running with HF spaces + gradio. Feels like cheating!

Fine-tuning a CLOOB-Conditioned Latent Diffusion Model on WikiArt

Prompt: ‘A sunset landscape painting, oil on canvas’ (fine-tuned Wikiart model)

As part of the Huggingface ‘#huggan’ event, I thought it would be interesting to fine-tune a latent diffusion model on the WikiArt dataset, which (as the name suggests) consists of paintings in various genres and styles.

What is CLOOB-Conditioned Latent Diffusion?

Diffusion models are getting a lot of fame at the moment thanks to GLIDE and DALL-E 2 which have recently rocked the internet with their astounding text-to-image capabilities. They are trained by gradually adding noise to an input image over a series of steps, and having the network predict how to ‘undo’ this process. If we start from pure noise and have the network progressively try to ‘fix’ the image we eventually end up with a nice looking output (if all is working well).
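
To make the ‘adding noise’ half of this concrete, here is a small sketch of the forward process in the common DDPM-style formulation, where a noisy sample is sqrt(a) * image + sqrt(1 - a) * noise and a shrinks as the step number grows. This is an assumption chosen for illustration: the linear schedule and its values are typical defaults, and the models discussed in this post may use a different parameterization.

```python
import math
import random


def alpha_bar(t, num_steps, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta) over the first t steps of a linear schedule."""
    product = 1.0
    for step in range(t):
        beta = beta_start + (beta_end - beta_start) * step / (num_steps - 1)
        product *= 1.0 - beta
    return product


def add_noise(pixels, t, num_steps, rng):
    """Return (noisy_pixels, noise); the network is trained to recover `noise`."""
    a_bar = alpha_bar(t, num_steps)
    noise = [rng.gauss(0.0, 1.0) for _ in pixels]
    noisy = [
        math.sqrt(a_bar) * x + math.sqrt(1.0 - a_bar) * n
        for x, n in zip(pixels, noise)
    ]
    return noisy, noise


rng = random.Random(0)
image = [0.5, -0.25, 1.0, 0.0]  # a stand-in "image" of four pixel values
slightly_noisy, _ = add_noise(image, t=1, num_steps=1000, rng=rng)
mostly_noise, _ = add_noise(image, t=1000, num_steps=1000, rng=rng)
```

At small t the sample is nearly the original image, and by the final step almost none of it is left, which is why generation can start from pure noise.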

An illustration of this kind of model from the website related to one of the key papers that first outlined this idea.

To add text-to-image capacity to these models, they are often ‘conditioned’ on some representation of the captions that go along with the images. That is, in addition to seeing a noisy image, they also get an encoding of the text describing the image to help in the de-noising step. Starting from noise again but this time giving a description of the desired output image as the text conditioning ideally steers the network towards generating an image that matches the description.

CLOOB architecture diagram (from the project page – which is worth a read!)

Downsides: these diffusion models are computationally intensive to train, and require images with text labels. Latent diffusion models reduce the computational requirements by doing the denoising in the latent space of an autoencoder rather than on images directly. And since CLOOB maps both images and text to the same space, we can substitute the CLOOB encodings of the image itself in place of actual caption encodings if we want to train with unlabelled images. A neat trick if you ask me!

The best non-closed text-to-image implementation at the moment is probably the latent diffusion model trained by the CompVis team, which you can try out here.

Training/Fine-Tuning a model

@JDP provides training code for CLOOB conditioned latent diffusion (https://github.com/JD-P/cloob-latent-diffusion) based on the similar CLIP conditioned diffusion trained by Katherine Crowson (https://github.com/crowsonkb/v-diffusion-pytorch). One of my #huggan team members, Théo Gigant, uploaded the WikiArt dataset to the huggingface hub, and the images were downloaded, resized and saved to a directory on a 2xA6000 GPU machine provided by Paperspace.

After a few false starts figuring out model loading and other little quirks, we did a ~12 hour training run and logged the results using Weights and Biases. You can view demo outputs from the model as it trains in the report, which thanks to the W&B magic showed them live as the model was training, making for exciting viewing among our team 🙂

Evaluating The Resulting Model

WikiArt is not a huge dataset relative to the model (which has over a billion parameters). One of the main things we were curious about was how the resulting model would be different from the one we started with, which was trained on a much larger and more diverse set of images. Has it ‘overfit’ to the point of being unusable? How much more ‘arty’ do the results look when passing descriptions that don’t necessarily suggest fine art? And has fine-tuning on a relatively ‘clean’ dataset lowered the ability of the model to produce disturbing outputs? To answer these questions, we generated hundreds of images with both models.

I’ve moved the side-by-side comparisons to a gallery at the end of this post. These were the key takeaways for me:

  • Starting from a ‘photorealistic’ autoencoder didn’t stop it from making very painterly outputs. This was useful – we thought we might have to train our own autoencoder first as well.
  • The type of output definitely shifted: almost everything it makes looks like a painting.
  • It lost a lot of more general concepts but does really well with styles/artists/image types present in the dataset. So landscape paintings are great, but ‘a frog’ is not going to give anything recognizable and ‘an avocado armchair’ is a complete fail 🙂
  • It may have over-fit, and this seems to have made it much less likely to generate disturbing content (at the expense of also being bad at a lot of other content types).

Closing Thoughts

Approaches like CLOOB-Conditioned Latent Diffusion are bringing down the barrier to entry and making it possible for individuals or small organisations to have a crack at training diffusion models without $$$ of compute.

Our model during training (left) vs OpenAI’s DALL-E 2 (right) which was unveiled during our project and inspired various memes 🙂

This little experiment of ours has shown that it is possible to train one of these models on a relatively small dataset and end up with something that can create pleasing outputs, even if it can’t quite manage an avocado armchair. And as a bonus, it’s domain-focused enough that I’m happily sharing a live demo that anyone can play with online, without worrying that it’ll be used to generate any highly-realistic fake photographs of celebrity nudity or other such nonsense. What a time to be alive!

Comparison images

Sketchy Unet

The model demo running on Huggingface Spaces

I wanted a fast way to go from an image to something like a rough charcoal sketch. This would be the first step in a longer pipeline that would later add detail and colour, so all it has to do is give a starting point with the right sort of proportions.

Finding a dataset

I found a small dataset that seemed like a good starting point (originally created in ‘APDrawingGAN: Generating Artistic Portrait Drawings From Face Photos With Hierarchical GANs‘ by Ran Yi, Yong-Jin Liu, Yu-Kun Lai, Paul L. Rosin). It’s quick to download, and (with a little datablock wrangling) easy enough to load with fastai. See the notebook for details.

Training the model

I chose to model this as an image-to-image task, and used fastai’s unet_learner function to create a U-net style network based on a Resnet34 backbone. Starting with 128px images and then moving up to 224px, the model is trained to minimise the MSE between the output and the reference sketch. In about 3 minutes (!!) we end up with a model that is doing pretty much exactly what I want:

Images (left), artist’s sketch (center), model outputs (right)

Sharing a Demo

I’ve been playing around with HuggingFace Spaces recently, and this model was a great candidate for a simple demo that should run reasonably fast even on a CPU (like those provided by Spaces). At the end of the training notebook you can see the gradio interface code. Very user-friendly for these quick demos! The trained model was uploaded to huggingface as well, and they somehow detected that my code was downloading it because it shows up as a ‘linked model’ from the space.

It’s neat that I can so easily share everything related to a mini-project like this for others to follow along. The colab notebook provides a free cloud environment to replicate training, the model is hosted by someone with lots of bandwidth and is easy to download, and the demo needs no technical skills and lets anyone try it out in seconds. Hooray for fastai, gradio, huggingface and so many others who work so hard to make our lives easy 🙂

Update: What’s this for?

Waterface demo: https://huggingface.co/spaces/johnowhitaker/waterface

I used this model to ‘sketchify’ images before loading them into an imstack and optimising that to match a CLOOB prompt like ‘A charcoal and watercolor sketch of a person’. After a few steps the result looks pretty OR more likely a little creepy. Ah, the power of AI 🙂 Try it out here.

AIAIART Course Retrospective

A few weeks ago we wrapped up the first run of ‘AIAIART’, a short course on creating art with deep learning. The course was originally delivered over Discord, but you can access recordings of the lessons on YouTube alongside Colab Notebooks containing the code and examples.

The experience of putting this together and sharing it was highly enjoyable. I always get a kick out of seeing my code or teaching being used by other people to make cool stuff, and our Discord server is a steady stream of fun projects and experiments that make me so happy.

If I had to distil a few key takeaways I’ve gained from this endeavour, they would be

  • Optimization is magic. Set things up so that <some function> uses <some parameters> to produce <some output> which can be evaluated against <some goal> in a differentiable way, and suddenly you can iteratively update those parameters bit by bit until (if all goes well) you achieve said goal. The secret here is that code for updating an image to look more like a description is practically identical to the code for updating the parameters of a neural network to solve some complicated task. And so while we were busy making art, everyone was secretly learning the much broader skill of solving problems with optimization 🙂
  • You don’t need a PhD to dabble with deep learning. Quite a few students had been playing with various AI art models but hadn’t been able to dig in and understand the code or inner workings. But once we started building up from simple examples, it suddenly ‘clicked’ and what was previously intimidating walls of code became fancier versions of the patterns we’d already seen again and again.
  • I really like teaching. Seeing that ‘aha’ moment makes me so happy – I’m going to have to find ways to do more of this 🙂
  • People are SO COOL! I love seeing how different people can see the same material and get inspired to create wildly different things.
  • AI art is SO COOL! We’re still at the beginning of this movement, but already there are such powerful and amazing models and techniques available to us. With a little bit of tinkering you can learn how to make them sing, and the results can be simply stunning. I look forward to seeing where the next few generations of tech take us.
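
The ‘optimization is magic’ point from the first takeaway can be sketched in a few lines. Here <some function> is a straight line, <some parameters> are its slope and intercept, and <some goal> is a set of points to match. The gradient is estimated numerically to keep the example dependency-free; in the course this job is done by automatic differentiation, so treat this as an illustration rather than how the lessons implement it.

```python
def numerical_gradient(loss_fn, params, eps=1e-6):
    """Estimate d(loss)/d(param) for each parameter by finite differences."""
    grads = []
    for i in range(len(params)):
        bumped = list(params)
        bumped[i] += eps
        grads.append((loss_fn(bumped) - loss_fn(params)) / eps)
    return grads


def model(params, x):
    """<some function>: a straight line with a slope and an intercept."""
    slope, intercept = params
    return slope * x + intercept


# <some goal>: match these points, which lie on y = 2x + 1.
points = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0), (3.0, 7.0)]


def loss(params):
    """Mean squared error between the model's output and the goal."""
    return sum((model(params, x) - y) ** 2 for x, y in points) / len(points)


params = [0.0, 0.0]  # <some parameters>, starting from a poor guess
learning_rate = 0.05
for _ in range(2000):
    grads = numerical_gradient(loss, params)
    # Nudge each parameter a little way downhill.
    params = [p - learning_rate * g for p, g in zip(params, grads)]
```

Swap the line for an image generator and the points for a text description scored by a model like CLIP, and the loop stays essentially the same.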

Anyway, that’s about all I have for this post. Check out the videos or come and hang out in the discord to see what we’re playing with next, and stay tuned since I might turn this V1 course into something a little more polished over the Christmas holidays. Happy arting – J

WhistleGen: Generating Traditional Irish music with ML

Video overview of this project – do check out my channel if you like this!

Earlier this year I did an experiment where I tried to write some code on a small, atomic project every day. The results are documented at https://johnowhitaker.github.io/days_of_code/. In this post I want to share one of my favorite little diversions – my attempt at teaching a computer to compose some whistle music!

Getting the Data

To train a model we will need some data. Previous attempts at music generation have worked on MIDI or raw audio. However, a lot of Irish music is shared in a simplified form called ‘ABC Notation’, using letters and a limited set of symbols to encode the essential melody and leaving embellishments, harmonies and accents largely up to the interpretation of the player. thesession.org is one large central repository of these tunes, but I couldn’t find an easy way to download them in bulk. Web Scraping to the rescue!

A neat(ish) dataset of tunes in ABC notation

You can see the code and details here. Web scraping is one of those cases where there are many valid approaches one could take, but in essence they all boil down to finding ways to identify the specific parts of the HTML that surround the data you are interested in. For example, on a page of results from thesession each song is listed as a list item taking the form <li class="manifest-item">. With a bit of patience we can get URLs for each tune and then scrape the relevant info from those URLs with some more effort. At the end of this process, we have a nice neat dataframe with the titles, metadata and note sequences.
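
As an illustration of the idea, the sketch below pulls the first link out of each <li class="manifest-item"> using the standard library's HTMLParser. The sample markup is made up to match the description above (the real pages will differ), and the linked notebook may well use a different scraping library.

```python
from html.parser import HTMLParser


class ManifestLinkParser(HTMLParser):
    """Collect the first link inside each <li class="manifest-item">."""

    def __init__(self):
        super().__init__()
        self.in_item = False
        self.links = []

    def handle_starttag(self, tag, attrs):
        attrs = dict(attrs)
        if tag == "li" and "manifest-item" in (attrs.get("class") or "").split():
            self.in_item = True
        elif tag == "a" and self.in_item and "href" in attrs:
            self.links.append(attrs["href"])
            self.in_item = False  # only keep the first link per item

    def handle_endtag(self, tag):
        if tag == "li":
            self.in_item = False


# Made-up markup in the shape described above.
sample_html = """
<ul>
  <li class="manifest-item"><a href="/tunes/1">First Tune</a> <a href="/members/9">someone</a></li>
  <li class="other"><a href="/about">About</a></li>
  <li class="manifest-item"><a href="/tunes/2">Second Tune</a></li>
</ul>
"""

parser = ManifestLinkParser()
parser.feed(sample_html)
tune_urls = parser.links
```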

Modelling

We’re going to train a ‘language model’ – a concept from the field of NLP, where a model (usually an LSTM or transformer-based architecture) tries to predict the next token in a sequence, allowing it to learn from unstructured data such as large chunks of text or, in this case, music. The end result of this is a generative model that can ‘autocomplete’ sequences. These language models can then be re-purposed for classification, translation etc. but in this case we want a generative model so that is unnecessary.

The text needs to be tokenized. We can simply split into individual characters, but since the notation includes ‘note modifiers’ such as ‘=’ which are sometimes placed before a note to sharpen or flatten it and some other 2-character symbols (like ‘|:’ for the start of a bar with a repeat), I chose to build a custom tokenizer. The notebook shows how to construct fastai dataloaders that package everything up neatly ready for training.
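
A tokenizer along these lines can be sketched with a single regular expression that tries multi-character symbols before falling back to single characters. The symbol list here is illustrative rather than the notebook's actual vocabulary, and treating an accidental plus its note as one token is my own assumption.

```python
import re

# Multi-character symbols are listed first so they win over single characters.
TOKEN_PATTERN = re.compile(
    r"\|:|:\||\|\]|\[\||"   # repeat and bar-line variants such as |: and :|
    r"[=^_][A-Ga-g]|"        # an accidental attached to the note it modifies
    r"."                     # anything else, one character at a time
)


def tokenize(tune):
    """Split an ABC string into tokens, dropping whitespace."""
    return [tok for tok in TOKEN_PATTERN.findall(tune) if not tok.isspace()]


tokens = tokenize("|:G2 =F^c :|")
# tokens is ['|:', 'G', '2', '=F', '^c', ':|']
```

Compared with a plain character split, this keeps symbols like ‘|:’ together so the model never has to learn that the two halves belong to each other.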

Once the dataloaders are ready, we can simply train this like any other language model. I used the learning rate finder (output shown above) to pick an initial learning rate and then, following the example in the fastai docs, gradually unfroze the model and continued to train it. After a few minutes the model is predicting the next token with ~38% accuracy!

Getting Some Output

Some early WhistleGen output

We can feed our model a few tokens and ask it to continue making guesses for the next token in the sequence: learn.predict('|:G', 100, temperature=0.7). The temperature parameter controls how ‘conservative’ the model is; higher values result in output with more randomness. To convert the string of letters that the model spews out into playable music, I used this handy online editor to preview, edit and download the songs.
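
To see what the temperature is doing, here is a small standalone sketch: the scores are divided by the temperature before the softmax, so low values sharpen the distribution and high values flatten it. The tokens and scores are made up, and this shows the general technique rather than fastai's internal implementation.

```python
import math
import random


def softmax_with_temperature(logits, temperature):
    """Divide the scores by the temperature before normalising."""
    scaled = [x / temperature for x in logits]
    highest = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - highest) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(tokens, logits, temperature, rng):
    """Draw one token according to the temperature-adjusted probabilities."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(tokens, weights=probs, k=1)[0]


tokens = ["G", "A", "B", "|"]
logits = [2.0, 1.0, 0.5, 0.1]  # made-up scores for the next token

cautious = softmax_with_temperature(logits, 0.3)     # mass piles onto "G"
adventurous = softmax_with_temperature(logits, 1.5)  # mass spreads out
next_token = sample_token(tokens, logits, 0.7, random.Random(0))
```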

The model is OK at guessing sensible notes, but it doesn’t produce much in the way of song structure. I found it best to use the output as a starting point, tweaking the odd bit of timing and adding repeats, separate parts and the odd extra flourish to create a song that is effectively co-written by myself and my AI assistant. It’s surprisingly fun! I hope this inspires you to try something like this yourself – do let me know what you create.

In Brief: Playing with Class Imbalance

We often encounter imbalanced data in the world of machine learning, and have to decide how best to handle this. In ‘real life’ it is up to us to decide how to evaluate performance, which types of errors we care the most about and so on. But in the example we’ll look at today, the situation is slightly different: we have an imbalanced training set, and the test set (we’re working with this competition) has had the class distribution modified to make it more balanced. So, we need to find a way to take this into account when submitting predictions. The following plot shows the difference in distributions:

Class distribution of the training set compared to the validation set

It’s worth pointing out that this is showing the results for the validation set – there is an unseen test set that could very well have its own slightly different class distribution. There isn’t much to say that wasn’t covered in the notebook, so check that out for implementation details. That said, let’s go over the main strategies we could use:

  1. Do nothing and hope for the best… Not great, but when the imbalance is small then some models are pretty decent at making sensible predictions. This isn’t going to win any competitions though!
  2. Drop some fraction of the majority class. This turned out to work surprisingly well – I suspect this mimics the steps the organizers took when preparing the data.
  3. Generate some extra ‘synthetic’ samples for the under-represented class using Synthetic Minority Oversampling Technique (SMOTE)
  4. Combine steps 2 and 3, to avoid relying on too much synthetic data. In this case I chose to use the imblearn library’s RandomUnderSampler to discard some of the majority class.
  5. Take advantage of the sample_weights parameter available in some models. For example, with CatBoost we can explicitly tell the model to assign less weight to samples from the majority class. This lets us use the whole dataset (no need to throw out perfectly good data) and it performed the best in some experiments, losing only to the basic under-sampling technique in the final assessment.

Dropping 5/6 of the rows from the majority class – a frustratingly successful approach!
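
Strategies 2 and 5 are simple enough to sketch without any libraries. The snippet below keeps a random 1/6 of the majority class, and separately computes per-sample weights that are inversely proportional to class frequency. The data is synthetic and the helper names are my own; the notebook's actual code may differ.

```python
import random
from collections import Counter


def undersample(rows, labels, majority_label, keep_fraction, seed=0):
    """Keep every minority row and a random fraction of the majority rows."""
    rng = random.Random(seed)
    kept_rows, kept_labels = [], []
    for row, label in zip(rows, labels):
        if label != majority_label or rng.random() < keep_fraction:
            kept_rows.append(row)
            kept_labels.append(label)
    return kept_rows, kept_labels


def balanced_weights(labels):
    """Weight each sample inversely to how common its class is."""
    counts = Counter(labels)
    n_classes = len(counts)
    return [len(labels) / (n_classes * counts[label]) for label in labels]


# Synthetic data: 600 majority rows (label 0) and 100 minority rows (label 1).
labels = [0] * 600 + [1] * 100
rows = list(range(len(labels)))

_, sampled_labels = undersample(rows, labels, majority_label=0, keep_fraction=1 / 6)
weights = balanced_weights(labels)
```

The weights list is the kind of thing you would hand to a model's sample-weight argument, where it supports one, in place of throwing rows away.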

Again, check out the notebook for details and code. Here are the results:

Strategy                               Log Loss (local)
Under-sampling the majority class      0.556998
CatBoost with altered sample weights   0.559395
SMOTE + RandomUnderSampler             0.579939
No modifications                       0.674555

Results

The big takeaway here for me was that getting this right makes a huge difference in these types of competition. Without a good strategy even the fanciest model has no hope of matching the top submissions. Fortunately, even basic under-sampling can get great results, and I hope that between my notebook and discussions from others sharing their tips we have an even playing field on this front, allowing competitors to work on the more interesting aspects like feature engineering.

BirdClef Entry: Bird Call Classification with FastAI

The Cornell Lab of Ornithology run an annual competition to identify bird calls in soundscapes. I decided to have a go at this year’s competition to get back into audio classification and try out some new approaches. For this first post I will examine the data, choose methods for picking the right clips within larger recordings and for generating a spectrogram from said clip, and train a simple model to use as a baseline for future experiments.

Finding the Calls

In many recordings, the bird in question is not calling continuously. The final task involves predicting which birds are calling at 5-second intervals, so that is my chosen input length. If we just sample a random 5-second clip from a full recording, we might end up with a clip in which the bird is not calling – not ideal! To get around this, we compute a sort of signal-to-noise measure (in this case, PCEN-based SNR as used by the BirdVox project). With this, we can choose ‘peaks’ where the calls are most prominent.
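
Once there is an SNR value per frame, picking the peaks can be as simple as the greedy sketch below: take the strongest frame, then the next strongest that is far enough away from those already chosen, and so on. The SNR values are made up and the PCEN-based SNR itself is not computed here; the selection rule is my own stand-in and may not match the notebook.

```python
def pick_peaks(snr, num_peaks, min_gap):
    """Return indices of the strongest frames, at least `min_gap` frames apart."""
    # Visit frames from strongest to weakest.
    order = sorted(range(len(snr)), key=lambda i: snr[i], reverse=True)
    chosen = []
    for idx in order:
        if all(abs(idx - other) >= min_gap for other in chosen):
            chosen.append(idx)
        if len(chosen) == num_peaks:
            break
    return sorted(chosen)


# Made-up SNR values, one per frame, with calls around frames 3 and 10.
snr = [0.1, 0.2, 2.0, 3.5, 2.2, 0.3, 0.1, 0.2, 0.4, 2.8, 3.0, 0.5]
peaks = pick_peaks(snr, num_peaks=2, min_gap=4)
# peaks is [3, 10]
```

The minimum gap stops all the chosen clips from clustering around a single loud call.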

Identifying ‘peaks’ with a high PCEN-based SNR

The code for this is in my first notebook. For each train file, we store the location of 20 peaks in a csv file which we will then use during training to select the appropriate clips.

Preparing the data for modelling

We could try feeding the raw audio data into a model, but 5 seconds of audio represents quite a lot of data. Some models can handle this, but in most cases a better approach is to find a more compressed representation of the sound. In this case I chose a fairly standard approach: the mel spectrogram. A spectrogram looks like a 2D image, with time on the X axis, frequency on the y axis and intensity represented by colour.
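
The ‘mel’ part refers to a frequency scale that is roughly linear at low frequencies and logarithmic higher up, which packs more resolution into the range where it matters most. The sketch below uses one common (HTK-style) conversion formula to show how evenly spaced mel bands map back to Hz. The band count and frequency range are arbitrary, and in practice an audio library computes the whole spectrogram for you.

```python
import math


def hz_to_mel(freq_hz):
    """Convert a frequency in Hz to mels (HTK-style formula)."""
    return 2595.0 * math.log10(1.0 + freq_hz / 700.0)


def mel_to_hz(mel):
    """Inverse of hz_to_mel."""
    return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)


def mel_band_centres(num_bands, f_min, f_max):
    """Centre frequencies (Hz) of bands spaced evenly on the mel scale."""
    low, high = hz_to_mel(f_min), hz_to_mel(f_max)
    step = (high - low) / (num_bands + 1)
    return [mel_to_hz(low + step * (i + 1)) for i in range(num_bands)]


# Arbitrary settings for illustration.
centres = mel_band_centres(num_bands=8, f_min=0.0, f_max=16000.0)
```

The centres bunch together at the low end and spread out at the high end, so the vertical axis of the spectrogram is stretched where frequencies are low.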

An example spectrogram

The model training notebook shows how we set up the dataloaders to read in a specified clip and turn it into a spectrogram that can be fed to the model. This is quite CPU-heavy, which does slow the training down. But I still chose this approach over pre-computing the spectrograms once at the start because it allows for data augmentation such as shifting the window, adding noise etc on the raw audio before it gets converted to a spectrogram.

You can see all the code in the baseline model notebook. Taking inspiration from the pets tutorial, we create our own custom Transform that handles ‘encoding’ a given clip/label pair, which in turn is used to create our DataLoaders. By adding a ‘decodes’ method we also enable functionality such as ‘show_batch()’.

Training

Loss plot over 3 epochs of training

I’m not doing anything fancy for training – our goal here is a simple model to use as a baseline for future tests. A few things I did learn however:

  • To be able to access the output of a Kaggle notebook, you have to re-run it by clicking ‘save’. This can eat up GPU time, so I have started running my interactive tests with small subsets of the data and then relying on the run triggered by a save to actually do the full training.
  • Because this is then running ‘in the background’, saving any info you need is a must. I use the CSVLogger callback to save the stats after each epoch, and do other things like saving loss plots as pngs. Finally, we save the model itself for future use.
  • With this small model and CPU-heavy dataloader, running on CPU was only a couple of times slower than on GPU. With a bit of patience, one could simply run this overnight rather than using up your weekly GPU quota, saving the GPU goodness for fast iteration when experimenting. In fact after the save failed a few times I ended up switching off the GPU and letting it train on the CPU over 7 or 8 hours.

Again, full code is in the notebook. After 3 epochs (the number of epochs and the learning rate chosen somewhat arbitrarily) we get to an accuracy of ~53% – impressive given the large number of classes. I’m sure a better model and more training would boost this, but that is something we can play with later…

Evaluation

During training we calculate an ‘accuracy’ score based on some clips withheld from the training data. These all have a single label (even though there may well be other calls mixed in) and they are taken from a different source to the actual test data that we will be scored on in the competition. We would assume a better accuracy in our simplified case will mean a better model, but ideally we want a way to evaluate our model in a setting that is as close as possible to the final task.

Fortunately, the competition hosts have provided some labelled audio recordings that match the format of the test set. We can use this in our evaluation notebook to simulate a submission. Our model needs to provide a list of all bird species calling in a given 5-second clip. The way we will approach this for now is to take the model’s output probabilities and pick some threshold above which we will include a given species.
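The thresholding step can be sketched in a few lines. This is a simplified illustration rather than the notebook's code: the function name, the 0.5 default threshold and the `nocall` fallback label are assumptions, and the species codes are made up.

```python
def species_above_threshold(probs, threshold=0.5, no_call="nocall"):
    """Return every species whose predicted probability reaches the threshold.

    probs maps a species name to the model's output probability for one clip.
    If nothing clears the threshold we fall back to a placeholder label
    (the 'nocall' name is an assumption, not something from the post).
    """
    picked = sorted(name for name, p in probs.items() if p >= threshold)
    return picked if picked else [no_call]

# Made-up probabilities for a single 5-second clip.
clip_probs = {"species_a": 0.81, "species_b": 0.35, "species_c": 0.62}
print(species_above_threshold(clip_probs, threshold=0.5))
```

Lowering the threshold catches more of the quieter background calls at the cost of more false positives, so it is worth tuning on the labelled evaluation recordings rather than picking a value blindly.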

In the future, we will want to take geographic location into account, as well as ideally training a model directly on this kind of multi-label task. Even without this, our very simple model gets an F1-score of about 0.64 on the provided evaluation set and a leaderboard score of 0.55. The notebook is very rough, but for completeness here is a link.
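For anyone wanting to compute a similar score on their own predictions, here is a rough sketch of an F1 calculation over sets of predicted and true species. I'm assuming a micro-averaged F1 (pooling true positives, false positives and false negatives across all clips); the competition's exact metric may differ in the details, and the function name and example labels are made up.

```python
def micro_f1(true_sets, pred_sets):
    """Micro-averaged F1 over per-clip sets of species labels."""
    tp = fp = fn = 0
    for truth, pred in zip(true_sets, pred_sets):
        tp += len(truth & pred)   # species correctly predicted
        fp += len(pred - truth)   # predicted but not actually present
        fn += len(truth - pred)   # present but missed
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

# Two made-up clips: one missed species and one false alarm.
truth = [{"species_a", "species_b"}, {"species_c"}]
preds = [{"species_a"}, {"species_c", "species_d"}]
print(round(micro_f1(truth, preds), 3))
```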

Conclusions and Next Steps

Our submission scores 0.55, placing 167th on the leaderboard. Not terrible, but there is a way to go before we are up there near the top. If I manage to spend some time on this, there will hopefully be a part 2 in which I explore ways to boost the score… Stay tuned for that 🙂

Language Models for Protein Sequence Classification

A 3D model of a protein kinase

We recently hosted a Zindi hackathon in partnership with Instadeep that challenged participants to predict the functional class of protein kinases (enzymes with some very specific functions in the cell) based on nothing but their amino acid sequences. This kind of sequence classification task has lots of potential applications – there is a lot of un-labelled data lying around on every computational biologist’s computer, and a tool that could guess a given protein’s function would be mighty handy.

Just one problem – it’s not exactly a simple task! There are 20-something amino acids which we represent as letters. Given a sequence like ‘AGASGSUFOFBEASASSSSSASBBBDGDBA’ (frantically monkey-types for emphasis) we need to find a way to a) encode this as something a model can make sense of and b) do the making-sense-of-ing! Fortunately, there’s another field where we need to go from a string of letters to something meaningful: Natural Language Processing. Since I’d just been watching the NLP lesson in the latest amazing fastai course I felt obliged to try out the techniques Jeremy was talking about on this sequence classification task.

The Basic Approach

Tokenized input (left) and class (right)

Treating this as a language task and drawing inspiration from ULMFiT[1], this was my basic approach:

  • I tokenized the sequences using ‘subword tokenization’ which captures not just individual amino acids as tokens but common groupings as well (eg ‘EELR’ is encoded as a single token). I think this basic approach was suggested by the SentencePiece paper[4] and it’s now part of fastai[5].
  • I then created a ‘pretext task’ of sequence completion to train a ‘language model’ (based on the AWD-LSTM architecture[2]). The model learns to predict the next token in a sequence with ~32% accuracy – the hope is that in doing so it also learns useful embeddings and some sort of latent understanding of how these sequences are structured.
  • We keep most of this network as the ‘encoder’ but modify the final layers for the actual task: sequence classification. Thanks to the pre-training, the model can very quickly learn the new task. I can get to 98% accuracy in a couple of minutes by training on only a small subset of the data.
  • Training the model for the sequence classification task takes a while on the full competition dataset, but it eventually reaches 99.8% accuracy with a log_loss on the test set (as used in the competition) of 0.08, which is equivalent to 3rd place.
  • Doing the normal tricks of ensembling, training a second model on reversed sequences etc quite easily bumps this up to glory territory, but that’s the boring bit.
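Since log loss and ensembling both come up in the list above, here is a small sketch of each. It is an illustration under assumptions, not the competition code: the function names are mine, the probabilities are made up, and the 'ensemble' is the simplest possible one (a plain average of the class probabilities from a forward model and a reversed-sequence model).

```python
import math

def log_loss(y_true, probs, eps=1e-15):
    """Mean negative log-probability assigned to the correct class."""
    total = 0.0
    for label, row in zip(y_true, probs):
        p = min(max(row[label], eps), 1 - eps)  # clip to avoid log(0)
        total -= math.log(p)
    return total / len(y_true)

def average_ensemble(*model_probs):
    """Average per-class probabilities from several models, example by example."""
    n_models = len(model_probs)
    return [
        [sum(vals) / n_models for vals in zip(*rows)]
        for rows in zip(*model_probs)
    ]

# Made-up predictions for two sequences and two classes.
labels = [0, 1]
forward_model = [[0.9, 0.1], [0.4, 0.6]]
reversed_model = [[0.7, 0.3], [0.2, 0.8]]

blended = average_ensemble(forward_model, reversed_model)
print(log_loss(labels, forward_model), log_loss(labels, blended))
```

Log loss punishes confident mistakes heavily, which is why averaging models that make different mistakes tends to help: the blend is rarely as confidently wrong as any single model.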

It was fun to see how well this worked. You can find a more detailed write-up of the initial experiments on that competition dataset here. Spurred by these early results, I figured it was worth looking into this a little more deeply. What have others been doing on this task? Is this approach any good compared to the SOTA? Has anyone tried this particular flow on this kind of problem?

Getting Formal

It should come as no surprise that the idea of treating sequence classification like a language modelling task has already occurred to some people. For example, UDSMProt[7] turns out to have very nearly the same approach as that outlined above (self-five!). Their github is a great resource.

There are other approaches as well – for example, ProtCNN[6] and DEEPred[8] propose their own deep learning architectures to solve these kinds of tasks. And there are some older approaches such as BLAST and its derivatives[9] that have long been standards in this field and still do decently (although they seem to be getting out-performed by these newer techniques).

So, we’re not the first to try this. However, I couldn’t find any papers using anything like the ‘subword’ tokenization. They either use individual amino acids as tokens, or in rare cases some choice of n-grams (for example, triplets of amino acids). The advantage of subword tokenization over these is that it can scale between the complexity of single-acid encodings and massive n-gram approaches by simply adjusting the vocabulary size.
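To show how the vocabulary size acts as a dial here, below is a toy byte-pair-encoding style tokenizer. To be clear, this is not SentencePiece and not what fastai does internally; it is a minimal sketch of the general idea of merging frequent neighbouring tokens, with hypothetical function names and a made-up two-sequence 'corpus'. With zero merges you get single-acid tokens, and every extra merge adds one longer token to the vocabulary.

```python
from collections import Counter

def apply_merge(tokens, a, b):
    """Replace every adjacent (a, b) pair in tokens with the merged token a+b."""
    out, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and tokens[i] == a and tokens[i + 1] == b:
            out.append(a + b)
            i += 2
        else:
            out.append(tokens[i])
            i += 1
    return out

def learn_merges(sequences, num_merges):
    """Greedily merge the most frequent adjacent token pair, num_merges times."""
    corpus = [list(seq) for seq in sequences]  # start from single amino acids
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for tokens in corpus:
            pairs.update(zip(tokens, tokens[1:]))
        if not pairs:
            break
        (a, b), count = pairs.most_common(1)[0]
        if count < 2:  # nothing repeats any more, so stop early
            break
        merges.append((a, b))
        corpus = [apply_merge(tokens, a, b) for tokens in corpus]
    return merges

def tokenize(seq, merges):
    """Tokenize a sequence by replaying the learned merges in order."""
    tokens = list(seq)
    for a, b in merges:
        tokens = apply_merge(tokens, a, b)
    return tokens

toy_corpus = ["EELREELR", "AEELRG"]  # made-up sequences
for n in (0, 1, 3):
    merges = learn_merges(toy_corpus, num_merges=n)
    print(n, tokenize("AEELRG", merges))
```

A fixed n-gram scheme has to commit to one token length up front, whereas here the number of merges (and so the vocabulary size) is a single hyperparameter that can be swept.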

Your Homework

I did some initial tests – this definitely smells promising, but there is a lot of work to do for this to be useful to anyone, and I don’t currently have the time or compute to give it a proper go. If you’re looking for a fun NLP challenge with the potential to turn into some interesting research, this could be the job for you! Here are my suggestions:

  • Pick one or more benchmarks. Classification of the PFam dataset is a nice one to start with. The ProtCNN paper[6] (quick link) has done a bunch of the ‘standard’ algorithms and shared their split as a kaggle dataset, so you can quickly compare to those results.
  • Get some data for language model training. The SWISSProt dataset is a nice one, and for early tests even just the PFam dataset is enough to try things out.
  • Train some language models. Do single-acid tokenization as a baseline and then try subword tokenization with a few different vocab sizes to compare.
  • See which models do best on the downstream classification task. Lots of experimenting to be done on sequence length, training regime and so on.
  • For bonus points, throw a transformer model or two at this kind of problem. I bet they’d be great, especially if pre-trained on a nice big dataset.
  • If (as I suspect) one of these does very well, document your findings, try everything again in case it was luck and publish it as a blog or, if you’re a masochist, a paper.
  • … profit?

I really hope someone reading this has the motivation to give this a go. If nothing else it’s a great learning project for language modelling and diving into a new domain. Please let me know if you’re interested – I’d love to chat, share ideas and send you the things I have tried. Good luck 🙂

References

[1] – Howard, J. and Ruder, S., 2018. Universal language model fine-tuning for text classification. arXiv preprint arXiv:1801.06146.

[2] – Merity, S., Keskar, N.S. and Socher, R., 2017. Regularizing and optimizing LSTM language models. arXiv preprint arXiv:1708.02182.

[3] – Smith, L.N., 2017, March. Cyclical learning rates for training neural networks. In 2017 IEEE Winter Conference on Applications of Computer Vision (WACV) (pp. 464-472). IEEE.

[4] – Kudo, T. and Richardson, J., 2018. Sentencepiece: A simple and language independent subword tokenizer and detokenizer for neural text processing. arXiv preprint arXiv:1808.06226.

[5] – Howard, J. and Gugger, S., 2020. Fastai: A layered API for deep learning. Information, 11(2), p.108.

[6] – Bileschi, M.L., Belanger, D., Bryant, D.H., Sanderson, T., Carter, B., Sculley, D., DePristo, M.A. and Colwell, L.J., 2019. Using deep learning to annotate the protein universe. bioRxiv, p.626507. (ProtCNN)

[7] – Strodthoff, N., Wagner, P., Wenzel, M. and Samek, W., 2020. UDSMProt: universal deep sequence models for protein classification. Bioinformatics, 36(8), pp.2401-2409. (UDSMProt)

[8] – Rifaioglu, A.S., Doğan, T., Martin, M.J., Cetin-Atalay, R. and Atalay, V., 2019. DEEPred: automated protein function prediction with multi-task feed-forward deep neural networks. Scientific Reports, 9(1), pp.1-16. (DEEPred)

[9] – Altschul, S.F., Madden, T.L., Schäffer, A.A., Zhang, J., Zhang, Z., Miller, W. and Lipman, D.J., 1997. Gapped BLAST and PSI-BLAST: a new generation of protein database search programs. Nucleic Acids Research, 25(17), pp.3389-3402.