WhistleGen: Generating Traditional Irish music with ML

Video overview of this project – do check out my channel if you like this!

Earlier this year I did an experiment where I tried to write some code on a small, atomic project every day. The results are documented at https://johnowhitaker.github.io/days_of_code/. In this post I want to share one of my favorite little diversions – my attempt at teaching a computer to compose some whistle music!

Getting the Data

To train a model we will need some data. Previous attempts at music generation have worked on MIDI or raw audio. However, a lot of Irish music is shared in a simplified form called ‘ABC Notation’, which uses letters and a limited set of symbols to encode the essential melody, leaving embellishments, harmonies and accents largely up to the interpretation of the player. thesession.org is one large central repository of these tunes, but I couldn’t find an easy way to download them in bulk. Web scraping to the rescue!

A neat(ish) dataset of tunes in ABC notation

You can see the code and details here. Web scraping is one of those cases where there are many valid approaches one could take, but all of them essentially boil down to finding ways to identify the specific parts of the HTML that surround the data you are interested in. For example, on a page of results from thesession each song is listed as a list item of the form <li class="manifest-item">. With a bit of patience we can get URLs for each tune, and with some more effort scrape the relevant info from each of those pages. At the end of this process, we have a nice, neat dataframe with the titles, metadata and note sequences.
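To give a flavour of that first step, here is a minimal sketch, assuming the markup described above (the site's structure and the exact selectors may well have changed since):

    # Collect tune URLs from a page of search results on thesession.org
    import requests
    from bs4 import BeautifulSoup

    def get_tune_urls(results_url):
        soup = BeautifulSoup(requests.get(results_url).text, "html.parser")
        urls = []
        # Each tune is a list item of the form <li class="manifest-item">
        for li in soup.find_all("li", class_="manifest-item"):
            link = li.find("a")
            if link is not None:
                urls.append("https://thesession.org" + link["href"])
        return urls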

Modelling

We’re going to train a ‘language model’ – a concept from the field of NLP, where a model (usually an LSTM- or transformer-based architecture) tries to predict the next token in a sequence, allowing it to learn from unstructured data such as large chunks of text or, in this case, music. The end result is a generative model that can ‘autocomplete’ sequences. Language models like this can then be re-purposed for classification, translation and so on, but here we want the generative model itself, so that step is unnecessary.

The text needs to be tokenized. We can simply split into individual characters, but since the notation includes ‘note modifiers’ such as ‘=’ which are sometimes placed before a note to sharpen or flatten it and some other 2-character symbols (like ‘|:’ for the start of a bar with a repeat), I chose to build a custom tokenizer. The notebook shows how to construct fastai dataloaders that package everything up neatly ready for training.
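Here is a sketch of the idea behind the tokenizer (the symbol lists are illustrative rather than the notebook's full set):

    # Split an ABC string into tokens, keeping two-character symbols whole
    # and attaching modifiers like '=' to the note that follows them.
    def tokenize_abc(s):
        two_char = {"|:", ":|", "|]", "[|", "||"}  # illustrative, not complete
        modifiers = set("=^_")
        tokens, i = [], 0
        while i < len(s):
            if s[i:i+2] in two_char:
                tokens.append(s[i:i+2]); i += 2
            elif s[i] in modifiers and i + 1 < len(s):
                tokens.append(s[i:i+2]); i += 2  # modifier + note as one token
            else:
                tokens.append(s[i]); i += 1
        return tokens

    tokenize_abc("|:G=FA|")  # ['|:', 'G', '=F', 'A', '|']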

Once the dataloaders are ready, we can simply train this like any other language model. I used the learning rate finder (output shown above) to pick an initial learning rate and then, following the example in the fastai docs, gradually unfroze the model and continued to train it. After a few minutes the model is predicting the next token with ~38% accuracy!
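In code, the training step looks roughly like this (a sketch, with dls being the dataloaders built in the notebook and the hyperparameters purely illustrative):

    from fastai.text.all import *

    learn = language_model_learner(dls, AWD_LSTM, pretrained=False,
                                   metrics=accuracy)
    learn.lr_find()               # plot loss vs learning rate, pick a value
    learn.fit_one_cycle(1, 1e-2)  # train the final layers first
    learn.unfreeze()              # then unfreeze and train the whole model
    learn.fit_one_cycle(5, 1e-3)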

Getting Some Output

Some early WhistleGen output

We can feed our model a few tokens and ask it to continue making guesses for the next token in the sequence: learn.predict('|:G', 100, temperature=0.7). The temperature parameter controls how ‘conservative’ the model is; higher values result in output with more randomness. To convert the string of letters that the model spews out into playable music, I used this handy online editor to preview, edit and download the songs.

The model is OK at guessing sensible notes, but it doesn’t produce much in the way of song structure. I found it best to use the output as a starting point, tweaking the odd bit of timing and adding repeats, separate parts and the odd extra flourish to create a song that is effectively co-written by myself and my AI assistant. It’s surprisingly fun! I hope this inspires you to try something like this yourself – do let me know what you create.

In Brief: Playing with Class Imbalance

We often encounter imbalanced data in the world of machine learning, and have to decide how best to handle this. In ‘real life’ it is up to us to decide how to evaluate performance, which types of errors we care the most about and so on. But in the example we’ll look at today, the situation is slightly different: we have an imbalanced training set, and the test set (we’re working with this competition) has had the class distribution modified to make it more balanced. So, we need to find a way to take this into account when submitting predictions. The following plot shows the difference in distributions:

Class distribution of the training set compared to the validation set

It’s worth pointing out that this is showing the results for the validation set – there is an unseen test set that could very well have its own slightly different class distribution. There isn’t much to say that wasn’t covered in the notebook, so check that out for implementation details. That said, let’s go over the main strategies we could use:

  1. Do nothing and hope for the best… Not great, but when the imbalance is small, some models are pretty decent at making sensible predictions. This isn’t going to win any competitions though!
  2. Drop some fraction of the majority class. This turned out to work surprisingly well – I suspect this mimics the steps the organizers took when preparing the data.
  3. Generate some extra ‘synthetic’ samples for the under-represented class using the Synthetic Minority Oversampling Technique (SMOTE).
  4. Combine steps 2 and 3, to avoid relying on too much synthetic data. In this case I chose to use the imblearn library’s RandomUnderSampler to discard some of the majority class.
  5. Take advantage of the sample_weights parameter available in some models. For example, with CatBoost we can explicitly tell the model to assign less weight to samples from the majority class. This lets us use the whole dataset (no need to throw out perfectly good data) and it performed the best in some experiments, losing only to the basic under-sampling technique in the final assessment. (The sketch below shows the gist of each strategy.)
Dropping 5/6 of the rows from the majority class – a frustratingly successful approach!
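Here's the gist of strategies 2–5 in code (a sketch on dummy data; the notebook applies these to the actual competition features):

    from sklearn.datasets import make_classification
    from imblearn.under_sampling import RandomUnderSampler
    from imblearn.over_sampling import SMOTE
    from catboost import CatBoostClassifier

    # Dummy imbalanced data: roughly 90% class 0, 10% class 1
    X, y = make_classification(n_samples=10_000, weights=[0.9], random_state=42)

    # 2. Drop some of the majority class (here, down to a 2:1 ratio)
    X_u, y_u = RandomUnderSampler(sampling_strategy=0.5).fit_resample(X, y)

    # 3 + 4. Synthesise extra minority samples with SMOTE after under-sampling
    X_s, y_s = SMOTE().fit_resample(X_u, y_u)

    # 5. Keep all the data, but down-weight the majority class instead
    weights = [0.2 if label == 0 else 1.0 for label in y]
    CatBoostClassifier(verbose=0).fit(X, y, sample_weight=weights)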

Again, check out the notebook for details and code. Here are the results:

Strategy                                 Log Loss (local)
Under-sampling the majority class        0.556998
CatBoost with altered sample weights     0.559395
SMOTE + RandomUnderSampler               0.579939
No modifications                         0.674555

Results

The big takeaway here for me was that getting this right makes a huge difference in these types of competition. Without a good strategy even the fanciest model has no hope of matching the top submissions. Fortunately, even basic under-sampling can get great results, and I hope that between my notebook and discussions from others sharing their tips we have an even playing field on this front, allowing competitors to work on the more interesting aspects like feature engineering.

BirdClef Entry: Bird Call Classification with FastAI

The Cornell Lab of Ornithology run an annual competition to identify bird calls in soundscapes. I decided to have a go at this year’s competition to get back into audio classification and try out some new approaches. For this first post I will examine the data, choose methods for picking the right clips within larger recordings and for generating a spectrogram from said clip, and train a simple model to use as a baseline for future experiments.

Finding the Calls

In many recordings, the bird in question is not calling continuously. The final task involves predicting which birds are calling at 5-second intervals, so that is my chosen input length. If we just sample a random 5-second clip from a full recording, we might end up with a clip in which the bird is not calling – not ideal! To get around this, we compute a sort of signal-to-noise measure (in this case, PCEN-based SNR as used by the BirdVox project). With this, we can choose ‘peaks’ where the calls are most prominent.

Identifying ‘peaks’ with a high PCEN-based SNR

The code for this is in my first notebook. For each train file, we store the locations of 20 peaks in a CSV file, which we will then use during training to select the appropriate clips.
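The gist of the peak-picking looks something like this (a sketch with librosa and scipy; the prominence measure here is a crude stand-in for the BirdVox recipe, and the parameters are illustrative):

    import librosa
    import numpy as np
    from scipy.signal import find_peaks

    y, sr = librosa.load("XC123456.ogg", sr=32000)  # hypothetical file
    mel = librosa.feature.melspectrogram(y=y, sr=sr)
    pcen = librosa.pcen(mel, sr=sr)             # per-channel energy normalisation
    snr = pcen.max(axis=0) - pcen.mean(axis=0)  # crude per-frame prominence
    peaks, _ = find_peaks(snr, distance=63)     # ~1s apart at hop_length=512
    top20 = peaks[np.argsort(snr[peaks])[-20:]] # the 20 most prominent peaks
    peak_times = librosa.frames_to_time(top20, sr=sr)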

Preparing the data for modelling

We could try feeding the raw audio data into a model, but 5 seconds of audio represents quite a lot of data. Some models can handle this, but in most cases a better approach is to find a more compressed representation of the sound. In this case I chose a fairly standard approach: the mel spectrogram. A spectrogram looks like a 2D image, with time on the x axis, frequency on the y axis and intensity represented by colour.

An example spectrogram
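Producing one of these from a clip takes only a few lines (a sketch; the notebook's parameters may differ):

    import librosa
    import numpy as np

    def clip_to_spec(y, sr):
        mel = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128)
        db = librosa.power_to_db(mel, ref=np.max)       # log-scale intensities
        return (db - db.min()) / (db.max() - db.min())  # normalise to [0, 1]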

The model training notebook shows how we set up the dataloaders to read in a specified clip and turn it into a spectrogram that can be fed to the model. This is quite CPU-heavy, which does slow the training down. But I still chose this approach over pre-computing the spectrograms once at the start, because it allows for data augmentation – shifting the window, adding noise and so on – on the raw audio before it is converted to a spectrogram.

You can see all the code in the baseline model notebook. Taking inspiration from the pets tutorial, we create our own custom Transform that handles ‘encoding’ a given clip/label pair, which in turn is used to create our DataLoaders. By adding a ‘decodes’ method we also enable functionality such as ‘show_batch()’.
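The rough shape of such a Transform is sketched below (load_clip is a hypothetical helper that reads five seconds of audio starting at a peak, items is an assumed list of (path, start_time, label_index) tuples built from the peaks CSV, and clip_to_spec is the function sketched earlier):

    from fastai.vision.all import *

    class ClipTransform(Transform):
        def encodes(self, row):
            path, start, label = row
            spec = clip_to_spec(*load_clip(path, start))  # 5s clip -> 2D array
            return TensorImage(tensor(spec)[None]), TensorCategory(label)
        # a matching decodes method, returning objects with a .show() method,
        # is what enables show_batch() and friends

    splits = RandomSplitter()(items)
    dls = TfmdLists(items, ClipTransform(), splits=splits).dataloaders(bs=64)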

Training

Loss plot over 3 epochs of training

I’m not doing anything fancy for training – our goal here is a simple model to use as a baseline for future tests. A few things I did learn however:

  • To be able to access the output of a Kaggle notebook, you have to re-run it by clicking ‘save’. This can eat up GPU time, so I have started running my interactive tests with small subsets of the data and then relying on the run triggered by a save to actually do the full training.
  • Because this is then running ‘in the background’, saving any info you need is a must. I use the CSVLogger callback to save the stats after each epoch, and do other things like saving loss plots as PNGs. Finally, we save the model itself for future use.
  • With this small model and a CPU-heavy dataloader, running on CPU was only a couple of times slower than on GPU. With a bit of patience, one could simply run this overnight rather than using up the weekly GPU quota, saving the GPU goodness for fast iteration when experimenting. In fact, after the save failed a few times, I ended up switching off the GPU and letting it train on the CPU over 7 or 8 hours.

Again, full code is in the notebook. After 3 epochs (the number of epochs and the learning rate chosen somewhat arbitrarily) we get to an accuracy of ~53% – impressive given the large number of classes. I’m sure a better model and more training would boost this, but that is something we can play with later…

Evaluation

During training we calculate an ‘accuracy’ score based on some clips withheld from the training data. These all have a single label (even though there may well be other calls mixed in) and they are taken from a different source to the actual test data that we will be scored on in the competition. We would assume a better accuracy in our simplified case will mean a better model, but ideally we want a way to evaluate our model in a setting that is as close as possible to the final task.

Fortunately, the competition hosts have provided some labelled audio recordings that match the format of the test set. We can use this in our evaluation notebook to simulate a submission. Our model needs to provide a list of all bird species calling in a given 5-second clip. The way we will approach this for now is to take the model’s output probabilities and pick some threshold above which we will include a given species.
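In code, that thresholding step might look like the following sketch (probs being an (n_clips, n_species) array of model outputs; the threshold value and the 'nocall' fallback label are assumptions to tune against the competition format):

    import numpy as np

    def clips_to_birds(probs, species_names, threshold=0.5):
        rows = []
        for row in probs:
            present = [species_names[i] for i in np.where(row > threshold)[0]]
            rows.append(" ".join(present) if present else "nocall")
        return rows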

In the future, we will want to take geographic location into account, as well as ideally training a model directly on this kind of multi-label task. Even without this, our very simple model gets an F1-score of about 0.64 on the provided evaluation set and a leaderboard score of 0.55. The notebook is very rough, but for completeness here is a link.

Conclusions and Next Steps

Our submission scores 0.55, placing 167th on the leaderboard. Not terrible, but there is a way to go before we are up there near the top. If I manage to spend some time on this, there will hopefully be a part 2 in which I explore ways to boost the score… Stay tuned for that 🙂

Language Models for Protein Sequence Classification

A 3D model of a protein kinase

We recently hosted a Zindi hackathon in partnership with Instadeep that challenged participants to predict the functional class of protein kinases (enzymes with some very specific functions in the cell) based on nothing but their amino acid sequences. This kind of sequence classification task has lots of potential applications – there is a lot of un-labelled data lying around on every computational biologist’s computer, and a tool that could guess a given protein’s function would be mighty handy.

Just one problem – it’s not exactly a simple task! There are 20-something amino acids which we represent as letters. Given a sequence like ‘AGASGSUFOFBEASASSSSSASBBBDGDBA’ (frantically monkey-types for emphasis) we need to find a way to a) encode this as something a model can make sense of and b) do the making-sense-of-ing! Fortunately, there’s another field where we need to go from a string of letters to something meaningful: Natural Language Processing. Since I’d just been watching the NLP lesson in the latest amazing fastai course I felt obliged to try out the techniques Jeremy was talking about on this sequence classification task.

The Basic Approach

Tokenized input (left) and class (right)

Treating this as a language task and drawing inspiration from ULMFiT[1], this was my basic approach:

  • I tokenized the sequences using ‘subword tokenization’, which captures not just individual amino acids as tokens but common groupings as well (e.g. ‘EELR’ is encoded as a single token). I think this basic approach was suggested by the SentencePiece paper[4] and it’s now part of fastai[5].
  • I then created a ‘pretext task’ of sequence completion to train a ‘language model’ (based on the AWD-LSTM architecture[2]). The model learns to predict the next token in a sequence with ~32% accuracy – the hope is that in doing so it also learns useful embeddings and some sort of latent understanding of how these sequences are structured.
  • We keep most of this network as the ‘encoder’ but modify the final layers for the actual task: sequence classification. Thanks to the pre-training, the model can very quickly learn the new task (see the sketch after this list). I can get to 98% accuracy in a couple of minutes by training on only a small subset of the data.
  • Training the model for the sequence classification task takes a while on the full competition dataset, but it eventually reaches 99.8% accuracy with a log_loss on the test set (as used in the competition) of 0.08, which is equivalent to 3rd place.
  • Doing the normal tricks of ensembling, training a second model on reversed sequences etc quite easily bumps this up to glory territory, but that’s the boring bit.
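For the curious, here is a condensed sketch of that pipeline in fastai (assuming a dataframe df with 'sequence' and 'class' columns; the vocab size and hyperparameters are illustrative, not the exact competition setup):

    from fastai.text.all import *

    tok = SubwordTokenizer(vocab_sz=1000)  # SentencePiece under the hood

    # Pretext task: next-token prediction on the raw sequences
    dls_lm = TextDataLoaders.from_df(df, text_col="sequence", is_lm=True,
                                     tok_tfm=tok)
    lm = language_model_learner(dls_lm, AWD_LSTM, pretrained=False,
                                metrics=accuracy)
    lm.fit_one_cycle(10, 1e-2)
    lm.save_encoder("protein_enc")

    # Downstream task: sequence classification, re-using the encoder
    dls_clf = TextDataLoaders.from_df(df, text_col="sequence", label_col="class",
                                      tok_tfm=tok, text_vocab=dls_lm.vocab)
    clf = text_classifier_learner(dls_clf, AWD_LSTM, pretrained=False,
                                  metrics=accuracy)
    clf.load_encoder("protein_enc")
    clf.fit_one_cycle(4, 1e-2)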

It was fun to see how well this worked. You can find a more detailed write-up of the initial experiments on that competition dataset here. Spurred by these early results, I figured it was worth looking into this a little more deeply. What have others been doing on this task? Is this approach any good compared to the SOTA? Has anyone tried this particular flow on this kind of problem?

Getting Formal

It should come as no surprise that the idea of treating sequence classification like a language modelling task has already occurred to some people. For example, UDSMProt[7] turns out to take very nearly the same approach as that outlined above (self-five!). Their GitHub is a great resource.

There are other approaches as well – for example, ProtCNN[6] and DEEPred[8] propose their own deep learning architectures to solve these kinds of tasks. And there are some older approaches, such as BLAST and its derivatives[9], that have long been standards in this field and still do decently (although they seem to be getting out-performed by these newer techniques).

So, we’re not the first to try this. However, I couldn’t find any papers using anything like the ‘subword’ tokenization. They either use individual amino acids as tokens, or in rare cases some choice of n-grams (for example, triplets of amino acids). The advantage of subword tokenization over these is that it can scale between the complexity of single-acid encodings and massive n-gram approaches by simply adjusting the vocabulary size.

Your Homework

I did some initial tests – this definitely smells promising, but there is a lot of work to do for this to be useful to anyone, and I don’t currently have the time or compute to give it a proper go. If you’re looking for a fun NLP challenge with the potential to turn into some interesting research, this could be the job for you! Here are my suggestions:

  • Pick one or more benchmarks. Classification of the PFam dataset is a nice one to start with. The ProtCNN paper[6] (quick link) has results for a bunch of the ‘standard’ algorithms, and their split is shared as a Kaggle dataset, so you can quickly compare to those results.
  • Get some data for language model training. The SWISSProt dataset is a nice one, and for early tests even just the PFam dataset is enough to try things out.
  • Train some language models. Do single-acid tokenization as a baseline and then try subword tokenization with a few different vocab sizes to compare.
  • See which models do best on the downstream classification task. Lots of experimenting to be done on sequence length, training regime and so on.
  • For bonus points, throw a transformer model or two at this kind of problem. I bet they’d be great, especially if pre-trained on a nice big dataset.
  • If (as I suspect) one of these does very well, document your findings, try everything again in case it was luck and publish it as a blog or, if you’re a masochist, a paper.
  • … profit?

I really hope someone reading this has the motivation to give this a go. If nothing else it’s a great learning project for language modelling and diving into a new domain. Please let me know if you’re interested – I’d love to chat, share ideas and send you the things I have tried. Good luck 🙂

References

[1] – Howard, J. and Ruder, S., 2018. Universal language model fine-tuning for text classification. arXiv preprint arXiv:1801.06146.

[2] – Merity, S., Keskar, N.S. and Socher, R., 2017. Regularizing and optimizing LSTM language models. arXiv preprint arXiv:1708.02182.

[3] – Smith, L.N., 2017, March. Cyclical learning rates for training neural networks. In 2017 IEEE Winter Conference on Applications of Computer Vision (WACV) (pp. 464-472). IEEE.

[4] – Kudo, T. and Richardson, J., 2018. SentencePiece: A simple and language independent subword tokenizer and detokenizer for neural text processing. arXiv preprint arXiv:1808.06226.

[5] – Howard, J. and Gugger, S., 2020. Fastai: A layered API for deep learning. Information, 11(2), p.108.

[6] – Bileschi, M.L., Belanger, D., Bryant, D.H., Sanderson, T., Carter, B., Sculley, D., DePristo, M.A. and Colwell, L.J., 2019. Using deep learning to annotate the protein universe. bioRxiv, p.626507. (ProtCNN)

[7] – Strodthoff, N., Wagner, P., Wenzel, M. and Samek, W., 2020. UDSMProt: universal deep sequence models for protein classification. Bioinformatics, 36(8), pp.2401-2409. (UDSMProt)

[8] – Rifaioglu, A.S., Doğan, T., Martin, M.J., Cetin-Atalay, R. and Atalay, V., 2019. DEEPred: automated protein function prediction with multi-task feed-forward deep neural networks. Scientific Reports, 9(1), pp.1-16. (DEEPred)

[9] – Altschul, S.F., Madden, T.L., Schäffer, A.A., Zhang, J., Zhang, Z., Miller, W. and Lipman, D.J., 1997. Gapped BLAST and PSI-BLAST: a new generation of protein database search programs. Nucleic Acids Research, 25(17), pp.3389-3402.

Data Glimpse: Predicted Historical Air Quality for African Cities

Air quality has been in the news a lot recently. Smoke from fires has had thousands of Californians searching for info around the health hazards of particulate matter pollution. Lockdown-induced changes have shown how a reduction in transport use can bring a breath of fresh air to cities. And a respiratory virus sweeping the globe has brought forward discussions around lung health and pollution, and the health burden associated with exposure to unhealthy levels of pollutants. There are thousands of air quality sensors around the world, but if you view a map of these sensors, it’s painfully obvious that some areas are underserved, with a marked lack of data:

Air Quality from sensors around the world. Source: https://waqi.info/

The ‘gap in the map’ was the motivation for a weekend hackathon hosted through Zindi, which challenged participants to build a model capable of predicting air quality (specifically PM2.5 concentration) based on available satellite and weather data.

The hackathon was a success, and was enough of a proof-of-concept that we decided to put a little more work into taking the results and turning them into something useful. Yasin Ayami and I spent a bit of time re-creating the initial data preparation phase (pulling the data from the Sentinel 5P satellite data collections in Google Earth Engine, creating a larger training set of known air quality readings and so on) and then we trained a model, inspired by the winning solutions, that is able to predict historical air quality with a mean absolute error of less than 20.

Dashboard for exploring air quality across Africa (http://www.datasciencecastnet.com/airq/)

A full report along with notebooks and explanation can be found in this GitHub repository. But the good news is that you don’t need to re-create the whole process if you’d just like a look at the model outputs – those predictions are available in the repository as well. For example, to get the predictions for major cities across Africa you can download and explore this CSV file. And if you don’t want to download anything, I’ve also made a quick dashboard to show the data, both as a time-series for whatever city you want to view and as a map showing the average for all the locations.

I’ve tagged this post as a ‘Data Glimpse’ since the details are already written up elsewhere 🙂 I hope it’s of interest, and as always let me know if you have any questions around this. J.

Personal Metrics

This is just a quick post following on from some recent conversations in this area. tldr: Tracking some data about yourself is a great exercise, and I highly recommend it. In this post I’ll share a few of the tools I use, and dig around in my own data to see if there are any interesting insights….

Time Tracker: Toggl

The first source of data is my time tracker: Toggl. It’s simple to use, and has a web app as well as a good Android app. As a consultant I find this useful for billing, but it has also just become a general habit to log what I’m working on. It’s good motivation not to context-switch, and it’s a great way to keep track of what I’m up to. A good day of work can sometimes mean 4 hours on the clock, since I tend not to log small tasks or admin, but it’s still good enough that I’ll bill clients based on the hours logged. Toggl lets you do some reporting within the app, but you can also export the data to CSV for later analysis. Here’s my last two years, total seconds per month:

Time logged per month (as of August 12)

As you can see, I’ve been busier than normal the past few months – one of the reasons this blog hasn’t had any new posts for a while!

Daily mood and activities: Daylio

Daylio is a smartphone app that asks ‘How was your day?’ every day, and optionally lets you log activities for the day. I’ve made it a habit, although tracking stopped for a few months at the start of the pandemic :/ One thing I like about this (and the previous thing I used, https://year-in-pixels.glitch.me/) is that it forces you to evaluate how you’re feeling. Was today great, or merely good? Why was it ‘Meh’? And by quantifying something less concrete than simply hours worked, it lets me see what I can do to optimize for generally better days.

Time worked on days marked as Average, Good or Great

Mondays are my lowest day, followed by Wednesdays. Being outdoors bumps my rating from ~4 (good) to nearly 4.5 (5 being ‘great’). As you can see in the image above, lots of work tends to mean not-so-great days. Around 3 hours per day logged (4-6 hours work) is where I start properly having fun, and if I can fit in activities like birding or something creative then it’s even closer to optimum. I’m in a pretty good place now despite the busyness – the average score (~4.3) is much higher than when I was still in uni trying to balance work and assignments (3.3). It’s nice to see this – on tougher days it’s amazing to look back and see how many good or great ones there are, and how lovely life is overall.

Moar data: uLogMe

I recently found a project called uLogMe (by Karpathy, of all people), and after reading his post about it I decided to give it a go. If you’re keen to try it, look for a fork on GitHub as the original project is deprecated. I only use the logging scripts, which keep track of the active window title and the number of keystrokes in each 9-second window. This is really fun data, as you can identify different activities, find patterns, see trends in when you’re most active… As one example, look at a fairly typical day from last month:

Keystroke intensity over time

You can see me start a little late, since it’s winter. After an initial burst of work I went on a long walk looking for insects (there was a bioblitz on) before hacking away during my 10am meeting. There are spikes of activity and periods of very little (meetings) or no (breaks) activity. 6-8pm is my class time, so I’m tapping away in demos as I teach, especially in the second half of the lesson.

Check out Karpathy’s post to see what else it’s possible to do with this data.

Putting it all together

I can’t wait to get a fitness tracker to add sleep tracking, exercise and heart rate. But even without those, I have some really great data to be playing with. I can see relationships between external factors (travel, activities, work) and my mood, explore how much time goes into different projects, graph the number of characters being typed in different applications (spoiler: I use Jupyter a LOT) and generally put some hard numbers behind my intuition around how I’m spending my time and how that’s affecting me.

A small subset of the data now waiting to be analysed

I hope that this post marks a return to this blog for me (hours are trending downwards now that some courses I teach are wrapping up) and that it inspires you to find some personal data to track! If you still aren’t convinced, here’s a TED talk that might push you over the edge. Happy hacking 🙂

Self-Supervised Learning with Image网

Intro

Until fairly recently, deep learning models needed a LOT of data to get decent performance. Then came an innovation called transfer learning, which we’ve covered in some previous posts. We train a network once on a huge dataset (such as ImageNet, or the entire text of Wikipedia), and it learns all kinds of useful features. We can then retrain or ‘fine-tune’ this pretrained model on a new task (say, elephant vs zebra), and get incredible accuracy with fairly small training sets. But what do we do when there isn’t a pretrained model available?

Pretext tasks (left) vs downstream task (right). I think I need to develop this style of illustration – how else will readers know that this blog is just a random dude writing on weekends? 🙂

Enter Self-Supervised Learning (SSL). The idea here is that in some domains, there may not be vast amounts of labeled data, but there may be an abundance of unlabeled data. Can we take advantage of this by using it somehow to train a model that, as with transfer learning, can then be re-trained for a new task on a small dataset? It turns out the answer is yes – and it’s shaking things up in a big way. This fastai blog post gives a nice breakdown of SSL, and shows some examples of ‘pretext tasks’ – tasks we can use to train a network on unlabeled data. In this post, we’ll try it for ourselves!

Follow along in the companion notebook.

Image网

Read the literature on computer vision, and you’ll see that ImageNet has become THE way to show off your new algorithm. Which is great, but coming in at 1.3 million images, it’s a little tricky for the average person to play with. To get around this, some folks are turning to smaller subsets of ImageNet for early experimentation – if something works well in small scale tests, *then* we can try it in the big leagues. Leading this trend have been Jeremy Howard and the fastai team, who often use ImageNette (10 easy classes from ImageNet), ImageWoof (Some dog breeds from ImageNet) and most recently Image网 (‘ImageWang’, 网 being ‘net’ in Chinese).

Image网 contains some images from both ImageNette and ImageWoof, but with a twist: only 10% of the images are labeled to use for training. The remainder are in a folder, unsup, specifically for use in unsupervised learning. We’ll be using this dataset to try our hand at self-supervised learning, using the unlabeled images to train our network on a pretext task before trying classification.

Defining Our Pretext Task

A pretext task should be one that forces the network to learn underlying patterns in the data. This is a new enough field that new ideas are being tried all the time, and I believe that a key skill in the future will be coming up with pretext tasks in different domains. For images, there are several options, explained well in this fastai blog, including:

  • Colorization of greyscale images
  • Classifying corrupted images
  • Image In-painting (filling in ‘cutouts’ in the image)
  • Solving jigsaws

For fun, I came up with a variant of the image in-painting task that combines it with colorization. Several sections of the input image are blurred and turned greyscale. The network tries to replace these regions with sensible values, with the goal being to have the output match the original image as closely as possible. One reason I like the idea of this as a pretext task is that we humans experience something similar. Each time we move our eyes, things that were in our blurry, greyscale peripheral vision are brought into sharp focus in our central vision – another input for the part of our brain that’s been pretending they were full HD colour the whole time 🙂

Here are some examples of the grey-blurred images and the desired outputs:

Input/Output pairs for our pretext task, using the RandomGreyBlur transform
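The transform itself boils down to something like this sketch (the notebook version is a fastai transform; the patch count, size and blur radius here are illustrative):

    import random
    from PIL import Image, ImageFilter

    def grey_blur_patches(img: Image.Image, n_patches=4, size=64):
        "Blur and desaturate a few random square regions of an image."
        out = img.copy()
        for _ in range(n_patches):
            x = random.randint(0, img.width - size)
            y = random.randint(0, img.height - size)
            patch = out.crop((x, y, x + size, y + size))
            patch = patch.convert("L").convert("RGB")          # greyscale
            patch = patch.filter(ImageFilter.GaussianBlur(8))  # then blur
            out.paste(patch, (x, y))
        return out  # the network input; the untouched img is the target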

We train our network on this task for 15 epochs, and then save its parameters for later use in the downstream task. See the notebook for implementation details.

Downstream Task: Image Classification

Now comes the fun part: seeing if our pretext task is of any use! We’ll follow the structure of the Image网 leaderboard here, looking at models for different image sizes trained with 5, 20, 80 or 200 epochs. The hope is that our pretext task has given us a decent network, so we should get reasonable results after just 5 epochs, and keep getting better and better results with more training.

Results from early testing

The notebook goes through the process, training models on the labeled data provided with Image网 and scoring them on the validation set. This step can be quite tedious, but the 5-epoch models are enough to show that we’ve made an improvement on the baseline, which is pretty exciting. For training runs of 20 epochs and greater, we still beat a baseline with no pre-training, but fall behind the current leaderboard entry based on simple inpainting. There is much tweaking to be done, and the runs take ~1 minute per epoch, so I’ll update this when I have more results.

Where Next?

Image网 is fairly new, and the leaderboard still needs filling in. Now is your chance for fame! Play with different pretext tasks (for example, try just greyscale instead of blurred greyscale – it’s a single line of code to change), or tweak some of the parameters in the notebook and see if you can get a better score. And someone please do 256px?

Beyond this toy example, remember that unlabeled data can be a useful asset, especially if labeled data is sparse. If you’re ever facing a domain where a pretrained model is unavailable, self-supervised learning might come to your rescue.

Meta ‘Data Glimpse’ – Google Dataset Search

Christmas came in January this year, with Google’s release of ‘Dataset Search’. They’ve indexed millions of cool datasets and made it easy to search through them. This post isn’t about any specific dataset, but rather I just wanted to share this epic new resource with you.

Google’s Dataset Search

I saw the news as it came out, which meant I had the pleasure of sharing it with my colleagues – all of whom got nerd-sniped to some degree, likely resulting in much lost revenue and a ton of fun had by all 🙂 A few minutes after clicking the link I was clustering dolphin vocalizations and smiling to myself. If you’re ever looking for an experiment to write up, have a trawl through the datasets on there and pick one that hasn’t got much ML baggage attached – you’ll have a nice novel project to brag about.

Clustering Dolphin noises

Say what you like about Google, there are people there doing so much to push research forward. Tools like Colab, Google Scholar, and now Dataset Search make it easy to do some pretty amazing research from anywhere. So go on – dive in 🙂

Swoggle Part 2 – Building a Policy Network with PyTorch, dealing with Cheaty Agents and ‘Beating’ the Game

In part 1, we laid the groundwork for our Reinforcement Learning experiments by creating a simple game (Swoggle) that we’d be trying to teach our AI to play. We also created some simple Agents that followed hard-coded rules for play, to give our AI some opponents. In this post, we’ll get to the hard part – using RL to learn to play this game.

The Task

Reinforcement Learning (Artist’s Depiction)

We want to create some sort of Agent capable of looking at the state of the game and deciding on the best move. It should be able to learn the rules and how to win by playing many games. Concretely, our agent should take in an array encoding the dice roll, the positions of the players and bases and so on, and it should output one of 192 possible moves (64 target squares, each with a normal move plus two special kinds of move, giving 64*3 possible actions). This agent shouldn’t just be a passive actor – it must also be able to learn from past games.

Policy Networks

In RL, a ‘policy’ is a map from game state to action. So when we talk about ‘Policy Learners’, ‘Policy Gradients’ or ‘Policy Networks’, we’re referring to something that is able to learn a good policy over time.

The network we’ll be training

So how would we ‘learn’ a policy? If we had a vast archive of past games, we could treat this as a supervised learning task – feed in the game state, chosen action and eventual reward for each action in the game history to a neural network or other learning algorithm and hope that it learns what ‘good’ actions look like. Sadly, we don’t have such an archive! So, we take the following approach:

  • Start a game (an ‘episode’)
  • Feed the game state through our policy network, which initially will give random output probabilities on each possible action
  • Pick an action, favoring those for which the network output is high
  • Keep making actions and feeding the resultant game state through the network to pick the next one, until the game ends.
  • Calculate the reward. If we won, +100. If we lost, -20. Maybe an extra +0.1 for each valid move made, and some negative reward for each time we tried to break the rules.
  • Update the network, so that it (hopefully) will better predict which moves will result in positive rewards.
  • Start another game and repeat, for as long as you want.

Here’s a notebook where I implement this. The code borrows a little from this implementation (with associated blog post that explains it well). Some things I changed:

  • The initial example (like most resources you’ll find if you look around) chooses a problem with a single action – up or down, for example. I modified the network to take in 585 inputs (the Swoggle game state representation) and give out 192 outputs for the 64*3 possible actions an agent could take. I also added the final sigmoid layer since I’ll be interpreting the outputs as probabilities.
  • Many implementations either take random actions (totally random) or look at the argmax of the network output. This isn’t great in our case – random actions are quite often invalid moves, but the top output of the network might also be invalid. Instead, we sample an action from the probability distribution represented by the network output (sketched after this list). This is like the approach Andrej Karpathy takes in his classic ‘Pong from Pixels’ post (which I highly recommend).
  • This game is dice-based (which adds randomness) and not all actions are possible at all times, so I needed to add code to handle cases where the proposed move is invalid. In those cases, we add a small negative reward and try a different action.
  • The implementation I started from used a parameter epsilon to shift from exploration (making random moves) to optimal play (picking the top network output). I removed this – by sampling from the probability distribution, we keep our agent on its toes, and it always has a chance of acting randomly/unpredictably. This should make it more fun to play against, while still keeping its ability to play well most of the time.
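Here's a sketch of the network and the sampling step in PyTorch (shapes as described above; the hidden size is illustrative, and the full reward handling lives in the notebook):

    import torch
    import torch.nn as nn

    policy = nn.Sequential(
        nn.Linear(585, 256),  # 585-dim game state encoding
        nn.ReLU(),
        nn.Linear(256, 192),  # 64 squares * 3 move types
        nn.Sigmoid(),         # outputs interpreted as action probabilities
    )

    def choose_action(state):
        probs = policy(state)
        dist = torch.distributions.Categorical(probs / probs.sum())
        action = dist.sample()                # sample rather than argmax
        return action, dist.log_prob(action)  # log-prob saved for the update

    # After each episode, scale the saved log-probs by the episode's reward and
    # take a gradient step so that rewarded actions become more likely:
    # loss = -(reward * torch.stack(log_probs)).sum(); loss.backward(); ...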

This whole approach takes a little bit of time to internalize, and I’m not best placed to explain it well. Check out the aforementioned ‘Pong from Pixels’ post and google for Policy Gradients to learn more.

Success? Or Cheaty Agents?

OpenAI’s glitch-finding players (source: https://openai.com/blog/emergent-tool-use/)

Early on, I seemed to have hit upon an excellent strategy. Within a few games, my Agent was winning nearly 50% of games against the basic game AI (for a four-player game, anything above 25% is great!). Digging a little deeper, I found my mistake. If the agent proposed a move that was invalid, it stayed where it was while the other agents moved around. This let it ‘camp’ on its base, or wait for a good dice roll before swoggling another base. I was able to get a similar win-rate with the following algorithm:

  1. Pick a random move
  2. If it’s valid, make the move. If not, stay put (not always a valid action but I gave the agent control of the board!)

That’s it – that’s the ‘CheatyAgent’ algorithm 🙂 Fortunately, I’m not the first to have flaws in my game engine exploited by RL agents – check out the clip from OpenAI above!

Another bug: See where I wrote sr.dice() instead of dice_roll? This let the network re-roll if it proposed an invalid move, which could lead to artificially high performance.

After a few more sneaky attempts by the AI to get around my rules, I finally got a setup that forced the AI to play by the rules, make valid moves and generally behave like a good and proper Swoggler should.

Winning for real

Learning to win!!!

With the bugs ironed out, I could start tweaking rewards and training the network! It took a few goes, but I was able to find a setup that let the agent learn to play in a remarkably short time. After a few thousand games, we end up with a network that can win against three BasicAgents about 40-45% of the time! I used the trained network to pick moves in 4000 games, and it won 1856 of them, confirming its superiority to the BasicAgents, who hung their heads in shame.

So much more to try

I’ve still got plenty to play around with. The network still tries to propose lots of invalid moves. Tweaking the rewards can change this (note the orange curve below that tracks ratio of valid:invalid moves) but at the cost of diverting the network from the true goal: winning games!

Learning to make valid moves, but at the cost of winning.

That said, I’m happy enough with the current state of things to share this blog. Give it a go yourself! I’ll probably keep playing with this, but unless I find something super interesting, there probably won’t be a part 3 in this series. Thanks for coming along on my RL journey 🙂

Swoggle Part 1 – RL Environments and Literate Programming with NBDev

I’m going to be exploring the world of Reinforcement Learning. But there will be no actual RL in this post – that’s for part two. This post will do two things: describe the game we’ll be training our AI on, and show how I developed it using a tool called NBDev which is making me so happy at the moment. Let’s start with NBDev.

What is NBDev?

Like many, I started my programming journey editing scripts in Notepad. Then I discovered the joy of IDEs with syntax highlighting, and life got better. I tried many editors over the years, benefiting from better debugging, code completion, stylish themes… But essentially, they all offer the same workflow: write code in an editor, run it and see what happens, make some changes, repeat. Then came Jupyter notebooks. Inline figures and explanations. Interactivity! Suddenly you don’t need to re-run everything just to try something new. You can work in stages, seeing the output of each stage before coding the next step. For some tasks, this is a major improvement. I found myself using them more and more, especially as I drifted into Data Science.

But what about when you want to deploy code? Until recently, my approach was to experiment in Jupyter, and then copy and paste code into a separate file or files which would become my library or application. This caused some friction – which is where NBDev comes in.

“Create delightful python projects using Jupyter Notebooks” – NBDev website

With NBDev, everything happens in your notebooks. By adding special comments like #export to the start of a cell, you tell NBDev how to treat the code. This means you can write a function that will be exported, write some examples to illustrate how it works, plot the results and surround it all with nice explanations in markdown. The exported code gets placed in a neat, well-ordered .py file that becomes your final product. The notebooks become documentation, and the extra examples you added to show functionality work as tests (although you can also add more formal unit testing). An extra line of code uploads your library for others to install with pip. And if you’re following their guide, you get a documentation site and continuous integration that updates whenever you push your changes to GitHub.
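For example, a cell like this one (a toy function, not from the actual Swoggle source) ends up in the generated module, while the un-exported cells around it become documentation and informal tests:

    #export
    import random

    def roll_dice():
        "Roll the Swoggle dice: a random integer from 1 to 6."
        return random.randint(1, 6)

    # in the next (non-exported) cell, a quick check that doubles as a test:
    assert 1 <= roll_dice() <= 6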

The upshot of all this is that you can effortlessly create good, clean code and documentation without having to switch between notebooks, editors and separate documentation. And the process you followed, the journey that led to the final design choices, is no longer hidden. You can show how things developed, and include experiments that justify a particular choice. This is ‘literate programming’, and it feels like a major shift in the way I think about software development. I could wax lyrical about this for ages, but you should just go and read about it in the launch post here.

What on Earth is Swoggle?

Christmas, 2019. Our wedding has brought a higher-than-normal influx of relatives to Cape Town, and when this extended family gets together, there are some things that are inevitable. One of these, it turns out, is the invention of new games to keep the cousins entertained. And thus, Swoggle was born 🙂

A Swoggle game in progress – 2 players are left.

The game is played on an 8×8 board. There are usually 4 players, each with a base in one of the corners. Players can move (a dice determines how far), “spoggle” other players (capturing them and placing them in the “swoggle spa” – none of this violent terminology) or ‘swoggle’ a base (gently retiring the base’s owner from the game – no killing here). To make things interesting, there are four ‘drones’ that can be used as shields or to take an occupied base. Moving with a drone halves the distance you can travel, to make up for the advantages. A player with a drone can’t be spoggled by another player unless they too have a drone, or they ‘powerjump’ from their base (a half-distance move) onto the droned player. Maybe I’ll make a video one day and explain the rules properly 🙂

So, that’s the game. Each round is fairly quick, so we usually play multiple rounds, awarding points for different achievements. Spoggling (capturing) a player: 1 point. Swoggling (taking out a base): 3 points. Last one standing: 5 points. The dice rolls add lots of randomness, but there is still plenty of room for tactics, sibling rivalry and comedic mistakes.

Game Representation

If we’re going to teach a computer to play this, we need a way to represent the game state, check if moves are valid, keep track of who’s in the swoggle spa and which bases are still standing, etc. I settled on something like this:

Game state representation

There is a Cell in each x, y location, with attributes for player, drone and base. These cells are grouped in a Board, which represents the game grid and tracks the spa. The Board class also contains some useful methods like is_valid_move() and ways to move a particular player around. At the highest level, I have a Swoggle class that wraps a board, handles setting up the initial layout, provides a few extra convenience functions and can be used to run a game manually or with some combination of agents (which we’ll cover in the next section). Since I’m working in NBDev, I have docs with almost no effort, so check out https://johnowhitaker.github.io/swoggle/ for details on this implementation.
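In simplified form, the representation looks something like this (a sketch – the real classes have more attributes and methods, and the movement rule shown is only approximate):

    class Cell:
        def __init__(self):
            self.player, self.drone, self.base = None, False, None

    class Board:
        def __init__(self, size=8):
            self.grid = [[Cell() for _ in range(size)] for _ in range(size)]
            self.spa = []  # captured ('spoggled') players wait here

        def is_valid_move(self, src, dst, roll):
            # within dice range; moving with a drone halves the distance
            dist = abs(src[0] - dst[0]) + abs(src[1] - dst[1])
            max_dist = roll // 2 if self.grid[src[0]][src[1]].drone else roll
            return dist <= max_dist

    class Swoggle:
        def __init__(self, agents):
            self.board, self.agents = Board(), agents  # setup omitted

Here’s what the documentation system turned my notebooks into: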

Part of the generated documentation

The ability to write code and comments in a notebook, and have that turn into a swanky docs page, is borderline magical. Mine is a little messy since this is a quick hobby project. To see what this looks like in a real project, check out the docs for NBDev itself or Fastai v2.

Creating Agents

Since the end goal is to use this for reinforcement learning, it would be nice to have an easy way to add ‘Agents’ – code that defines how a player in the game will make a move in a given situation. It would also be useful to have a few non-RL agents to test things out and, later, to act as opponents for my fancier bots. I implemented two types of agent:

  • RandomAgent: simply picks a random but valid move by trial and error, and makes that move.
  • BasicAgent: adds a few simple heuristics. If it can take a base, it does so. If it can spoggle a player, it does so. If neither of these options is possible, it moves randomly.

You can see the agent code here. The notebook also defines a few other useful functions, such as win_rates() to pit different agents against each other and see how they do. This is fun to play with – after a few experiments it’s obvious that the board layout and order of players matter a lot. A BasicAgent going last will win ~62% of games against three RandomAgents – not unexpected. But of the three RandomAgents, the one opposite the BasicAgent (and thus furthest from it) will win the majority of the remaining games.

Next Step: Reinforcement Learning!

This was a fun little holiday coding exercise. I’m definitely an NBDev convert – I feel so much more productive using this compared to any other development approach I’ve tried. Thank you Jeremy, Sylvain and co for this excellent tool!

Now, the main point of this wasn’t just to get the game working – it was to use it for something interesting. And that, I hope, is coming soon in Part 2. As I type this, a neural network is slowly but surely learning to follow the rules and figuring out how to beat those sneaky RandomAgents. Wish it luck, stay tuned, and, if you’re *really* bored, pip install swoggle and watch some BasicAgents battle it out 🙂