Source code for climb.tool.impl.data_suite.models.mcd

import logging

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from climb.tool.impl.data_suite.models.base_model import MyDataset, Net, benchmark_trainer

torch.manual_seed(42)


[docs] def all_equal2(iterator): return len(set(iterator)) <= 1
[docs] class mc_dropout: def __init__(self, epochs=10, lr=0.01, batch_size=5, device="cpu"): self.epochs = epochs self.lr = lr self.batch_size = batch_size self.device = device self.tr = None
[docs] def fit(self, x_train, y_train): """ > The function instantiates a model, trains it, and then checks if the predictions are all equal. If they are, it instantiates a new model and repeats the process Args: x_train: the training data y_train: the training labels """ dataset_train = MyDataset( data=x_train, targets=y_train, transform=None, ) train_loader = DataLoader(dataset_train, batch_size=self.batch_size) dim = x_train.shape[1] all_equal = True runs = 0 while all_equal: # Instantiate the model + optimizer - can be anything logging.info("Instantiating model - MCD") model = Net(dim=dim).to(self.device) optimizer = optim.Adam(model.parameters(), lr=self.lr) # generic training & test loop self.tr = benchmark_trainer(model, self.device) logging.info("Training model (MCD)...") self.tr.fit(train_loader, optimizer, epochs=self.epochs) preds, uncertainty = self.tr.predict(train_loader, mc_samples=3) runs += 1 if runs > 3: break if all_equal2(preds) is False: break
[docs] def predict(self, x_test, y_test, mc_samples=3): """ > The function takes in the test data and test labels, and returns the predictions and the uncertainty of the predictions Args: x_test: the test data y_test: the actual labels of the test set mc_samples: number of Monte Carlo samples to use for prediction. Defaults to 3 Returns: The predictions and the uncertainty of the predictions. """ logging.info("Testing model...") dataset_test = MyDataset(data=x_test, targets=y_test, transform=None) test_loader = DataLoader(dataset_test, batch_size=self.batch_size) preds, uncertainty = self.tr.predict( test_loader, mc_samples=mc_samples, ) return preds, uncertainty