Torch_EEG_plants_electrophy.../inspi.py
2026-05-13 11:39:59 +02:00

124 lines
4.7 KiB
Python

# CONTIENT UNE VERSION DU MODELE ET UN DATASET PERSONNALISE,
# JE PEUX M'EN INSPIRER POUR NETTOYER DATASETTRANSFORM.PY ET LIRE LES FICHIERS
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from torchmetrics import Accuracy
import os
# --- Global parameters ---
# Training loop settings.
BATCH_SIZE, EPOCHS = 8, 50

# Data geometry, read by EEGNet when it builds its layers.
fs = 173              # sampling rate used to size the conv kernels (presumably Hz)
channel = 1           # input height; the depthwise conv kernel spans all of it
num_input = 1         # input feature maps of the first convolution
num_class = 5         # number of outputs of the final Dense layer
signal_length = 4097  # samples per segment along the time axis
device = 'cpu'

# --- EEGNet architecture constants (unchanged from the reference model) ---
F1, D = 8, 3          # temporal filters, depth multiplier of the depthwise conv
F2 = F1 * D           # filters of the pointwise (separable) convolution
class EEGNet(nn.Module):
    """EEGNet-style convolutional classifier.

    All hyper-parameters (``num_input``, ``F1``, ``D``, ``F2``, ``fs``,
    ``channel``, ``signal_length``, ``num_class``) are read from module-level
    globals when the model is instantiated.

    ``forward`` takes a float tensor of shape
    ``(batch, num_input, channel, signal_length)`` and returns a
    ``(batch, num_class)`` tensor of softmax probabilities.
    """

    def __init__(self):
        super().__init__()
        # Block 1: temporal convolution (kernel of round(fs / 2) samples), then a
        # depthwise convolution spanning all `channel` rows, D maps per filter.
        self.conv2d = nn.Conv2d(num_input, F1, (1, round(fs / 2)),
                                padding=(0, round(fs / 4) - 1))
        self.Batch_normalization_1 = nn.BatchNorm2d(F1)
        self.Depthwise_conv2D = nn.Conv2d(F1, D * F1, (channel, 1), groups=F1)
        self.Batch_normalization_2 = nn.BatchNorm2d(D * F1)
        self.Elu = nn.ELU()
        self.Average_pooling2D_1 = nn.AvgPool2d((1, 4))
        # Dropout2d zeroes whole feature maps; this stateless module is reused
        # after both pooling stages.
        self.Dropout = nn.Dropout2d(0.2)
        # Block 2: separable convolution = depthwise temporal conv + 1x1 pointwise conv.
        self.Separable_conv2D_depth = nn.Conv2d(D * F1, D * F1, (1, round(fs / 8)),
                                                padding=(0, round(fs / 16)),
                                                groups=D * F1)
        self.Separable_conv2D_point = nn.Conv2d(D * F1, F2, (1, 1))
        self.Batch_normalization_3 = nn.BatchNorm2d(F2)
        self.Average_pooling2D_2 = nn.AvgPool2d((1, 8))
        self.Flatten = nn.Flatten()
        # Size the classifier from the real feature-extractor output. The former
        # `F2 * (signal_length // 32)` shortcut only matched for some lengths
        # (e.g. 63 samples -> 48 features, not 24) and failed at forward time.
        self.Dense = nn.Linear(self._flattened_size(), num_class)
        self.Softmax = nn.Softmax(dim=1)

    def _flattened_size(self):
        """Return the feature count `_features` yields for one `signal_length` input."""
        probe = torch.zeros(1, num_input, channel, signal_length)
        # eval() so the probe does not update BatchNorm running statistics;
        # a freshly constructed module is in training mode, so restore it.
        self.eval()
        with torch.no_grad():
            size = self._features(probe).shape[1]
        self.train()
        return size

    def _features(self, x):
        """Run the convolutional feature extractor and flatten to (batch, features)."""
        y = self.Batch_normalization_1(self.conv2d(x))
        y = self.Batch_normalization_2(self.Depthwise_conv2D(y))
        y = self.Elu(y)
        y = self.Dropout(self.Average_pooling2D_1(y))
        y = self.Separable_conv2D_depth(y)
        y = self.Batch_normalization_3(self.Separable_conv2D_point(y))
        y = self.Elu(y)
        y = self.Dropout(self.Average_pooling2D_2(y))
        return self.Flatten(y)

    def forward(self, x):
        """Return class probabilities of shape (batch, num_class) for input ``x``."""
        # NOTE(review): this returns softmax probabilities. If the model is
        # trained with nn.CrossEntropyLoss (which applies log-softmax itself),
        # the loss should receive the Dense output instead — confirm against
        # the training code.
        return self.Softmax(self.Dense(self._features(x)))
