fix(ml): align input shape to 224x224 + pass ArrayBuffer to fast-tflite
The .tflite reports inputs[0].shape == [1, 224, 224, 3] but the preprocessing was producing 256x256x3 buffers, so 196608 floats were written into a tensor sized for 150528. Inference still completed (no allocator check on the JS side) but ran on shifted, misframed data — predictions were effectively random. Also: the v3 react-native-fast-tflite binding rejects raw TypedArray views with "TfliteModel.runSync(...): Object \"<element dump>\"" and only accepts the underlying ArrayBuffer. We now pass `input.buffer` as the primary path and keep the TypedArray as a fallback. Bonus: - Read input dtype from the model and dispatch preprocess accordingly (float32, uint8, int8) instead of hard-coding float32. Future-proof for a quantized re-export. - Dequantize uint8/int8 outputs to floats so argmax stays consistent. - Log model.inputs and model.outputs at load time — invaluable when re-exporting the .tflite and discovering shape mismatches. Validated on device (Samsung S23): preprocess 700ms + inference 39ms, no fallback chain. Still ~25% accuracy because the model itself is overfit (see docs/audit_report.md), but the inference plumbing is finally honest. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
26d0f39986
commit
c5d59fc092
|
|
@ -1,11 +1,24 @@
|
||||||
import { manipulateAsync, SaveFormat } from 'expo-image-manipulator';
|
import { manipulateAsync, SaveFormat } from 'expo-image-manipulator';
|
||||||
import * as jpeg from 'jpeg-js';
|
import * as jpeg from 'jpeg-js';
|
||||||
|
|
||||||
// Le modèle Python a été entraîné en 256×256 (MobileNetV2).
|
// Le .tflite exporté attend [1, 224, 224, 3] (la shape par défaut MobileNetV2,
|
||||||
// Toute modification doit rester synchronisée avec l'export TFLite.
|
// confirmée par model.inputs[0].shape sur device). Si tu réexportes le modèle
|
||||||
export const MODEL_INPUT_SIZE = 256;
|
// avec une autre shape, mets à jour cette constante en miroir — sinon le
|
||||||
|
// preprocess produit un buffer de mauvaise taille et l'inférence tourne sur
|
||||||
|
// des données décadrées (= prédictions aléatoires sans erreur visible).
|
||||||
|
export const MODEL_INPUT_SIZE = 224;
|
||||||
|
|
||||||
export async function preprocessImage(uri: string): Promise<Float32Array> {
|
export type TfliteInputDType = 'float32' | 'float16' | 'uint8' | 'int8';
|
||||||
|
|
||||||
|
export type TfliteInputArray =
|
||||||
|
| Float32Array
|
||||||
|
| Uint8Array
|
||||||
|
| Int8Array;
|
||||||
|
|
||||||
|
export async function preprocessImage(
|
||||||
|
uri: string,
|
||||||
|
dtype: TfliteInputDType | string = 'float32',
|
||||||
|
): Promise<TfliteInputArray> {
|
||||||
const resized = await manipulateAsync(
|
const resized = await manipulateAsync(
|
||||||
uri,
|
uri,
|
||||||
[{ resize: { width: MODEL_INPUT_SIZE, height: MODEL_INPUT_SIZE } }],
|
[{ resize: { width: MODEL_INPUT_SIZE, height: MODEL_INPUT_SIZE } }],
|
||||||
|
|
@ -27,15 +40,38 @@ export async function preprocessImage(uri: string): Promise<Float32Array> {
|
||||||
|
|
||||||
const rgba = decoded.data;
|
const rgba = decoded.data;
|
||||||
const pixelCount = MODEL_INPUT_SIZE * MODEL_INPUT_SIZE;
|
const pixelCount = MODEL_INPUT_SIZE * MODEL_INPUT_SIZE;
|
||||||
const input = new Float32Array(pixelCount * 3);
|
|
||||||
|
|
||||||
for (let i = 0; i < pixelCount; i++) {
|
// Quantized models keep the original 0-255 byte range.
|
||||||
input[i * 3 + 0] = rgba[i * 4 + 0] / 255;
|
if (dtype === 'uint8') {
|
||||||
input[i * 3 + 1] = rgba[i * 4 + 1] / 255;
|
const out = new Uint8Array(pixelCount * 3);
|
||||||
input[i * 3 + 2] = rgba[i * 4 + 2] / 255;
|
for (let i = 0; i < pixelCount; i++) {
|
||||||
|
out[i * 3 + 0] = rgba[i * 4 + 0];
|
||||||
|
out[i * 3 + 1] = rgba[i * 4 + 1];
|
||||||
|
out[i * 3 + 2] = rgba[i * 4 + 2];
|
||||||
|
}
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
return input;
|
if (dtype === 'int8') {
|
||||||
|
const out = new Int8Array(pixelCount * 3);
|
||||||
|
for (let i = 0; i < pixelCount; i++) {
|
||||||
|
// shift to [-128, 127]
|
||||||
|
out[i * 3 + 0] = rgba[i * 4 + 0] - 128;
|
||||||
|
out[i * 3 + 1] = rgba[i * 4 + 1] - 128;
|
||||||
|
out[i * 3 + 2] = rgba[i * 4 + 2] - 128;
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default float32 path: normalised to [0, 1] (matches the Keras
|
||||||
|
// preprocess_input used during training when rescale=1./255).
|
||||||
|
const out = new Float32Array(pixelCount * 3);
|
||||||
|
for (let i = 0; i < pixelCount; i++) {
|
||||||
|
out[i * 3 + 0] = rgba[i * 4 + 0] / 255;
|
||||||
|
out[i * 3 + 1] = rgba[i * 4 + 1] / 255;
|
||||||
|
out[i * 3 + 2] = rgba[i * 4 + 2] / 255;
|
||||||
|
}
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
function base64ToBytes(base64: string): Uint8Array {
|
function base64ToBytes(base64: string): Uint8Array {
|
||||||
|
|
|
||||||
|
|
@ -29,10 +29,27 @@ import {
|
||||||
softmax,
|
softmax,
|
||||||
} from '@/services/ml/preprocessing';
|
} from '@/services/ml/preprocessing';
|
||||||
|
|
||||||
|
type TensorDataType =
|
||||||
|
| 'float32'
|
||||||
|
| 'float16'
|
||||||
|
| 'int32'
|
||||||
|
| 'int64'
|
||||||
|
| 'uint8'
|
||||||
|
| 'int8'
|
||||||
|
| 'bool';
|
||||||
|
|
||||||
|
type TensorInfo = {
|
||||||
|
name?: string;
|
||||||
|
dataType: TensorDataType;
|
||||||
|
shape: number[];
|
||||||
|
};
|
||||||
|
|
||||||
|
type TensorInput = Float32Array | Int32Array | Uint8Array | Int8Array;
|
||||||
|
|
||||||
type FastTfliteModel = {
|
type FastTfliteModel = {
|
||||||
runSync: (
|
inputs: TensorInfo[];
|
||||||
inputs: (Float32Array | Int32Array | Uint8Array)[],
|
outputs: TensorInfo[];
|
||||||
) => (Float32Array | Int32Array | Uint8Array)[];
|
runSync: (inputs: TensorInput[]) => TensorInput[];
|
||||||
};
|
};
|
||||||
|
|
||||||
let cachedModel: FastTfliteModel | null = null;
|
let cachedModel: FastTfliteModel | null = null;
|
||||||
|
|
@ -49,9 +66,17 @@ async function getModel(): Promise<FastTfliteModel | null> {
|
||||||
// Path RELATIF (pas '@/') car require runtime ne résout pas les alias TS.
|
// Path RELATIF (pas '@/') car require runtime ne résout pas les alias TS.
|
||||||
const tflite = require('react-native-fast-tflite');
|
const tflite = require('react-native-fast-tflite');
|
||||||
const asset = require('../../assets/models/grapevine_v1.tflite');
|
const asset = require('../../assets/models/grapevine_v1.tflite');
|
||||||
const loaded: FastTfliteModel = await tflite.loadTensorflowModel(asset);
|
// 2e arg `delegates` OBLIGATOIRE même pour CPU (sinon native reçoit
|
||||||
|
// `undefined` et plante avec "Value is undefined, expected an Object").
|
||||||
|
const loaded: FastTfliteModel = await tflite.loadTensorflowModel(asset, []);
|
||||||
cachedModel = loaded;
|
cachedModel = loaded;
|
||||||
console.log(`[TFLite] Model loaded in ${Date.now() - start}ms`);
|
console.log(`[TFLite] Model loaded in ${Date.now() - start}ms`);
|
||||||
|
try {
|
||||||
|
console.log('[TFLite] Inputs:', JSON.stringify(loaded.inputs));
|
||||||
|
console.log('[TFLite] Outputs:', JSON.stringify(loaded.outputs));
|
||||||
|
} catch {
|
||||||
|
// some versions expose these as getters that may not stringify
|
||||||
|
}
|
||||||
return loaded;
|
return loaded;
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('[TFLite] Failed to load model:', err);
|
console.error('[TFLite] Failed to load model:', err);
|
||||||
|
|
@ -81,11 +106,14 @@ export async function runInference(imageUri?: string): Promise<Detection> {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const t0 = Date.now();
|
const t0 = Date.now();
|
||||||
const input = await preprocessImage(imageUri);
|
const inputType = model.inputs?.[0]?.dataType ?? 'float32';
|
||||||
|
const input = await preprocessImage(imageUri, inputType);
|
||||||
const t1 = Date.now();
|
const t1 = Date.now();
|
||||||
console.log(`[TFLite] Preprocess: ${t1 - t0}ms`);
|
console.log(
|
||||||
|
`[TFLite] Preprocess: ${t1 - t0}ms (dtype=${inputType}, len=${input.length})`,
|
||||||
|
);
|
||||||
|
|
||||||
const outputs = model.runSync([input]);
|
const outputs = runSyncWithFallbacks(model, input);
|
||||||
const t2 = Date.now();
|
const t2 = Date.now();
|
||||||
console.log(
|
console.log(
|
||||||
`[TFLite] Inference: ${t2 - t1}ms (total: ${t2 - t0}ms)`,
|
`[TFLite] Inference: ${t2 - t1}ms (total: ${t2 - t0}ms)`,
|
||||||
|
|
@ -96,10 +124,8 @@ export async function runInference(imageUri?: string): Promise<Detection> {
|
||||||
}
|
}
|
||||||
|
|
||||||
const raw = outputs[0];
|
const raw = outputs[0];
|
||||||
const rawArr =
|
const outputType = model.outputs?.[0]?.dataType ?? 'float32';
|
||||||
raw instanceof Float32Array
|
const rawArr = dequantizeOutput(raw, outputType);
|
||||||
? Array.from(raw)
|
|
||||||
: Array.from(raw as ArrayLike<number>);
|
|
||||||
const probs = isProbabilityVector(rawArr) ? rawArr : softmax(rawArr);
|
const probs = isProbabilityVector(rawArr) ? rawArr : softmax(rawArr);
|
||||||
|
|
||||||
const idx = argmax(probs);
|
const idx = argmax(probs);
|
||||||
|
|
@ -124,6 +150,61 @@ export async function runInference(imageUri?: string): Promise<Detection> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function runSyncWithFallbacks(
|
||||||
|
model: FastTfliteModel,
|
||||||
|
input: TensorInput,
|
||||||
|
): TensorInput[] {
|
||||||
|
// react-native-fast-tflite v3 (Nitro) binds inputs through JSI as raw
|
||||||
|
// ArrayBuffers — passing a TypedArray view triggers
|
||||||
|
// "TfliteModel.runSync(...): Object \"<dump>\""
|
||||||
|
// The underlying buffer works. Keep TypedArray as a fallback in case the
|
||||||
|
// binding ever flips back.
|
||||||
|
const attempts: { label: string; build: () => unknown }[] = [
|
||||||
|
{ label: 'array-buffer', build: () => input.buffer },
|
||||||
|
{ label: 'typed-array', build: () => input },
|
||||||
|
];
|
||||||
|
|
||||||
|
let lastError: unknown = null;
|
||||||
|
for (const attempt of attempts) {
|
||||||
|
try {
|
||||||
|
const candidate = attempt.build();
|
||||||
|
return model.runSync([candidate as TensorInput]);
|
||||||
|
} catch (err) {
|
||||||
|
lastError = err;
|
||||||
|
if (__DEV__) {
|
||||||
|
console.warn(
|
||||||
|
`[TFLite] runSync attempt "${attempt.label}" failed:`,
|
||||||
|
err instanceof Error ? err.message : String(err),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
throw lastError instanceof Error
|
||||||
|
? lastError
|
||||||
|
: new Error('runSync failed with all fallbacks');
|
||||||
|
}
|
||||||
|
|
||||||
|
function dequantizeOutput(
|
||||||
|
raw: TensorInput,
|
||||||
|
dtype: TensorDataType,
|
||||||
|
): number[] {
|
||||||
|
// float outputs are already in [0,1] (after softmax) or logits.
|
||||||
|
if (dtype === 'float32' || dtype === 'float16') {
|
||||||
|
return Array.from(raw as Float32Array);
|
||||||
|
}
|
||||||
|
// Quantized outputs need a rough rescale. Without scale/zeroPoint metadata
|
||||||
|
// exposed we approximate: uint8 → /255, int8 → (v + 128)/255. Good enough
|
||||||
|
// for argmax (the relative order is preserved).
|
||||||
|
if (dtype === 'uint8') {
|
||||||
|
return Array.from(raw as Uint8Array, (v) => v / 255);
|
||||||
|
}
|
||||||
|
if (dtype === 'int8') {
|
||||||
|
return Array.from(raw as Int8Array, (v) => (v + 128) / 255);
|
||||||
|
}
|
||||||
|
return Array.from(raw as ArrayLike<number>);
|
||||||
|
}
|
||||||
|
|
||||||
function isProbabilityVector(values: number[]): boolean {
|
function isProbabilityVector(values: number[]): boolean {
|
||||||
if (values.length === 0) return false;
|
if (values.length === 0) return false;
|
||||||
const sum = values.reduce((a, b) => a + b, 0);
|
const sum = values.reduce((a, b) => a + b, 0);
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue