Source code for unetseg.console.train

# -*- coding: utf-8 -*-
"""
Console script for training a unetseg model.

This module parses the command line options, sets up logging, builds a
``TrainConfig`` from the parsed options and calls ``train`` with it.

To install it as a command, add an entry like the following to the
[options.entry_points] section in setup.cfg (replace ``<command>`` with the
name the command should have):

    console_scripts =
         <command> = unetseg.console.train:run

Then run `python setup.py install`, which will install the command inside your
current environment.
"""

import argparse
import logging
import sys

from unetseg import __version__
from unetseg.train import TrainConfig, train

__author__ = "Damián Silvani"
__copyright__ = "Dymaxion Labs"
__license__ = "apache-2.0"

_logger = logging.getLogger(__name__)


def parse_args(args):
    """Parse command line parameters

    The boolean options ``image_augmentation`` and ``evaluate`` are each
    controlled by a pair of flags (``--x`` / ``--no-x``) writing to the same
    destination. Both are enabled unless the ``--no-`` flag is given; when
    both flags of a pair are given, the last one on the command line wins.

    Args:
      args ([str]): command line parameters as list of strings

    Returns:
      :obj:`argparse.Namespace`: command line parameters namespace
    """
    parser = argparse.ArgumentParser(
        description="Train a model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--version",
        action="version",
        version="unetseg {ver}".format(ver=__version__),
    )

    # Verbosity flags: both write to ``loglevel``, which stays None when
    # neither flag is given.
    parser.add_argument(
        "-v",
        "--verbose",
        dest="loglevel",
        help="set loglevel to INFO",
        action="store_const",
        const=logging.INFO,
    )
    parser.add_argument(
        "-vv",
        "--very-verbose",
        dest="loglevel",
        help="set loglevel to DEBUG",
        action="store_const",
        const=logging.DEBUG,
    )

    # Input dataset and output model.
    parser.add_argument(
        "train_dir",
        help="Path to image tiles and masks (directory with images/ and masks/)",
    )
    parser.add_argument(
        "-o", "--output", help="path to output model (.h5)", default="./unet.h5"
    )

    # Tile geometry (width/height have no default and are None when omitted).
    parser.add_argument("-W", "--width", type=int, help="Image tile width")
    parser.add_argument("-H", "--height", type=int, help="Image tile height")
    parser.add_argument(
        "-N", "--num-channels", default=3, type=int, help="Number of channels"
    )
    parser.add_argument(
        "-C", "--num-classes", default=1, type=int, help="Number of classes"
    )

    # Training schedule.
    parser.add_argument(
        "-E", "--epochs", default=15, type=int, help="number of training epochs"
    )
    parser.add_argument(
        "--steps-per-epoch", default=100, type=int, help="steps per epoch"
    )
    parser.add_argument(
        "--early-stopping-patience",
        default=3,
        type=int,
        help="number of epochs with no improvement after which training will be stopped",
    )
    parser.add_argument("--batch-size", default=32, type=int, help="batch size")

    # Paired on/off flags. Only the positive flag declares the default for the
    # shared destination; the negative flag uses SUPPRESS so that the effective
    # default does not depend on the order in which the two flags are
    # registered, and so that the help output does not advertise a
    # contradictory "(default: False)" next to the ``--no-`` flag.
    parser.add_argument(
        "--image-augmentation",
        dest="image_augmentation",
        help="Apply image augmentation",
        action="store_true",
        default=True,
    )
    parser.add_argument(
        "--no-image-augmentation",
        dest="image_augmentation",
        help="Do not apply image augmentation",
        action="store_false",
        default=argparse.SUPPRESS,
    )
    parser.add_argument(
        "--evaluate",
        dest="evaluate",
        help="Evaluate metrics over validation set at the end of training",
        action="store_true",
        default=True,
    )
    parser.add_argument(
        "--no-evaluate",
        dest="evaluate",
        help="Do not evaluate metrics over validation set at the end of training",
        action="store_false",
        default=argparse.SUPPRESS,
    )

    parser.add_argument(
        "-s",
        "--seed",
        default=None,
        type=int,
        help="Seed number for the random number generation",
    )

    return parser.parse_args(args)
def setup_logging(loglevel):
    """Setup basic logging

    Messages are written to standard output, prefixed with a timestamp, the
    level name and the logger name.

    Args:
      loglevel (int): minimum loglevel for emitting messages
    """
    settings = {
        "level": loglevel,
        "stream": sys.stdout,
        "format": "[%(asctime)s] %(levelname)s:%(name)s:%(message)s",
        "datefmt": "%Y-%m-%d %H:%M:%S",
    }
    logging.basicConfig(**settings)
def main(args):
    """Main entry point allowing external calls

    Parses the command line, sets up logging, builds the training
    configuration from the parsed options and starts training.

    Args:
      args ([str]): command line parameter list
    """
    options = parse_args(args)
    setup_logging(options.loglevel)

    # Map each parsed command line option onto its TrainConfig keyword.
    config_kwargs = {
        "width": options.width,
        "height": options.height,
        "n_channels": options.num_channels,
        "n_classes": options.num_classes,
        "epochs": options.epochs,
        "steps_per_epoch": options.steps_per_epoch,
        "early_stopping_patience": options.early_stopping_patience,
        "apply_image_augmentation": options.image_augmentation,
        "batch_size": options.batch_size,
        "seed": options.seed,
        "images_path": options.train_dir,
        "evaluate": options.evaluate,
        "model_path": options.output,
    }
    train(TrainConfig(**config_kwargs))
def run():
    """Entry point for console_scripts"""
    # Drop the program name; only the actual parameters are forwarded.
    cli_args = sys.argv[1:]
    main(cli_args)
# Allow executing this module directly as a script.
if __name__ == "__main__":
    run()