fix(ml): align input shape to 224x224 + pass ArrayBuffer to fast-tflite

The .tflite reports inputs[0].shape == [1, 224, 224, 3] but the
preprocessing was producing 256x256x3 buffers, so 196608 floats were
written into a tensor sized for 150528. Inference still completed
(no allocator check on the JS side) but ran on shifted, misaligned
data — predictions were effectively random.

Also: the v3 react-native-fast-tflite binding rejects raw TypedArray
views with "TfliteModel.runSync(...): Object \"<element dump>\"" and
only accepts the underlying ArrayBuffer. We now pass `input.buffer`
as the primary path and keep the TypedArray as a fallback.

Bonus:
- Read input dtype from the model and dispatch preprocess accordingly
  (float32, uint8, int8) instead of hard-coding float32. Future-proof
  for a quantized re-export.
- Dequantize uint8/int8 outputs to floats so argmax stays consistent.
- Log model.inputs and model.outputs at load time — invaluable when
  re-exporting the .tflite and discovering shape mismatches.

Validated on device (Samsung S23): preprocess 700ms + inference 39ms,
no fallback chain. Still ~25% accuracy because the model itself is
overfit (see docs/audit_report.md), but the inference plumbing is
finally honest.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Yanis 2026-05-01 13:14:12 +02:00
parent 26d0f39986
commit c5d59fc092
2 changed files with 138 additions and 21 deletions

View file

