From 1a39c53370746ba305b203abb6a31e67533e15ee Mon Sep 17 00:00:00 2001 From: Leon Camus Date: Fri, 28 Aug 2020 14:03:08 +0200 Subject: [PATCH] feat: Add mixed backend --- lib/api/onnx-impl.ts | 4 +++- lib/backend.ts | 2 +- lib/backends/backend-mixed.ts | 14 ++++++++++++++ lib/backends/mixed-session-handler.ts | 24 ++++++++++++++++++++++++ 4 files changed, 42 insertions(+), 2 deletions(-) create mode 100644 lib/backends/backend-mixed.ts create mode 100644 lib/backends/mixed-session-handler.ts diff --git a/lib/api/onnx-impl.ts b/lib/api/onnx-impl.ts index 1f20accf..52f4896d 100644 --- a/lib/api/onnx-impl.ts +++ b/lib/api/onnx-impl.ts @@ -8,6 +8,7 @@ import {WebGLBackend} from '../backends/backend-webgl'; import {Environment} from './env'; import {envImpl} from './env-impl'; import {Backend} from './onnx'; +import {MixedBackend} from '../backends/backend-mixed'; export * from './env'; export * from './onnx'; @@ -17,7 +18,8 @@ export * from './inference-session'; export const backend: Backend = { cpu: new CpuBackend(), wasm: new WasmBackend(), - webgl: new WebGLBackend() + webgl: new WebGLBackend(), + mixed: new MixedBackend() }; export const ENV: Environment = envImpl; diff --git a/lib/backend.ts b/lib/backend.ts index 5df89984..ed135a6b 100644 --- a/lib/backend.ts +++ b/lib/backend.ts @@ -82,7 +82,7 @@ const backendsCache: Map = new Map(); */ export async function Backend(hint?: string|ReadonlyArray): Promise { if (!hint) { - return Backend(['webgl', 'wasm', 'cpu']); + return Backend(['webgl', 'mixed', 'wasm', 'cpu']); } else { const hints = typeof hint === 'string' ? [hint] : hint; diff --git a/lib/backends/backend-mixed.ts b/lib/backends/backend-mixed.ts new file mode 100644 index 00000000..f4f91e46 --- /dev/null +++ b/lib/backends/backend-mixed.ts @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +import {SessionHandler} from '../backend'; +import {Session} from '../session'; +import {MixedSessionHandler} from './mixed-session-handler'; +import {WebGLBackend} from './backend-webgl'; + + +export class MixedBackend extends WebGLBackend { + createSessionHandler(context: Session.Context): SessionHandler { + return new MixedSessionHandler(this, context); + } +} diff --git a/lib/backends/mixed-session-handler.ts b/lib/backends/mixed-session-handler.ts new file mode 100644 index 00000000..fcff77df --- /dev/null +++ b/lib/backends/mixed-session-handler.ts @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +import {WebGLSessionHandler} from './webgl/session-handler'; +import {Graph} from '../graph'; +import {OpSet, resolveOperator} from '../opset'; +import {Operator} from '../operators'; +import {CPU_OP_RESOLVE_RULES} from './cpu/op-resolve-rules'; +import {Logger} from '../instrument'; + +export class MixedSessionHandler extends WebGLSessionHandler { + resolve(node: Graph.Node, opsets: ReadonlyArray): Operator { + try { + return super.resolve(node, opsets); + } catch (e) { + Logger.warning( + 'MixedSessionHandler', + `Unable to initialize operator '${node.opType}' with webgl. trying with cpu...`); + const op = resolveOperator(node, opsets, CPU_OP_RESOLVE_RULES); + op.initialize(node.attributes); + return op; + } + } +} \ No newline at end of file