# --- Dataset adapté à ton architecture de dossiers ---
class EEGDataset(torch.utils.data.Dataset):
    """Single-channel signals stored as one text file per sample.

    Expected layout: ``root_path/<class>/<file>.txt`` where ``<class>`` is one
    of ``F``, ``N``, ``O``, ``S``, ``Z`` (extension matched case-insensitively).
    Only files directly inside each class folder are read (no recursion);
    missing class folders are skipped.

    Each item is ``(x, label)``: ``x`` is a float32 tensor of shape
    ``(1, 1, signal_length)`` and ``label`` a scalar int64 tensor.
    """

    def __init__(self, root_path, signal_length=4097):
        self.root_path = root_path
        self.signal_length = signal_length
        self.class_map = {'F': 0, 'N': 1, 'O': 2, 'S': 3, 'Z': 4}
        self.samples = []
        for class_name, label in self.class_map.items():
            class_folder = os.path.join(root_path, class_name)
            # isdir rather than exists: a stray *file* named like a class would
            # otherwise make os.listdir raise NotADirectoryError.
            if not os.path.isdir(class_folder):
                continue
            # os.listdir order is filesystem-dependent; sort so the sample order
            # (and hence unshuffled val/test batches) is reproducible.
            for f in sorted(os.listdir(class_folder)):
                if f.lower().endswith('.txt'):
                    self.samples.append((os.path.join(class_folder, f), label))

    def __len__(self):
        """Return the number of discovered sample files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(signal_tensor, label_tensor)``."""
        file_path, label = self.samples[idx]
        signal = []
        with open(file_path, 'r') as f:
            for line in f:
                line = line.strip()
                # Skip blank lines and lines starting with '[' (presumably headers).
                if line and not line.startswith('['):
                    try:
                        signal.append(float(line))
                    except ValueError:
                        # Best-effort: ignore any other non-numeric line.
                        continue
        # Zero-pad short signals / truncate long ones to exactly signal_length.
        if len(signal) < self.signal_length:
            signal.extend([0.0] * (self.signal_length - len(signal)))
        else:
            signal = signal[:self.signal_length]
        # One input map, one row, signal_length columns.
        x = torch.tensor(signal, dtype=torch.float32).view(1, 1, self.signal_length)
        return x, torch.tensor(label, dtype=torch.long)
# --- Data initialisation ---
# Root folder expected to contain train/, val/ and test/ sub-folders, each laid
# out as EEGDataset expects. Set the EEG_DATA_ROOT environment variable to use
# another location instead of editing this machine-specific default path.
data_root = os.environ.get("EEG_DATA_ROOT", "C:/DATA/M1/Stages/Fablab/data")


def _make_loader(split, shuffle):
    """Build a DataLoader over ``EEGDataset(data_root/split)`` with BATCH_SIZE."""
    return DataLoader(EEGDataset(os.path.join(data_root, split)),
                      batch_size=BATCH_SIZE,
                      shuffle=shuffle)


# Only the training split is shuffled; val/test keep a fixed order.
train_loader = _make_loader('train', shuffle=True)
val_loader = _make_loader('val', shuffle=False)
test_loader = _make_loader('test', shuffle=False)
# --- Visualisation helper ---
def plot_eeg_samples(loader, n_samples=3):
    """Plot up to ``n_samples`` signals from the first batch ``loader`` yields.

    One subplot row per signal, titled with its class letter; the figure is
    displayed with ``plt.show()``.
    """
    samples, labels = next(iter(loader))
    # Invert the dataset's letter -> index mapping when it exposes one, so the
    # titles stay in sync with the dataset; otherwise use the default classes.
    class_map = getattr(getattr(loader, 'dataset', None), 'class_map', None)
    if not class_map:
        class_map = {'F': 0, 'N': 1, 'O': 2, 'S': 3, 'Z': 4}
    classes_names = {index: name for name, index in class_map.items()}
    # Size the grid by what is actually plotted so a short batch leaves no
    # empty rows.
    n_plots = min(n_samples, len(samples))
    plt.figure(figsize=(12, 8))
    for i in range(n_plots):
        plt.subplot(n_plots, 1, i + 1)
        plt.plot(samples[i].view(-1).numpy())
        plt.title(f"Classe: {classes_names[labels[i].item()]}")
        plt.ylabel("Amplitude")
    plt.tight_layout()
    plt.show()
# Quick visual sanity check: plot a few signals from one training batch.
plot_eeg_samples(loader=train_loader, n_samples=3)