AIToolbox - Model Training Framework for PyTorch
AIToolbox is a framework that helps you train deep learning models in PyTorch and iterate quickly on experiments. It hides the repetitive technicalities of training neural nets and frees you to focus on the interesting part: devising new models. In essence, it offers a Keras-style train loop abstraction that handles the higher-level training process while still allowing manual control at the lower level when desired.
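To make the idea of a Keras-style train loop more concrete, here is a minimal, framework-free sketch of the general pattern. A `fit()` method hides the repetitive epoch and batch iteration, and callback hooks hand control back to the user at specific points. The names `TrainLoopSketch`, `Callback`, `EpochRecorder` and `sgd_step` are hypothetical illustrations of the pattern, not AIToolbox's actual API, and a toy one-weight model stands in for a PyTorch network.

```python
from typing import Callable, List, Sequence


class Callback:
    """Hypothetical hook object: override only the events you need."""

    def on_epoch_begin(self, loop: "TrainLoopSketch") -> None:
        pass

    def on_epoch_end(self, loop: "TrainLoopSketch") -> None:
        pass


class TrainLoopSketch:
    """Hypothetical Keras-style loop; NOT the AIToolbox TrainLoop API."""

    def __init__(self, step_fn: Callable[[float], float], batches: Sequence[float]):
        self.step_fn = step_fn          # runs one training step on a batch, returns its loss
        self.batches = batches
        self.epoch = 0
        self.history: List[float] = []  # mean training loss per epoch

    def fit(self, num_epochs: int, callbacks: Sequence[Callback] = ()) -> List[float]:
        for epoch in range(num_epochs):
            self.epoch = epoch
            for cb in callbacks:
                cb.on_epoch_begin(self)
            # The repetitive inner loop the user no longer writes by hand
            losses = [self.step_fn(batch) for batch in self.batches]
            self.history.append(sum(losses) / len(losses))
            for cb in callbacks:
                cb.on_epoch_end(self)
        return self.history


# Toy "model": a single weight w fitted to the target values by gradient descent
state = {"w": 0.0}


def sgd_step(target: float, lr: float = 0.1) -> float:
    error = state["w"] - target
    state["w"] -= lr * 2 * error  # gradient of error**2 with respect to w
    return error ** 2             # loss measured before the update


class EpochRecorder(Callback):
    """Lower-level control point: user code that runs after every epoch."""

    def __init__(self) -> None:
        self.epochs: List[int] = []

    def on_epoch_end(self, loop: TrainLoopSketch) -> None:
        self.epochs.append(loop.epoch)


recorder = EpochRecorder()
history = TrainLoopSketch(sgd_step, batches=[1.0, 1.0]).fit(num_epochs=3, callbacks=[recorder])
```

The split mirrors the trade-off described above: the loop owns the boilerplate, while callbacks remain the escape hatch for custom behaviour.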
In addition to orchestrating the model training loop, the framework also helps you keep track of different experiments by automatically saving models in a structured, traceable way and by creating performance reports. These can be stored either locally or on AWS S3 (Google Cloud Storage support is in beta), which makes the library very useful when training on GPU instances on AWS: the instance can be shut down automatically when training finishes, with all the results safely stored on S3.
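The idea of structured, traceable storage can be illustrated by deriving one relative key per experiment run and reusing it both for the local results folder and for the S3 location. The `<project>/<experiment>_<timestamp>/<artifact>` layout and the helper names below are hypothetical, chosen only for illustration; they are not AIToolbox's actual folder naming scheme.

```python
from datetime import datetime
from pathlib import PurePosixPath


def experiment_key(project: str, experiment: str, started_at: datetime) -> str:
    """Hypothetical run identifier: the timestamp keeps repeated runs apart."""
    stamp = started_at.strftime("%Y%m%d_%H%M%S")
    return f"{project}/{experiment}_{stamp}"


def local_path(root: str, key: str, artifact: str) -> PurePosixPath:
    """Where an artifact (e.g. 'model' or 'results') would live on local disk."""
    return PurePosixPath(root) / key / artifact


def s3_uri(bucket: str, key: str, artifact: str) -> str:
    """The same relative layout mirrored under an S3 bucket."""
    return f"s3://{bucket}/{key}/{artifact}"


key = experiment_key("sentiment", "lstm_baseline", datetime(2024, 1, 2, 3, 4, 5))
model_dir = local_path("/home/user/results", key, "model")
model_uri = s3_uri("my-results-bucket", key, "model")
```

Because both locations share one key, a given run can be found in either storage with the same identifier.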
AIToolbox consists of three main user-facing components:
aitoolbox.torchtrain - PyTorch train loop engine
aitoolbox.experiment - experiment tracking
aitoolbox.cloud - cloud operations for AWS and Google Cloud
All three AIToolbox components can be used independently when only a subset of the functionality is needed in a project.
However, the greatest benefit of AIToolbox comes when all components are used together, making
PyTorch model training and experiment tracking as easy as possible.
Most of this top-level API is exposed to the user via the functionality implemented in
To install the AIToolbox package, execute:
pip install aitoolbox
If you want to install the most recent version from the GitHub repository, first clone the repository and then install it via pip:
git clone https://github.com/mv1388/aitoolbox.git
pip install ./aitoolbox
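Whichever install route you choose, a quick way to confirm which version ended up in the active environment is to query the installed distribution metadata with the standard library's `importlib.metadata` (Python 3.8+). The helper name `installed_version` is hypothetical; the sketch only assumes that the distribution is named `aitoolbox`, as in the pip command above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str = "aitoolbox") -> Optional[str]:
    """Return the installed version string, or None if the package is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_version()
    print(f"aitoolbox {found} is installed" if found else "aitoolbox is not installed")
```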
The AIToolbox package can also be provided as a dependency in the requirements.txt file by simply specifying the
aitoolbox dependency. Alternatively, to automatically download the current master branch from GitHub, include the following dependency specification in requirements.txt: