Modelo Mamdani en PyTorch

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import os


class MamdaniFISPyTorch(nn.Module):
    """Mamdani fuzzy inference system with trainable Gaussian membership functions.

    Each of the ``n_rules`` rules owns one Gaussian fuzzy set per input
    dimension (the premises, ``p_premises``) and one per output dimension
    (the consequents, ``p_consequent``); every set is a ``[center, sigma]``
    pair. Inference uses ``min`` as AND across input dimensions, ``min`` for
    implication, ``max`` to aggregate across rules, and centroid
    defuzzification evaluated on a uniform grid of ``n_points`` output values.
    ``fit`` trains all centers and spreads with Adam on an MSE loss.

    Args:
        n_rules: Number of rules. Must be >= 2 because the initial sets are
            spaced evenly between the observed minimum and maximum.
        n_points: Number of grid points used to discretise each output
            universe (defuzzification) and each input universe (plots).

    Raises:
        ValueError: If ``n_rules < 2``.
    """

    # Membership degrees below this value are forced to exactly zero.
    _MV_THRESHOLD = 1e-3
    # Floor for initial spreads: a constant feature would otherwise get
    # sigma == 0, and exp(-0 / 0) is NaN at the set's center.
    _MIN_SIGMA = 1e-6

    def __init__(self, n_rules=6, n_points=1000):
        super().__init__()
        if n_rules < 2:
            # delta = range / (n_rules - 1) below would divide by zero.
            raise ValueError(f"n_rules must be >= 2, got {n_rules}")
        self.n_rules = n_rules
        self.n_points = n_points
        self.is_fit = False

    @staticmethod
    def _as_2d_tensor(data):
        """Return *data* as a detached 2-D tensor of shape (n_samples, n_dims).

        Non-tensor inputs (NumPy arrays, lists) are converted to float32;
        tensors keep their dtype but are detached and copied so the caller's
        tensor is neither modified nor tracked. A 1-D input becomes a column.
        """
        if isinstance(data, torch.Tensor):
            data = data.detach().clone()
        else:
            data = torch.tensor(data, dtype=torch.float32)
        if data.dim() == 1:
            data = data.unsqueeze(1)
        return data

    def _uniform_grid(self, lo, hi):
        """Return an (n_points, n_dims) grid whose column j spans [lo[j], hi[j]]."""
        columns = [torch.linspace(a, b, self.n_points)
                   for a, b in zip(lo.tolist(), hi.tolist())]
        return torch.stack(columns, dim=1)

    def _init_gaussians(self, params, lo, hi):
        """Write evenly spaced centers and a shared sigma into ``params`` in place.

        Rule k gets center ``lo + k * delta`` with
        ``delta = (hi - lo) / (n_rules - 1)``. Sigma is chosen so that the
        full width at half maximum equals ``delta``, so neighbouring sets
        cross at membership 0.5.
        """
        delta = (hi - lo) / (self.n_rules - 1)
        sigma = torch.clamp(delta / (2 * np.sqrt(2 * np.log(2))), min=self._MIN_SIGMA)
        ranks = torch.arange(self.n_rules).unsqueeze(1)  # (n_rules, 1)
        with torch.no_grad():
            params[:, :, 0] = lo + ranks * delta
            params[:, :, 1] = sigma

    def initialize_parameters(self, X, Y):
        """Spread the fuzzy sets evenly over the observed data range.

        Stores ``x_min``/``x_max``/``y_min``/``y_max`` and the grids ``x_``
        (n_points, input_dim) and ``y_`` (n_points, output_dim). It then
        overwrites ``p_premises``/``p_consequent`` in place with evenly
        spaced centers and half-maximum-overlapping spreads.

        ``input_dim``, ``output_dim``, ``p_premises`` and ``p_consequent``
        must already exist; ``fit`` creates them before calling this.

        Args:
            X: Input tensor of shape (n_samples, input_dim).
            Y: Target tensor of shape (n_samples, output_dim).
        """
        self.x_min = X.min(dim=0).values
        self.x_max = X.max(dim=0).values
        self.y_min = Y.min(dim=0).values
        self.y_max = Y.max(dim=0).values

        self.x_ = self._uniform_grid(self.x_min, self.x_max)
        self.y_ = self._uniform_grid(self.y_min, self.y_max)

        self._init_gaussians(self.p_premises, self.x_min, self.x_max)
        self._init_gaussians(self.p_consequent, self.y_min, self.y_max)

    def compute_membership_values(self, x, parameters):
        """Evaluate every rule's Gaussian fuzzy set on every sample.

        Args:
            x: Tensor of shape (n_samples, n_dims).
            parameters: Tensor of shape (n_rules, n_dims, 2) holding
                ``[center, sigma]`` per rule and dimension.

        Returns:
            Tensor of shape (n_samples, n_rules, n_dims) with membership
            degrees in [0, 1]. Degrees below ``_MV_THRESHOLD`` are replaced
            by 0, which also stops their gradient.
        """
        centers = parameters[:, :, 0].unsqueeze(0)  # (1, n_rules, n_dims)
        spreads = parameters[:, :, 1].unsqueeze(0)  # (1, n_rules, n_dims)
        # Broadcasting (n_samples, 1, n_dims) against (1, n_rules, n_dims)
        # avoids materialising repeated copies of each operand.
        mv = torch.exp(-((x.unsqueeze(1) - centers) ** 2) / (2 * spreads ** 2))
        return mv.masked_fill(mv < self._MV_THRESHOLD, 0.0)

    def _infer(self, X):
        """Run Mamdani inference on a prepared (n_samples, input_dim) tensor.

        Returns:
            Tuple ``(mv_inputs, outputs_rules, outputs_aggreg, output_defuzz)``
            with shapes (n_samples, n_rules, input_dim),
            (n_samples, n_points, n_rules, output_dim),
            (n_samples, n_points, output_dim) and (n_samples, output_dim).
        """
        mv_inputs = self.compute_membership_values(X, self.p_premises)
        mv_outputs = self.compute_membership_values(self.y_, self.p_consequent)  # (n_points, n_rules, output_dim)

        # AND across input dimensions -> one firing strength per rule.
        firing_strengths = mv_inputs.min(dim=2).values  # (n_samples, n_rules)
        # Implication: clip each rule's consequent at its firing strength.
        outputs_rules = torch.minimum(firing_strengths[:, None, :, None],
                                      mv_outputs.unsqueeze(0))
        # Aggregation: pointwise max over rules.
        outputs_aggreg = outputs_rules.max(dim=2).values
        # Centroid on the output grid; eps avoids 0/0 when no rule fires.
        weighted = (self.y_.unsqueeze(0) * outputs_aggreg).sum(dim=1)
        total = outputs_aggreg.sum(dim=1) + torch.finfo(torch.float32).eps
        return mv_inputs, outputs_rules, outputs_aggreg, weighted / total

    def forward(self, X):
        """Predict outputs for ``X``.

        Args:
            X: Array-like or tensor of shape (n_samples, input_dim), or 1-D
                for a single input dimension.

        Returns:
            Tensor of shape (n_samples, output_dim) with defuzzified outputs.

        Raises:
            RuntimeError: If ``fit`` has not been called yet.
        """
        if not self.is_fit:
            raise RuntimeError("Model not fit. Call fit() first.")
        return self._infer(self._as_2d_tensor(X))[-1]

    def fit(self, X, Y, lr=0.01, epochs=500, savefigs=False, val_ratio=0.2, random_seed=42, idx_point=None):
        """Initialise the fuzzy sets from the data and train them with Adam.

        Args:
            X: Inputs, shape (n_samples, input_dim) or 1-D.
            Y: Targets, shape (n_samples, output_dim) or 1-D.
            lr: Adam learning rate.
            epochs: Number of full-batch training steps.
            savefigs: If True, save membership-function plots to ``temp/``
                before training and every 10 epochs.
            val_ratio: Fraction of samples held out (randomly) for validation.
            random_seed: Seed passed to ``torch.manual_seed`` for the split.
            idx_point: Index into the training split of the sample shown in
                the "point" plots. Defaults to 40, capped to the last index.

        Returns:
            ``(train_losses, val_losses)``: per-epoch MSE values. A validation
            loss is NaN when the validation split is empty.
        """
        torch.manual_seed(random_seed)

        X = self._as_2d_tensor(X)
        Y = self._as_2d_tensor(Y)

        self.input_dim = X.shape[1]
        self.output_dim = Y.shape[1]

        self.p_premises = nn.Parameter(torch.zeros(self.n_rules, self.input_dim, 2))    # [center, sigma]
        self.p_consequent = nn.Parameter(torch.zeros(self.n_rules, self.output_dim, 2))  # [center, sigma]

        # Note: ranges come from all of X/Y, including the validation part.
        self.initialize_parameters(X, Y)
        # Must be set before the loop so forward() accepts calls.
        self.is_fit = True

        optimizer = optim.Adam(self.parameters(), lr=lr)
        loss_fn = nn.MSELoss()

        n_samples = X.shape[0]
        perm = torch.randperm(n_samples)
        n_val = int(val_ratio * n_samples)
        val_idx, train_idx = perm[:n_val], perm[n_val:]
        X_train, Y_train = X[train_idx], Y[train_idx]
        X_val, Y_val = X[val_idx], Y[val_idx]

        sample_point = None
        if savefigs:
            if idx_point is None:
                # Capped so small training sets do not raise IndexError.
                idx_point = min(40, X_train.shape[0] - 1)
            sample_point = (X_train[idx_point].unsqueeze(0), Y_train[idx_point].unsqueeze(0))
            self.plot_fs(figname='fig-epoch000')
            self.plot_fs(figname='fig-point-epoch000', sample_point=sample_point)

        train_losses, val_losses = [], []
        for epoch in range(epochs):
            self.train()
            optimizer.zero_grad()
            loss = loss_fn(self(X_train), Y_train)
            train_losses.append(loss.item())
            loss.backward()
            optimizer.step()

            self.eval()
            with torch.no_grad():
                if n_val > 0:
                    val_loss = loss_fn(self(X_val), Y_val).item()
                else:
                    val_loss = float('nan')
            val_losses.append(val_loss)

            if (epoch + 1) % 10 == 0:
                # Implicit concatenation keeps source indentation out of the log.
                print(f"Epoch {epoch + 1}/{epochs}, "
                      f"Train Loss: {loss.item():.6f}, "
                      f"Val Loss: {val_loss:.6f}")

                if savefigs:
                    self.plot_fs(figname=f'fig-epoch{epoch + 1}')
                    self.plot_fs(figname=f'fig-point-epoch{epoch + 1}', sample_point=sample_point)

        return train_losses, val_losses

    def plot_fs(self, sample_point=None, figname=None, output_dir='temp'):
        """Plot every rule's input and output fuzzy sets in a subplot grid.

        Rows are rules; columns are the input dimensions followed by the
        output dimensions. When ``sample_point=(x, y)`` is given, the figure
        also marks x and its membership degree in each input set, and
        overlays each rule's clipped consequent. It then adds a bottom row
        with the aggregated output set, the defuzzified prediction and the
        target y (red).

        Args:
            sample_point: Optional ``(x, y)`` pair; each part is reshaped to a
                single row of shape (1, n_dims).
            figname: If given, save to ``<output_dir>/<figname>.pdf``
                (creating the directory) and close the figure; otherwise call
                ``plt.show()``.
            output_dir: Directory used when ``figname`` is given.
        """
        # Plotting only: no autograd graph is needed.
        with torch.no_grad():
            mv_inputs = self.compute_membership_values(self.x_, self.p_premises).numpy()
            mv_outputs = self.compute_membership_values(self.y_, self.p_consequent).numpy()
            if sample_point is not None:
                x_point, y_point = (torch.as_tensor(v).detach().reshape(1, -1) for v in sample_point)
                mv_point, rules_point, aggreg_point, defuzz_point = (
                    t.numpy() for t in self._infer(x_point))
                x_np = x_point.numpy().ravel()
                y_np = y_point.numpy().ravel()

        n_in, n_out = self.input_dim, self.output_dim
        nrows = self.n_rules + (1 if sample_point is not None else 0)
        # squeeze=False keeps axs 2-D even for a single row.
        fig, axs = plt.subplots(nrows=nrows, ncols=n_in + n_out,
                                figsize=(12, 10), squeeze=False)
        x_grid, y_grid = self.x_.numpy(), self.y_.numpy()
        x_lo, x_hi = self.x_min.tolist(), self.x_max.tolist()
        y_lo, y_hi = self.y_min.tolist(), self.y_max.tolist()

        # Input fuzzy sets.
        for d in range(n_in):
            for r in range(self.n_rules):
                ax = axs[r, d]
                ax.plot(x_grid[:, d], mv_inputs[:, r, d])
                ax.set_ylim([0, 1])
                ax.set_xlim([x_lo[d], x_hi[d]])
                if sample_point is not None:
                    ax.axvline(x_np[d])
                    # axhline's xmin/xmax are axes fractions; the defaults
                    # (0, 1) already span the whole x-range set above.
                    ax.axhline(mv_point[0, r, d])
                # Only the first column shows y tick labels (rule names).
                if d == 0:
                    ax.set_ylabel(f'Rule {r + 1}')
                else:
                    ax.set_yticklabels([])
                # Only the first row shows x tick labels, placed on top.
                if r == 0:
                    ax.xaxis.set_ticks_position('top')
                else:
                    ax.set_xticklabels([])
            axs[0, d].set_title(f'$x_{{{d + 1}}}$')

        # Output fuzzy sets.
        for d in range(n_out):
            col = n_in + d
            for r in range(self.n_rules):
                ax = axs[r, col]
                ax.plot(y_grid[:, d], mv_outputs[:, r, d])
                if sample_point is not None:
                    # Consequent clipped at this rule's firing strength.
                    ax.plot(y_grid[:, d], rules_point[0, :, r, d])
                ax.set_ylim([0, 1])
                ax.set_xlim([y_lo[d], y_hi[d]])
                if r == 0:
                    ax.xaxis.set_ticks_position('top')
                    ax.set_title(f'$y_{{{d + 1}}}$')
                else:
                    ax.set_xticklabels([])
                ax.set_yticklabels([])

        if sample_point is not None:
            # Bottom row: aggregated output set for the sample.
            for d in range(n_in):
                axs[-1, d].axis('off')
            for d in range(n_out):
                ax = axs[-1, n_in + d]
                ax.plot(y_grid[:, d], aggreg_point[0, :, d])
                ax.set_ylim([0, 1])
                ax.set_xlim([y_lo[d], y_hi[d]])
                if d == 0:
                    ax.set_ylabel('Output')
                else:
                    ax.set_yticklabels([])
                ax.axvline(defuzz_point[0, d])         # defuzzified prediction
                ax.axvline(y_np[d], color='red')       # target value

        plt.tight_layout(pad=0.05, w_pad=0.02, h_pad=0.02)
        if figname is not None:
            os.makedirs(output_dir, exist_ok=True)
            plt.savefig(os.path.join(output_dir, f'{figname}.pdf'))
            plt.close(fig)
        else:
            plt.show()

