I’ve been working on variable-sized experts in MoEs (previous post), using a modified version of MegaBlocks. The TL;DR from that work: at my scale, I didn’t find efficiencies beyond what you’d get from simply using narrower experts across the board. But since I have this hammer, I’ve been looking for nails.
I’d had my eye on doing a project with FlexOlmo for a while, and it seemed like a perfect nail. The core idea is training specialized experts separately on their own domains, then combining them into an MoE¹. They got good results, and the architecture opens the door to data collaboration: organizations can train experts on private data without surrendering it. But training a 4.3B-parameter expert isn’t cheap, and in data-constrained situations it doesn’t necessarily make sense to train a model that large. If smaller experts work, that lowers the barrier to participation significantly.
Since I have limited resources, I figured the best approach was to prune the expert MLPs of one of the existing expert + public model combinations that AI2 released and retrain it with distillation, following two NVIDIA papers (Muralidharan et al., Sreenivas et al.), rather than training new experts from scratch. After ~228M tokens of retraining with KL-divergence (KLD) distillation, I’m happy with the results, and they prove the concept well. Even pruning the expert down to ~800M parameters improves the Math2 score from 8.1 (the public model with no expert) to 29.1.
I pruned the Flex-math-2x7B-1T math expert’s MLP intermediate width to three sizes: 8192, 5504, and 2048 (down from 11008). The hidden size has to match the public model’s, and I didn’t want to touch the attention heads or the number of layers, so the expert MLPs were the only thing I shrank. One thing I’d like to try eventually: pruning different experts by different amounts within the same MoE, with sizing informed by per-expert importance scores rather than arbitrary uniform targets. Variable-sized experts make this straightforward.
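As a sanity check on the Expert Params column in the table further down, the expert’s parameter count scales linearly with MLP width. This sketch assumes OLMo 2 7B dimensions (hidden size 4096, 32 layers) and a bias-free SwiGLU MLP with gate, up, and down projections, and it ignores the router. Those dimensions are assumptions, not something stated in this post, and `expert_mlp_params` is a hypothetical helper.

```python
# Assumed OLMo-2-7B-style dimensions (not stated in this post).
D_MODEL = 4096       # hidden size, kept fixed during pruning
N_LAYERS = 32        # transformer layers, also kept fixed
BASE_WIDTH = 11008   # original expert MLP intermediate size


def expert_mlp_params(width: int, d_model: int = D_MODEL, n_layers: int = N_LAYERS) -> int:
    """Parameter count of a stack of bias-free SwiGLU expert MLPs.

    Each layer has gate and up projections (d_model x width) and a
    down projection (width x d_model): 3 * d_model * width weights.
    Router weights are ignored.
    """
    return 3 * d_model * width * n_layers


for width in (11008, 8192, 5504, 2048):
    params = expert_mlp_params(width)
    print(f"width={width:>5}  expert params={params / 1e9:.2f}B  "
          f"({width / BASE_WIDTH:.0%} of full width)")
```

Under these assumptions the four widths come out to roughly 4.33B, 3.22B, 2.16B, and 0.81B expert parameters, which lines up with the rounded values in the table.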
For the 2048-width model, I wanted to test how much the dataset used for importance analysis (the step where we decide which neurons to prune) matters. Not necessarily shocking, but it turns out: quite a lot! I tried two calibration datasets: a subset of the math data from dolmino-mix-1124², and general data from the same dataset, which does include some math.
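To make the importance-analysis step concrete, here is a toy, pure-Python sketch of activation-based neuron scoring and width pruning for a single SwiGLU MLP layer. Scoring each intermediate neuron by its mean absolute activation over calibration tokens is one common choice, roughly in the spirit of the NVIDIA pruning papers; it is an assumption here, not necessarily the exact aggregation used for these runs. Weights follow the PyTorch `nn.Linear` layout (out_features × in_features), and all function names are hypothetical. Swapping the calibration inputs (math vs. general) is exactly what changes which neurons survive.

```python
import math


def silu(z: float) -> float:
    """SiLU (swish) activation used in the SwiGLU gate."""
    return z / (1.0 + math.exp(-z))


def matvec(w, x):
    """Multiply a weight matrix (list of rows, out x in) by a vector."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def intermediate_activations(x, w_gate, w_up):
    """SwiGLU intermediate silu(gate(x)) * up(x): one value per MLP neuron."""
    return [silu(g) * u for g, u in zip(matvec(w_gate, x), matvec(w_up, x))]


def neuron_importance(calib_tokens, w_gate, w_up):
    """Mean absolute activation of each neuron over the calibration tokens."""
    totals = [0.0] * len(w_gate)
    for x in calib_tokens:
        for j, a in enumerate(intermediate_activations(x, w_gate, w_up)):
            totals[j] += abs(a)
    return [t / len(calib_tokens) for t in totals]


def top_k_neurons(scores, k):
    """Indices of the k highest-scoring neurons, returned in original order."""
    ranked = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    return sorted(ranked[:k])


def prune_mlp(w_gate, w_up, w_down, keep):
    """Keep rows `keep` of the gate/up projections and matching columns of down."""
    new_gate = [w_gate[j] for j in keep]
    new_up = [w_up[j] for j in keep]
    new_down = [[row[j] for j in keep] for row in w_down]
    return new_gate, new_up, new_down


# Toy layer: d_model = 2, width = 4, pruned to width = 2.
w_gate = [[1.0, 0.0], [0.0, 1.0], [0.1, 0.1], [2.0, 2.0]]
w_up = [[1.0, 0.0], [0.0, 1.0], [0.1, 0.1], [1.0, 1.0]]
w_down = [[1.0, 2.0, 3.0, 4.0], [0.5, 0.6, 0.7, 0.8]]
calib = [[1.0, 0.5], [0.2, 1.0]]  # stand-in for hidden states entering the MLP

scores = neuron_importance(calib, w_gate, w_up)
keep = top_k_neurons(scores, k=2)
new_gate, new_up, new_down = prune_mlp(w_gate, w_up, w_down, keep)
print(keep)      # [1, 3]
print(new_down)  # [[2.0, 4.0], [0.6, 0.8]]
```

The hidden size (the columns of gate/up and the rows of down) is untouched; only the intermediate dimension shrinks, which is what keeps the pruned expert compatible with the public model.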
58% of the top 2048 most important neurons differ between the two analyses. The early layers mostly agree on what’s important, but from layer 6 onward the rankings diverge.
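A per-layer comparison like this one boils down to set overlap between two top-k rankings. The sketch below uses tiny hypothetical score lists (k = 2 instead of 2048) to show the computation. Averaging the per-layer values is one way to get a single summary number, but that aggregation is an assumption, not a description of how the 58% figure was computed.

```python
def top_k_set(scores, k):
    """Set of indices of the k highest-scoring neurons."""
    return set(sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)[:k])


def top_k_disagreement(scores_a, scores_b, k):
    """Fraction of ranking A's top-k neurons that ranking B did not select."""
    return 1.0 - len(top_k_set(scores_a, k) & top_k_set(scores_b, k)) / k


# Hypothetical per-layer importance scores from two calibration sets.
math_scores = [
    [0.9, 0.8, 0.1, 0.2],   # layer 0: same top-2 as the general set
    [0.9, 0.1, 0.8, 0.2],   # layer 1: completely different top-2
]
general_scores = [
    [0.85, 0.9, 0.05, 0.1],
    [0.1, 0.9, 0.2, 0.8],
]

per_layer = [top_k_disagreement(m, g, k=2) for m, g in zip(math_scores, general_scores)]
print(per_layer)  # [0.0, 1.0]
```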

The math-calibrated model also trained better: validation loss started lower and ended lower. I did stop the general-calibrated run early, but it wasn’t going to catch up, and I wanted to move on to training the larger models.
Here are the loss curves for the math and general models:

These models were retrained by distillation on the top-128 logprobs generated by Flex-math-2x7B-1T over the GSM8K, Metamath-owmfilter, and TuluMath subsets of the DOLMino mix dataset (the same dataset FlexOlmo was trained with), about 620K documents in total. Logprobs dataset here.
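Here is a minimal pure-Python sketch of a distillation loss computed from stored top-k teacher logprobs at a single token position. It renormalizes the teacher’s top-k logprobs over those k tokens and compares them against the student’s full-vocabulary log-softmax, giving a forward KL restricted to the teacher’s top-k support. The renormalization and the absence of a temperature are assumptions for illustration, not necessarily how these runs handled the truncated tail, and `topk_kl` is a hypothetical helper.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def topk_kl(teacher_ids, teacher_logprobs, student_logits):
    """Forward KL(teacher || student) over the teacher's top-k tokens.

    Assumption: the teacher's top-k logprobs are renormalized to sum to 1,
    and student probability mass outside those k tokens is ignored.
    """
    m = max(teacher_logprobs)
    t_lse = m + math.log(sum(math.exp(lp - m) for lp in teacher_logprobs))
    t_logp = [lp - t_lse for lp in teacher_logprobs]  # renormalized teacher
    s_logp = log_softmax(student_logits)              # full-vocab student
    return sum(math.exp(tl) * (tl - s_logp[i]) for i, tl in zip(teacher_ids, t_logp))


# Teacher puts 0.4 on each of tokens 0 and 1 (0.2 elsewhere); only the top-2 is stored.
teacher_ids = [0, 1]
teacher_logprobs = [math.log(0.4), math.log(0.4)]
student_logits = [0.0, 0.0, 0.0, 0.0]  # uniform student over a 4-token vocab
loss = topk_kl(teacher_ids, teacher_logprobs, student_logits)
print(round(loss, 4))  # 0.6931, i.e. log(2)
```

In a real training loop this would be batched over positions with tensors and averaged. Working only from the stored top-128 entries keeps storage manageable, since the full teacher vocabulary distribution never has to be saved.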
Using the LM Evaluation Harness (which handles base models well), we got pretty close to the paper’s baseline numbers. Math2 is the macro average of GSM8K and MATH³.
| Model | Total Params | Expert Params | Expert Width | GSM8K | MATH | Math2 |
|---|---|---|---|---|---|---|
| Public model, no expert⁴ | 7.3B | 0 | — | — | — | 8.1 |
| Flex-math-2x7B-1T (baseline) | 11.6B | 4.3B | 11008 (100%) | 69.7 | 35.4 | 52.5 |
| flex-math-8192 | 10.5B | 3.2B | 8192 (74%) | 70.1 | 31.3 | 50.7 |
| flex-math-5504 | 9.5B | 2.2B | 5504 (50%) | 66.6 | 26.8 | 46.7 |
| flex-math-2048 | 8.1B | 0.8B | 2048 (19%) | 44.3 | 13.9 | 29.1 |
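The footnote on Math2 is unsure whether it is a macro average. A quick arithmetic check shows the table’s numbers are at least consistent with that definition: each computed value matches the reported Math2 to within rounding of the per-task scores. This doesn’t confirm the paper’s definition, and `math2` is a hypothetical helper.

```python
# GSM8K, MATH, and reported Math2, copied from the table above.
rows = {
    "Flex-math-2x7B-1T": (69.7, 35.4, 52.5),
    "flex-math-8192": (70.1, 31.3, 50.7),
    "flex-math-5504": (66.6, 26.8, 46.7),
    "flex-math-2048": (44.3, 13.9, 29.1),
}


def math2(gsm8k: float, math_score: float) -> float:
    """Macro average: unweighted mean of the two per-benchmark scores."""
    return (gsm8k + math_score) / 2


for name, (gsm8k, math_score, reported) in rows.items():
    print(f"{name:<18} computed={math2(gsm8k, math_score):.2f} reported={reported}")

# Ratio of the 2048 model's Math2 to the no-expert public model's 8.1.
print(f"{29.1 / 8.1:.1f}x")  # 3.6x
```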
The 8192 model is just about on par with the full-sized expert. Even the 2048 model (0.8B expert params) scores 3.6x the no-expert baseline on Math2. The half-sized expert (5504) is pretty competitive with its larger siblings.
I think this is a pretty good nail!
I think the way this would go down would be:
I only evaluated on math benchmarks (GSM8K and MATH). It’s possible that pruning the expert hurts general reasoning or other capabilities that I didn’t measure. Running BBH was going to take like 60 hours on my home system, so I figured I’d just punt and do these. Since we’re pruning a math expert and testing math performance, I think the evals here are the right ones, but broader evaluation would be nice. I also really want to know how FlexOlmo works with post-training. Can we mix a post-trained public model and an expert model with just continued pretraining? Or just post-train the expert model? Etc.
¹ This glosses over a few details, but I think it’s an OK way to think about it. ↩
² Using the following files from the dataset: data/math/gsm8k/**/*.jsonl, data/math/metamath-owmfilter/**/*.jsonl, data/math/tulu_math/**/*.jsonl. About a week later, I honestly don’t remember why I chose only those. I remember trying to avoid code: MathCoder and a couple of other parts of the math dataset are code-heavy, but I don’t remember why I skipped, e.g., DolminoSynthMath. Bit of an oversight, but I think the results are still meaningful with a smaller dataset. ↩
³ I think it’s the macro average! I can’t find the exact definition anywhere, but the metrics I found line up well with the paper’s. ↩
⁴ Scores reported from the FlexOlmo paper, Table 1. This is the public-only model with no math expert attached. ↩