# CONTAINS A VERSION OF THE MODEL AND A CUSTOM DATASET.
# I CAN USE IT AS INSPIRATION TO CLEAN UP DATASETTRANSFORM.PY AND READ THE FILES.
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.data import DataLoader
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from torchmetrics import Accuracy
|
|
import os
|
|
|
|
# --- Paramètres globaux ---
|
|
BATCH_SIZE = 8
|
|
EPOCHS = 50
|
|
fs = 173
|
|
channel = 1
|
|
num_input = 1
|
|
num_class = 5
|
|
signal_length = 4097
|
|
device = 'cpu'
|
|
|
|
# --- Architecture EEGNet (inchangée) ---
|
|
F1, D = 8, 3
|
|
F2 = D * F1
|
|
|
|
class EEGNet(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.conv2d = nn.Conv2d(num_input, F1, (1, round(fs/2)), padding=(0, round(fs/4)-1))
|
|
self.Batch_normalization_1 = nn.BatchNorm2d(F1)
|
|
self.Depthwise_conv2D = nn.Conv2d(F1, D*F1, (channel, 1), groups=F1)
|
|
self.Batch_normalization_2 = nn.BatchNorm2d(D*F1)
|
|
self.Elu = nn.ELU()
|
|
self.Average_pooling2D_1 = nn.AvgPool2d((1, 4))
|
|
self.Dropout = nn.Dropout2d(0.2)
|
|
self.Separable_conv2D_depth = nn.Conv2d(D*F1, D*F1, (1, round(fs/8)), padding=(0, round(fs/16)), groups=D*F1)
|
|
self.Separable_conv2D_point = nn.Conv2d(D*F1, F2, (1, 1))
|
|
self.Batch_normalization_3 = nn.BatchNorm2d(F2)
|
|
self.Average_pooling2D_2 = nn.AvgPool2d((1, 8))
|
|
self.Flatten = nn.Flatten()
|
|
# Calcul dynamique de la taille de sortie pour la couche Dense
|
|
self.Dense = nn.Linear(F2 * (signal_length // 32), num_class)
|
|
self.Softmax = nn.Softmax(dim=1)
|
|
|
|
def forward(self, x):
|
|
y = self.Batch_normalization_1(self.conv2d(x))
|
|
y = self.Batch_normalization_2(self.Depthwise_conv2D(y))
|
|
y = self.Elu(y)
|
|
y = self.Dropout(self.Average_pooling2D_1(y))
|
|
y = self.Separable_conv2D_depth(y)
|
|
y = self.Batch_normalization_3(self.Separable_conv2D_point(y))
|
|
y = self.Elu(y)
|
|
y = self.Dropout(self.Average_pooling2D_2(y))
|
|
y = self.Flatten(y)
|
|
return self.Softmax(self.Dense(y))
|
|
|
|
# --- Dataset adapté à ton architecture de dossiers ---
|
|
class EEGDataset(torch.utils.data.Dataset):
|
|
def __init__(self, root_path, signal_length=4097):
|
|
self.root_path = root_path
|
|
self.signal_length = signal_length
|
|
self.class_map = {'F': 0, 'N': 1, 'O': 2, 'S': 3, 'Z': 4}
|
|
self.samples = []
|
|
|
|
# Parcourt récursivement root_path (ex: data/train/) pour trouver les .txt dans F, N, O, S, Z
|
|
for class_name in self.class_map.keys():
|
|
class_folder = os.path.join(root_path, class_name)
|
|
if os.path.exists(class_folder):
|
|
for f in os.listdir(class_folder):
|
|
if f.lower().endswith('.txt'):
|
|
self.samples.append((os.path.join(class_folder, f), self.class_map[class_name]))
|
|
|
|
def __len__(self):
|
|
return len(self.samples)
|
|
|
|
def __getitem__(self, idx):
|
|
file_path, label = self.samples[idx]
|
|
signal = []
|
|
with open(file_path, 'r') as f:
|
|
for line in f:
|
|
line = line.strip()
|
|
if line and not line.startswith('['):
|
|
try:
|
|
signal.append(float(line))
|
|
except ValueError: continue
|
|
|
|
# Padding ou troncature pour assurer signal_length
|
|
if len(signal) < self.signal_length:
|
|
signal.extend([0.0] * (self.signal_length - len(signal)))
|
|
else:
|
|
signal = signal[:self.signal_length]
|
|
|
|
x = torch.tensor(signal, dtype=torch.float32).view(1, 1, -1)
|
|
return x, torch.tensor(label, dtype=torch.long)
|
|
|
|
# --- Initialisation des données ---
|
|
data_root = "C:/DATA/M1/Stages/Fablab/data"
|
|
train_loader = DataLoader(EEGDataset(os.path.join(data_root, 'train')),
|
|
batch_size=BATCH_SIZE,
|
|
shuffle=True)
|
|
val_loader = DataLoader(EEGDataset(os.path.join(data_root, 'val')),
|
|
batch_size=BATCH_SIZE,
|
|
shuffle=False)
|
|
test_loader = DataLoader(EEGDataset(os.path.join(data_root, 'test')),
|
|
batch_size=BATCH_SIZE,
|
|
shuffle=False)
|
|
|
|
# --- Fonction de Visualisation ---
|
|
def plot_eeg_samples(loader, n_samples=3):
|
|
samples, labels = next(iter(loader))
|
|
classes_names = {0: 'F', 1: 'N', 2: 'O', 3: 'S', 4: 'Z'}
|
|
|
|
plt.figure(figsize=(12, 8))
|
|
for i in range(min(n_samples, len(samples))):
|
|
plt.subplot(n_samples, 1, i+1)
|
|
plt.plot(samples[i].view(-1).numpy())
|
|
plt.title(f"Classe: {classes_names[labels[i].item()]}")
|
|
plt.ylabel("Amplitude")
|
|
plt.tight_layout()
|
|
plt.show()
|
|
|
|
# Test de visualisation
|
|
plot_eeg_samples(train_loader) |