diff --git a/node/src/core/index.ts b/node/src/core/index.ts index 8a6b0fb9..d2349840 100644 --- a/node/src/core/index.ts +++ b/node/src/core/index.ts @@ -302,6 +302,8 @@ import { enable_compile, disable_compile, checkpoint, + export_function, + export_to_dot, type NanToNumOptions, type CloseOptions, type ContiguousOptions, @@ -556,6 +558,8 @@ export { enable_compile, disable_compile, checkpoint, + export_function, + export_to_dot, }; export type { SplitOptions, PadOptions, SliceOptions, AsStridedOptions, NumberOfElementsOptions, @@ -835,6 +839,8 @@ const core = { enable_compile, disable_compile, checkpoint, + export_function, + export_to_dot, device: deviceModule, Dtype, dtype: dtypeModule, diff --git a/node/src/core/ops.ts b/node/src/core/ops.ts index c62e3ef5..cf4396e0 100644 --- a/node/src/core/ops.ts +++ b/node/src/core/ops.ts @@ -3246,3 +3246,30 @@ export function checkpoint(fn: MultiArrayFn): (...args: MLXArray[]) => MLXArray }; } +// ============================================================ +// Export ops +// ============================================================ + +/** Export a function trace to a file for later import */ +export function export_function( + file: string, + fn: MultiArrayFn, + args: MLXArray[], + shapeless?: boolean, +): void { + const nativeFn = (...nativeArgs: any[]) => { + const jsArgs = nativeArgs.map((a: any) => MLXArray.fromHandle(a)); + const result = fn(...jsArgs); + if (Array.isArray(result)) return result.map(toNativeHandle); + return toNativeHandle(result); + }; + const nativeArgs = args.map(toNativeHandle); + addon.export_function(file, nativeFn, nativeArgs, shapeless); +} + +/** Export a computation graph to DOT format (Graphviz) */ +export function export_to_dot(...arrays: MLXArray[]): string { + const nativeArrays = arrays.map(toNativeHandle); + return addon.export_to_dot(...nativeArrays); +} + diff --git a/node/src/index.ts b/node/src/index.ts index 1493f9c2..a554f453 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -243,6 +243,8 @@ export const { enable_compile, disable_compile, checkpoint, + export_function, + export_to_dot, // DType constants bool, int8, diff --git a/node/src/native/array.cc b/node/src/native/array.cc index 743cdace..6b520f61 100644 --- a/node/src/native/array.cc +++ b/node/src/native/array.cc @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -25,6 +26,7 @@ #include "mlx/memory.h" #include "mlx/device.h" #include "mlx/einsum.h" +#include "mlx/graph_utils.h" #include "mlx_bridge.h" #include "dtype.h" @@ -8352,6 +8354,69 @@ Napi::Value SaveGguf(const Napi::CallbackInfo& info) { return env.Undefined(); } +// ============================================================ +// Export ops +// ============================================================ + +// export_function(file, fn, args, shapeless?) +Napi::Value ExportFunction(const Napi::CallbackInfo& info) { + auto env = info.Env(); + if (info.Length() < 3 || !info[0].IsString() || !info[1].IsFunction()) { + Napi::TypeError::New(env, "export_function requires (file: string, fn: Function, args: MLXArray[])") + .ThrowAsJavaScriptException(); + return env.Null(); + } + std::string file = info[0].As().Utf8Value(); + auto multiFn = WrapJsFn(info[1].As()); + // Parse args array + std::vector args; + if (info[2].IsArray()) { + auto jsArr = info[2].As(); + for (uint32_t i = 0; i < jsArr.Length(); i++) { + args.push_back(ToArray(env, jsArr.Get(i))); + if (env.IsExceptionPending()) return env.Null(); + } + } else { + args.push_back(ToArray(env, info[2])); + if (env.IsExceptionPending()) return env.Null(); + } + bool shapeless = false; + if (info.Length() > 3 && info[3].IsBoolean()) + shapeless = info[3].As().Value(); + try { + mlx::core::export_function(file, multiFn, args, shapeless); + } catch (const std::exception& e) { + Napi::Error::New(env, e.what()).ThrowAsJavaScriptException(); + } + return env.Undefined(); +} + +// export_to_dot(arrays) -> string (DOT format) +Napi::Value ExportToDot(const Napi::CallbackInfo& info) { + auto env = info.Env(); + std::vector outputs; + for (size_t i = 0; i < info.Length(); i++) { + if (info[i].IsArray()) { + auto jsArr = info[i].As(); + for (uint32_t j = 0; j < jsArr.Length(); j++) { + outputs.push_back(ToArray(env, jsArr.Get(j))); + if (env.IsExceptionPending()) return env.Null(); + } + } else { + outputs.push_back(ToArray(env, info[i])); + if (env.IsExceptionPending()) return env.Null(); + } + } + try { + std::ostringstream oss; + mlx::core::export_to_dot(oss, outputs); + return Napi::String::New(env, oss.str()); + } catch (const std::exception& e) { + Napi::Error::New(env, e.what()).ThrowAsJavaScriptException(); + return env.Null(); + } +} + } // namespace Napi::Object Init(Napi::Env env, Napi::Object exports) { @@ -8668,6 +8733,8 @@ Napi::Object Init(Napi::Env env, Napi::Object exports) { // Export/import operations core.Set("import_function", Napi::Function::New(env, ImportFunction, "import_function", &data)); + core.Set("export_function", Napi::Function::New(env, ExportFunction, "export_function", &data)); + core.Set("export_to_dot", Napi::Function::New(env, ExportToDot, "export_to_dot", &data)); // Eval operations core.Set("eval", Napi::Function::New(env, Eval, "eval", &data)); diff --git a/node/test/core/batch7-export.test.ts b/node/test/core/batch7-export.test.ts new file mode 100644 index 00000000..ee83b00b --- /dev/null +++ b/node/test/core/batch7-export.test.ts @@ -0,0 +1,37 @@ +import { describe, it } from 'mocha'; +import * as assert from 'assert'; +import * as core from '../../src/core'; +import * as fs from 'fs'; +import * as path from 'path'; +import * as os from 'os'; + +describe('batch 7: export ops', () => { + const tmpDir = os.tmpdir(); + + it('export_function and import_function roundtrip', () => { + const f = (x: InstanceType) => + core.multiply(x, x); + const x = core.array(new Float32Array([2, 3, 4]), [3]); + const file = path.join(tmpDir, `mlx_export_${Date.now()}.mlxfn`); + try { + core.export_function(file, f as any, [x]); + assert.ok(fs.existsSync(file)); + const imported = core.import_function(file); + const results = imported(x); + assert.ok(Array.isArray(results)); + const result = results[0]; + core.eval_op(result); + assert.deepStrictEqual(result.toArray(), [4, 9, 16]); + } finally { + if (fs.existsSync(file)) fs.unlinkSync(file); + } + }); + + it('export_to_dot returns DOT string', () => { + const a = core.array(new Float32Array([1, 2]), [2]); + const b = core.add(a, a); + const dot = core.export_to_dot(b); + assert.ok(typeof dot === 'string'); + assert.ok(dot.includes('digraph')); + }); +});