Generación de datos de prueba

np.random.seed(0)
d = 2    # input dimension
n = 100  # samples per cluster

# Cluster 1: correlation 0.7, centered at the origin. Standard-normal draws
# z are correlated via the Cholesky factor L of the target covariance
# (z @ L.T has covariance L @ L.T == cov).
c = 0.7
mean = np.asarray([0., 0.])
cov = np.asarray([[1., c], [c, 1.]])
L = np.linalg.cholesky(cov)
X1 = np.random.normal(size=(n, d)) @ L.T + mean

# Cluster 2: correlation 0.5, centered at (5, 2). It draws its OWN noise;
# reusing cluster 1's draws would make every point here an exact affine
# image of the matching point in cluster 1.
c = 0.5
mean = np.asarray([5., 2.])
cov = np.asarray([[1., c], [c, 1.]])
L = np.linalg.cholesky(cov)
X2 = np.random.normal(size=(n, d)) @ L.T + mean

X = np.concatenate([X1, X2], axis=0)
# Two linear targets; note Y[:, 1] == 2 * Y[:, 0] + 14 exactly.
Y1 = X @ np.array([2, 5]) + 3
Y2 = X @ np.array([4, 10]) + 20
Y = np.column_stack([Y1, Y2])

Entrenamiento del modelo



