147 lines
4.1 KiB
Python
147 lines
4.1 KiB
Python
### COPIE-COLLE DE TRAINDATA.PY SI LE CODE NE FONCTIONNE PAS
|
|
|
|
# Runs the training, computes loss and accuracy per epoch and exports them to CSV (the plotting code at the end of the file is currently commented out)
|
|
|
|
from datasettransform import *
|
|
from model import *
|
|
|
|
from torchmetrics import Accuracy
|
|
import torch.nn as nn
|
|
import matplotlib.pyplot as plt
|
|
import torch.optim as optim
|
|
import pandas as pd
|
|
import os
|
|
|
|
# Device that inputs, targets, the accuracy metrics and the loss function are moved to.
device = "cpu"
|
|
|
|
|
|
### TRAIN FUNCTION ###
|
|
def train_one_epoch(model, train_loader, loss_fn, optimizer):
    """Run one optimisation pass over ``train_loader``.

    Reads the module-level globals ``device`` and ``num_class`` (the latter
    presumably comes from the star imports at the top of the file - confirm).

    Args:
        model: network to train; it is switched to training mode.
        train_loader: iterable yielding ``(inputs, targets)`` batches.
        loss_fn: criterion called as ``loss_fn(outputs, targets)``.
        optimizer: optimizer stepping ``model``'s parameters.

    Returns:
        ``(model, mean_loss, accuracy)``. ``mean_loss`` is the batch-size
        weighted mean of the per-batch losses; ``accuracy`` is the multiclass
        accuracy accumulated over the whole epoch, as a Python float.
    """
    model.train()
    loss_meter = AverageMeter()
    accuracy = Accuracy(task="multiclass", num_classes=num_class).to(device)

    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        outputs = model(inputs)
        loss = loss_fn(outputs, targets)

        loss.backward()
        # Cap the global gradient norm at 1 to damp exploding updates.
        nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        optimizer.zero_grad()

        # Weight by batch size so a smaller final batch does not skew the epoch
        # mean (assumes loss_fn returns a per-sample mean, as the default
        # nn.CrossEntropyLoss() does).
        loss_meter.update(loss.item(), n=targets.size(0))
        # update() only accumulates state (forward() would also compute an
        # unused per-batch value); detach so the metric never touches autograd.
        accuracy.update(outputs.detach(), targets.int())

    return model, loss_meter.avg, accuracy.compute().item()
|
|
|
|
def validation(model, val_loader, loss_fn):
    """Evaluate ``model`` on ``val_loader`` without updating it.

    Reads the module-level globals ``device`` and ``num_class`` (the latter
    presumably comes from the star imports at the top of the file - confirm).

    Args:
        model: network to evaluate; it is switched to eval mode.
        val_loader: iterable yielding ``(inputs, targets)`` batches.
        loss_fn: criterion called as ``loss_fn(outputs, targets)``.

    Returns:
        ``(model, mean_loss, accuracy)``. ``mean_loss`` is the batch-size
        weighted mean of the per-batch losses; ``accuracy`` is the multiclass
        accuracy over all batches, as a Python float.
    """
    import torch  # the file only binds ``torch.nn``/``torch.optim`` by name

    model.eval()
    loss_meter = AverageMeter()
    accuracy = Accuracy(task="multiclass", num_classes=num_class).to(device)

    # Evaluation needs no gradients: skip building autograd graphs entirely.
    # (Gradient clipping was dropped here - there is no backward pass to clip.)
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            # Weight by batch size so the mean is per sample, not per batch.
            loss_meter.update(loss.item(), n=targets.size(0))
            accuracy.update(outputs, targets.int())

    return model, loss_meter.avg, accuracy.compute().item()
|
|
|
|
# rajouter def test ici ?
|
|
|
|
|
|
### UTILS ###
|
|
class AverageMeter(object):
    """Keep the latest value plus a running, count-weighted mean of a stream."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the running sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh ``avg``."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
|
|
|
|
|
|
### TRAINING ###
|
|
# Cross-entropy criterion and NAdam optimizer (lr 0.01) over model001's parameters.
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = optim.NAdam(model001.parameters(), lr=0.01)

# Per-epoch metric histories, exported to CSV once training ends.
loss_train_hist = []
acc_train_hist = []
loss_val_hist = []
acc_val_hist = []

for epoch in range(EPOCHS):
    model, loss_train, acc_train = train_one_epoch(
        model001, data_loader, loss_fn, optimizer
    )
    # NOTE(review): "validation" reuses data_loader, the same loader used for
    # training, so the val metrics measure training-set fit rather than
    # generalisation. Pass a held-out loader here if one is available.
    model, loss_val, acc_val = validation(model001, data_loader, loss_fn)

    loss_train_hist.append(loss_train)
    acc_train_hist.append(acc_train)
    loss_val_hist.append(loss_val)
    acc_val_hist.append(acc_val)

    # Report every 5th epoch (epochs 0, 5, 10, ...).
    if epoch % 5 == 0:
        print(f'epoch {epoch}:')
        print(f' train loss= {loss_train:.4}, val loss={loss_val:.4}, train acc= {int(acc_train*100)}%, val acc= {int(acc_val*100)}% \n')
|
|
|
|
|
|
# Sauvegarder des resultats en csv
|
|
# One row per epoch, one column per tracked metric.
dataframe = pd.DataFrame({"loss_train": loss_train_hist, "acc_train": acc_train_hist, "loss_val": loss_val_hist, "acc_val": acc_val_hist})

dir_export = "./resultats"
# Create the output directory on first run; os.listdir would raise otherwise.
os.makedirs(dir_export, exist_ok=True)

# Start numbering at the count of existing CSV files (the original scheme),
# then skip past any name already taken so a previous run is never overwritten
# (e.g. when an earlier file was deleted and the numbering has gaps).
n = sum(1 for file in os.listdir(dir_export) if file.endswith(".csv"))
csv_path = os.path.join(dir_export, f"model_perf_{n}.csv")
while os.path.exists(csv_path):
    n += 1
    csv_path = os.path.join(dir_export, f"model_perf_{n}.csv")

dataframe.to_csv(csv_path, sep=',', index=False)  # write the CSV file
|
|
|
|
### COURBE D'APPRENTISSAGE ###
|
|
# plt.plot(range(EPOCHS), acc_train_hist, 'b-', label='Train')
|
|
# plt.xlabel('Epoch')
|
|
# plt.ylabel('Acc')
|
|
# plt.grid(True)
|
|
# plt.legend()
|
|
# plt.title('Evolution de l entraînement (train)')
|
|
# plt.show()
|
|
# # print(len(acc_train_hist))
|
|
|
|
# ### COURBES LOSS ###
|
|
# plt.plot(loss_train_hist, label='Train loss')
|
|
# plt.plot(loss_val_hist, label='Val loss')
|
|
# plt.xlabel('Epoch')
|
|
# plt.ylabel('Loss')
|
|
# plt.grid(True)
|
|
# plt.legend()
|
|
# plt.title('Evolution de la perte (loss)')
|
|
# plt.show()
|
|
|
|
# ### COURBES ACC ###
|
|
# plt.plot(acc_train_hist, label='Train acc')
|
|
# plt.plot(acc_val_hist, label='Val acc')
|
|
# plt.xlabel('Epoch')
|
|
# plt.ylabel('Acc')
|
|
# plt.grid(True)
|
|
# plt.legend()
|
|
# plt.title('Evolution de la précision (acc)')
|
|
# plt.show() |