Generative AI services like GPT-4 (and the machine learning and deep learning processes that train them) consume mountains of data. That fuels a belief that more data means better AI. Consider how EleutherAI’s recent expansion of the Pile, already one of the world’s biggest AI training datasets, was a headline-worthy announcement.
But more data doesn’t guarantee your generative AI will respond better to user queries. As natural language datasets get larger, the potential for redundant and detrimental data increases, leading to:
- inflated costs
- increased latency
- decreased model accuracy
Fortunately, dataset pruning can sift inefficient, detrimental examples out of these massive sets. The Data & AI Research Team (DART) here at WillowTree has seen that for some natural language datasets, training on as little as 20% of the dataset and classifying the remaining 80% often produces accuracies similar to training on 90% and classifying the remaining 10%.
A Dual-Model Approach to Dataset Pruning
Our data pruning technique combines two models:
- a classification model trained on the original dataset (and eventually on the pruned dataset)
- a regression model used to determine which examples to prune from the original dataset
For the regression model, we create a binary inclusion dataset where each column represents one example in our full original dataset and each row is one train/test split of that dataset. Each row also has an accuracy: the test accuracy of the classification model for that train/test split. (See Figure 1 in the “Generate train/test splits” section below for more detail.)
This binary inclusion dataset lets us frame pruning in terms of the feature importance of our regression model. The features are the individual data points in our dataset, and the regression model predicts the accuracy (or some other metric) of the inference model from which data points were included. Each feature’s importance then tells us whether to prune that data point.
Our inference model is a k-nearest neighbors (KNN) algorithm with k=1 that performs intent classification, taking the embedding of the user prompt as its input. Our regression model is linear regression, which predicts the accuracy of our classification model based on the examples included in its training set.
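To make the classifier concrete, here is a minimal NumPy sketch of a 1-nearest-neighbor intent classifier over precomputed prompt embeddings. It assumes cosine similarity as the closeness measure (the distance metric is not specified here), and the toy 2-D vectors and intent labels are hypothetical stand-ins for real embedding vectors.

```python
import numpy as np


def predict_1nn(train_emb: np.ndarray, train_labels: np.ndarray,
                query_emb: np.ndarray) -> np.ndarray:
    """Label each query with the label of its most cosine-similar training example.

    Cosine similarity is an assumption for this sketch, not a stated detail.
    """
    # Normalize rows so a dot product equals cosine similarity.
    train_n = train_emb / np.linalg.norm(train_emb, axis=1, keepdims=True)
    query_n = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    sims = query_n @ train_n.T        # shape: (n_queries, n_train)
    nearest = sims.argmax(axis=1)     # index of the single nearest neighbor (k=1)
    return train_labels[nearest]


# Hypothetical toy "embeddings" standing in for real embedding vectors.
train_emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
train_labels = np.array(["billing", "billing", "support", "support"])
queries = np.array([[0.8, 0.2], [0.2, 0.8]])
print(predict_1nn(train_emb, train_labels, queries))
```

With k=1 there is no voting step: each query simply inherits the label of its single closest training example, which is why individual training points can have such an outsized effect on accuracy.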
Throughout our research, we’ve tested this strategy using multiple embedding models, namely OpenAI’s ada-002 and 3-large embeddings, and Voyage AI’s voyage-large-2 embeddings.
Why We Developed This Data Pruning Strategy
Distinguishing beneficial data from excessive or harmful data is a common source of uncertainty in our field, more so than simply identifying shortcomings in existing training sets. In our work with intent classification, we often face the challenge of refining our training datasets. Whether we manually curate data, pay for organized sets, or generate synthetic data with models like GPT-4, we frequently question the optimal balance of data inclusion and exclusion.
When using a generative AI model to create a new training dataset, we often default to generating a very large number of examples for each intent category. Because generation is so easy, the thinking goes, we might as well produce a huge number of examples. After all, more data is typically better.
However, we found that this often wasn’t the case. For some datasets we examined, training on as little as 20% of the dataset to classify the remaining 80% often produced accuracies similar to training on 90% of the dataset and classifying the remaining 10%, suggesting that the dataset is highly repetitive in nature. Viewed differently, we could conclude that the informational density of this dataset (value per observation) is quite low.
This led us to believe that there is a limit to the usefulness of synthetic examples, especially those generated using the same prompt, model, and model settings. However, using these tools to generate examples is still much faster and less expensive than human generation, and knowing exactly how many useful examples can be wrung out of each prompt is a prohibitively difficult task.
In response, our pruning approach lets us generate a very large dataset (or combination of datasets) with many examples for each intent category, then prune it down to a subset. This allows us to use quickly and easily generated synthetic data without having to decide exactly how many examples to generate for each category with each generation method.
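One way to check a dataset for the redundancy described above (similar accuracy when training on 20% versus 90%) is to average test accuracy over many random splits at each training fraction. The sketch below assumes precomputed embeddings in a NumPy array and uses Euclidean 1-NN; the helper names and the clustered toy data are hypothetical.

```python
import numpy as np


def predict_1nn(train_x, train_y, query_x):
    # Euclidean 1-nearest-neighbor: copy the label of the closest training point.
    d = ((query_x[:, None, :] - train_x[None, :, :]) ** 2).sum(axis=2)
    return train_y[d.argmin(axis=1)]


def mean_split_accuracy(x, y, train_frac, n_splits=50, seed=0):
    """Average 1-NN test accuracy over random splits that train on `train_frac` of the data."""
    rng = np.random.default_rng(seed)
    n_train = int(round(train_frac * len(y)))
    accs = []
    for _ in range(n_splits):
        perm = rng.permutation(len(y))
        train, test = perm[:n_train], perm[n_train:]
        accs.append((predict_1nn(x[train], y[train], x[test]) == y[test]).mean())
    return float(np.mean(accs))


# Hypothetical toy data: three intent clusters of 100 tightly grouped "embeddings" each.
rng = np.random.default_rng(42)
centers = rng.normal(size=(3, 8))
x = np.concatenate([c + 0.3 * rng.normal(size=(100, 8)) for c in centers])
y = np.repeat(np.arange(3), 100)

acc_20 = mean_split_accuracy(x, y, 0.2)
acc_90 = mean_split_accuracy(x, y, 0.9)
print(f"train 20%: {acc_20:.3f}  train 90%: {acc_90:.3f}")
```

If the two averages are close, adding more examples of the same kind buys little, which is the low informational density the text describes.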
How to Prune Your Own Datasets
Using this method will help you determine which examples should remain in the training dataset and which can be removed, all without manually inspecting the dataset’s examples for biases or potential shortcomings. We illustrate each step with results from our own testing for reference.
The heart of this method is feature importance: The aim is to find which examples in the set have the highest contribution to model accuracy. This relies on two models: an intent classification model and an accuracy prediction model. We used a KNN with k=1 for our intent classification model, and linear regression for our accuracy prediction model.
1. Generate train/test splits
To start, generate a large number of train/test splits and find the test accuracy of the KNN for each split. This should produce a data frame like the (truncated) version in Figure 1 shown below:
- each row represents one unique train/test split
- each numbered column indicates whether that example was included in the training set
- the accuracy column denotes the test accuracy for the given train/test split
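This method requires more train/test splits than there are example prompts in the intent classification dataset, so that the matrix containing all of the train/test splits has full column rank. Otherwise, the linear regression in the next step cannot assign a unique coefficient to each example.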
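A minimal NumPy sketch of this step follows. It assumes precomputed embeddings `x` and integer intent labels `y`, uses Euclidean 1-NN as the classifier, and returns the binary inclusion matrix (one row per split, one column per example, as in Figure 1) together with the accuracy column. The function name, toy data, and the choice to vary the training-set size per split are illustrative assumptions rather than details stated above: with a fixed training size, every row sums to the same value, so an intercept in the next step would be collinear with the example columns.

```python
import numpy as np


def predict_1nn(train_x, train_y, query_x):
    # Euclidean 1-nearest-neighbor classifier.
    d = ((query_x[:, None, :] - train_x[None, :, :]) ** 2).sum(axis=2)
    return train_y[d.argmin(axis=1)]


def build_inclusion_dataset(x, y, n_splits, min_frac=0.5, max_frac=0.9, seed=0):
    """Return (inclusion, accuracy) with one row per random train/test split.

    inclusion[s, i] is 1.0 if example i was in the training set of split s, else 0.0;
    accuracy[s] is the 1-NN test accuracy for split s.
    """
    rng = np.random.default_rng(seed)
    n = len(y)
    lo = max(1, int(min_frac * n))
    hi = min(n - 1, int(max_frac * n))   # leave at least one test example
    inclusion = np.zeros((n_splits, n))
    accuracy = np.empty(n_splits)
    for s in range(n_splits):
        # Vary the training-set size (assumption) so rows do not all sum to the same value.
        n_train = rng.integers(lo, hi + 1)
        perm = rng.permutation(n)
        train, test = perm[:n_train], perm[n_train:]
        inclusion[s, train] = 1.0
        accuracy[s] = (predict_1nn(x[train], y[train], x[test]) == y[test]).mean()
    return inclusion, accuracy


# Hypothetical toy data: 40 examples in two intent clusters.
rng = np.random.default_rng(1)
x = np.concatenate([rng.normal(0, 1, (20, 4)), rng.normal(2, 1, (20, 4))])
y = np.repeat([0, 1], 20)

# 400 splits for 40 examples: more rows than columns, as the method requires.
inclusion, accuracy = build_inclusion_dataset(x, y, n_splits=400)
print(np.linalg.matrix_rank(inclusion) == inclusion.shape[1])  # full column rank?
```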
2. Train the linear regression model
After you create the train/test splits, the next step is to train the linear regression model. Train the model to predict test accuracy based on which examples you include in the training dataset.
Theoretically, the positive coefficients correspond to example prompts that increase the intent classification model’s accuracy by being included in the training dataset. The negative coefficients correspond to example prompts that decrease the intent classification model’s accuracy via their inclusion.
Because the trained linear regression model provides an approximation of each example’s contribution to intent classification accuracy via its coefficients, these coefficients will be used to perform pruning.
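Step 2 amounts to ordinary least squares on the inclusion matrix. The sketch below fits it with `numpy.linalg.lstsq` rather than any particular ML library, then sorts examples by coefficient so the lowest (first candidates for pruning) come first. The toy inclusion matrix and the "true" per-example effects are hypothetical, used only to show that the fit recovers which examples hurt accuracy.

```python
import numpy as np


def example_coefficients(inclusion, accuracy):
    """OLS fit of accuracy ~ intercept + sum_i coef_i * inclusion[:, i].

    Returns one coefficient per example; the intercept is discarded.
    """
    design = np.column_stack([np.ones(len(accuracy)), inclusion])
    solution, *_ = np.linalg.lstsq(design, accuracy, rcond=None)
    return solution[1:]


# Hypothetical check: simulate split accuracies from known per-example effects.
rng = np.random.default_rng(0)
inclusion = (rng.random((200, 10)) < 0.7).astype(float)
true_coef = np.array([0.02, 0.01, -0.03, 0.0, 0.015, 0.005, -0.01, 0.02, 0.0, 0.01])
accuracy = 0.6 + inclusion @ true_coef + rng.normal(0, 0.001, 200)

coef = example_coefficients(inclusion, accuracy)
prune_order = np.argsort(coef)  # lowest coefficients first: prune these first
print(prune_order[:2])
```

Here the two examples given negative effects (indices 2 and 6) come first in `prune_order`, mirroring how negative coefficients flag examples whose inclusion lowers accuracy.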
3. Prune the dataset
To prune the dataset, gradually remove examples from the training set, starting with those that have the lowest coefficients in the linear regression model. At each iteration, recompute the accuracy of the intent classification model on the pruned dataset.
To avoid issues of small test sets when only a few examples have been removed from the dataset, we used the following modification of k-fold cross-validation accuracy:
For each dataset, compute this modified accuracy over a wide range of removal counts to find how many removals yield the highest cross-validation accuracy. In our work, we also compared removing examples with our method against removing them at random, to show that our method performs better than chance (see Figure 3 below).
Based on the results in Figure 3, we would choose the dataset with the 500 lowest-coefficient examples removed, since that maximizes accuracy. With those examples removed, accuracy was 75% or greater. The baseline accuracy (with no pruning at all) was just below 66%, a nearly 10-percentage-point improvement over using the original dataset.
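The modified cross-validation referred to above is not spelled out here, so this sketch substitutes plain k-fold cross-validation to illustrate the sweep: for each removal count, it scores the dataset with the lowest-coefficient examples removed and, as a baseline, with the same number of randomly chosen examples removed, then picks the count with the highest accuracy. Function names, the fold count, and the toy data (including the mislabeled examples and the coefficient vector flagging them) are hypothetical.

```python
import numpy as np


def predict_1nn(train_x, train_y, query_x):
    # Euclidean 1-nearest-neighbor classifier.
    d = ((query_x[:, None, :] - train_x[None, :, :]) ** 2).sum(axis=2)
    return train_y[d.argmin(axis=1)]


def kfold_accuracy(x, y, k=5, seed=0):
    """Plain k-fold cross-validated 1-NN accuracy (a stand-in for the modified CV)."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(y)), k)
    correct = 0
    for fold in folds:
        train = np.setdiff1d(np.arange(len(y)), fold)
        correct += (predict_1nn(x[train], y[train], x[fold]) == y[fold]).sum()
    return correct / len(y)


def pruning_sweep(x, y, coef, removal_counts, seed=0):
    """Return (r, pruned_acc, random_acc) for each removal count r."""
    rng = np.random.default_rng(seed)
    all_idx = np.arange(len(y))
    by_coef = np.argsort(coef)              # lowest coefficients (most harmful) first
    random_order = rng.permutation(len(y))  # random-removal baseline
    results = []
    for r in removal_counts:
        keep_pruned = np.setdiff1d(all_idx, by_coef[:r])
        keep_random = np.setdiff1d(all_idx, random_order[:r])
        results.append((r,
                        kfold_accuracy(x[keep_pruned], y[keep_pruned]),
                        kfold_accuracy(x[keep_random], y[keep_random])))
    return results


# Hypothetical toy data: two intent clusters, with the first five examples mislabeled.
rng = np.random.default_rng(3)
x = np.concatenate([rng.normal(0, 1, (50, 4)), rng.normal(3, 1, (50, 4))])
y = np.repeat([0, 1], 50)
y[:5] = 1
# Stand-in for step-2 coefficients: flag the mislabeled examples as harmful.
coef = np.zeros(len(y))
coef[:5] = -1.0

results = pruning_sweep(x, y, coef, removal_counts=[0, 5, 10, 20])
best_r = max(results, key=lambda row: row[1])[0]  # removal count with highest accuracy
for r, pruned_acc, random_acc in results:
    print(f"removed={r:3d}  pruned={pruned_acc:.3f}  random={random_acc:.3f}")
```

Keeping the random-removal column alongside the coefficient-based one is what lets you tell whether a gain comes from the ranking or merely from shrinking the dataset.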
Testing this method on a variety of datasets and embedding models (OpenAI’s ada-002 and 3-large, Voyage AI’s voyage-large-2), we saw results similar to those in Figure 3.
The accuracy gain from pruning in Figure 4 is smaller than in Figure 3, which may be partly because that dataset’s baseline accuracy is much higher, leaving less room for improvement.
Parting Insights: Accuracy and Optimal Number of Examples
The likely presence of redundant examples, especially in a large synthetic dataset, is an intuitive explanation for why we might not see a dropoff in accuracy by removing some examples from our training set. However, a potential explanation for increased accuracy after removing training data is somewhat less intuitive.
One explanation is that some training points carry one label but sit in a region of the embedding space dominated by examples from a different intent category. Figure 5 below shows how a single data point whose intent category is anomalous for its position in the embedding space can hurt model performance: with k=1, it becomes the nearest neighbor of nearby queries from the other category and causes them to be misclassified.
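A toy example makes this mechanism concrete. In the hypothetical 2-D embedding space below, one training point labeled "support" sits inside the "billing" cluster. With Euclidean 1-NN it becomes the nearest neighbor of two nearby billing queries and flips their predictions; removing it restores them. The coordinates and labels are invented for illustration.

```python
import numpy as np


def predict_1nn(train_x, train_y, query_x):
    # Euclidean 1-nearest-neighbor classifier.
    d = ((query_x[:, None, :] - train_x[None, :, :]) ** 2).sum(axis=2)
    return train_y[d.argmin(axis=1)]


# Two intent clusters in a toy 2-D embedding space (hypothetical values).
train_x = np.array([[0.0, 0.0], [0.2, 0.1], [0.1, 0.3],   # "billing" region
                    [3.0, 3.0], [3.1, 2.8], [2.9, 3.2],   # "support" region
                    [0.15, 0.15]])                         # labeled "support", sits among "billing"
train_y = np.array(["billing"] * 3 + ["support"] * 3 + ["support"])
test_x = np.array([[0.16, 0.14], [0.1, 0.2], [3.0, 2.9]])
test_y = np.array(["billing", "billing", "support"])

with_outlier = (predict_1nn(train_x, train_y, test_x) == test_y).mean()
without_outlier = (predict_1nn(train_x[:-1], train_y[:-1], test_x) == test_y).mean()
print(with_outlier, without_outlier)
```

Here `with_outlier` is 1/3 (both billing queries are misclassified) and `without_outlier` is 1.0, so removing a single training point raises accuracy.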
Another insight from our results involves the optimal number of examples per intent category. Across multiple datasets, the optimal pruned datasets did not converge on a specific number of examples per intent category. Based on these results, developers cannot simply aim for a fixed number of examples for each category and be confident that it is the optimal number and variety of examples.
This again reveals the need for effective dataset pruning. Getting the best number of examples for each intent category doesn’t seem to be a matter of picking a single number and generating that number. Rather, generating a large number of examples for each category, then filtering out less helpful examples via data pruning, appears to be an effective way of settling on the number of examples to include for each individual category.
Master the Skill of Optimizing Generative AI Data
Our findings indicate that over-inflated synthetic datasets lead to suboptimal model performance, increased costs, and poorer query results overall. They further demonstrate that data pruning improved model accuracy compared with using the complete dataset: removing detrimental data improved performance.
Equally important, our research shows there isn’t a single optimal number of examples per intent category in a pruned dataset, underlining the need for individualized pruning strategies. Implementing such a strategy will:
- enhance system performance
- minimize infrastructural costs
- use resources more efficiently
This is particularly true of strategies that enable pruning without manually investigating specific biases. Further research could yield more insights, such as regression models that more accurately capture each example’s feature importance when predicting classification model accuracy, which would improve pruning.