# Fit a 10-rule model; savefigs=True also writes membership-function plots
# as PDFs under ./temp (before training and every 10 epochs).
model = MamdaniFISPyTorch(n_rules=10)
train_losses, val_losses = model.fit(X, Y, lr=0.05, epochs=100, savefigs=True)

# Inference only: call the module (runs hooks) and skip autograd bookkeeping.
with torch.no_grad():
    y_pred = model(X)
y_pred_np = y_pred.numpy()

# One figure per output: target vs. prediction, sample by sample.
for k in range(Y.shape[1]):
    plt.figure()
    plt.plot(Y[:, k], label='target')
    plt.plot(y_pred_np[:, k], label='prediction')
    plt.title(f'Output y{k + 1}')
    plt.xlabel('sample index')
    plt.legend()

# Learning curves.
plt.figure()
plt.plot(train_losses, label='train')
plt.plot(val_losses, label='validation')
plt.xlabel('epoch')
plt.ylabel('MSE')
plt.legend()
plt.show()

torch.Size([1, 2]) torch.Size([1, 2])
Epoch 10/100,                       Train Loss: 75.618050,                         Val Loss: 2.430635
Epoch 20/100,                       Train Loss: 0.956672,                         Val Loss: 0.846283
Epoch 30/100,                       Train Loss: 0.541443,                         Val Loss: 0.692809
Epoch 40/100,                       Train Loss: 0.375537,                         Val Loss: 0.463002
Epoch 50/100,                       Train Loss: 0.282071,                         Val Loss: 0.450215
Epoch 60/100,                       Train Loss: 0.242044,                         Val Loss: 0.488789
Epoch 70/100,                       Train Loss: 0.225082,                         Val Loss: 0.484702
Epoch 80/100,                       Train Loss: 0.211399,                         Val Loss: 0.482753
Epoch 90/100,                       Train Loss: 0.201069,                         Val Loss: 0.483513
Epoch 100/100,                       Train Loss: 0.189533,                         Val Loss: 0.496988