
train_family_classifier.py

michele edited this page Oct 13, 2021 · 2 revisions


Imported Modules



  • from nets.Family_Classifier_net import Net as Family_Net
  • from nets.generators.fresh_generators import get_generator


Classes and functions

train_network(fresh_ds_path, checkpoint_path, training_run, epochs, train_split_proportion, valid_split_proportion, test_split_proportion, batch_size, random_seed, workers) (function, baker command) - Train a family classifier model on the fresh dataset for the malware family classification task.

  • fresh_ds_path (arg) - Path to the directory containing the fresh dataset (.dat files)
  • checkpoint_path (arg) - Path to the model checkpoint to load (default: 'None')
  • training_run (arg) - Training run identifier (default: 0)
  • epochs (arg) - How many epochs to train for (default: 25)
  • train_split_proportion (arg) - Relative proportion of the dataset assigned to the training subsplit (default: 7)
  • valid_split_proportion (arg) - Relative proportion of the dataset assigned to the validation subsplit (default: 1)
  • test_split_proportion (arg) - Relative proportion of the dataset assigned to the test subsplit (default: 2)
  • batch_size (arg) - How many samples per batch to load (default: 250)
  • random_seed (arg) - If provided, seed random number generation with this value (default: None -> no seeding)
  • workers (arg) - How many workers (threads) the dataloader should use (default: 0 -> use multiprocessing.cpu_count())
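The three proportion arguments act as relative weights: with the defaults 7:1:2, and assuming the script normalizes them by their sum, the subsplits would hold roughly 70%, 10% and 20% of the samples. The sketch below illustrates that assumption together with the optional seeding and the workers=0 fallback described above. compute_split_sizes, split_indices and resolve_workers are hypothetical helpers written for illustration, not functions from train_family_classifier.py or get_generator.

```python
import multiprocessing
import random


def compute_split_sizes(n_samples, train_prop=7, valid_prop=1, test_prop=2):
    """Hypothetical helper: turn relative proportions into subsplit sizes."""
    total = train_prop + valid_prop + test_prop
    if total <= 0:
        raise ValueError("proportions must sum to a positive value")
    valid_n = n_samples * valid_prop // total
    test_n = n_samples * test_prop // total
    # Any rounding remainder is given to the training subsplit.
    train_n = n_samples - valid_n - test_n
    return train_n, valid_n, test_n


def split_indices(n_samples, proportions=(7, 1, 2), random_seed=None):
    """Hypothetical helper: shuffle sample indices and cut them into subsplits."""
    # random_seed=None means no explicit seed (Python seeds from system entropy).
    rng = random.Random(random_seed)
    indices = list(range(n_samples))
    rng.shuffle(indices)
    train_n, valid_n, _ = compute_split_sizes(n_samples, *proportions)
    return (indices[:train_n],
            indices[train_n:train_n + valid_n],
            indices[train_n + valid_n:])


def resolve_workers(workers=0):
    """Hypothetical helper: 0 means 'use one worker per CPU core'."""
    return multiprocessing.cpu_count() if workers == 0 else workers
```

For example, compute_split_sizes(1000) returns (700, 100, 200). Assigning the rounding remainder to the training subsplit is a design choice of this sketch; the actual script may round differently.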

__main__ (main) - Start baker so that the script can be run from the command line, exposing function names as commands and their parameters as optparse-style options.
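To make the resulting command-line interface concrete, the sketch below rebuilds an equivalent interface for train_network with the standard library's argparse instead of baker. It assumes fresh_ds_path is the only positional argument and every parameter with a default becomes an option of the same name; the exact option spelling baker generates may differ, and build_parser and the "/data/fresh" path are hypothetical.

```python
import argparse


def build_parser():
    """Hypothetical argparse equivalent of the train_network baker command."""
    parser = argparse.ArgumentParser(
        description="Train a family classifier on the fresh dataset.")
    # Parameter without a default -> positional argument.
    parser.add_argument("fresh_ds_path",
                        help="directory containing the fresh dataset (.dat files)")
    # Parameters with defaults -> options (defaults copied from this page).
    parser.add_argument("--checkpoint_path", default=None,
                        help="model checkpoint to load")
    parser.add_argument("--training_run", type=int, default=0)
    parser.add_argument("--epochs", type=int, default=25)
    parser.add_argument("--train_split_proportion", type=int, default=7)
    parser.add_argument("--valid_split_proportion", type=int, default=1)
    parser.add_argument("--test_split_proportion", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=250)
    parser.add_argument("--random_seed", type=int, default=None)
    parser.add_argument("--workers", type=int, default=0)
    return parser


# Example: parse an explicit argument list instead of sys.argv.
args = build_parser().parse_args(["/data/fresh", "--epochs", "10"])
```

Here args.epochs is 10 while every other option keeps its documented default.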

