This repository contains a customizable PyTorch training loop template that simplifies training, validation, and testing of models. It includes support for:
- ✅ Early stopping
- 📉 Learning rate scheduling
- 📊 Metric logging
- ❄️ Freezing part of your network for the first few epochs
- 🔧 Fast dev run: train the model on a single batch to ensure soundness
- 🛑 Graceful `KeyboardInterrupt` handling during training: the results up to the current epoch are returned
Instead of rewriting boilerplate code for every project, use this reusable trainer as a solid starting point and adapt it to your specific needs!
A working example using the Iris dataset, including a custom `Dataset` class and a simple FNN model, is available in `main.py`.
- Copy the `trainer.py` file into your project.
- Add the packages listed in `requirements.txt` to your project environment.
- By default, your `DataLoader` should return batches as a dictionary with the following keys: `{'input': [...], 'target': [...]}`
- Create an `output_parse(self, output)` method in your model class to match the output requirements of your model.
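As a sketch of the two requirements above, a `Dataset` can yield per-sample dictionaries (the default `collate_fn` will batch them into a `{'input': ..., 'target': ...}` dict), and the model can expose an `output_parse` method. The class names `DictDataset` and `TinyNet`, and the assumption that `output_parse` converts logits into class predictions, are illustrative, not part of this repository:

```python
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

class DictDataset(Dataset):
    """Yields samples as dicts; the default collate_fn then
    produces batches of the form {'input': tensor, 'target': tensor}."""
    def __init__(self, X, y):
        self.X = torch.as_tensor(X, dtype=torch.float32)
        self.y = torch.as_tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return {'input': self.X[idx], 'target': self.y[idx]}

class TinyNet(nn.Module):
    def __init__(self, in_features=2, n_classes=2):
        super().__init__()
        self.fc = nn.Linear(in_features, n_classes)

    def forward(self, x):
        return self.fc(x)

    def output_parse(self, output):
        # Assumed contract: turn raw logits into class predictions.
        return output.argmax(dim=-1)

loader = DataLoader(DictDataset([[0.0, 1.0], [2.0, 3.0]], [0, 1]), batch_size=2)
batch = next(iter(loader))
model = TinyNet()
preds = model.output_parse(model(batch['input']))
print(sorted(batch.keys()), tuple(preds.shape))
```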
```python
from trainer import Trainer

trainer = Trainer(model=model, device='cpu')
results = trainer.fit(
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    max_epochs=100,
    early_stopping=True,
    patience=5,
    early_stopping_monitor='accuracy',
    early_stopping_mode='max',
    scheduler=None,
    metrics={
        'accuracy': sklearn.metrics.accuracy_score
    },
    fast_dev_run=False,
)
```

The `trainer.fit` function returns a dictionary containing training, validation, and test losses and metrics.
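The `metrics` argument maps names to sklearn-style callables taking `(y_true, y_pred)`, so extra metrics can be added with `functools.partial` when a metric needs fixed keyword arguments. The sketch below only assumes that calling convention and uses dummy labels:

```python
from functools import partial
from sklearn.metrics import accuracy_score, f1_score

# Hypothetical metrics dict as it would be passed to trainer.fit;
# each value is a callable with the sklearn (y_true, y_pred) signature.
metrics = {
    'accuracy': accuracy_score,
    'f1_macro': partial(f1_score, average='macro'),
}

# Dummy labels to show how each metric gets evaluated.
y_true = [0, 1, 2, 2]
y_pred = [0, 1, 1, 2]
scores = {name: fn(y_true, y_pred) for name, fn in metrics.items()}
print(scores)
```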
```python
test_results = trainer.test(test_loader, criterion)
```

The `trainer.test` function returns a dictionary containing test losses and metrics.
```python
y_pred = trainer.predict(test_loader)
```

The `trainer.predict` function accepts either a `DataLoader` or a single input, and returns the predictions.
This project is under the MIT license. See LICENSE for more information.
