by Emma Drobina with Andrew Engel and Tony Chiang

What is the Neural Tangent Kernel?

I expect many people reading this are familiar with sentiment analysis – using machine learning to classify text with emotional labels – but far less familiar with the neural tangent kernel. It’s a fairly complex topic, and I’m not going to try to explain it in detail here. For more information, see these resources:

  1. original paper on the neural tangent kernel
  2. video summary of the neural tangent kernel
  3. explanation of the math behind the neural tangent kernel
  4. Andrew’s paper on using the neural tangent kernel to explain neural networks

At a very high level: the neural tangent kernel (NTK) tells us how similar two data points are from the point of view of a given neural network. Given a pre-computed NTK, we can transform a neural network into a linear (kernel) model with comparable performance. (For infinite-width neural networks, this NTK-based kernel model is exactly equivalent, but of course we are usually not working with infinite-width models.) Linear models are much nicer to work with for explanation purposes.
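To make that concrete, here is a minimal sketch of how a pre-computed kernel can act as a surrogate model. The file names are hypothetical, and the choice of a support vector machine over a precomputed kernel is just one convenient option in scikit-learn, not the exact surrogate used in this project.

```python
# Minimal sketch: a kernel machine built directly on a precomputed NTK.
import numpy as np
from sklearn.svm import SVC

# Assumed inputs, produced elsewhere (e.g. by an NTK library); file names are placeholders.
#   K_train: (n_train, n_train) kernel between all pairs of training points
#   K_test:  (n_test, n_train) kernel between test points and training points
K_train = np.load("ntk_train_train.npy")
K_test = np.load("ntk_test_train.npy")
y_train = np.load("train_labels.npy")

# scikit-learn accepts precomputed Gram matrices directly.
surrogate = SVC(kernel="precomputed")
surrogate.fit(K_train, y_train)
preds = surrogate.predict(K_test)
```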

For this project, we use the trace neural tangent kernel (trNTK), which is a faster-to-calculate approximation of the NTK. The trNTKs were computed with projection-NTK, a library for NTK computation created by PNNL. At the moment, it’s not publicly available, but it may be made open-source in the future. I also used two other kernels in my experiments (a toy sketch of all three quantities follows the list below):

  • the trNTK0, which is the non-normalized trNTK.
  • the conjugate kernel (CK), which is calculated from the last hidden layer of the model and is an additive component of the NTK for fully-connected models. This lets us explore the predictive value of the full model vs. only its last layer.
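Here is a toy sketch of these three kernels on a small fully-connected network. It is illustrative only: the project computed these quantities for a fine-tuned BERT model with projection-NTK, and I am assuming a cosine-style normalization for the trNTK (dividing by the kernel’s diagonal), which may differ from the library’s exact convention.

```python
import torch
import torch.nn as nn

# Toy stand-in for the real network; the project used a fine-tuned BERT model
# and PNNL's projection-NTK library rather than this hand-rolled loop.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))

def logit_gradients(x):
    """One flattened parameter-gradient vector per output logit."""
    logits = model(x).squeeze(0)
    rows = []
    for c in range(logits.shape[0]):
        grads = torch.autograd.grad(logits[c], model.parameters(), retain_graph=True)
        rows.append(torch.cat([g.flatten() for g in grads]))
    return torch.stack(rows)                      # (num_classes, num_params)

x1, x2 = torch.randn(1, 8), torch.randn(1, 8)
G1, G2 = logit_gradients(x1), logit_gradients(x2)

trntk0 = (G1 * G2).sum()                          # non-normalized trace kernel (trNTK0)
trntk = trntk0 / ((G1 * G1).sum().sqrt() * (G2 * G2).sum().sqrt())  # normalized trNTK (assumed cosine-style)

# Conjugate kernel: inner product of the last hidden layer's representations.
last_hidden = nn.Sequential(*list(model.children())[:-1])
ck = torch.dot(last_hidden(x1).flatten(), last_hidden(x2).flatten())
```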

What data did you use?

All these analyses were performed using the Twitter US Airline Sentiment dataset released publicly on Kaggle. These tweets date from February 2015 and were hand-labelled as negative, neutral, or positive. I subsampled the dataset to ensure class parity, since the original dataset was unsurprisingly skewed towards negative tweets, and removed all tweets where the labellers were less than 65% confident in their label.
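A minimal sketch of that preprocessing, assuming the column names from the Kaggle CSV (“airline_sentiment” and “airline_sentiment_confidence”); the file name and random seed are placeholders.

```python
import pandas as pd

df = pd.read_csv("Tweets.csv")

# Drop tweets whose sentiment label has less than 65% annotator confidence.
df = df[df["airline_sentiment_confidence"] >= 0.65]

# Subsample to class parity: keep the same number of tweets per class.
n_per_class = df["airline_sentiment"].value_counts().min()
balanced = (
    df.groupby("airline_sentiment", group_keys=False)
      .apply(lambda g: g.sample(n_per_class, random_state=0))
)
```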

Negative class example: thank you for dishonoring my upgrade and putting me in a seat I didn't want, all while not even notifying me. great 1K service [thumbs down emoji]
Neutral class example: I have submitted my request. I would appreciate a call by 9am eastern. Thank you.
Positive class example: I really love your customer service Lou Ann in Phoenix rocks. Thanks SW. #Be Our Guest

Figure 1: An example of each class from the dataset.

I fine-tuned a BERT-base-uncased model to 83% accuracy on this dataset for my analyses. This is far from state of the art, but for this project, I was interested in exploring how the NTK functions as a surrogate model, not developing the best sentiment analysis model I could.
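For reference, here is a sketch of that fine-tuning setup, assuming Hugging Face transformers and datasets. The label mapping, split, and hyperparameters are placeholders rather than the values actually used, and it continues from the (hypothetical) balanced dataframe in the preprocessing sketch above.

```python
from datasets import Dataset
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

# Placeholder label mapping; `balanced` comes from the preprocessing sketch above.
label2id = {"negative": 0, "neutral": 1, "positive": 2}
balanced["label"] = balanced["airline_sentiment"].map(label2id)
dataset = Dataset.from_pandas(balanced[["text", "label"]]).train_test_split(test_size=0.2)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=128)

dataset = dataset.map(tokenize, batched=True)

# Hyperparameters below are illustrative placeholders.
args = TrainingArguments(output_dir="bert-airline", num_train_epochs=3,
                         per_device_train_batch_size=16)
trainer = Trainer(model=model, args=args,
                  train_dataset=dataset["train"], eval_dataset=dataset["test"])
trainer.train()
```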

Figure 2: Confusion matrix for the BERT-base-uncased model fine-tuned on the Twitter airline data.

Why counterfactuals?

Looking at the trNTK for a single data point can only tell you so much. However, by slightly perturbing a single data point, we can generate counterfactual examples – plausible sentences that are similar enough to existing test sentences that we can use them to better understand changes in the model’s confidence and predictions.

I hand-generated counterfactuals for a number of sentences. For the rest of this post, we will focus on a set of counterfactuals generated for the sentence, “the whole plane. Flight 561 from LGA to PBI”. The perturbations I made fell into three categories:

  1. “neutral” changes that swap words for synonyms or change grammar
  2. “positive” changes that add positive-toned words
  3. “negative” changes that add negative-toned words

Figure 3: Example counterfactual sentences for each category.

When the kernel attributions are graphed for all counterfactuals, grouped by category and kernel, we can see how the perturbations change them. The trNTK and the CK show the neutrally-perturbed counterfactuals largely staying neutral, and the positively- and negatively-perturbed counterfactuals switching over to their targeted class. The trNTK0, however, shows a much more fragmented pattern for the negatively-perturbed counterfactuals.

Figure 4.
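For readers who want to reproduce this kind of grouping, here is a hedged sketch. It approximates a counterfactual’s “attribution” to each class as its summed kernel similarity to training points of that class, which is a simplification of the surrogate-model attribution actually used; the inputs are random placeholders standing in for the real kernel, labels, and category assignments.

```python
import numpy as np
import matplotlib.pyplot as plt

# Placeholder inputs (stand-ins for the real quantities):
#   K_cf: kernel between counterfactuals and training points
#   y_train: training labels (0=negative, 1=neutral, 2=positive)
#   categories: perturbation category of each counterfactual
rng = np.random.default_rng(0)
K_cf = rng.random((30, 300))
y_train = rng.integers(0, 3, size=300)
categories = np.array(["neutral", "positive", "negative"] * 10)

def class_attributions(K_cf, y_train, num_classes=3):
    """Summed kernel similarity of each counterfactual to each training class."""
    return np.stack([K_cf[:, y_train == c].sum(axis=1)
                     for c in range(num_classes)], axis=1)

attr = class_attributions(K_cf, y_train)

# One panel per perturbation category, one box per class attribution.
fig, axes = plt.subplots(1, 3, sharey=True, figsize=(10, 3))
for ax, cat in zip(axes, ["neutral", "positive", "negative"]):
    ax.boxplot(attr[categories == cat], labels=["neg", "neu", "pos"])
    ax.set_title(f"{cat} perturbations")
plt.show()
```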

Looking at the individual level

Since the trNTK is computed layer by layer, we can look at the performance of the model at individual layer steps. This can reveal interesting information about where the original neural network starts to make its decisions. For example, the next three images show how placing an exclamation mark can make the model act significantly differently in the last layers of the neural network, in a way that is not obvious from looking at the full NTK alone.
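As a rough illustration of that layer-by-layer decomposition, the sketch below splits the (unnormalized) trace kernel into one contribution per parameter tensor of a toy MLP and accumulates them. For brevity it uses the gradient of the summed logits rather than the per-logit trace, and it is not the project’s BERT pipeline.

```python
import torch
import torch.nn as nn

# Toy model again; in the project this decomposition comes from the real
# BERT model's parameter blocks, not this small MLP.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
x1, x2 = torch.randn(1, 8), torch.randn(1, 8)

def per_block_contributions(model, x1, x2):
    """Per-parameter-tensor contributions to the unnormalized trace kernel."""
    params = list(model.parameters())
    g1 = torch.autograd.grad(model(x1).sum(), params)
    g2 = torch.autograd.grad(model(x2).sum(), params)
    return torch.tensor([torch.dot(a.flatten(), b.flatten())
                         for a, b in zip(g1, g2)])

contrib = per_block_contributions(model, x1, x2)
cumulative = torch.cumsum(contrib, dim=0)   # kernel value as layers accumulate
print(cumulative)
```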

Likewise, we can compare and contrast examples of negatively-perturbed and positively-perturbed counterfactuals. At least in our sample, positively-perturbed counterfactuals have their final classification emerge as a candidate at earlier layers than negatively-perturbed counterfactuals do. Additionally, positively-perturbed counterfactuals have more similar-looking graphs overall than negatively-perturbed counterfactuals. This reflects what we saw in the aggregate graph back in Figure 4 – the trNTK values for positively-perturbed counterfactuals were all highly positive, while many of the negatively-perturbed counterfactuals were harder to distinguish from other classes.

These analyses show us some interesting initial results about how NTKs and counterfactuals can help explore the robustness of large language models.

This project was done May-July 2023 at Pacific Northwest National Lab as part of the National Security Internship Program.