Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions node/src/core/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ import {
enable_compile,
disable_compile,
checkpoint,
export_function,
export_to_dot,
type NanToNumOptions,
type CloseOptions,
type ContiguousOptions,
Expand Down Expand Up @@ -556,6 +558,8 @@ export {
enable_compile,
disable_compile,
checkpoint,
export_function,
export_to_dot,
};
export type {
SplitOptions, PadOptions, SliceOptions, AsStridedOptions, NumberOfElementsOptions,
Expand Down Expand Up @@ -835,6 +839,8 @@ const core = {
enable_compile,
disable_compile,
checkpoint,
export_function,
export_to_dot,
device: deviceModule,
Dtype,
dtype: dtypeModule,
Expand Down
27 changes: 27 additions & 0 deletions node/src/core/ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3246,3 +3246,30 @@ export function checkpoint(fn: MultiArrayFn): (...args: MLXArray[]) => MLXArray
};
}

// ============================================================
// Export ops
// ============================================================

/** Export a function trace to a file for later import */
export function export_function(
file: string,
fn: MultiArrayFn,
args: MLXArray[],
shapeless?: boolean,
): void {
const nativeFn = (...nativeArgs: any[]) => {
const jsArgs = nativeArgs.map((a: any) => MLXArray.fromHandle(a));
const result = fn(...jsArgs);
if (Array.isArray(result)) return result.map(toNativeHandle);
return toNativeHandle(result);
};
const nativeArgs = args.map(toNativeHandle);
addon.export_function(file, nativeFn, nativeArgs, shapeless);
}

/** Export a computation graph to DOT format (Graphviz) */
export function export_to_dot(...arrays: MLXArray[]): string {
const nativeArrays = arrays.map(toNativeHandle);
return addon.export_to_dot(...nativeArrays);
}

2 changes: 2 additions & 0 deletions node/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ export const {
enable_compile,
disable_compile,
checkpoint,
export_function,
export_to_dot,
// DType constants
bool,
int8,
Expand Down
67 changes: 67 additions & 0 deletions node/src/native/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <memory>
#include <numeric>
#include <optional>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
Expand All @@ -25,6 +26,7 @@
#include "mlx/memory.h"
#include "mlx/device.h"
#include "mlx/einsum.h"
#include "mlx/graph_utils.h"
#include "mlx_bridge.h"

#include "dtype.h"
Expand Down Expand Up @@ -8352,6 +8354,69 @@ Napi::Value SaveGguf(const Napi::CallbackInfo& info) {
return env.Undefined();
}

// ============================================================
// Export ops
// ============================================================

// export_function(file, fn, args, shapeless?)
Napi::Value ExportFunction(const Napi::CallbackInfo& info) {
auto env = info.Env();
if (info.Length() < 3 || !info[0].IsString() || !info[1].IsFunction()) {
Napi::TypeError::New(env, "export_function requires (file: string, fn: Function, args: MLXArray[])")
.ThrowAsJavaScriptException();
return env.Null();
}
std::string file = info[0].As<Napi::String>().Utf8Value();
auto multiFn = WrapJsFn(info[1].As<Napi::Function>());
// Parse args array
std::vector<mlx::core::array> args;
if (info[2].IsArray()) {
auto jsArr = info[2].As<Napi::Array>();
for (uint32_t i = 0; i < jsArr.Length(); i++) {
args.push_back(ToArray(env, jsArr.Get(i)));
if (env.IsExceptionPending()) return env.Null();
}
} else {
args.push_back(ToArray(env, info[2]));
if (env.IsExceptionPending()) return env.Null();
}
bool shapeless = false;
if (info.Length() > 3 && info[3].IsBoolean())
shapeless = info[3].As<Napi::Boolean>().Value();
try {
mlx::core::export_function(file, multiFn, args, shapeless);
} catch (const std::exception& e) {
Napi::Error::New(env, e.what()).ThrowAsJavaScriptException();
}
return env.Undefined();
}

// export_to_dot(arrays) -> string (DOT format)
Napi::Value ExportToDot(const Napi::CallbackInfo& info) {
auto env = info.Env();
std::vector<mlx::core::array> outputs;
for (size_t i = 0; i < info.Length(); i++) {
if (info[i].IsArray()) {
auto jsArr = info[i].As<Napi::Array>();
for (uint32_t j = 0; j < jsArr.Length(); j++) {
outputs.push_back(ToArray(env, jsArr.Get(j)));
if (env.IsExceptionPending()) return env.Null();
}
} else {
outputs.push_back(ToArray(env, info[i]));
if (env.IsExceptionPending()) return env.Null();
}
}
try {
std::ostringstream oss;
mlx::core::export_to_dot(oss, outputs);
return Napi::String::New(env, oss.str());
} catch (const std::exception& e) {
Napi::Error::New(env, e.what()).ThrowAsJavaScriptException();
return env.Null();
}
}

} // namespace

Napi::Object Init(Napi::Env env, Napi::Object exports) {
Expand Down Expand Up @@ -8668,6 +8733,8 @@ Napi::Object Init(Napi::Env env, Napi::Object exports) {

// Export/import operations
core.Set("import_function", Napi::Function::New(env, ImportFunction, "import_function", &data));
core.Set("export_function", Napi::Function::New(env, ExportFunction, "export_function", &data));
core.Set("export_to_dot", Napi::Function::New(env, ExportToDot, "export_to_dot", &data));

// Eval operations
core.Set("eval", Napi::Function::New(env, Eval, "eval", &data));
Expand Down
37 changes: 37 additions & 0 deletions node/test/core/batch7-export.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { describe, it } from 'mocha';
import * as assert from 'assert';
import * as core from '../../src/core';
import * as fs from 'fs';
import * as path from 'path';
import * as os from 'os';

describe('batch 7: export ops', () => {
const tmpDir = os.tmpdir();

it('export_function and import_function roundtrip', () => {
const f = (x: InstanceType<typeof core.MLXArray>) =>
core.multiply(x, x);
const x = core.array(new Float32Array([2, 3, 4]), [3]);
const file = path.join(tmpDir, `mlx_export_${Date.now()}.mlxfn`);
try {
core.export_function(file, f as any, [x]);
assert.ok(fs.existsSync(file));
const imported = core.import_function(file);
const results = imported(x);
assert.ok(Array.isArray(results));
const result = results[0];
core.eval_op(result);
assert.deepStrictEqual(result.toArray(), [4, 9, 16]);
} finally {
if (fs.existsSync(file)) fs.unlinkSync(file);
}
});

it('export_to_dot returns DOT string', () => {
const a = core.array(new Float32Array([1, 2]), [2]);
const b = core.add(a, a);
const dot = core.export_to_dot(b);
assert.ok(typeof dot === 'string');
assert.ok(dot.includes('digraph'));
});
});
Loading