fix(ml): align input shape to 224x224 + pass ArrayBuffer to fast-tflite
The .tflite reports inputs[0].shape == [1, 224, 224, 3] but the preprocessing was producing 256x256x3 buffers, so 196608 floats were written into a tensor sized for 150528. Inference still completed (no allocator check on the JS side) but ran on shifted, mis-framed data — predictions were effectively random. Also: the v3 react-native-fast-tflite binding rejects raw TypedArray views with "TfliteModel.runSync(...): Object \"<element dump>\"" and only accepts the underlying ArrayBuffer. We now pass `input.buffer` as the primary path and keep the TypedArray as a fallback. Bonus: - Read input dtype from the model and dispatch preprocess accordingly (float32, uint8, int8) instead of hard-coding float32. Future-proof for a quantized re-export. - Dequantize uint8/int8 outputs to floats so argmax stays consistent. - Log model.inputs and model.outputs at load time — invaluable when re-exporting the .tflite and discovering shape mismatches. Validated on device (Samsung S23): preprocess 700ms + inference 39ms, no fallback chain. Still ~25% accuracy because the model itself is overfit (see docs/audit_report.md), but the inference plumbing is finally honest. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
26d0f39986
commit
c5d59fc092
|
|
@ -1,11 +1,24 @@
|
|||
import { manipulateAsync, SaveFormat } from 'expo-image-manipulator';
|
||||
import * as jpeg from 'jpeg-js';
|
||||
|
||||
// Le modèle Python a été entraîné en 256×256 (MobileNetV2).
|
||||
// Toute modification doit rester synchronisée avec l'export TFLite.
|
||||
export const MODEL_INPUT_SIZE = 256;
|
||||
// Le .tflite exporté attend [1, 224, 224, 3] (la shape par défaut MobileNetV2,
|
||||
// confirmée par model.inputs[0].shape sur device). Si tu réexportes le modèle
|
||||
// avec une autre shape, mets à jour cette constante en miroir — sinon le
|
||||
// preprocess produit un buffer de mauvaise taille et l'inférence tourne sur
|
||||
// des données décadrées (= prédictions aléatoires sans erreur visible).
|
||||
export const MODEL_INPUT_SIZE = 224;
|
||||
|
||||
export async function preprocessImage(uri: string): Promise<Float32Array> {
|
||||
export type TfliteInputDType = 'float32' | 'float16' | 'uint8' | 'int8';
|
||||
|
||||
export type TfliteInputArray =
|
||||
| Float32Array
|
||||
| Uint8Array
|
||||
| Int8Array;
|
||||
|
||||
export async function preprocessImage(
|
||||
uri: string,
|
||||
dtype: TfliteInputDType | string = 'float32',
|
||||
): Promise<TfliteInputArray> {
|
||||
const resized = await manipulateAsync(
|
||||
uri,
|
||||
[{ resize: { width: MODEL_INPUT_SIZE, height: MODEL_INPUT_SIZE } }],
|
||||
|
|
@ -27,15 +40,38 @@ export async function preprocessImage(uri: string): Promise<Float32Array> {
|
|||
|
||||
const rgba = decoded.data;
|
||||
const pixelCount = MODEL_INPUT_SIZE * MODEL_INPUT_SIZE;
|
||||
const input = new Float32Array(pixelCount * 3);
|
||||
|
||||
// Quantized models keep the original 0-255 byte range.
|
||||
if (dtype === 'uint8') {
|
||||
const out = new Uint8Array(pixelCount * 3);
|
||||
for (let i = 0; i < pixelCount; i++) {
|
||||
input[i * 3 + 0] = rgba[i * 4 + 0] / 255;
|
||||
input[i * 3 + 1] = rgba[i * 4 + 1] / 255;
|
||||
input[i * 3 + 2] = rgba[i * 4 + 2] / 255;
|
||||
out[i * 3 + 0] = rgba[i * 4 + 0];
|
||||
out[i * 3 + 1] = rgba[i * 4 + 1];
|
||||
out[i * 3 + 2] = rgba[i * 4 + 2];
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
return input;
|
||||
if (dtype === 'int8') {
|
||||
const out = new Int8Array(pixelCount * 3);
|
||||
for (let i = 0; i < pixelCount; i++) {
|
||||
// shift to [-128, 127]
|
||||
out[i * 3 + 0] = rgba[i * 4 + 0] - 128;
|
||||
out[i * 3 + 1] = rgba[i * 4 + 1] - 128;
|
||||
out[i * 3 + 2] = rgba[i * 4 + 2] - 128;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
// Default float32 path: normalised to [0, 1] (matches the Keras
|
||||
// preprocess_input used during training when rescale=1./255).
|
||||
const out = new Float32Array(pixelCount * 3);
|
||||
for (let i = 0; i < pixelCount; i++) {
|
||||
out[i * 3 + 0] = rgba[i * 4 + 0] / 255;
|
||||
out[i * 3 + 1] = rgba[i * 4 + 1] / 255;
|
||||
out[i * 3 + 2] = rgba[i * 4 + 2] / 255;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function base64ToBytes(base64: string): Uint8Array {
|
||||
|
|
|
|||
|
|
@ -29,10 +29,27 @@ import {
|
|||
softmax,
|
||||
} from '@/services/ml/preprocessing';
|
||||
|
||||
type TensorDataType =
|
||||
| 'float32'
|
||||
| 'float16'
|
||||
| 'int32'
|
||||
| 'int64'
|
||||
| 'uint8'
|
||||
| 'int8'
|
||||
| 'bool';
|
||||
|
||||
type TensorInfo = {
|
||||
name?: string;
|
||||
dataType: TensorDataType;
|
||||
shape: number[];
|
||||
};
|
||||
|
||||
type TensorInput = Float32Array | Int32Array | Uint8Array | Int8Array;
|
||||
|
||||
type FastTfliteModel = {
|
||||
runSync: (
|
||||
inputs: (Float32Array | Int32Array | Uint8Array)[],
|
||||
) => (Float32Array | Int32Array | Uint8Array)[];
|
||||
inputs: TensorInfo[];
|
||||
outputs: TensorInfo[];
|
||||
runSync: (inputs: TensorInput[]) => TensorInput[];
|
||||
};
|
||||
|
||||
let cachedModel: FastTfliteModel | null = null;
|
||||
|
|
@ -49,9 +66,17 @@ async function getModel(): Promise<FastTfliteModel | null> {
|
|||
// Path RELATIF (pas '@/') car require runtime ne résout pas les alias TS.
|
||||
const tflite = require('react-native-fast-tflite');
|
||||
const asset = require('../../assets/models/grapevine_v1.tflite');
|
||||
const loaded: FastTfliteModel = await tflite.loadTensorflowModel(asset);
|
||||
// 2e arg `delegates` OBLIGATOIRE même pour CPU (sinon native reçoit
|
||||
// `undefined` et plante avec "Value is undefined, expected an Object").
|
||||
const loaded: FastTfliteModel = await tflite.loadTensorflowModel(asset, []);
|
||||
cachedModel = loaded;
|
||||
console.log(`[TFLite] Model loaded in ${Date.now() - start}ms`);
|
||||
try {
|
||||
console.log('[TFLite] Inputs:', JSON.stringify(loaded.inputs));
|
||||
console.log('[TFLite] Outputs:', JSON.stringify(loaded.outputs));
|
||||
} catch {
|
||||
// some versions expose these as getters that may not stringify
|
||||
}
|
||||
return loaded;
|
||||
} catch (err) {
|
||||
console.error('[TFLite] Failed to load model:', err);
|
||||
|
|
@ -81,11 +106,14 @@ export async function runInference(imageUri?: string): Promise<Detection> {
|
|||
|
||||
try {
|
||||
const t0 = Date.now();
|
||||
const input = await preprocessImage(imageUri);
|
||||
const inputType = model.inputs?.[0]?.dataType ?? 'float32';
|
||||
const input = await preprocessImage(imageUri, inputType);
|
||||
const t1 = Date.now();
|
||||
console.log(`[TFLite] Preprocess: ${t1 - t0}ms`);
|
||||
console.log(
|
||||
`[TFLite] Preprocess: ${t1 - t0}ms (dtype=${inputType}, len=${input.length})`,
|
||||
);
|
||||
|
||||
const outputs = model.runSync([input]);
|
||||
const outputs = runSyncWithFallbacks(model, input);
|
||||
const t2 = Date.now();
|
||||
console.log(
|
||||
`[TFLite] Inference: ${t2 - t1}ms (total: ${t2 - t0}ms)`,
|
||||
|
|
@ -96,10 +124,8 @@ export async function runInference(imageUri?: string): Promise<Detection> {
|
|||
}
|
||||
|
||||
const raw = outputs[0];
|
||||
const rawArr =
|
||||
raw instanceof Float32Array
|
||||
? Array.from(raw)
|
||||
: Array.from(raw as ArrayLike<number>);
|
||||
const outputType = model.outputs?.[0]?.dataType ?? 'float32';
|
||||
const rawArr = dequantizeOutput(raw, outputType);
|
||||
const probs = isProbabilityVector(rawArr) ? rawArr : softmax(rawArr);
|
||||
|
||||
const idx = argmax(probs);
|
||||
|
|
@ -124,6 +150,61 @@ export async function runInference(imageUri?: string): Promise<Detection> {
|
|||
}
|
||||
}
|
||||
|
||||
function runSyncWithFallbacks(
|
||||
model: FastTfliteModel,
|
||||
input: TensorInput,
|
||||
): TensorInput[] {
|
||||
// react-native-fast-tflite v3 (Nitro) binds inputs through JSI as raw
|
||||
// ArrayBuffers — passing a TypedArray view triggers
|
||||
// "TfliteModel.runSync(...): Object \"<dump>\""
|
||||
// The underlying buffer works. Keep TypedArray as a fallback in case the
|
||||
// binding ever flips back.
|
||||
const attempts: { label: string; build: () => unknown }[] = [
|
||||
{ label: 'array-buffer', build: () => input.buffer },
|
||||
{ label: 'typed-array', build: () => input },
|
||||
];
|
||||
|
||||
let lastError: unknown = null;
|
||||
for (const attempt of attempts) {
|
||||
try {
|
||||
const candidate = attempt.build();
|
||||
return model.runSync([candidate as TensorInput]);
|
||||
} catch (err) {
|
||||
lastError = err;
|
||||
if (__DEV__) {
|
||||
console.warn(
|
||||
`[TFLite] runSync attempt "${attempt.label}" failed:`,
|
||||
err instanceof Error ? err.message : String(err),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
throw lastError instanceof Error
|
||||
? lastError
|
||||
: new Error('runSync failed with all fallbacks');
|
||||
}
|
||||
|
||||
function dequantizeOutput(
|
||||
raw: TensorInput,
|
||||
dtype: TensorDataType,
|
||||
): number[] {
|
||||
// float outputs are already in [0,1] (after softmax) or logits.
|
||||
if (dtype === 'float32' || dtype === 'float16') {
|
||||
return Array.from(raw as Float32Array);
|
||||
}
|
||||
// Quantized outputs need a rough rescale. Without scale/zeroPoint metadata
|
||||
// exposed we approximate: uint8 → /255, int8 → (v + 128)/255. Good enough
|
||||
// for argmax (the relative order is preserved).
|
||||
if (dtype === 'uint8') {
|
||||
return Array.from(raw as Uint8Array, (v) => v / 255);
|
||||
}
|
||||
if (dtype === 'int8') {
|
||||
return Array.from(raw as Int8Array, (v) => (v + 128) / 255);
|
||||
}
|
||||
return Array.from(raw as ArrayLike<number>);
|
||||
}
|
||||
|
||||
function isProbabilityVector(values: number[]): boolean {
|
||||
if (values.length === 0) return false;
|
||||
const sum = values.reduce((a, b) => a + b, 0);
|
||||
|
|
|
|||
Loading…
Reference in a new issue