Notes on Designing Machine Learning Systems
Published
Machine learning system design
Area: US🇺🇸
Date: 2022/10/19 → 2022/12/25
Requirements for ML systems
- Reliability
- Scalability: resource scaling (up-scaling, down-scaling)
- Maintainability
- Adaptability
Data
Data is full of biases, introduced during collection, sampling, and labeling.
Data models
Relational model, document model, graph model
Batch processing vs Stream processing
Stream:
- low latency; nowadays streaming technologies are also highly scalable and fully distributed, and can compute in parallel
- the strength of streaming processing is in stateful computation
Sampling
Helps accomplish a task faster and cheaper with less data.
Useful or necessary when:
- don’t have access to all real-world data, training on a subset
- infeasible to process all data you have access to
Common sampling methods:
- Non-probability sampling
- convenience sampling, snowball sampling, judgement sampling, quota sampling.
- prone to selection bias
- Simple random sampling:
- rare categories may disappear
- Stratified sampling
- not always possible to divide all samples into groups, e.g. in multi-label tasks
- Weighted sampling
- allow you to leverage domain expertise
- correct the sampled data distribution toward the true data distribution
- Reservoir sampling
- deals with streaming data by maintaining a reservoir of k elements (see the sketch after this list)
- Importance sampling
- allow us to sample from a distribution when we only have access to another distribution
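A minimal sketch of reservoir sampling (Algorithm R) in Python; the function name and the toy stream are my own, not from the book. Every item seen so far ends up in the reservoir with equal probability k/n:

```python
import random

def reservoir_sample(stream, k):
    """Keep a uniform random sample of k items from a stream of unknown length."""
    reservoir = []
    for i, item in enumerate(stream):
        if i < k:
            reservoir.append(item)      # fill the reservoir with the first k items
        else:
            j = random.randint(0, i)    # pick a random index in [0, i]
            if j < k:
                reservoir[j] = item     # replace an element with probability k / (i + 1)
    return reservoir

# Toy usage: sample 5 items from a "stream" of 10,000 integers
print(reservoir_sample(range(10_000), k=5))
```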
Labeling
Hand labels
- expensive
- poses a threat to data privacy
- slow
- label multiplicity: conflicting labels from multiple annotators
- data lineage: keep track of the origin of each data sample and its labels; helps you both flag potential biases in your data and debug your models
Natural labels
feedback loop length may be long
Deal with lack of label
Weak supervision
- leverage heuristics to obtain noisy labels; a simple but powerful paradigm
- A good way to start exploring the effectiveness of ML without investing too much labeling effort
- Labeling functions encode heuristics such as: keyword matching, regular expressions, database lookups, outputs of other models (see the sketch below)
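A toy sketch of labeling functions in the spirit of weak supervision; the label space, function names, and heuristics below are invented for illustration (real projects often use a framework such as Snorkel):

```python
import re

ABSTAIN, NEGATIVE, POSITIVE = -1, 0, 1   # hypothetical label space

def lf_keyword(text):
    # keyword heuristic
    return POSITIVE if "refund" in text.lower() else ABSTAIN

def lf_regex(text):
    # regular-expression heuristic
    return POSITIVE if re.search(r"order #\d+", text.lower()) else ABSTAIN

def lf_other_model(text, spam_score):
    # output of another model (e.g. an existing spam classifier)
    return NEGATIVE if spam_score > 0.9 else ABSTAIN

def weak_label(text, spam_score=0.0):
    votes = [lf_keyword(text), lf_regex(text), lf_other_model(text, spam_score)]
    votes = [v for v in votes if v != ABSTAIN]
    # majority vote over the labeling functions that did not abstain
    return max(set(votes), key=votes.count) if votes else ABSTAIN

print(weak_label("Please process my refund for order #123"))   # -> 1
```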
Semi-supervision
leverage structural assumptions to generate new labels based on a small set of initial labels.
- self-training
- data samples that share similar characteristics share the same labels
- perturbation-based method: small perturbations shouldn’t change a sample’s label
- Use a large evaluation set to select the best model, then continue training it on the evaluation set
Transfer learning
the larger the pre-trained base model, the better its performance on the downstream tasks
Active learning
ML models can achieve greater accuracy with fewer labels if they can choose which samples to learn from.
- label samples according to metrics or heuristics, e.g. uncertainty measurement
- query-by-committee (ensemble method): use multiple models and have them vote; label the samples the committee disagrees on most
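A minimal sketch of uncertainty-based sample selection, assuming a scikit-learn style classifier that exposes `predict_proba`; the function name is my own:

```python
import numpy as np

def select_most_uncertain(model, unlabeled_X, n=10):
    """Return indices of the n unlabeled samples with the highest predictive entropy."""
    probs = model.predict_proba(unlabeled_X)                   # shape: (num_samples, num_classes)
    entropy = -np.sum(probs * np.log(probs + 1e-12), axis=1)   # uncertainty per sample
    return np.argsort(entropy)[-n:]                            # send these to human annotators
```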
Class imbalance
- insufficient signal for your model to learn to detect minority classes.
- can get stuck in a non-optimal solution by exploiting a simple heuristic instead of learning the underlying data pattern
- leads to asymmetric costs of errors. (e.g. misclassification of cancerous cells is much more dangerous.)
Handle class imbalance
- use right evaluation metrics
- data-level methods: resampling (undersampling: remove samples from the majority class; oversampling: duplicate samples from the minority class; see the sketch after this list)
- algorithm-level methods: class-balanced loss, focal loss, etc.
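A minimal sketch of a data-level method, random oversampling of the minority class with NumPy; it assumes a binary problem and is written from scratch purely for illustration (libraries such as imbalanced-learn provide ready-made resamplers):

```python
import numpy as np

def random_oversample(X, y, minority_class):
    """Duplicate minority-class samples (with replacement) until the two classes are balanced."""
    minority_idx = np.where(y == minority_class)[0]
    majority_count = int(np.sum(y != minority_class))
    extra = np.random.choice(minority_idx, size=majority_count - len(minority_idx), replace=True)
    keep = np.concatenate([np.arange(len(y)), extra])
    return X[keep], y[keep]
```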
Data augmentation
- simple label-preserving transformations
- perturbation
- data synthesis
Data distribution shifts
covariate shift
- sample selection bias problem
- training data is artificially altered to make it easier for the model to learn
- caused by model’s learning process, especially active learning
- major change in environment or the way your application is used
Detecting data distribution shifts
- statistical methods: compare summary metrics, or run two-sample hypothesis tests (see the sketch below)
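A small sketch of the two-sample idea on a single feature, using the Kolmogorov–Smirnov test from SciPy; the data and the alerting threshold are illustrative only:

```python
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(0)
train_feature = rng.normal(loc=0.0, scale=1.0, size=5_000)   # feature values from training data
prod_feature = rng.normal(loc=0.3, scale=1.0, size=5_000)    # feature values observed in production

stat, p_value = ks_2samp(train_feature, prod_feature)
if p_value < 0.01:   # illustrative threshold; real alerting needs tuning per feature
    print(f"possible distribution shift (KS statistic = {stat:.3f})")
```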
Addressing data distribution shifts
- design your system to make it more robust to shifts
- make it easier for it to adapt to shifts
Data leakage
The phenomenon when a form of the label (which is not available during inference) ‘leaks’ into the set of features used for making predictions.
Cautions for data leakage
- split time-correlated data by time, instead of randomly
- split data before any analysis and processing, so you don’t gain information about the test split
- check for duplicates before splitting; watch out for group leakage
- use statistics from only the train split to scale your features and handle missing values (see the sketch after this list)
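A minimal scikit-learn sketch of the last two points: split first, then fit the scaler on the train split only; the data is synthetic:

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

X = np.random.rand(1_000, 5)
y = (X[:, 0] > 0.5).astype(int)

# Split before any analysis or preprocessing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)   # mean/std computed from the train split only
X_test_scaled = scaler.transform(X_test)         # test split is transformed, never fitted on
```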
Model development
Training
Tips for Model selection
- Start with the simplest model, which leaves you a lot of room to improve upon.
- potential for improvements in the future.
- Use learning curves to gauge how easy or difficult it will be to achieve further improvements.
- evaluate trade-off: recall vs precision, latency vs performance, etc.
Ensemble of multi-models for prediction
- Bootstrap aggregating (bagging): sample data with replacement to train different models
- use the majority vote of all models; improves unstable methods like neural networks (see the sketch after this list)
- Boosting
- Stacking
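A minimal bagging sketch with scikit-learn; the dataset is synthetic and the default base estimator (a decision tree) stands in for any unstable model:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import BaggingClassifier
from sklearn.model_selection import cross_val_score

X, y = make_classification(n_samples=1_000, n_features=20, random_state=0)

# Each estimator is trained on a bootstrap sample (drawn with replacement);
# predictions are combined by majority vote.
bagged = BaggingClassifier(n_estimators=50, random_state=0)
print(cross_val_score(bagged, X, y, cv=5).mean())
```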
Debugging ML models
Tips:
- start with simple and gradually add more components
- overfit a single batch, to verify the model and training loop can actually learn (see the sketch after this list)
- set a random seed
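A minimal PyTorch sketch of the "overfit a single batch" check; the model and the batch are made up, the point is only that the loss should drop close to zero if the training loop is wired correctly:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)   # set a random seed for reproducibility

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

# One fixed batch: a correct pipeline should be able to memorize it
x = torch.randn(16, 10)
y = torch.randint(0, 2, (16,))

for step in range(500):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

print(f"final loss on the single batch: {loss.item():.4f}")   # should be near 0
```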
Distributed training
data & model parallelism
AutoML
- soft AutoML: Hyper-parameter tuning
- as a part of standard pipelines
- Graduate student descent (GSD)
- popular methods: random search, grid search, Bayesian optimization (see the sketch after this list)
- sensitive hyper-parameters should be carefully tuned
- hard AutoML: Architecture search and learned optimizer
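A minimal random-search sketch with scikit-learn's RandomizedSearchCV; the model, dataset, and search space are illustrative:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

X, y = make_classification(n_samples=1_000, random_state=0)

# Illustrative search space; sensitive hyper-parameters deserve finer-grained ranges
param_distributions = {
    "n_estimators": [50, 100, 200, 400],
    "max_depth": [None, 5, 10, 20],
    "min_samples_leaf": [1, 2, 5, 10],
}

search = RandomizedSearchCV(RandomForestClassifier(random_state=0),
                            param_distributions, n_iter=10, cv=3, random_state=0)
search.fit(X, y)
print(search.best_params_, search.best_score_)
```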
Model offline validation
Baselines
random baseline, simple heuristic, zero rule baseline (always predict the most common class; see the sketch below), human baseline, existing solutions
differentiate between ‘a good system’ and ‘a useful system’
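A quick sketch of a zero rule baseline using scikit-learn's DummyClassifier on a synthetic, imbalanced dataset; a model has to beat this number to be useful:

```python
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1_000, weights=[0.9, 0.1], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Zero rule: always predict the most common class
baseline = DummyClassifier(strategy="most_frequent")
baseline.fit(X_train, y_train)
print("zero-rule accuracy:", baseline.score(X_test, y_test))   # ~0.9 on this imbalanced data
```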
Evaluation methods
Besides performance metrics, we need our model to be robust, fair, calibrated, and to make sense overall.
- perturbation tests: add noisy data for testing
- invariance tests: changing sensitive information in the input shouldn’t change the output; better yet, exclude sensitive info from the features used for training
- directional expectation tests
- model calibration
- slice-based evaluation: evaluate per slice (e.g. per class), and pay extra attention to minority or critical slices (see the sketch after this list)
- focusing only on overall performance may blind us to large potential model improvements
- some subsets of data may be more critical
- Simpson’s paradox: aggregation can conceal and even contradict the actual situation
- helps improve performance both overall and on critical data, and helps detect potential biases
- methods to slice:
- heuristics-based
- error analysis
- slice finder: generate slice candidates with beam search, clustering, or decision trees
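A tiny sketch of slice-based evaluation with pandas; the slice column (device type) and the predictions are invented for illustration:

```python
import pandas as pd

# One row per evaluated example, with a column identifying its slice
df = pd.DataFrame({
    "slice": ["mobile", "mobile", "desktop", "desktop", "desktop", "tablet"],
    "label": [1, 0, 1, 1, 0, 1],
    "pred":  [1, 0, 0, 1, 0, 0],
})

df["correct"] = (df["label"] == df["pred"]).astype(int)
print(df.groupby("slice")["correct"].mean())   # overall accuracy hides the weak tablet slice
```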
Monitoring
- accuracy-related metrics
- predictions (monitor predictions for data shifts)
- features
- feature validation (table testing): check features against the expected schema
- two-sample test
- raw input
Model compression
- low-rank factorization: replace high-dimensional tensors with lower-dimensional ones, e.g. compact convolution filters
- knowledge distillation: sensitive to applications and model architectures
- Pruning:
- remove entire nodes
- make weights sparse: find the least useful parameters and set them to 0
- Quantization: straightforward to do, and generalizes over tasks and architectures (see the sketch after this list)
- quantization aware training
- post-training
- fixed-point inference has become standard in the industry
- TensorRT offers post-training quantization for free
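A minimal sketch of post-training dynamic quantization in PyTorch (the exact API location varies across versions; TensorRT has its own post-training quantization workflow):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
model.eval()

# Weights of the listed layer types are stored in int8;
# activations are quantized on the fly at inference time.
quantized = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

x = torch.randn(1, 128)
print(quantized(x).shape)   # same output shape, smaller and often faster model
```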
Model optimization
- vectorization, parallelization, loop tiling (change data accessing order in a loop to leverage hardware’s memory layout and cache), operator fusion
- use ML to optimize: autoTVM