transformertc.train
Training functions for token-level classification with BERT.
Module Contents
transformertc.train.logger
transformertc.train.set_seed(seed)
transformertc.train.get_steps_and_epochs(epochs, epoch_steps, max_steps)
transformertc.train.log_train_start(total_steps, epoch_steps, epochs, batch_size)
transformertc.train.get_optimizer_and_scheduler(model, lr, wdecay, adam_eps, warmup_steps, total_steps)
transformertc.train.log_step(model, device, logging_steps, global_step, logging_loss, total_loss, scheduler, val_dataloader=None, label_list=None, no_tqdm=False)
transformertc.train.pytorch_train(model, device, dataloader, epochs=4, max_steps=0, lr=5e-05, wdecay=0.0, warmup_steps=0, adam_epsilon=1e-08, max_grad_norm=1.0, logging_steps=0, labels=None, val_dataloader=None, no_tqdm=False)
Train the model with PyTorch.
Either args.num_train_epochs or args.max_steps must be passed (epochs and max_steps in the signature above). The latter takes priority if both are passed.
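The priority rule above can be sketched as follows. This is a minimal illustration, not the library's get_steps_and_epochs implementation: the function name resolve_steps_and_epochs, the return order, and the rounding of a partial last epoch are all assumptions.

```python
import math


def resolve_steps_and_epochs(epochs, epoch_steps, max_steps):
    """Hypothetical helper: return (total_steps, epochs).

    A positive max_steps takes priority over epochs.
    """
    if max_steps and max_steps > 0:
        # Assumption: run enough epochs to cover max_steps,
        # rounding up so a partial last epoch is counted.
        return max_steps, math.ceil(max_steps / epoch_steps)
    if epochs and epochs > 0:
        # Only epochs given: total steps is epochs * steps per epoch.
        return epochs * epoch_steps, epochs
    raise ValueError("pass a positive value for epochs or max_steps")


print(resolve_steps_and_epochs(4, 100, 0))    # (400, 4)
print(resolve_steps_and_epochs(4, 100, 250))  # (250, 3)
```

In this sketch, passing neither value raises an error instead of silently training for zero steps; the actual function may behave differently.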
- Parameters
args.device – the device (e.g. a GPU) on which the model is located
args.max_steps (int, optional) – maximum steps, defaults to epochs * steps_per_epoch
args.num_train_epochs (int, optional) – number of epochs to train for
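args.learning_rate (float, optional) – learning rate for the optimizer
args.wdecay (float, optional) – weight decay for the optimizer
args.warmup_steps (int, optional) – number of warmup steps for the learning rate scheduler
args.adam_epsilon (float, optional) – epsilon for the Adam optimizer
args.max_grad_norm (float, optional) – maximum gradient norm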
args.no_tqdm (bool, optional) – if True, disable the tqdm progress bar
args.seed (int, optional) – random seed for python/numpy/torch
args.checkpoint_dir (str) – save model checkpoints to this directory
args.logging_steps (int, optional) – log every logging_steps steps
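The seed parameter is described as covering python/numpy/torch. A helper along those lines might look like the sketch below. The name seed_everything is hypothetical and this is not the body of transformertc.train.set_seed; NumPy and PyTorch are seeded only if they can be imported, so the snippet also runs without them.

```python
import random


def seed_everything(seed):
    """Hypothetical helper: seed Python, and NumPy/PyTorch when installed."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; skip
    try:
        import torch
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            # Seed every visible CUDA device as well.
            torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass  # PyTorch not installed; skip


seed_everything(42)
first = random.random()
seed_everything(42)
second = random.random()
print(first == second)  # True: reseeding reproduces the same draw
```

Seeding alone does not guarantee bit-identical GPU runs; some CUDA kernels are nondeterministic unless deterministic algorithms are also enabled.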