transformertc.train

Training functions for token-level classification with BERT.

Module Contents

transformertc.train.logger
transformertc.train.set_seed(seed)
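set_seed has no docstring. The seed parameter documented under pytorch_train is described as a random seed for Python, NumPy and PyTorch, so a plausible minimal sketch of set_seed, assuming it seeds exactly those three generators, is:

```python
import random

import numpy as np
import torch


def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch RNGs (sketch, not the module's source)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)           # CPU generator (recent PyTorch also seeds CUDA here)
    torch.cuda.manual_seed_all(seed)  # all GPUs; silently ignored when CUDA is unavailable
```

Seeding alone does not guarantee bit-for-bit reproducible GPU runs, since some CUDA kernels are nondeterministic unless deterministic algorithms are also enabled.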
transformertc.train.get_steps_and_epochs(epochs, epoch_steps, max_steps)
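get_steps_and_epochs is also undocumented. Assuming it implements the rule stated for pytorch_train (a positive max_steps takes priority; otherwise training runs for epochs * epoch_steps), it could look like the sketch below. The (total_steps, epochs) return order and rounding the epoch count up to a whole number are assumptions.

```python
import math
from typing import Tuple


def get_steps_and_epochs(epochs: int, epoch_steps: int, max_steps: int) -> Tuple[int, int]:
    """Return (total_steps, epochs); a positive max_steps overrides epochs (sketch)."""
    if epoch_steps <= 0:
        raise ValueError("epoch_steps must be positive")
    if max_steps > 0:
        # Run enough whole epochs to cover max_steps; training stops at max_steps.
        return max_steps, math.ceil(max_steps / epoch_steps)
    # No step budget: train for the requested number of full epochs.
    return epochs * epoch_steps, epochs
```

For example, with 100 batches per epoch, get_steps_and_epochs(4, 100, 0) gives (400, 4), while get_steps_and_epochs(4, 100, 250) gives (250, 3).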
transformertc.train.log_train_start(total_steps, epoch_steps, epochs, batch_size)
transformertc.train.get_optimizer_and_scheduler(model, lr, wdecay, adam_eps, warmup_steps, total_steps)
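get_optimizer_and_scheduler receives a learning rate, weight decay, Adam epsilon and warmup/total step counts. A common choice for BERT fine-tuning, assumed here rather than taken from the module, is torch.optim.AdamW with a schedule that ramps the learning rate up linearly during warmup and then decays it linearly to zero:

```python
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR


def get_optimizer_and_scheduler(model, lr, wdecay, adam_eps, warmup_steps, total_steps):
    """AdamW plus linear warmup/decay (sketch, not the module's source)."""
    # Assumed convention: no weight decay on biases and LayerNorm weights.
    no_decay = ("bias", "LayerNorm.weight")
    named = list(model.named_parameters())
    groups = [
        {"params": [p for n, p in named if not any(nd in n for nd in no_decay)],
         "weight_decay": wdecay},
        {"params": [p for n, p in named if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0},
    ]
    optimizer = AdamW(groups, lr=lr, eps=adam_eps)

    def lr_lambda(step: int) -> float:
        # Multiplier on lr: 0 -> 1 over warmup_steps, then 1 -> 0 at total_steps.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

    scheduler = LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler
```

The transformers library provides an equivalent scheduler, get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps); the LambdaLR version only avoids that dependency. Whether the module uses either, or excludes any parameters from weight decay, is not confirmed by the signature.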
transformertc.train.log_step(model, device, logging_steps, global_step, logging_loss, total_loss, scheduler, val_dataloader=None, label_list=None, no_tqdm=False)
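Judging only from its parameter names, log_step appears to follow the common pattern in which total_loss is a running sum of batch losses and logging_loss is that sum at the previous log point, so their difference divided by logging_steps is the mean loss over the latest logging window. The hypothetical helper below replays just that bookkeeping; the real function also receives the model, scheduler and validation data, presumably to report the learning rate and validation metrics, which the sketch omits.

```python
from typing import Iterable, List, Tuple


def logged_averages(losses: Iterable[float], logging_steps: int) -> List[Tuple[int, float]]:
    """Return (global_step, mean loss since last log) pairs (hypothetical helper)."""
    total_loss, logging_loss = 0.0, 0.0
    out: List[Tuple[int, float]] = []
    for global_step, loss in enumerate(losses, start=1):
        total_loss += loss
        # Assumption: logging_steps == 0 disables logging.
        if logging_steps > 0 and global_step % logging_steps == 0:
            out.append((global_step, (total_loss - logging_loss) / logging_steps))
            logging_loss = total_loss  # snapshot the running sum at this log point
    return out
```

For instance, logged_averages([1.0, 3.0, 0.5, 0.5], 2) reports a mean of 2.0 at step 2 and 0.5 at step 4.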
transformertc.train.pytorch_train(model, device, dataloader, epochs=4, max_steps=0, lr=5e-05, wdecay=0.0, warmup_steps=0, adam_epsilon=1e-08, max_grad_norm=1.0, logging_steps=0, labels=None, val_dataloader=None, no_tqdm=False)
Train the model with PyTorch.

Either args.num_train_epochs or args.max_steps must be passed; if both are passed, args.max_steps takes priority.

Parameters
  • args.device – the device (e.g. a GPU) on which the model resides

  • args.max_steps (int, optional) – maximum number of training steps; defaults to num_train_epochs * steps_per_epoch

  • args.num_train_epochs (int, optional) – number of epochs to train for

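  • args.learning_rate (float, optional) – learning rate for the optimizer; defaults to 5e-05

  • args.wdecay (float, optional) – weight decay coefficient; defaults to 0.0

  • args.warmup_steps (int, optional) – number of learning-rate warmup steps; defaults to 0

  • args.adam_epsilon (float, optional) – epsilon term of the Adam optimizer, added for numerical stability; defaults to 1e-08

  • args.max_grad_norm (float, optional) – maximum gradient norm used for gradient clipping; defaults to 1.0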

  • args.no_tqdm (bool, optional) – if True, disable the tqdm progress bars

  • args.seed (int, optional) – random seed for Python, NumPy and PyTorch

  • args.checkpoint_dir (str) – directory in which model checkpoints are saved

  • args.logging_steps (int, optional) – log every logging_steps training steps
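Putting the parameters together, one plausible shape for the inner loop of pytorch_train is sketched below. The function name train_sketch, the (features, labels) batch format and the explicit cross-entropy loss are illustrative assumptions; a Hugging Face ...ForTokenClassification model would instead compute the loss itself when labels are passed. The sketch shows where max_grad_norm (gradient clipping) and the scheduler step fit, and how a step budget can end training partway through an epoch.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_sketch(model, device, dataloader, optimizer, scheduler, total_steps, max_grad_norm=1.0):
    """Minimal token-classification training loop (sketch, not the module's source).

    Assumes the model is already on `device`, each batch is (features, labels),
    and the model maps (batch, seq, dim) features to (batch, seq, num_labels) logits.
    """
    if len(dataloader) == 0:
        raise ValueError("dataloader is empty")
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    global_step, total_loss = 0, 0.0
    while global_step < total_steps:
        for features, labels in dataloader:
            features, labels = features.to(device), labels.to(device)
            logits = model(features)
            # Flatten (batch, seq) so every token is one classification example.
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            total_loss += loss.item()
            global_step += 1
            if global_step >= total_steps:
                break  # step budget reached mid-epoch
    return global_step, total_loss / global_step
```

A toy run: with 8 sequences and batch_size=4 there are 2 batches per epoch, so total_steps=5 covers two full epochs plus one batch of a third.

```python
torch.manual_seed(0)
features = torch.randn(8, 5, 4)        # 8 sequences, 5 tokens, 4 features
labels = torch.randint(0, 3, (8, 5))   # 3 hypothetical token labels
loader = DataLoader(TensorDataset(features, labels), batch_size=4)
model = nn.Linear(4, 3)
opt = torch.optim.AdamW(model.parameters(), lr=1e-2)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda step: 1.0)
steps, avg_loss = train_sketch(model, torch.device("cpu"), loader, opt, sched, total_steps=5)
```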