Source code for climb.tool.impl.data_suite.models.ensemble

import logging

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from climb.tool.impl.data_suite.models.base_model import MyDataset, Net, benchmark_trainer

# Seed PyTorch's global random number generator as a side effect of importing
# this module (presumably so that ensemble training is repeatable between
# runs -- TODO confirm that callers do not re-seed afterwards).
torch.manual_seed(42)


class ensemble:
    """Ensemble of independently trained ``Net`` models.

    ``fit`` trains up to ``n_models`` members, each wrapped in its own
    ``benchmark_trainer``; ``predict`` queries every stored member and
    aggregates their predictions into a mean and a standard deviation.

    Attributes:
        epochs: number of epochs passed to each trainer's ``fit``.
        lr: learning rate used for each member's Adam optimizer.
        batch_size: batch size of the train/test ``DataLoader``.
        device: device every model is moved to.
        n_models: number of ensemble members to train.
        tr: the most recently created trainer (``None`` before ``fit``).
        ensemble: mapping ``member index -> trainer`` of accepted members.
    """

    # Training attempts allowed per member before that member is given up on.
    _MAX_ATTEMPTS = 3

    def __init__(self, epochs=10, lr=0.01, batch_size=5, n_models=5, device="cpu"):
        self.epochs = epochs
        self.lr = lr
        self.batch_size = batch_size
        self.device = device
        self.n_models = n_models
        self.tr = None
        self.ensemble = {}

    @staticmethod
    def _all_equal(values):
        """Return True when ``values`` holds at most one distinct element."""
        return len(set(values)) <= 1

    def fit(self, x_train, y_train):
        """Fit an ensemble of ``n_models`` models to the data.

        A freshly trained member is rejected when its predictions on the
        training data are all identical; training is then retried with a new
        model, up to ``_MAX_ATTEMPTS`` attempts per member. A member that
        never yields varying predictions is left out of ``self.ensemble``
        and a warning is logged.

        Args:
            x_train: the training data; must expose ``shape[1]`` (the input
                dimension handed to ``Net``).
            y_train: the training labels.
        """
        dataset_train = MyDataset(
            data=x_train,
            targets=y_train,
            transform=None,
        )
        train_loader = DataLoader(dataset_train, batch_size=self.batch_size)
        dim = x_train.shape[1]

        # Start from a clean slate so members from an earlier fit cannot
        # linger when a member fails this time round.
        self.ensemble = {}

        for i in range(self.n_models):
            for _attempt in range(self._MAX_ATTEMPTS):
                # Instantiate the model + optimizer - can be anything
                logging.info(f"Instantiating model {i} - Ensemble")
                model = Net(dim=dim).to(self.device)
                optimizer = optim.Adam(model.parameters(), lr=self.lr)

                # generic training & test loop
                self.tr = benchmark_trainer(model, self.device)
                logging.info(f"Training model {i} - Ensemble...")
                self.tr.fit(train_loader, optimizer, epochs=self.epochs)

                preds, _ = self.tr.predict(
                    train_loader,
                    mc_samples=3,
                )

                # Judge the model *before* consulting the retry budget so
                # that a good final attempt is kept instead of discarded.
                if not self._all_equal(preds):
                    self.ensemble[i] = self.tr
                    break
            else:
                logging.warning(
                    f"Model {i} produced constant predictions in all "
                    f"{self._MAX_ATTEMPTS} attempts - skipped in ensemble"
                )

    def predict(self, x_test, y_test, mc_samples=3):
        """Predict with every ensemble member and aggregate the results.

        Args:
            x_test: the test data.
            y_test: the true labels of the test set.
            mc_samples: number of Monte Carlo samples to use for prediction.

        Returns:
            A tuple ``(preds, uncertainty)``: the mean and the standard
            deviation of the members' predictions, taken across members.
            With a single member the standard deviation is all zeros.

        Raises:
            RuntimeError: if ``fit`` has not been called yet.
        """
        # Checked first so an unfitted ensemble fails with a clear message.
        if not self.ensemble and self.tr is None:
            raise RuntimeError("ensemble.predict() called before ensemble.fit()")

        dataset_test = MyDataset(data=x_test, targets=y_test, transform=None)
        test_loader = DataLoader(dataset_test, batch_size=self.batch_size)

        # If every member was rejected during fit, fall back to the last
        # trained model (best effort) rather than failing outright.
        members = self.ensemble if self.ensemble else {0: self.tr}

        preds_overall = []
        for i, trainer in members.items():
            logging.info(f"Testing model {i}...")
            preds, _ = trainer.predict(test_loader, mc_samples=mc_samples)
            preds_overall.append(preds)

        stacked = np.array(preds_overall)
        preds = np.mean(stacked, axis=0)
        uncertainty = np.std(stacked, axis=0)

        return preds, uncertainty