@ -1,11 +1,24 @@
import { manipulateAsync, SaveFormat } from 'expo-image-manipulator';
import * as jpeg from 'jpeg-js';
// Le modèle Python a été entraîné en 256×256 (MobileNetV2).
// Toute modification doit rester synchronisée avec l'export TFLite.
export const MODEL_INPUT_SIZE = 256;
// Le .tflite exporté attend [1, 224, 224, 3] (la shape par défaut MobileNetV2,
// confirmée par model.inputs[0].shape sur device). Si tu réexportes le modèle
// avec une autre shape, mets à jour cette constante en miroir — sinon le
// preprocess produit un buffer de mauvaise taille et l'inférence tourne sur
// des données décadrées (= prédictions aléatoires sans erreur visible).
export const MODEL_INPUT_SIZE = 224;
export async function preprocessImage(uri: string): Promise<Float32Array> {
export type TfliteInputDType = 'float32' | 'float16' | 'uint8' | 'int8';
export type TfliteInputArray =
| Float32Array
| Uint8Array
| Int8Array;
export async function preprocessImage(
uri: string,
dtype: TfliteInputDType | string = 'float32',
): Promise<TfliteInputArray> {
const resized = await manipulateAsync(
uri,
[{ resize: { width: MODEL_INPUT_SIZE, height: MODEL_INPUT_SIZE } }],
@ -27,15 +40,38 @@ export async function preprocessImage(uri: string): Promise<Float32Array> {
const rgba = decoded.data;
const pixelCount = MODEL_INPUT_SIZE * MODEL_INPUT_SIZE;
const input = new Float32Array(pixelCount * 3);
// Quantized models keep the original 0-255 byte range.
if (dtype === 'uint8') {
const out = new Uint8Array(pixelCount * 3);
for (let i = 0; i < pixelCount; i++) {
input[i * 3 + 0] = rgba[i * 4 + 0] / 255;
input[i * 3 + 1] = rgba[i * 4 + 1] / 255;
input[i * 3 + 2] = rgba[i * 4 + 2] / 255;
out[i * 3 + 0] = rgba[i * 4 + 0];
out[i * 3 + 1] = rgba[i * 4 + 1];
out[i * 3 + 2] = rgba[i * 4 + 2];
}
return out;
}
return input;
if (dtype === 'int8') {
const out = new Int8Array(pixelCount * 3);
for (let i = 0; i < pixelCount; i++) {
// shift to [-128, 127]
out[i * 3 + 0] = rgba[i * 4 + 0] - 128;
out[i * 3 + 1] = rgba[i * 4 + 1] - 128;
out[i * 3 + 2] = rgba[i * 4 + 2] - 128;
}
return out;
}
// Default float32 path: normalised to [0, 1] (matches the Keras
// preprocess_input used during training when rescale=1./255).
const out = new Float32Array(pixelCount * 3);
for (let i = 0; i < pixelCount; i++) {
out[i * 3 + 0] = rgba[i * 4 + 0] / 255;
out[i * 3 + 1] = rgba[i * 4 + 1] / 255;
out[i * 3 + 2] = rgba[i * 4 + 2] / 255;
}
return out;
}
function base64ToBytes(base64: string): Uint8Array {

View file

@ -29,10 +29,27 @@ import {
softmax,
} from '@/services/ml/preprocessing';
type TensorDataType =
| 'float32'
| 'float16'
| 'int32'
| 'int64'
| 'uint8'
| 'int8'
| 'bool';
type TensorInfo = {
name?: string;
dataType: TensorDataType;
shape: number[];
};
type TensorInput = Float32Array | Int32Array | Uint8Array | Int8Array;
type FastTfliteModel = {
runSync: (
inputs: (Float32Array | Int32Array | Uint8Array)[],
) => (Float32Array | Int32Array | Uint8Array)[];
inputs: TensorInfo[];
outputs: TensorInfo[];
runSync: (inputs: TensorInput[]) => TensorInput[];
};
let cachedModel: FastTfliteModel | null = null;
@ -49,9 +66,17 @@ async function getModel(): Promise<FastTfliteModel | null> {
// Path RELATIF (pas '@/') car require runtime ne résout pas les alias TS.
const tflite = require('react-native-fast-tflite');
const asset = require('../../assets/models/grapevine_v1.tflite');
const loaded: FastTfliteModel = await tflite.loadTensorflowModel(asset);
// 2e arg `delegates` OBLIGATOIRE même pour CPU (sinon native reçoit
// `undefined` et plante avec "Value is undefined, expected an Object").
const loaded: FastTfliteModel = await tflite.loadTensorflowModel(asset, []);
cachedModel = loaded;
console.log(`[TFLite] Model loaded in ${Date.now() - start}ms`);
try {
console.log('[TFLite] Inputs:', JSON.stringify(loaded.inputs));
console.log('[TFLite] Outputs:', JSON.stringify(loaded.outputs));
} catch {
// some versions expose these as getters that may not stringify
}
return loaded;
} catch (err) {
console.error('[TFLite] Failed to load model:', err);
@ -81,11 +106,14 @@ export async function runInference(imageUri?: string): Promise<Detection> {
try {
const t0 = Date.now();
const input = await preprocessImage(imageUri);
const inputType = model.inputs?.[0]?.dataType ?? 'float32';
const input = await preprocessImage(imageUri, inputType);
const t1 = Date.now();
console.log(`[TFLite] Preprocess: ${t1 - t0}ms`);
console.log(
`[TFLite] Preprocess: ${t1 - t0}ms (dtype=${inputType}, len=${input.length})`,
);
const outputs = model.runSync([input]);
const outputs = runSyncWithFallbacks(model, input);
const t2 = Date.now();
console.log(
`[TFLite] Inference: ${t2 - t1}ms (total: ${t2 - t0}ms)`,
@ -96,10 +124,8 @@ export async function runInference(imageUri?: string): Promise<Detection> {
}
const raw = outputs[0];
const rawArr =
raw instanceof Float32Array
? Array.from(raw)
: Array.from(raw as ArrayLike<number>);
const outputType = model.outputs?.[0]?.dataType ?? 'float32';
const rawArr = dequantizeOutput(raw, outputType);
const probs = isProbabilityVector(rawArr) ? rawArr : softmax(rawArr);
const idx = argmax(probs);
@ -124,6 +150,61 @@ export async function runInference(imageUri?: string): Promise<Detection> {
}
}
/**
 * Invokes `model.runSync` with the input encodings the native binding may
 * accept, in order of preference, returning the first successful result.
 *
 * react-native-fast-tflite v3 (Nitro) binds inputs through JSI as raw
 * ArrayBuffers — passing a TypedArray view triggers
 * "TfliteModel.runSync(...): Object \"<dump>\""
 * The underlying buffer works. Keep the TypedArray as a fallback in case the
 * binding ever flips back.
 *
 * @param model loaded fast-tflite model handle
 * @param input preprocessed tensor data (float32/uint8/int8 view)
 * @returns the raw output tensors from `runSync`
 * @throws the last attempt's error when every encoding is rejected
 */
function runSyncWithFallbacks(
  model: FastTfliteModel,
  input: TensorInput,
): TensorInput[] {
  const attempts: { label: string; build: () => unknown }[] = [
    {
      label: 'array-buffer',
      build: () =>
        // `input.buffer` is only equivalent to the view when the view spans
        // the whole backing store. For a subarray/offset view we must slice
        // out exactly the bytes the view covers, otherwise the native side
        // would read unrelated bytes before/after the tensor data.
        input.byteOffset === 0 && input.byteLength === input.buffer.byteLength
          ? input.buffer
          : input.buffer.slice(
              input.byteOffset,
              input.byteOffset + input.byteLength,
            ),
    },
    { label: 'typed-array', build: () => input },
  ];
  let lastError: unknown = null;
  for (const attempt of attempts) {
    try {
      const candidate = attempt.build();
      return model.runSync([candidate as TensorInput]);
    } catch (err) {
      lastError = err;
      if (__DEV__) {
        console.warn(
          `[TFLite] runSync attempt "${attempt.label}" failed:`,
          err instanceof Error ? err.message : String(err),
        );
      }
    }
  }
  throw lastError instanceof Error
    ? lastError
    : new Error('runSync failed with all fallbacks');
}
/**
 * Converts a raw output tensor into a plain number[] suitable for
 * softmax/argmax, rescaling quantized byte tensors to the float range.
 *
 * Without scale/zeroPoint metadata exposed by the binding we approximate:
 * uint8 → v/255, int8 → (v+128)/255. That is good enough for argmax — the
 * relative order of the values is preserved.
 *
 * @param raw   output tensor as returned by runSync
 * @param dtype declared dataType of the model's first output
 */
function dequantizeOutput(
  raw: TensorInput,
  dtype: TensorDataType,
): number[] {
  switch (dtype) {
    // Float outputs arrive ready to use — [0,1] after softmax, or logits.
    case 'float32':
    case 'float16':
      return Array.from(raw as Float32Array);
    case 'uint8':
      return Array.from(raw as Uint8Array, (v) => v / 255);
    case 'int8':
      return Array.from(raw as Int8Array, (v) => (v + 128) / 255);
    default:
      return Array.from(raw as ArrayLike<number>);
  }
}
function isProbabilityVector(values: number[]): boolean {
if (values.length === 0) return false;
const sum = values.reduce((a, b) => a + b, 0);