diff --git a/.gitignore b/.gitignore index 1b1dda6..e199e48 100644 --- a/.gitignore +++ b/.gitignore @@ -20,5 +20,8 @@ # Build script compiled data .lei/ +# Cgo generated files +_obj/ + # Dependency directories (remove the comment below to include it) # vendor/ diff --git a/.gitmodules b/.gitmodules index 2ea0cb6..1fcd3e2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,3 @@ -[submodule "internal/luau"] - path = internal/luau - url = https://github.com/luau-lang/luau [submodule "ffi/luau"] path = ffi/luau url = https://github.com/luau-lang/luau.git diff --git a/build/build.go b/build/build.go index c3ce84a..aa20229 100644 --- a/build/build.go +++ b/build/build.go @@ -1,152 +1,49 @@ package main import ( - "fmt" + "log" "os" - "path" "strings" - - "github.com/gookit/color" - "golang.org/x/term" ) -const LUAU_VERSION = "0.634" -const ARTIFACT_NAME = "libLuau.VM.a" - -func bail(err error) { - if err != nil { - panic(err) - } -} - -func cloneSrc() string { - color.Blue.Println("> Cloning luau-lang/luau") - - dir, tempDirErr := os.MkdirTemp("", "lei-build") - bail(tempDirErr) - - // Clone down the Luau repo and checkout the required tag - Exec("git", "", "clone", "https://github.com/luau-lang/luau.git", dir) - Exec("git", dir, "checkout", LUAU_VERSION) - - color.Green.Printf("> Cloned repo to%s\n\n", dir) - return dir -} - -func buildVm(srcPath string, artifactPath string, includesDir string, cmakeFlags ...string) { - color.Blue.Println("> Compile libLuau.VM.a") - - // Build the Luau VM using CMake - buildDir := path.Join(srcPath, "cmake") - buildDirErr := os.Mkdir(buildDir, os.ModePerm) - if !os.IsExist(buildDirErr) { - bail(buildDirErr) - } - - defaultCmakeFlags := []string{"..", "-DCMAKE_BUILD_TYPE=RelWithDebInfo", "-DLUAU_EXTERN_C=ON", "-DCMAKE_POLICY_VERSION_MINIMUM=3.5"} - Exec("cmake", buildDir, append(defaultCmakeFlags, cmakeFlags...)...) - Exec("cmake", buildDir, "--build", ".", "--target Luau.VM", "--config", "RelWithDebInfo") - - color.Green.Println("> Successfully compiled!\n") - - // Copy the artifact to the artifact directory - artifactFile, artifactErr := os.ReadFile(path.Join(buildDir, ARTIFACT_NAME)) - bail(artifactErr) - bail(os.WriteFile(artifactPath, artifactFile, os.ModePerm)) - - // Copy the header files into the includes directory - headerDir := path.Join(srcPath, "VM", "include") - headerFiles, headerErr := os.ReadDir(headerDir) - bail(headerErr) - for _, file := range headerFiles { - src := path.Join(headerDir, file.Name()) - dest := path.Join(includesDir, file.Name()) - - headerContents, headerReadErr := os.ReadFile(src) - bail(headerReadErr) - - os.WriteFile(dest, headerContents, os.ModePerm) - } -} - func main() { - workDir, workDirErr := os.Getwd() - bail(workDirErr) - - artifactDir := path.Join(workDir, ".lei") - artifactPath := path.Join(artifactDir, ARTIFACT_NAME) - lockfilePath := path.Join(artifactDir, ".lock") - includesDir := path.Join(artifactDir, "includes") - - bail(os.MkdirAll(includesDir, os.ModePerm)) // includesDir is the deepest dir, creates all - - gitignore, gitignoreErr := os.ReadFile(".gitignore") - if gitignoreErr == nil && !strings.Contains(string(gitignore), ".lei") { - color.Yellow.Println("> WARN: The gitignore in the CWD does not include `.lei`, consider adding it") + usage := func() { log.Fatal("Usage: buildProject ") } + if len(os.Args) < 2 { + usage() } - // TODO: Args for clean build - args := os.Args[1:] - - goArgs := []string{} - cmakeFlags := []string{} - features := []string{} + switch os.Args[1] { + case "buildProject": + for _, project := range os.Args[2:] { + if !strings.HasPrefix(project, "Luau.") { + log.Fatalf("Invalid project name: %s", project) + } - // TODO: maybe use env vars for this config instead - for _, arg := range args { - if arg == "--enable-vector4" { - features = append(features, "LUAU_VECTOR4") - // FIXME: This flag apparently isn't recognized by cmake for some reason - cmakeFlags = append(cmakeFlags, "-DLUAU_VECTOR_SIZE=4") - - } else { - goArgs = append(goArgs, arg) + compileLuauProject(project) } - } - lockfileContents, err := os.ReadFile(lockfilePath) - if !os.IsNotExist(err) { - bail(err) + // Display usage menu + case "-h", "--help": + fallthrough + default: + usage() } +} - serFeatures := fmt.Sprintf("%v", features) - toCleanBuild := (string(lockfileContents) != serFeatures) || os.Getenv("LEI_CLEAN_BUILD") == "true" - if _, err := os.Stat(artifactPath); err == nil && !toCleanBuild { - fmt.Printf("[build] Using existing artifact at %s\n", artifactPath) - } else { - srcPath, notUnset := os.LookupEnv("LEI_LUAU_SRC") - if !notUnset { - srcPath = cloneSrc() - defer os.RemoveAll(srcPath) - } - - buildVm(srcPath, artifactPath, includesDir, cmakeFlags...) - bail(os.WriteFile(lockfilePath, []byte(serFeatures), os.ModePerm)) - } +func compileLuauProject(project string) { + if err := os.Mkdir("_obj", os.ModePerm); err == nil || !os.IsExist(err) { + // Directory already exists, i.e., config files generated + Exec( + "cmake", + "-S", "luau", + "-B", "_obj", + "-G", "Ninja", - buildTags := []string{} - if len(features) > 0 { - buildTags = append(buildTags, []string{"-tags", strings.Join(features, ",")}...) + // Flags + "-DCMAKE_BUILD_TYPE=RelWithDebInfo", + "-DLUAU_EXTERN_C=ON", + ) } - w, _, termErr := term.GetSize(int(os.Stdout.Fd())) - bail(termErr) - fmt.Println(strings.Repeat("=", w)) - - subcommand := goArgs[0] - goArgs = goArgs[1:] - combinedArgs := append(buildTags, goArgs...) - cmd, _, _, _ := Command("go"). - WithArgs(append([]string{subcommand}, combinedArgs...)...). - WithVar( - "CGO_LDFLAGS", - fmt.Sprintf("-L %s -lLuau.VM -lm -lstdc++", artifactDir), - ). - WithVar("CGO_CFLAGS", fmt.Sprintf("-I%s", includesDir)). - WithVar("CGO_ENABLED", "1"). - PipeAll(Forward). - ToCommand() - - bail(cmd.Start()) - bail(cmd.Wait()) + Exec("cmake", "--build", "_obj", "-t", project, "--config", "RelWithDebInfo") } diff --git a/build/cmd.go b/build/cmd.go index 072be6e..1aabf3f 100644 --- a/build/cmd.go +++ b/build/cmd.go @@ -3,8 +3,10 @@ package main import ( "bytes" "io" + "log" "os" "os/exec" + "strings" ) type CommandPipeMode int @@ -127,17 +129,23 @@ func pipeModeToWriter(mode CommandPipeMode, def io.Writer) io.Writer { } } -func Exec(name string, dir string, args ...string) { - cmd, _, _, _ := Command(name).WithArgs(args...).Dir(dir).PipeAll(Forward).ToCommand() - startErr := cmd.Start() - if startErr != nil { - panic(startErr) +func Exec(exe string, args ...string) { + cmd, _, _, _ := Command(exe). + WithArgs(args...). + WithVar("CLICOLOR_FORCE", "1"). + PipeAll(Forward). + ToCommand() + + if err := cmd.Start(); err != nil { + log.Fatalf("Failed to start command %s: %v", exe, err) } - cmdErr := cmd.Wait() - if cmdErr != nil { - panic(cmdErr) - // err := cmdErr.(*exec.ExitError) - // os.Exit(err.ExitCode()) + if err := cmd.Wait(); err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + commandStr := strings.Join(append([]string{exe}, args...), " ") + log.Fatalf("'%s' exited with %d", commandStr, exitErr.ExitCode()) + } + + log.Fatalf("%s command failed: %v", exe, err) } } diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..79fb550 --- /dev/null +++ b/doc.go @@ -0,0 +1,3 @@ +// Root package. Use `ffi` or `lua` modules for low-level bindings +// or high-level APIs respectively. +package main diff --git a/ffi/bytecode.go b/ffi/bytecode.go new file mode 100644 index 0000000..1940fee --- /dev/null +++ b/ffi/bytecode.go @@ -0,0 +1,23 @@ +package ffi + +/* +#cgo CFLAGS: -Iluau/Common/include +#include "Luau/Bytecode.h" + +enum LuauBytecodeTag LuauBytecodeTag; +*/ +import "C" + +// +// Version Constants +// + +const ( + LBC_VERSION_MIN = C.LBC_VERSION_MIN + LBC_VERSION_MAX = C.LBC_VERSION_MAX + LBC_VERSION_TARGET = C.LBC_VERSION_TARGET + + LBC_TYPE_VERSION_MIN = C.LBC_TYPE_VERSION_MIN + LBC_TYPE_VERSION_MAX = C.LBC_TYPE_VERSION_MAX + LBC_TYPE_VERSION_TARGET = C.LBC_TYPE_VERSION_TARGET +) diff --git a/ffi/clua.c b/ffi/clua.c index a0a9d27..ad0e06d 100644 --- a/ffi/clua.c +++ b/ffi/clua.c @@ -1,6 +1,5 @@ #include -#include -#include <_cgo_export.h> +#include "luau/VM/include/lua.h" // void* clua_alloc(void* ud, void *ptr, size_t osize, size_t nsize) // { diff --git a/ffi/clua.h b/ffi/clua.h index 004e7ed..b97cc71 100644 --- a/ffi/clua.h +++ b/ffi/clua.h @@ -1,5 +1,5 @@ #include -#include +#include "luau/VM/include/lua.h" lua_State* clua_newstate(void* f, void* ud); l_noret cluaL_errorL(lua_State* L, char* msg); diff --git a/ffi/lauxlib.go b/ffi/lauxlib.go index 08a5a3c..5faad1b 100644 --- a/ffi/lauxlib.go +++ b/ffi/lauxlib.go @@ -1,8 +1,10 @@ package ffi +//go:generate go run ../build buildProject Luau.VM + /* -#cgo CFLAGS: -Iluau/VM/include -I/usr/lib/gcc/x86_64-pc-linux-gnu/14.1.1/include -#cgo LDFLAGS: -Lluau/cmake -lLuau.VM -lm -lstdc++ +#cgo CFLAGS: -Iluau/VM/include +#cgo LDFLAGS: -L_obj -lLuau.VM -lm -lstdc++ #include #include #include @@ -66,8 +68,6 @@ func LArgError(L *LuaState, narg int32, extramsg string) { func LCheckLString(L *LuaState, narg int32, l *uint64) string { p := C.luaL_checklstring(L, C.int(narg), (*C.size_t)(l)) - defer C.free(unsafe.Pointer(p)) - return C.GoString(p) } @@ -76,8 +76,6 @@ func LOptLString(L *LuaState, narg int32, def string, l *uint64) string { defer C.free(unsafe.Pointer(cdef)) p := C.luaL_optlstring(L, C.int(narg), cdef, (*C.ulong)(l)) - defer C.free(unsafe.Pointer(p)) - return C.GoString(p) } @@ -238,31 +236,29 @@ func LOptString(L *LuaState, n int32, d string) string { } const ( - LUA_COLIBNAME = "coroutine" - LUA_TABLIBNAME = "table" - LUA_OSLIBNAME = "os" - LUA_STRLIBNAME = "string" - LUA_BITLIBNAME = "bit32" - LUA_BUFFERLIBNAME = "buffer" - LUA_UTF8LIBNAME = "utf8" - LUA_MATHLIBNAME = "math" - LUA_DBLIBNAME = "debug" + LUA_COLIBNAME = C.LUA_COLIBNAME + LUA_TABLIBNAME = C.LUA_TABLIBNAME + LUA_OSLIBNAME = C.LUA_OSLIBNAME + LUA_STRLIBNAME = C.LUA_STRLIBNAME + LUA_BITLIBNAME = C.LUA_BITLIBNAME + LUA_BUFFERLIBNAME = C.LUA_BUFFERLIBNAME + LUA_UTF8LIBNAME = C.LUA_UTF8LIBNAME + LUA_MATHLIBNAME = C.LUA_MATHLIBNAME + LUA_DBLIBNAME = C.LUA_DBLIBNAME + LUA_VECLIBNAME = C.LUA_VECLIBNAME ) -// DIVERGENCE: We cannot export wrapper functions around C functions if we want to -// pass them to API functions, we preserve the real C pointer by having 'opener' -// functions - -func CoroutineOpener() C.lua_CFunction { return C.lua_CFunction(C.luaopen_base) } -func BaseOpener() C.lua_CFunction { return C.lua_CFunction(C.luaopen_base) } -func TableOpener() C.lua_CFunction { return C.lua_CFunction(C.luaopen_table) } -func OsOpener() C.lua_CFunction { return C.lua_CFunction(C.luaopen_os) } -func StringOpener() C.lua_CFunction { return C.lua_CFunction(C.luaopen_string) } -func Bit32Opener() C.lua_CFunction { return C.lua_CFunction(C.luaopen_bit32) } -func BufferOpener() C.lua_CFunction { return C.lua_CFunction(C.luaopen_buffer) } -func Utf8Opener() C.lua_CFunction { return C.lua_CFunction(C.luaopen_utf8) } -func MathOpener() C.lua_CFunction { return C.lua_CFunction(C.luaopen_math) } -func DebugOpener() C.lua_CFunction { return C.lua_CFunction(C.luaopen_debug) } -func LibsOpener() C.lua_CFunction { return C.lua_CFunction(C.luaL_openlibs) } +func OpenBase(L *LuaState) { C.luaopen_base(L) } +func OpenCoroutine(L *LuaState) { C.luaopen_coroutine(L) } +func OpenTable(L *LuaState) { C.luaopen_table(L) } +func OpenOs(L *LuaState) { C.luaopen_os(L) } +func OpenString(L *LuaState) { C.luaopen_string(L) } +func OpenBit32(L *LuaState) { C.luaopen_bit32(L) } +func OpenBuffer(L *LuaState) { C.luaopen_buffer(L) } +func OpenUtf8(L *LuaState) { C.luaopen_utf8(L) } +func OpenMath(L *LuaState) { C.luaopen_math(L) } +func OpenDebug(L *LuaState) { C.luaopen_debug(L) } +func OpenVector(L *LuaState) { C.luaopen_vector(L) } +func LOpenLibs(L *LuaState) { C.luaL_openlibs(L) } // TODO: More utility functions, buffer bindings diff --git a/ffi/lua.go b/ffi/lua.go index f8884c3..4314c12 100644 --- a/ffi/lua.go +++ b/ffi/lua.go @@ -1,8 +1,10 @@ package ffi +//go:generate go run ../build buildProject Luau.VM + /* -#cgo CFLAGS: -Iluau/VM/include -I/usr/lib/gcc/x86_64-pc-linux-gnu/15.2.1/include -#cgo LDFLAGS: -Lluau/cmake -lLuau.VM -lm -lstdc++ +#cgo CFLAGS: -Iluau/VM/include +#cgo LDFLAGS: -L_obj -lLuau.VM -lm -lstdc++ #include #include #include @@ -36,9 +38,9 @@ const LUA_MULTRET = -1 // const ( - LUA_REGISTRYINDEX = -LUAI_MAXCSTACK - 2000 - LUA_ENVIRONINDEX = -LUAI_MAXCSTACK - 2001 - LUA_GLOBALSINDEX = -LUAI_MAXCSTACK - 2002 + LUA_REGISTRYINDEX = C.LUA_REGISTRYINDEX + LUA_ENVIRONINDEX = C.LUA_ENVIRONINDEX + LUA_GLOBALSINDEX = C.LUA_GLOBALSINDEX ) // @@ -48,7 +50,7 @@ const ( // const ( - LUA_OK = iota + 1 + LUA_OK = iota LUA_YIELD LUA_ERRRUN LUA_ERRSYNTAX @@ -64,7 +66,7 @@ const ( // const ( - LUA_CORUN = iota + 1 + LUA_CORUN = iota LUA_COSUS LUA_CONOR LUA_COFIN @@ -288,8 +290,21 @@ func ToUnsignedX(L *LuaState, idx int32, isnum *bool) LuaUnsigned { return unsigned } -func ToVector(L *LuaState, idx int32) { - C.lua_tovector(L, C.int(idx)) +// DIVERGENCE: We cannot cast and reinterpret the C owned vector returned as +// a Go value, as it breaks cgo pointer rules. Instead, we allocate new Go +// owned floats on the heap and only read the floats returned by C + +func ToVector(L *C.lua_State, idx int32) (x, y, z *float32) { + vec := C.lua_tovector(L, C.int(idx)) + if vec == nil { + return nil, nil, nil + } + + v := (*[3]C.float)(unsafe.Pointer(vec)) + x, y, z = new(float32), new(float32), new(float32) + *x, *y, *z = float32(v[0]), float32(v[1]), float32(v[2]) + + return } func ToBoolean(L *LuaState, idx int32) bool { @@ -398,9 +413,13 @@ func PushString(L *LuaState, s string) { // arguments from Go->C isn't something that is possible. // func PushFStringL(L *lua_State, fmt string) {} -func PushCClosureK(L *LuaState, f unsafe.Pointer, debugname string, nup int32, cont unsafe.Pointer) { - cdebugname := C.CString(debugname) - defer C.free(unsafe.Pointer(cdebugname)) +func PushCClosureK(L *LuaState, f unsafe.Pointer, debugname *string, nup int32, cont unsafe.Pointer) { + var cdebugname *C.char + if debugname != nil && *debugname != "" { + cdebugname = C.CString(*debugname) + defer C.free(unsafe.Pointer(cdebugname)) + } + C.clua_pushcclosurek(L, f, cdebugname, C.int(nup), cont) } @@ -421,6 +440,10 @@ func PushLightUserdataTagged(L *LuaState, p unsafe.Pointer, tag int32) { C.lua_pushlightuserdatatagged(L, p, C.int(tag)) } +func PushVector(L *LuaState, x, y, z float32) { + C.lua_pushvector(L, C.float(x), C.float(y), C.float(z)) +} + func NewUserdataTagged(L *LuaState, sz uint64, tag int32) unsafe.Pointer { return C.lua_newuserdatatagged(L, C.size_t(sz), C.int(tag)) } @@ -491,8 +514,8 @@ func SetSafeEnv(L *LuaState, idx int32, enabled bool) { C.lua_setsafeenv(L, C.int(idx), cenabled) } -func GetMetatable(L *LuaState, objindex int32) int32 { - return int32(C.lua_getmetatable(L, C.int(objindex))) +func GetMetatable(L *LuaState, objindex int32) bool { + return int32(C.lua_getmetatable(L, C.int(objindex))) == 1 } func Getfenv(L *LuaState, idx int32) { @@ -528,8 +551,8 @@ func SetMetatable(L *LuaState, objindex int32) int32 { return int32(C.lua_setmetatable(L, C.int(objindex))) } -func Setfenv(L *LuaState, idx int32) int32 { - return int32(C.lua_setfenv(L, C.int(idx))) +func Setfenv(L *LuaState, idx int32) bool { + return C.lua_setfenv(L, C.int(idx)) == 0 } // @@ -538,14 +561,21 @@ func Setfenv(L *LuaState, idx int32) int32 { // ========================= // -func LuauLoad(L *LuaState, chunkname string, data string, size uint64, env int32) int32 { +func LuauLoad(L *LuaState, chunkname string, data []byte, size uint64, env int32) bool { cchunkname := C.CString(chunkname) defer C.free(unsafe.Pointer(cchunkname)) - cdata := C.CString(data) - defer C.free(unsafe.Pointer(cdata)) + var cdata *C.char + if size == 0 { + // NULL for empty slices + cdata = (*C.char)(C.NULL) + } else { + cdata = (*C.char)(unsafe.Pointer(&data[0])) + } + + // NOTE: We don't free the bytecode after it's loaded - return int32(C.luau_load(L, cchunkname, cdata, C.size_t(size), C.int(env))) + return C.luau_load(L, cchunkname, cdata, C.size_t(size), C.int(env)) == 0 } func Call(L *LuaState, nargs int32, nresults int32) { @@ -562,23 +592,23 @@ func Pcall(L *LuaState, nargs int32, nresults int32, errfunc int32) int32 { // ======================== // -func LuaYield(L *LuaState, nresults int32) int32 { +func Yield(L *LuaState, nresults int32) int32 { return int32(C.lua_yield(L, C.int(nresults))) } -func LuaBreak(L *LuaState) int32 { +func Break(L *LuaState) int32 { return int32(C.lua_break(L)) } -func LuaResume(L *LuaState, from *LuaState, nargs int32) int32 { +func Resume(L *LuaState, from *LuaState, nargs int32) int32 { return int32(C.lua_resume(L, from, C.int(nargs))) } -func LuaResumeError(L *LuaState, from *LuaState) int32 { +func ResumeError(L *LuaState, from *LuaState) int32 { return int32(C.lua_resumeerror(L, from)) } -func LuaStatus(L *LuaState) int32 { +func Status(L *LuaState) int32 { return int32(C.lua_status(L)) } @@ -726,6 +756,10 @@ func Unref(L *LuaState, ref int32) { C.lua_unref(L, C.int(ref)) } +func GetRef(L *LuaState, ref int32) int32 { + return RawGetI(L, LUA_REGISTRYINDEX, ref) +} + // // ================== // Debug API @@ -957,18 +991,18 @@ func PushLiteral(L *LuaState, s string) { } func PushCFunction(L *LuaState, f unsafe.Pointer) { - PushCClosureK(L, f, *new(string), 0, nil) + PushCClosureK(L, f, new(string), 0, nil) } -func PushCFunctionD(L *LuaState, f unsafe.Pointer, debugname string) { +func PushCFunctionD(L *LuaState, f unsafe.Pointer, debugname *string) { PushCClosureK(L, f, debugname, 0, nil) } func PushCClosure(L *LuaState, f unsafe.Pointer, nup int32) { - PushCClosureK(L, f, *new(string), nup, nil) + PushCClosureK(L, f, new(string), nup, nil) } -func PushCClosureD(L *LuaState, f unsafe.Pointer, debugname string, nup int32) { +func PushCClosureD(L *LuaState, f unsafe.Pointer, debugname *string, nup int32) { PushCClosureK(L, f, debugname, nup, nil) } diff --git a/ffi/luacode.go b/ffi/luacode.go new file mode 100644 index 0000000..c8500fe --- /dev/null +++ b/ffi/luacode.go @@ -0,0 +1,163 @@ +package ffi + +//go:generate go run ../build buildProject Luau.VM Luau.Compiler Luau.Ast + +/* +#cgo CFLAGS: -Iluau/Compiler/include +#cgo LDFLAGS: -L_obj -lLuau.Compiler -lLuau.Ast -lm -lstdc++ +#include +#include +*/ +import "C" +import "unsafe" + +type CompileConstant *C.void + +type CompileOptions struct { + OptimizationLevel int + DebugLevel int + TypeInfoLevel int + CoverageLevel int + + VectorLib string + VectorCtor string + VectorType string + + MutableGlobals []string + UserdataTypes []string + + LibrariesWithKnownMembers []string + LibraryMemberTypeCb unsafe.Pointer + LibraryMemberConstantCb unsafe.Pointer + + DisabledBuiltins []string +} + +func LuauCompile(source string, size int, options *CompileOptions, outsize *int) []byte { + var goArrToC = func(goArr []string) **C.char { + if len(goArr) == 0 { + return nil + } + + // Allocate space for N+1 pointers (extra for NULL terminator) + arr := C.malloc(C.size_t(len(goArr)+1) * C.size_t(unsafe.Sizeof(uintptr(0)))) + slice := (*[1 << 30]*C.char)(arr)[: len(goArr)+1 : len(goArr)+1] + + for i, s := range goArr { + slice[i] = C.CString(s) + } + slice[len(goArr)] = nil // NULL terminator + return (**C.char)(arr) + } + + var freeCArr = func(arr **C.char) { + if arr == nil { + return + } + // Free strings until we hit NULL + for i := 0; ; i++ { + ptr := *(**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(arr)) + uintptr(i)*unsafe.Sizeof(uintptr(0)))) + if ptr == nil { + break + } + C.free(unsafe.Pointer(ptr)) + } + C.free(unsafe.Pointer(arr)) + } + + csource := C.CString(source) + coutsize := C.size_t(*outsize) + coptions := (*C.lua_CompileOptions)(C.malloc(C.sizeof_lua_CompileOptions)) + + coptions.optimizationLevel = C.int(options.OptimizationLevel) + coptions.debugLevel = C.int(options.DebugLevel) + coptions.typeInfoLevel = C.int(options.TypeInfoLevel) + coptions.coverageLevel = C.int(options.CoverageLevel) + + coptions.vectorLib = C.CString(options.VectorLib) + coptions.vectorCtor = C.CString(options.VectorCtor) + coptions.vectorType = C.CString(options.VectorType) + + coptions.mutableGlobals = goArrToC(options.MutableGlobals) + coptions.userdataTypes = goArrToC(options.UserdataTypes) + coptions.librariesWithKnownMembers = goArrToC(options.LibrariesWithKnownMembers) + + coptions.libraryMemberTypeCb = C.lua_LibraryMemberTypeCallback(options.LibraryMemberTypeCb) + coptions.libraryMemberConstantCb = C.lua_LibraryMemberConstantCallback(options.LibraryMemberConstantCb) + + coptions.disabledBuiltins = goArrToC(options.DisabledBuiltins) + + defer C.free(unsafe.Pointer(csource)) + defer C.free(unsafe.Pointer(coptions.vectorLib)) + defer C.free(unsafe.Pointer(coptions.vectorCtor)) + defer C.free(unsafe.Pointer(coptions.vectorType)) + defer C.free(unsafe.Pointer(coptions)) + + defer freeCArr(coptions.mutableGlobals) + defer freeCArr(coptions.userdataTypes) + defer freeCArr(coptions.librariesWithKnownMembers) + defer freeCArr(coptions.disabledBuiltins) + + bytecode := C.luau_compile(csource, C.size_t(size), coptions, &coutsize) + defer C.free(unsafe.Pointer(bytecode)) + + *outsize = int(coutsize) + result := make([]byte, coutsize) + + copy(result, (*[1 << 30]byte)(unsafe.Pointer(bytecode))[:coutsize:coutsize]) + + return result +} + +func LuauSetCompileConstantNil(constant unsafe.Pointer) { + C.luau_set_compile_constant_nil((*C.lua_CompileConstant)(constant)) +} + +func LuauSetCompileConstantBoolean(constant unsafe.Pointer, b bool) { + var cBool C.int + if b { + cBool = 1 + } else { + cBool = 0 + } + C.luau_set_compile_constant_boolean((*C.lua_CompileConstant)(constant), cBool) +} + +func LuauSetCompileConstantNumber(constant unsafe.Pointer, n float64) { + C.luau_set_compile_constant_number((*C.lua_CompileConstant)(constant), C.double(n)) +} + +func LuauSetCompileConstantVector(constant unsafe.Pointer, x, y, z, w float32) { + C.luau_set_compile_constant_vector( + (*C.lua_CompileConstant)(constant), + C.float(x), + C.float(y), + C.float(z), + C.float(w), + ) +} + +func LuauSetCompileConstantString(constant unsafe.Pointer, s string) { + if len(s) == 0 { + C.luau_set_compile_constant_string((*C.lua_CompileConstant)(constant), nil, 0) + return + } + + bytes := []byte(s) + ptr := (*C.char)(unsafe.Pointer(&bytes[0])) + size := C.size_t(len(s)) + + C.luau_set_compile_constant_string((*C.lua_CompileConstant)(constant), ptr, size) +} + +func LuauSetCompileConstantStringBytes(constant unsafe.Pointer, data []byte) { + if len(data) == 0 { + C.luau_set_compile_constant_string((*C.lua_CompileConstant)(constant), nil, 0) + return + } + + ptr := (*C.char)(unsafe.Pointer(&data[0])) + size := C.size_t(len(data)) + + C.luau_set_compile_constant_string((*C.lua_CompileConstant)(constant), ptr, size) +} diff --git a/ffi/luacode_test.go b/ffi/luacode_test.go new file mode 100644 index 0000000..b46e267 --- /dev/null +++ b/ffi/luacode_test.go @@ -0,0 +1,151 @@ +package ffi_test + +import ( + "slices" + "testing" + + "github.com/CompeyDev/lei/ffi" +) + +func TestLuauCompile_Basic(t *testing.T) { + source := ` + local function add(a, b) + return a + b + end + return add(1, 2) + ` + + outsize := 0 + options := &ffi.CompileOptions{ + OptimizationLevel: 1, + DebugLevel: 1, + TypeInfoLevel: 0, + CoverageLevel: 0, + } + + bytecode := ffi.LuauCompile(source, len(source), options, &outsize) + + if bytecode == nil { + t.Fatal("LuauCompile returned nil") + } + + if outsize == 0 { + t.Fatal("Output size is 0") + } + + if len(bytecode) != outsize { + t.Errorf("Expected bytecode length %d, got %d", outsize, len(bytecode)) + } + + t.Logf("Compiled successfully: %d bytes", outsize) +} + +func TestLuauCompile_SyntaxError(t *testing.T) { + source := ` + local function broken( + -- missing closing parenthesis and end + ` + + outsize := 0 + options := &ffi.CompileOptions{ + OptimizationLevel: 1, + DebugLevel: 1, + } + + bytecode := ffi.LuauCompile(source, len(source), options, &outsize) + + // The function should still return bytecode containing the error + if bytecode == nil { + t.Fatal("LuauCompile returned nil even for error case") + } + + t.Logf("Error bytecode: %d bytes", outsize) +} + +func TestLuauCompile_WithOptions(t *testing.T) { + source := ` + local x = vector.create(1, 2, 3) + return x + ` + + outsize := 0 + options := &ffi.CompileOptions{ + OptimizationLevel: 2, + DebugLevel: 2, + TypeInfoLevel: 1, + CoverageLevel: 1, + VectorLib: "vector", + VectorCtor: "create", + VectorType: "vector", + MutableGlobals: []string{"_G"}, + UserdataTypes: []string{"MyUserdata"}, + DisabledBuiltins: []string{"math.random"}, + } + + bytecode := ffi.LuauCompile(source, len(source), options, &outsize) + + if bytecode == nil { + t.Fatal("LuauCompile returned nil") + } + + if outsize == 0 { + t.Fatal("Output size is 0") + } + + t.Logf("Compiled with options: %d bytes", outsize) +} + +func TestLuauCompile_EmptySource(t *testing.T) { + source := "" + outsize := 0 + options := &ffi.CompileOptions{ + OptimizationLevel: 1, + DebugLevel: 1, + } + + bytecode := ffi.LuauCompile(source, len(source), options, &outsize) + + if bytecode == nil { + t.Fatal("LuauCompile returned nil for empty source") + } + + t.Logf("Empty source compiled: %d bytes", outsize) +} + +func TestLuauCompile_BinaryDataIntegrity(t *testing.T) { + source := `return "test"` + outsize := 0 + options := &ffi.CompileOptions{ + OptimizationLevel: 1, + DebugLevel: 1, + } + + bytecode := ffi.LuauCompile(source, len(source), options, &outsize) + hasNullByte := slices.Contains(bytecode, 0) + + t.Logf("Bytecode contains null bytes: %v", hasNullByte) + t.Logf("Bytecode length: %d, outsize: %d", len(bytecode), outsize) + + if len(bytecode) != outsize { + t.Errorf("Bytecode length mismatch: expected %d, got %d", outsize, len(bytecode)) + } +} + +func TestLuauCompile_ExecuteBytecode(t *testing.T) { + source := `return 42` + outsize := 0 + options := &ffi.CompileOptions{OptimizationLevel: 1, DebugLevel: 1} + + bytecode := ffi.LuauCompile(source, len(source), options, &outsize) + + L := ffi.LNewState() + defer ffi.LuaClose(L) + + ffi.LuauLoad(L, "test", bytecode, uint64(outsize), 0) + ffi.Pcall(L, 0, 1, 0) + + result := ffi.ToInteger(L, -1) + if result != 42 { + t.Error("Executed result did not match") + } +} diff --git a/ffi/luacodegen.go b/ffi/luacodegen.go new file mode 100644 index 0000000..ae1a2c9 --- /dev/null +++ b/ffi/luacodegen.go @@ -0,0 +1,24 @@ +package ffi + +//go:generate go run ../build buildProject Luau.VM Luau.CodeGen + +/* +#cgo CFLAGS: -Iluau/VM/include -Iluau/CodeGen/include +#cgo LDFLAGS: -L_obj -lLuau.VM -lLuau.CodeGen -lm -lstdc++ +#include +#include +#include +*/ +import "C" + +func LuauCodegenSupported() bool { + return C.luau_codegen_supported() == 1 +} + +func LuauCodegenCreate(state *C.lua_State) { + C.luau_codegen_create(state) +} + +func LuauCodegenCompile(state *C.lua_State, idx int) { + C.luau_codegen_compile(state, C.int(idx)) +} diff --git a/ffi/util.go b/ffi/util.go deleted file mode 100644 index 9b7982a..0000000 --- a/ffi/util.go +++ /dev/null @@ -1,50 +0,0 @@ -package ffi - -//#include -import "C" -import "unsafe" - -func GetSubtable(L *LuaState, idx int32, fname string) bool { - absIdx := AbsIndex(L, idx) - if !CheckStack(L, 3+20) { - panic("stack overflow") - } - - PushString(L, fname) - if GetTable(L, absIdx) == LUA_TTABLE { - return true - } - - Pop(L, 1) - NewTable(L) - PushString(L, fname) - PushValue(L, -2) - SetTable(L, absIdx) - return false -} - -func RequireLib(L *LuaState, modName string, openF unsafe.Pointer, isGlobal bool) { - if !CheckStack(L, 3+20) { - LErrorL(L, "stack overflow") - } - - GetSubtable(L, LUA_REGISTRYINDEX, "_LOADED") - if GetField(L, -1, modName) == LUA_TNIL { - Pop(L, 1) - PushCFunction(L, openF) - PushString(L, modName) - Call(L, 1, 1) - PushValue(L, -1) - SetField(L, -3, modName) - } - - if isGlobal { - PushNil(L) - SetGlobal(L, modName) - } else { - PushValue(L, -1) - SetGlobal(L, modName) - } - - Replace(L, -2) -} diff --git a/ffi/vector3.go b/ffi/vector3.go deleted file mode 100644 index b5b918e..0000000 --- a/ffi/vector3.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build !LUAU_VECTOR4 - -package ffi - -/* -#cgo CFLAGS: -Iluau/VM/include -I/usr/lib/gcc/x86_64-pc-linux-gnu/14.1.1/include -// #cgo LDFLAGS: -L${SRCDIR}/luau/cmake -lLuau.VM -lm -lstdc++ -#include -*/ -import "C" - -func PushVector(L *LuaState, x float32, y float32, z float32) { - C.lua_pushvector(L, C.float(x), C.float(y), C.float(z)) -} diff --git a/ffi/vector4.go b/ffi/vector4.go deleted file mode 100644 index 45bd0dc..0000000 --- a/ffi/vector4.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build LUAU_VECTOR4 - -package internal - -/* -#cgo CFLAGS: -Iluau/VM/include -I/usr/lib/gcc/x86_64-pc-linux-gnu/14.1.1/include -DLUA_VECTOR_SIZE=4 -// #cgo LDFLAGS: -L${SRCDIR}/luau/cmake -lLuau.VM -lm -lstdc++ -#include -*/ -import "C" - -func PushVector(L *LuaState, x float32, y float32, z float32, w float32) { - C.lua_pushvector(L, C.float(x), C.float(y), C.float(z), C.float(w)) -} diff --git a/go.mod b/go.mod index 7cd75e9..558c3ab 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,3 @@ module github.com/CompeyDev/lei go 1.23.0 toolchain go1.24.2 - -require ( - github.com/gookit/color v1.5.4 // indirect - github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect - golang.org/x/sys v0.32.0 // indirect - golang.org/x/term v0.31.0 // indirect -) diff --git a/go.sum b/go.sum index 8e636d2..e69de29 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +0,0 @@ -github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0= -github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w= -github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 h1:QldyIu/L63oPpyvQmHgvgickp1Yw510KJOqX7H24mg8= -github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs= -golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= -golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= -golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= diff --git a/lua/buffer.go b/lua/buffer.go new file mode 100644 index 0000000..f47893c --- /dev/null +++ b/lua/buffer.go @@ -0,0 +1,69 @@ +package lua + +import ( + "unsafe" + + "github.com/CompeyDev/lei/ffi" +) + +type LuaBuffer struct { + vm *Lua + index int + size uint64 +} + +func (b *LuaBuffer) Size() uint64 { return b.size } +func (b *LuaBuffer) IsEmpty() bool { return b.size == 0 } + +func (b *LuaBuffer) Read(offset uint64, count uint64) []byte { + b.deref(b.vm) + defer ffi.Pop(b.vm.state(), 1) + + if buf := ffi.ToBuffer(b.vm.state(), -1, &b.size); buf != nil && offset <= b.size { + // Clamp to the size if the count exceeds it + if offset+count > b.size { + count = b.size - offset + } + + // Copy data to Go owned byte array for safety + data := make([]byte, count) + slice := unsafe.Slice((*byte)(buf), b.size) + copy(data, slice[offset:offset+count]) + + return data + } + + return nil +} + +func (b *LuaBuffer) Write(offset uint64, data []byte) { + if len(data) == 0 { + return + } + + b.deref(b.vm) + defer ffi.Pop(b.vm.state(), 1) + + if buf := ffi.ToBuffer(b.vm.state(), -1, &b.size); buf != nil && offset <= b.size { + // Truncate the data to buffer end if exceeding + count := uint64(len(data)) + if offset+count > b.size { + count = b.size - offset + } + + dest := unsafe.Slice((*byte)(buf), b.size) + copy(dest[offset:offset+count], data[:count]) + } +} + +// +// LuaValue implementation +// + +var _ LuaValue = (*LuaBuffer)(nil) + +func (b *LuaBuffer) lua() *Lua { return b.vm } +func (b *LuaBuffer) ref() int { return b.index } +func (b *LuaBuffer) deref(lua *Lua) int { + return int(ffi.GetRef(lua.state(), int32(b.ref()))) +} diff --git a/lua/chunk.go b/lua/chunk.go new file mode 100644 index 0000000..8469182 --- /dev/null +++ b/lua/chunk.go @@ -0,0 +1,123 @@ +package lua + +import "github.com/CompeyDev/lei/ffi" + +type ChunkMode int + +const ( + // Raw text source code that must be compiled before executing + ChunkModeSOURCE ChunkMode = iota + + // Compiled bytecode that can be directly executed + ChunkModeBYTECODE + + // A C function pointer loaded onto the stack + ChunkModeFUNCTION +) + +type LuaChunk struct { + vm *Lua + env *LuaTable + mode ChunkMode + + // Values only applicable for source or bytecode types + name *string + data []byte + compiler *Compiler + + // An index is held for chunks of the function type + index int +} + +func (c *LuaChunk) Environment() *LuaTable { return c.env } +func (c *LuaChunk) SetEnvironment(env *LuaTable) { c.env = env } + +func (c *LuaChunk) Mode() ChunkMode { return c.mode } +func (c *LuaChunk) SetMode(mode ChunkMode) { c.mode = mode } + +func (c *LuaChunk) Compiler() *Compiler { return c.compiler } +func (c *LuaChunk) SetCompiler(compiler *Compiler) { c.compiler = compiler } + +func (c *LuaChunk) Call(args ...LuaValue) ([]LuaValue, error) { + state := c.vm.state() + + initialStack := ffi.GetTop(state) // Track initial stack size + c.pushToStack() + + argsCount := len(args) + if c.mode == ChunkModeFUNCTION { + // Chunk is a C function, push length and args + ffi.PushNumber(state, ffi.LuaNumber(argsCount)) + argsCount++ + for _, arg := range args { + arg.deref(c.vm) + } + } + + status := ffi.Pcall(state, int32(argsCount), -1, 0) + if status != ffi.LUA_OK { + return nil, newLuaError(state, int(status)) + } + + stackNow := ffi.GetTop(state) + resultsCount := stackNow - initialStack + + if resultsCount == 0 { + return nil, nil + } + + results := make([]LuaValue, resultsCount) + for i := range resultsCount { + // The stack has grown by the number of returns of the chunk from the + // initial value tracked at the beginning. We add one to that due to + // Lua's 1-based indexing system + stackIndex := int32(initialStack + i + 1) + results[i] = intoLuaValue(c.vm, stackIndex) + } + + return results, nil +} + +func (c *LuaChunk) pushToStack() error { + state := c.vm.state() + + if c.data == nil { + // Chunk is of a C function, need to deref + ffi.GetRef(state, int32(c.index)) + } else { + // Chunk is bytecode, load it into the VM + var bytecode []byte + + if c.mode == ChunkModeSOURCE { + // Need to compile + var err error + if bytecode, err = c.compiler.Compile(string(c.data)); err != nil { + return err + } + } else { + // Already compiled + bytecode = c.data + } + + hasLoaded := ffi.LuauLoad(state, *c.name, bytecode, uint64(len(bytecode)), 0) + if !hasLoaded { + // Miscellaneous error is denoted with a -1 code + return &LuaError{Code: -1, Message: ffi.ToLString(state, -1, nil)} + } + + // Apply native code generation if requested + if ffi.LuauCodegenSupported() && c.vm.codegenEnabled { + ffi.LuauCodegenCompile(state, -1) + } + } + + if c.env != nil { + // If a custom environment was provided, set it for the loaded value + c.env.deref(c.vm) + if ok := ffi.Setfenv(c.vm.state(), -2); !ok { + return &LuaError{Code: -1, Message: "Failed to set environment for chunk"} + } + } + + return nil +} diff --git a/lua/compiler.go b/lua/compiler.go new file mode 100644 index 0000000..31d8edb --- /dev/null +++ b/lua/compiler.go @@ -0,0 +1,121 @@ +package lua + +import ( + "github.com/CompeyDev/lei/ffi" +) + +type Compiler struct{ options *ffi.CompileOptions } + +func (c *Compiler) WithOptimizationLevel(lvl int) *Compiler { + opts := *c.options + opts.OptimizationLevel = lvl + return &Compiler{options: &opts} +} + +func (c *Compiler) WithDebugLevel(lvl int) *Compiler { + opts := *c.options + opts.DebugLevel = lvl + return &Compiler{options: &opts} +} + +func (c *Compiler) WithTypeInfoLevel(lvl int) *Compiler { + opts := *c.options + opts.TypeInfoLevel = lvl + return &Compiler{options: &opts} +} + +func (c *Compiler) WithCoverageLevel(lvl int) *Compiler { + opts := *c.options + opts.CoverageLevel = lvl + return &Compiler{options: &opts} +} + +func (c *Compiler) WithMutableGlobals(globals []string) *Compiler { + opts := *c.options + opts.MutableGlobals = append(append([]string{}, c.options.MutableGlobals...), globals...) + return &Compiler{options: &opts} +} + +func (c *Compiler) WithUserdataTypes(types []string) *Compiler { + opts := *c.options + opts.UserdataTypes = append(append([]string{}, c.options.UserdataTypes...), types...) + return &Compiler{options: &opts} +} + +func (c *Compiler) WithConstantLibraries(libs []string) *Compiler { + opts := *c.options + opts.LibrariesWithKnownMembers = append(append([]string{}, c.options.LibrariesWithKnownMembers...), libs...) + return &Compiler{options: &opts} +} + +func (c *Compiler) WithDisabledBuiltins(builtins []string) *Compiler { + opts := *c.options + opts.DisabledBuiltins = append(append([]string{}, c.options.DisabledBuiltins...), builtins...) + return &Compiler{options: &opts} +} + +func (c *Compiler) Compile(source string) ([]byte, error) { + outsize := 0 + bytecode := ffi.LuauCompile(source, len(source), c.options, &outsize) + + // Check for compilation error + // If bytecode starts with 0, the rest is an error message starting with ':' + // See https://github.com/luau-lang/luau/blob/0.671/Compiler/src/Compiler.cpp#L4410 + if outsize > 0 && bytecode[0] == 0 { + // Extract error message (skip the 0 byte and ':' character) + message := "" + if outsize > 2 { + message = string(bytecode[2:]) + } + + // Check if input is incomplete (ends with ) + incompleteInput := len(message) > 0 && + (len(message) >= 5 && message[len(message)-5:] == "") + + return nil, &SyntaxError{ + IncompleteInput: incompleteInput, + Message: message, + } + } + + return bytecode, nil +} + +func DefaultCompiler() *Compiler { + return &Compiler{options: &ffi.CompileOptions{ + OptimizationLevel: 1, + DebugLevel: 1, + TypeInfoLevel: 0, + CoverageLevel: 0, + MutableGlobals: make([]string, 0), + UserdataTypes: make([]string, 0), + LibrariesWithKnownMembers: make([]string, 0), + DisabledBuiltins: make([]string, 0), + }} +} + +type SyntaxError struct { + IncompleteInput bool + Message string +} + +func (e *SyntaxError) Error() string { + if e.IncompleteInput { + return "incomplete input: " + e.Message + } + + return "syntax error: " + e.Message +} + +func isBytecode(data []byte) bool { + // Luau bytecode starts with a version byte (currently 0-5 range) + // See: https://github.com/luau-lang/luau/blob/0.671/Compiler/src/BytecodeBuilder.cpp#L13 + if len(data) == 0 { + return false + } + + // Check if the first byte is within the bytecode versionByte range (source code starting with + // these bytes would be extremely rare) + versionByte := data[0] + return versionByte >= ffi.LBC_VERSION_MIN && versionByte <= ffi.LBC_VERSION_MAX +} diff --git a/lua/errors.go b/lua/errors.go new file mode 100644 index 0000000..dcbcd83 --- /dev/null +++ b/lua/errors.go @@ -0,0 +1,38 @@ +package lua + +import ( + "fmt" + + "github.com/CompeyDev/lei/ffi" +) + +type LuaError struct { + Code int + Message string +} + +func (e *LuaError) Error() string { + switch e.Code { + case ffi.LUA_ERRSYNTAX: + return "syntax error: " + e.Message + case ffi.LUA_ERRMEM: + return "memory allocation error: " + e.Message + case ffi.LUA_ERRERR: + return "error handler error: " + e.Message + default: + return fmt.Sprintf("load error (code %d): %s", e.Code, e.Message) + } +} + +func newLuaError(state *ffi.LuaState, code int) *LuaError { + if code != ffi.LUA_OK { + message := ffi.ToString(state, -1) + err := &LuaError{Code: code, Message: message} + + ffi.Pop(state, 1) + + return err + } + + return nil +} diff --git a/lua/memory.go b/lua/memory.go new file mode 100644 index 0000000..2b63507 --- /dev/null +++ b/lua/memory.go @@ -0,0 +1,168 @@ +package lua + +/* +#include +#include + +extern void* allocator(void* ud, void* ptr, size_t osize, size_t nsize); +*/ +import "C" +import ( + "runtime" + "unsafe" + + "github.com/CompeyDev/lei/ffi" +) + +const SYS_MIN_ALIGN = unsafe.Sizeof(uintptr(0)) * 2 + +type MemoryState struct { + usedMemory int + memoryLimit int + ignoreLimit bool + limitReached bool +} + +func NewMemoryState() *MemoryState { + return &MemoryState{ + usedMemory: 0, + memoryLimit: 0, + ignoreLimit: false, + limitReached: false, + } +} + +func (m *MemoryState) Used() int { + return m.usedMemory +} + +func (m *MemoryState) Limit() int { + return m.memoryLimit +} + +func (m *MemoryState) SetLimit(limit int) int { + prevLimit := m.memoryLimit + m.memoryLimit = limit + return prevLimit +} + +func RelaxLimitWith(state *ffi.LuaState, f func()) { + memState := getMemoryState(state) + if memState != nil { + memState.ignoreLimit = true + f() + memState.ignoreLimit = false + } else { + f() + } +} + +func LimitReached(state *ffi.LuaState) bool { + return getMemoryState(state).limitReached +} + +func getMemoryState(state *ffi.LuaState) *MemoryState { + var memState unsafe.Pointer + ffi.GetAllocF(state, &memState) + + if memState == nil { + panic("Lua state has no allocator userdata") + } + + return (*MemoryState)(memState) +} + +//export allocator +func allocator(ud, ptr unsafe.Pointer, osize, nsize C.size_t) unsafe.Pointer { + memState := (*MemoryState)(ud) + + // Avoid GC of pointer for this call period + runtime.KeepAlive(memState) + memState.limitReached = false + + // Free memory + if nsize == 0 { + if ptr != nil { + C.free(ptr) + memState.usedMemory -= int(osize) + } + return nil + } + + if nsize > C.size_t(^uint(0)>>1) { + return nil + } + + var memDiff int + if ptr != nil { + memDiff = int(nsize) - int(osize) + } else { + memDiff = int(nsize) + } + + memLimit := memState.memoryLimit + newUsedMemory := memState.usedMemory + memDiff + if memLimit > 0 && newUsedMemory > memLimit && !memState.ignoreLimit { + memState.limitReached = true + panic("allocations exceeded set limit for memory") + } + memState.usedMemory = newUsedMemory + + var newPtr unsafe.Pointer + if ptr == nil { + newPtr = C.malloc(nsize) + if newPtr == nil { + panic("memory allocation failed") + } + } else { + newPtr = C.realloc(ptr, nsize) + if newPtr == nil { + panic("memory reallocation failed") + } + } + + return newPtr +} + +type StateWithMemory struct { + luaState *ffi.LuaState + memState *MemoryState + pinner *runtime.Pinner +} + +func newStateWithAllocator(initState *MemoryState) *StateWithMemory { + var memState *MemoryState + if initState != nil { + memState = initState + } else { + memState = NewMemoryState() + } + + // Pin the memory state to prevent GC from moving it + pinner := &runtime.Pinner{} + pinner.Pin(memState) + + state := ffi.NewState(C.allocator, unsafe.Pointer(memState)) + + return &StateWithMemory{ + luaState: state, + memState: memState, + pinner: pinner, + } +} + +func (s *StateWithMemory) LuaState() *ffi.LuaState { + return s.luaState +} + +func (s *StateWithMemory) MemState() *MemoryState { + return s.memState +} + +func (s *StateWithMemory) Close() { + if s.pinner != nil { + s.pinner.Unpin() + } + + ffi.LuaClose(s.luaState) +} diff --git a/lua/nil.go b/lua/nil.go new file mode 100644 index 0000000..c068392 --- /dev/null +++ b/lua/nil.go @@ -0,0 +1,15 @@ +package lua + +import "github.com/CompeyDev/lei/ffi" + +type LuaNil struct{} + +// +// LuaValue Implementation +// + +var _ LuaValue = (*LuaNil)(nil) + +func (n *LuaNil) lua() *Lua { return nil } +func (n *LuaNil) ref() int { return ffi.LUA_REFNIL } +func (n *LuaNil) deref(_ *Lua) int { return 0 } diff --git a/lua/number.go b/lua/number.go new file mode 100644 index 0000000..f784d06 --- /dev/null +++ b/lua/number.go @@ -0,0 +1,21 @@ +package lua + +import "github.com/CompeyDev/lei/ffi" + +type LuaNumber float64 + +// +// LuaValue implementation +// + +var _ LuaValue = (*LuaNumber)(nil) + +// Numbers are cheap to copy, so we don't store the reference index + +func (n LuaNumber) lua() *Lua { return nil } +func (n LuaNumber) ref() int { return ffi.LUA_NOREF } +func (n LuaNumber) deref(lua *Lua) int { + state := lua.state() + ffi.PushNumber(state, ffi.LuaNumber(n)) + return int(ffi.GetTop(state)) +} diff --git a/lua/registry.c b/lua/registry.c new file mode 100644 index 0000000..7f1eec4 --- /dev/null +++ b/lua/registry.c @@ -0,0 +1,31 @@ +#include +#include +#include +#include <_cgo_export.h> + +typedef struct registryTrampolineImpl_return trampolineResult; + +int registryTrampoline(lua_State* L) { + uintptr_t* handle_ptr = (uintptr_t*)lua_touserdata(L, lua_upvalueindex(1)); + trampolineResult result = registryTrampolineImpl(L, *handle_ptr); + + // Handle errors after crossing the C boundary to prevent a longjmp triggered + // from the Go side, which would violate Go's stack winding rules + + int status = result.r0; + char* err = result.r1; + + // TODO: Figure out what happens if some Lua code calls this without a pcall, longjmp? + if (err != NULL) { + lua_pushstring(L, err); + free(err); + lua_error(L); + } + + return status; +} + +void registryTrampolineDtor(lua_State* L) { + uintptr_t* handle_ptr = (uintptr_t*)lua_touserdata(L, lua_upvalueindex(1)); + registryTrampolineDtorImpl(L, *handle_ptr); +} diff --git a/lua/registry.go b/lua/registry.go new file mode 100644 index 0000000..81b6258 --- /dev/null +++ b/lua/registry.go @@ -0,0 +1,127 @@ +package lua + +import ( + "fmt" + "runtime/cgo" + + "github.com/CompeyDev/lei/ffi" +) + +//go:generate go tool cgo -- -I../ffi/luau/VM/include $GOFILE + +/* +#cgo CFLAGS: -I../ffi/luau/VM/include +#cgo LDFLAGS: -L../ffi/_obj -lLuau.VM -lm -lstdc++ + +#include +#include +#include + +int registryTrampoline(lua_State* L); +void registryTrampolineDtor(lua_State* L); +*/ +import "C" + +var registryTrampoline = C.registryTrampoline +var registryTrampolineDtor = C.registryTrampolineDtor + +//export registryTrampolineImpl +func registryTrampolineImpl(lua *C.lua_State, handle uintptr) (C.int, *C.char) { + rawState := (*ffi.LuaState)(lua) + state := &Lua{ + // FIXME: what about the function registry? + inner: &StateWithMemory{ + memState: getMemoryState(rawState), + luaState: rawState, + }, + } + + entry := cgo.Handle(handle).Value().(*functionEntry) + + fn, ok := entry.registry.get(entry.id) + if !ok { + return C.int(-1), C.CString("function not found in registry") + } + + argsCount := int(ffi.ToNumber(rawState, 1)) + args := make([]LuaValue, argsCount) + + for i := range argsCount { + // Lua stack is 1-based, and the first argument is at index 2 (since index 1 is the count) + stackIndex := int32(i + 2) + args[i] = intoLuaValue(state, stackIndex) + } + + returns, callErr := fn(state, args...) + + // SAFETY: This must be caught elsewhere to avoid the longjmp + if callErr != nil { + return C.int(-1), C.CString(callErr.Error()) + } + + for _, ret := range returns { + ret.deref(state) + } + + return C.int(len(returns)), nil +} + +//export registryTrampolineDtorImpl +func registryTrampolineDtorImpl(_ *C.lua_State, handle C.uintptr_t) { + entry := cgo.Handle(handle).Value().(*functionEntry) + delete(entry.registry.functions, entry.id) + cgo.Handle(handle).Delete() +} + +type GoFunction = func(lua *Lua, args ...LuaValue) ([]LuaValue, error) + +type functionRegistry struct { + recoverPanics bool + functions map[uintptr]GoFunction + nextID uintptr +} + +type functionEntry struct { + registry *functionRegistry + id uintptr +} + +func newFunctionRegistry() *functionRegistry { + return &functionRegistry{ + functions: make(map[uintptr]GoFunction), + } +} + +func (fr *functionRegistry) register(fn GoFunction) *functionEntry { + fr.nextID++ + id := fr.nextID + fr.functions[id] = fn + return &functionEntry{registry: fr, id: id} +} + +func (fr *functionRegistry) get(id uintptr) (GoFunction, bool) { + fn, ok := fr.functions[id] + + if fr.recoverPanics { + rawFn := fn + fn = func(lua *Lua, args ...LuaValue) (result []LuaValue, err error) { + defer func() { + // Deferred panic handler + if recv := recover(); recv != nil { + switch v := recv.(type) { + case error: + err = v + default: + err = fmt.Errorf("go panic: %v", v) + } + } + }() + + result, err = rawFn(lua, args...) + + return result, err + } + } + + return fn, ok +} diff --git a/lua/state.go b/lua/state.go new file mode 100644 index 0000000..1af5849 --- /dev/null +++ b/lua/state.go @@ -0,0 +1,276 @@ +package lua + +import ( + "runtime" + "runtime/cgo" + "unsafe" + + "github.com/CompeyDev/lei/ffi" +) + +type LuaOptions struct { + InitMemoryState *MemoryState + CollectGarbage bool + IsSafe bool + CatchPanics bool + EnableCodegen bool + EnableSandbox bool + Compiler *Compiler +} + +type Lua struct { + inner *StateWithMemory + compiler *Compiler + fnRegistry *functionRegistry + codegenEnabled bool +} + +func (l *Lua) Load(name string, input []byte) *LuaChunk { + var mode ChunkMode = ChunkModeSOURCE + if isBytecode(input) { + mode = ChunkModeBYTECODE + } + + return &LuaChunk{ + vm: l, + data: input, + name: &name, + mode: mode, + compiler: l.compiler, + } +} + +func (l *Lua) Memory() *MemoryState { + return l.inner.MemState() +} + +func (l *Lua) SetCodegen(enabled bool) bool { + // NOTE: disabling codegen if it was enabled before still has the codegen + // backend enabled for the state since we already called LuauCodegenCreate + // during state initialization + + supported := ffi.LuauCodegenSupported() + if supported { + l.codegenEnabled = enabled + } + + return supported +} + +func (l *Lua) GetGlobal(global string) LuaValue { + state := l.state() + + ffi.GetGlobal(state, global) + value := intoLuaValue(l, -1) + + ffi.Pop(state, 1) + return value +} + +func (l *Lua) SetGlobal(global string, value LuaValue) { + value.deref(l) + ffi.SetGlobal(l.state(), global) +} + +func (l *Lua) CreateTable() *LuaTable { + state := l.inner.luaState + + ffi.NewTable(state) + index := ffi.Ref(state, -1) + + t := &LuaTable{vm: l, index: int(index)} + runtime.SetFinalizer(t, valueUnrefer[*LuaTable](l)) + + return t +} + +func (l *Lua) CreateString(str string) *LuaString { + state := l.inner.luaState + + ffi.PushString(state, str) + index := ffi.Ref(state, -1) + + s := &LuaString{vm: l, index: int(index)} + runtime.SetFinalizer(s, valueUnrefer[*LuaString](l)) + + return s +} + +func (l *Lua) CreateFunction(name *string, fn GoFunction) *LuaChunk { + state := l.state() + + entry := l.fnRegistry.register(fn) + pushUpvalue(state, entry, registryTrampolineDtor) + + ffi.PushCClosureK(state, registryTrampoline, name, 1, nil) + + index := ffi.Ref(state, -1) + c := &LuaChunk{vm: l, index: int(index), name: name, mode: ChunkModeFUNCTION} + runtime.SetFinalizer(c, func(c *LuaChunk) { ffi.Unref(state, index) }) + + return c +} + +func (l *Lua) CreateUserData(value IntoUserData) *LuaUserData { + state := l.state() + userdata := &LuaUserData{vm: l, inner: value} + + // TOOD: custom destructor support + ud := ffi.NewUserdata(state, uint64(unsafe.Sizeof(uintptr(0)))) + *(*IntoUserData)(unsafe.Pointer(ud)) = value + + if ffi.LNewMetatable(state, "") { + fieldsMap := newFieldMap() + methodsMap := newMethodMap(l.fnRegistry) + metaMethodsMap := newMethodMap(l.fnRegistry) + + value.Fields(fieldsMap) + value.Methods(methodsMap) + value.MetaMethods(metaMethodsMap) + + pushUpvalue(state, fieldsMap, fieldMapDtor) + pushUpvalue(state, methodsMap, methodMapDtor) + + ffi.PushCClosureK(state, indexMt, nil, 2, nil) + ffi.SetField(state, -2, "__index") + + for method, impl := range metaMethodsMap.inner { + if method == "__index" { + panic("Cannot have a manual __index implementation") + } + + pushUpvalue(state, impl, registryTrampolineDtor) + ffi.PushCClosureK(state, registryTrampoline, nil, 1, nil) + ffi.SetField(state, -2, method) + } + } + + ffi.SetMetatable(state, -2) + + userdata.index = int(ffi.Ref(state, -1)) + runtime.SetFinalizer(userdata, valueUnrefer[*LuaUserData](l)) + + return userdata +} + +func (l *Lua) CreateBuffer(size uint64) *LuaBuffer { + state := l.state() + + ffi.NewBuffer(state, size) + index := ffi.Ref(state, -1) + + b := &LuaBuffer{vm: l, index: int(index), size: size} + runtime.SetFinalizer(b, valueUnrefer[*LuaBuffer](l)) + + return b +} + +func (l *Lua) CreateThread(chunk *LuaChunk) (*LuaThread, error) { + mainState := l.state() + threadState := ffi.NewThread(mainState) + + chunk.pushToStack() + ffi.XMove(mainState, threadState, 1) + + index := ffi.Ref(mainState, -1) + t := &LuaThread{vm: l, chunk: chunk, index: int(index)} + + runtime.SetFinalizer(t, func(t *LuaThread) { + ffi.LuaClose(t.state()) + ffi.Unref(l.state(), int32(t.ref())) + }) + + return t, nil +} + +func (l *Lua) SetCompiler(compiler *Compiler) { + l.compiler = compiler +} + +func (l *Lua) Close() { + l.inner.Close() +} + +func (l *Lua) state() *ffi.LuaState { + return l.inner.luaState +} + +func New() *Lua { + return NewWith(StdLibALLSAFE, LuaOptions{ + CollectGarbage: true, + IsSafe: true, + CatchPanics: true, + EnableCodegen: true, + Compiler: DefaultCompiler(), + }) +} + +func NewWith(libs StdLib, options LuaOptions) *Lua { + state := newStateWithAllocator(options.InitMemoryState) + if state == nil { + panic("Failed to create Lua state") + } + + ffi.OpenBase(state.luaState) + luaLibs := map[StdLib]func(*ffi.LuaState){ + StdLibCOROUTINE: ffi.OpenCoroutine, + StdLibTABLE: ffi.OpenTable, + StdLibOS: ffi.OpenOs, + StdLibSTRING: ffi.OpenString, + StdLibUTF8: ffi.OpenUtf8, + StdLibBIT: ffi.OpenBit32, + StdLibBUFFER: ffi.OpenBuffer, + StdLibMATH: ffi.OpenMath, + StdLibDEBUG: ffi.OpenDebug, + StdLibVECTOR: ffi.OpenVector, + } + + for library, opener := range luaLibs { + if (!options.IsSafe || StdLibALLSAFE.Contains(library)) && libs.Contains(library) { + opener(state.luaState) + } + } + + if options.EnableSandbox { + ffi.LSandbox(state.luaState) + } + + compiler := options.Compiler + if compiler == nil { + compiler = DefaultCompiler() + } + + fnReg := newFunctionRegistry() + fnReg.recoverPanics = options.CatchPanics + + lua := &Lua{inner: state, compiler: compiler, fnRegistry: fnReg, codegenEnabled: false} + if options.EnableCodegen && ffi.LuauCodegenSupported() { + ffi.LuauCodegenCreate(state.luaState) + lua.codegenEnabled = true + } + + runtime.SetFinalizer(lua, func(l *Lua) { + if options.CollectGarbage { + ffi.LuaGc(l.state(), ffi.LUA_GCCOLLECT, 0) + } + + l.Close() + }) + + return lua +} + +func pushUpvalue[T any](state *ffi.LuaState, ptr *T, dtor unsafe.Pointer) *uintptr { + var up *uintptr + + sz := uint64(unsafe.Sizeof(uintptr(0))) + if dtor != nil { + up = (*uintptr)(ffi.NewUserdataDtor(state, sz, dtor)) + } else { + up = (*uintptr)(ffi.NewUserdata(state, sz)) + } + + *up = uintptr(cgo.NewHandle(ptr)) + + return up +} diff --git a/lua/stdlib.go b/lua/stdlib.go new file mode 100644 index 0000000..22ea25e --- /dev/null +++ b/lua/stdlib.go @@ -0,0 +1,127 @@ +package lua + +import "github.com/CompeyDev/lei/ffi" + +// StdLib represents flags describing the set of Lua standard libraries to load. +type StdLib uint32 + +const ( + // COROUTINE library + // https://luau.org/library#coroutine-library + StdLibCOROUTINE StdLib = 1 << 0 + + // TABLE library + // https://luau.org/library#table-library + StdLibTABLE StdLib = 1 << 1 + + // OS library + // https://luau.org/library#os-library + StdLibOS StdLib = 1 << 3 + + // STRING library + // https://luau.org/library#string-library + StdLibSTRING StdLib = 1 << 4 + + // UTF8 library + // https://luau.org/library#utf8-library + StdLibUTF8 StdLib = 1 << 5 + + // BIT library + // https://luau.org/library#bit32-library + StdLibBIT StdLib = 1 << 6 + + // MATH library + // https://luau.org/library#math-library + StdLibMATH StdLib = 1 << 7 + + // BUFFER library (Luau) + // https://luau.org/library#buffer-library + StdLibBUFFER StdLib = 1 << 9 + + // VECTOR library (Luau) + // https://luau.org/library#vector-library + StdLibVECTOR StdLib = 1 << 10 + + // DEBUG library (unsafe) + // https://luau.org/library#debug-library + StdLibDEBUG StdLib = 1 << 31 + + // StdLibNONE represents no libraries + StdLibNONE StdLib = 0 + + // StdLibALL represents all standard libraries (unsafe) + StdLibALL StdLib = ^StdLib(0) // equivalent to uint32 max + + // StdLibALLSAFE represents the safe subset of standard libraries + StdLibALLSAFE StdLib = (1 << 30) - 1 +) + +func (s StdLib) Contains(lib StdLib) bool { + return (s & lib) != 0 +} + +func (s StdLib) And(lib StdLib) StdLib { + return s & lib +} + +func (s StdLib) Or(lib StdLib) StdLib { + return s | lib +} + +func (s StdLib) Xor(lib StdLib) StdLib { + return s ^ lib +} + +func (s *StdLib) Add(lib StdLib) { + *s |= lib +} + +func (s *StdLib) Remove(lib StdLib) { + *s &^= lib +} + +func (s *StdLib) Toggle(lib StdLib) { + *s ^= lib +} + +func (s StdLib) String() string { + if s == StdLibNONE { + return "NONE" + } + if s == StdLibALL { + return "ALL" + } + + var libs []string + flags := map[StdLib]string{ + StdLibCOROUTINE: ffi.LUA_COLIBNAME, + StdLibTABLE: ffi.LUA_TABLIBNAME, + StdLibOS: ffi.LUA_OSLIBNAME, + StdLibSTRING: ffi.LUA_STRLIBNAME, + StdLibUTF8: ffi.LUA_UTF8LIBNAME, + StdLibBIT: ffi.LUA_BITLIBNAME, + StdLibMATH: ffi.LUA_MATHLIBNAME, + StdLibBUFFER: ffi.LUA_BUFFERLIBNAME, + StdLibVECTOR: ffi.LUA_VECLIBNAME, + StdLibDEBUG: ffi.LUA_VECLIBNAME, + } + + for flag, name := range flags { + if s.Contains(flag) { + libs = append(libs, name) + } + } + + if len(libs) == 0 { + return "NONE" + } + + result := "" + for i, lib := range libs { + if i > 0 { + result += "|" + } + result += lib + } + return result +} diff --git a/lua/string.go b/lua/string.go new file mode 100644 index 0000000..0b383b6 --- /dev/null +++ b/lua/string.go @@ -0,0 +1,43 @@ +package lua + +import ( + "unsafe" + + "github.com/CompeyDev/lei/ffi" +) + +type LuaString struct { + vm *Lua + index int +} + +func (s *LuaString) ToString() string { + state := s.vm.state() + + s.deref(s.vm) + defer ffi.Pop(state, 1) + + return ffi.ToString(state, -1) +} + +func (s *LuaString) ToPointer() unsafe.Pointer { + state := s.vm.state() + + s.deref(s.vm) + defer ffi.Pop(state, 1) + + return ffi.ToPointer(state, -1) +} + +// +// LuaValue implementation +// + +var _ LuaValue = (*LuaString)(nil) + +func (s *LuaString) lua() *Lua { return s.vm } +func (s *LuaString) ref() int { return s.index } + +func (s *LuaString) deref(lua *Lua) int { + return int(ffi.GetRef(lua.state(), int32(s.ref()))) +} diff --git a/lua/table.go b/lua/table.go new file mode 100644 index 0000000..b6139bf --- /dev/null +++ b/lua/table.go @@ -0,0 +1,247 @@ +package lua + +import "github.com/CompeyDev/lei/ffi" + +type LuaTable struct { + vm *Lua + index int +} + +func (t *LuaTable) Set(key LuaValue, value LuaValue) { + state := t.vm.state() + + t.deref(t.vm) // table (-3) + key.deref(t.vm) // key (-2) + value.deref(t.vm) // value (-1) + + // Pop the table off + defer ffi.Pop(state, 1) + + ffi.SetTable(state, -3) +} + +func (t *LuaTable) Get(key LuaValue) LuaValue { + state := t.vm.state() + + t.deref(t.vm) //////////////////// table (-3) + key.deref(t.vm) //////////////////// key (-2) + + // Pop the table and value off + defer ffi.Pop(state, 2) + + ffi.GetTable(state, -2) + val := intoLuaValue(t.vm, -1) ////// value (-1) + + return val +} + +func (t *LuaTable) RawSet(key LuaValue, value LuaValue) { + state := t.vm.state() + + t.deref(t.vm) // table (-3) + key.deref(t.vm) // key (-2) + value.deref(t.vm) // value (-1) + + // Pop the table off + defer ffi.Pop(state, 1) + + ffi.RawSet(state, -3) +} + +func (t *LuaTable) RawGet(key LuaValue) LuaValue { + state := t.vm.state() + + t.deref(t.vm) // table (-2) + key.deref(t.vm) // key (-1) + + // Pop the table and value off + defer ffi.Pop(state, 2) + + ffi.RawGet(state, -2) + val := intoLuaValue(t.vm, -1) // value (-1) + + return val +} + +func (t *LuaTable) Push(value LuaValue) { + state := t.vm.state() + + t.deref(t.vm) // table (-2) + value.deref(t.vm) // value (-1) + + // Pop the table and key off + defer ffi.Pop(state, 2) + + // Insert new index and set it to the value + len := ffi.ObjLen(state, -2) + ffi.PushInteger(state, ffi.LuaInteger(len+1)) + ffi.Insert(state, -2) + ffi.SetTable(state, -3) +} + +func (t *LuaTable) Pop() LuaValue { + state := t.vm.state() + + t.deref(t.vm) // table (-1) + + // Pop the table off + defer ffi.Pop(state, 1) + + // Get the last value and nil it out + len := ffi.ObjLen(state, -1) + ffi.PushInteger(state, ffi.LuaInteger(len)) // key (-1), table (-2) + ffi.GetTable(state, -2) // value (-1), table (-2) + val := intoLuaValue(t.vm, -1) + + ffi.PushInteger(state, ffi.LuaInteger(len)) // key (-1), value (-2), table (-3) + ffi.PushNil(state) // nil (-1), key (-2), value (-3), table (-4) + ffi.SetTable(state, -4) + + return val +} + +func (t *LuaTable) RawPush(value LuaValue) { + state := t.vm.state() + + t.deref(t.vm) // table (-2) + value.deref(t.vm) // value (-1) + + // Pop the table off + defer ffi.Pop(state, 1) + + // Insert new index and set it to the value + len := ffi.ObjLen(state, -2) + ffi.PushInteger(state, ffi.LuaInteger(len+1)) // key (-1), value (-2), table (-3) + ffi.Insert(state, -2) // value (-1), key (-2), table (-3) + ffi.RawSet(state, -3) +} + +func (t *LuaTable) RawPop() LuaValue { + state := t.vm.state() + + t.deref(t.vm) // table (-1) + + // Pop the table off + defer ffi.Pop(state, 1) + + // Get the last value and nil it out + len := ffi.ObjLen(state, -1) + ffi.PushInteger(state, ffi.LuaInteger(len)) // key (-1), table (-2) + ffi.RawGet(state, -2) // value (-1), table (-2) + val := intoLuaValue(t.vm, -1) + + ffi.PushInteger(state, ffi.LuaInteger(len)) // key (-1), value (-2), table (-3) + ffi.PushNil(state) // nil (-1), key (-2), value (-3), table (-4) + ffi.RawSet(state, -4) + + return val +} + +func (t *LuaTable) Equals(other LuaValue) bool { + state := t.vm.state() + + // Compare by reference first + otherTable, ok := other.(*LuaTable) + if !ok { + return false + } + if t.index == otherTable.index { + return true + } + + // Compare by value + t.deref(t.vm) // table1 (-2) + otherTable.deref(t.vm) // table2 (-1) + + // Pop off both tables + defer ffi.Pop(state, 2) + + return ffi.Equal(state, -1, -2) +} + +func (t *LuaTable) Clear() { + state := t.vm.state() + + t.deref(t.vm) // table (-1) + + defer ffi.Pop(state, 1) + + // Iterate and nil out all keys + ffi.PushNil(state) + for ffi.Next(state, -2) != 0 { + ffi.Pop(state, 1) + ffi.PushValue(state, -1) + ffi.PushNil(state) + ffi.SetTable(state, -4) + } +} + +func (t *LuaTable) Len() int { + state := t.vm.state() + + t.deref(t.vm) + defer ffi.Pop(state, 1) + + return int(ffi.ObjLen(state, -1)) +} + +func (t *LuaTable) IsEmpty() bool { return t.Len() == 0 } + +func (t *LuaTable) Iterable() map[LuaValue]LuaValue { + state := t.vm.state() + + t.deref(t.vm) + tableIndex := ffi.GetTop(state) + ffi.PushNil(state) + + obj := make(map[LuaValue]LuaValue) + for ffi.Next(state, tableIndex) != 0 { + key, value := intoLuaValue(t.vm, -2), intoLuaValue(t.vm, -1) + obj[key] = value + + ffi.Pop(state, 1) // only pop value, leave key in place + } + + ffi.Pop(state, 1) + return obj +} + +func (t *LuaTable) SetMetatable(metatable *LuaTable) { + state := t.vm.state() + + t.deref(t.vm) // table (-2) + metatable.deref(t.vm) // metatable (-1) + + // Set the metatable for the table + ffi.SetMetatable(state, -2) + + // Pop metatable, re-ref the table + ffi.Pop(state, 1) + t.index = int(ffi.Ref(state, -1)) +} + +func (t *LuaTable) GetMetatable() *LuaTable { + state := t.vm.state() + + if ok := ffi.GetMetatable(state, int32(t.index)); ok { + index := ffi.Ref(state, -1) + ffi.Pop(state, 1) + + return &LuaTable{vm: t.vm, index: int(index)} + } + + return nil +} + +// +// LuaValue implementation +// + +var _ LuaValue = (*LuaTable)(nil) + +func (t *LuaTable) lua() *Lua { return t.vm } +func (t *LuaTable) ref() int { return t.index } + +func (t *LuaTable) deref(lua *Lua) int { + return int(ffi.GetRef(lua.state(), int32(t.ref()))) +} diff --git a/lua/thread.go b/lua/thread.go new file mode 100644 index 0000000..b7fc3b8 --- /dev/null +++ b/lua/thread.go @@ -0,0 +1,109 @@ +package lua + +import "github.com/CompeyDev/lei/ffi" + +type LuaThread struct { + vm *Lua + chunk *LuaChunk + index int +} + +func (t *LuaThread) Resume() ([]LuaValue, error) { + threadState := t.state() + t.pushMainFunction() + + status := int(ffi.Resume(threadState, nil, 0)) + return t.collectResults(threadState, status) +} + +func (t *LuaThread) ResumeWith(args ...LuaValue) ([]LuaValue, error) { + mainState := t.vm.state() + threadState := t.state() + + // Push the function if required + t.pushMainFunction() + + // Push args length and then the args + argsCount := len(args) + ffi.PushNumber(threadState, ffi.LuaNumber(argsCount)) + + for _, arg := range args { + arg.deref(t.vm) + ffi.XMove(mainState, threadState, 1) + } + + status := int(ffi.Resume(threadState, nil, int32(argsCount+1))) // +1 for count arg + return t.collectResults(threadState, status) +} + +func (t *LuaThread) Status() int { + threadState := t.state() + return int(ffi.Status(threadState)) +} + +func (t *LuaThread) IsYielded() bool { + return t.Status() == ffi.LUA_YIELD +} + +func (t *LuaThread) IsFinished() bool { + status := t.Status() + threadState := t.state() + return status == ffi.LUA_OK && ffi.GetTop(threadState) == 0 +} + +func (t *LuaThread) collectResults(threadState *ffi.LuaState, status int) ([]LuaValue, error) { + if status != ffi.LUA_OK && status != ffi.LUA_YIELD { + // Return error if thread did not run successfully + return nil, newLuaError(threadState, status) + } + + nresults := int(ffi.GetTop(threadState)) + if nresults == 0 { + return nil, nil + } + + mainState := t.vm.state() + results := make([]LuaValue, nresults) + + // Push arguments onto main thread and ref them into LuaValues + for i := range nresults { + ffi.PushValue(threadState, int32(i+1)) + ffi.XMove(threadState, mainState, 1) + + results[i] = intoLuaValue(t.vm, int32(ffi.GetTop(mainState))) + } + + return results, nil +} + +func (t *LuaThread) pushMainFunction() { + if threadState := t.state(); t.Status() == ffi.LUA_OK && t.chunk != nil { + // Reset the thread and push the coroutine function if the thread has + // finished running and returned a non-resumable state + ffi.ResetThread(threadState) + t.chunk.pushToStack() + ffi.XMove(t.vm.state(), threadState, 1) + } +} + +func (t *LuaThread) state() *ffi.LuaState { + state := t.vm.state() + + t.deref(t.vm) + defer ffi.Pop(state, 1) + + return ffi.ToThread(state, -1) +} + +// +// LuaValue implementation +// + +var _ LuaValue = (*LuaThread)(nil) + +func (t *LuaThread) lua() *Lua { return t.vm } +func (t *LuaThread) ref() int { return t.index } + +func (t *LuaThread) deref(lua *Lua) int { + return int(ffi.GetRef(lua.state(), int32(t.ref()))) +} diff --git a/lua/userdata.c b/lua/userdata.c new file mode 100644 index 0000000..2e4c495 --- /dev/null +++ b/lua/userdata.c @@ -0,0 +1,17 @@ +#include +#include +#include <_cgo_export.h> + +int indexMt(lua_State* L) { + const char* key = lua_tostring(L, 2); + if (key == NULL) { + lua_pushnil(L); + return 1; + } + + uintptr_t* fields_handle = (uintptr_t*)lua_touserdata(L, lua_upvalueindex(1)); + uintptr_t* methods_handle = (uintptr_t*)lua_touserdata(L, lua_upvalueindex(2)); + + indexMtImpl(L, *fields_handle, *methods_handle, (char*)key); + return 1; +} diff --git a/lua/userdata.go b/lua/userdata.go new file mode 100644 index 0000000..a40b949 --- /dev/null +++ b/lua/userdata.go @@ -0,0 +1,149 @@ +package lua + +//go:generate go tool cgo -- -I../ffi/luau/VM/include $GOFILE + +/* +#cgo CFLAGS: -I../ffi/luau/VM/include +#cgo LDFLAGS: -L../ffi/_obj -lLuau.VM -lm -lstdc++ + +#include +#include +#include + +int indexMt(lua_State* L); +void methodMapDtorImpl(lua_State* L, uintptr_t); +void fieldMapDtorImpl(lua_State* L, uintptr_t); +*/ +import "C" + +import ( + "runtime/cgo" + + "github.com/CompeyDev/lei/ffi" +) + +var indexMt = C.indexMt + +var methodMapDtor = C.methodMapDtorImpl +var fieldMapDtor = C.fieldMapDtorImpl + +type LuaUserData struct { + vm *Lua + index int + inner IntoUserData +} + +func (ud *LuaUserData) Downcast() IntoUserData { + if ud.inner != nil { + return ud.inner + } + + ud.deref(ud.vm) + ptr := ffi.ToUserdata(ud.vm.state(), -1) + + if ptr != nil { + return *(*IntoUserData)(ptr) + } else { + return nil + } +} + +// +// LuaValue implementation +// + +var _ LuaValue = (*LuaUserData)(nil) + +func (ud *LuaUserData) lua() *Lua { return ud.vm } +func (ud *LuaUserData) ref() int { return ud.index } + +func (ud *LuaUserData) deref(lua *Lua) int { + return int(ffi.GetRef(lua.state(), int32(ud.ref()))) +} + +type IntoUserData interface { + Methods(*MethodMap) + MetaMethods(*MethodMap) + Fields(*FieldMap) +} + +type ValueRegistry[T any, U any] struct { + inner map[string]T + transformer func(fn U) T +} + +func (vr *ValueRegistry[T, U]) Insert(name string, value any) { + if getter, ok := value.(T); ok { + vr.inner[name] = getter + } else { + vr.inner[name] = vr.transformer(value.(U)) + } +} + +type MethodMap = ValueRegistry[*functionEntry, GoFunction] + +func newMethodMap(fnRegistry *functionRegistry) *MethodMap { + return &MethodMap{ + inner: make(map[string]*functionEntry), + transformer: func(fn GoFunction) *functionEntry { return fnRegistry.register(fn) }, + } +} + +type FieldGetter = func(*Lua) LuaValue +type FieldMap = ValueRegistry[FieldGetter, LuaValue] + +func newFieldMap() *FieldMap { + return &FieldMap{ + inner: make(map[string]FieldGetter), + transformer: func(value LuaValue) FieldGetter { + return func(*Lua) LuaValue { return value } + }, + } +} + +//export indexMtImpl +func indexMtImpl(lua *C.lua_State, fieldHandle, methodHandle uintptr, key *C.char) { + rawState := (*ffi.LuaState)(lua) + state := &Lua{ + // FIXME: what about the function registry? + inner: &StateWithMemory{ + memState: getMemoryState(rawState), + luaState: rawState, + }, + } + keyStr := C.GoString(key) + + // Field lookup + fields := cgo.Handle(fieldHandle).Value().(*FieldMap) + if getter := fields.inner[keyStr]; getter != nil { + value := getter(state) + value.deref(state) + return + } + + // Method lookup + methods := cgo.Handle(methodHandle).Value().(*MethodMap) + if method := methods.inner[keyStr]; method != nil { + pushUpvalue(rawState, method, registryTrampolineDtor) + ffi.PushCClosureK(rawState, registryTrampoline, nil, 1, nil) + return + } + + ffi.PushNil(rawState) +} + +func valueRegistryDtorImpl[T any, U any](handle C.uintptr_t) { + entry := cgo.Handle(handle).Value().(*ValueRegistry[T, U]) + clear(entry.inner) + cgo.Handle(handle).Delete() +} + +//export methodMapDtorImpl +func methodMapDtorImpl(_ *C.lua_State, handle C.uintptr_t) { + valueRegistryDtorImpl[*functionRegistry, GoFunction](handle) +} + +//export fieldMapDtorImpl +func fieldMapDtorImpl(_ *C.lua_State, handle C.uintptr_t) { + valueRegistryDtorImpl[FieldGetter, LuaValue](handle) +} diff --git a/lua/value.go b/lua/value.go new file mode 100644 index 0000000..cc29273 --- /dev/null +++ b/lua/value.go @@ -0,0 +1,228 @@ +package lua + +import ( + "fmt" + "reflect" + "strings" + + "github.com/CompeyDev/lei/ffi" +) + +type LuaValue interface { + // Optionally returns the Lua VM this value belongs to + lua() *Lua + // Returns the reference index of this value in the Lua registry + ref() int + // Dereferences this value onto the Lua stack, returning the stack index + deref(*Lua) int +} + +// +// Lua<->Go Type Conversion +// + +func As[T any](v LuaValue) (T, error) { + var zero T + + targetType := reflect.TypeOf(zero) + reflectValue, err := asReflectValue(v, targetType) + + return reflectValue.Interface().(T), err +} + +func asReflectValue(v LuaValue, t reflect.Type) (reflect.Value, error) { + zero := reflect.Zero(t) + + switch val := v.(type) { + case *LuaNumber: + // Map of all numeric types for O(1) lookup + var numericKinds = map[reflect.Kind]bool{ + reflect.Int: true, reflect.Int8: true, reflect.Int16: true, reflect.Int32: true, reflect.Int64: true, + reflect.Uint: true, reflect.Uint8: true, reflect.Uint16: true, reflect.Uint32: true, reflect.Uint64: true, + reflect.Uintptr: true, + reflect.Float32: true, reflect.Float64: true, + reflect.Complex64: true, reflect.Complex128: true, + } + + if kind := t.Kind(); numericKinds[kind] { + num := reflect.ValueOf(*val).Convert(t) + return num, nil + } + + case *LuaString: + if t.Kind() == reflect.String { + str := reflect.ValueOf(val.ToString()).Convert(t) + return str, nil + } + + case *LuaTable: + switch t.Kind() { + case reflect.Map: + res := reflect.MakeMap(t) + for key, value := range val.Iterable() { + var kVal reflect.Value + var vVal reflect.Value + var err error + + // Key conversion + if t.Key() == reflect.TypeOf((*LuaValue)(nil)).Elem() { + kVal = reflect.ValueOf(key) + } else { + kVal, err = asReflectValue(key, t.Key()) + if err != nil { + return zero, err + } + } + + // Value conversion + if t.Elem() == reflect.TypeOf((*LuaValue)(nil)).Elem() { + vVal = reflect.ValueOf(value) + } else { + vVal, err = asReflectValue(value, t.Elem()) + if err != nil { + return zero, err + } + } + + res.SetMapIndex(kVal, vVal) + } + + return res, nil + + case reflect.Struct: + res := reflect.New(t).Elem() + fieldSet := make(map[int]bool) + + for key, value := range val.Iterable() { + keyStr, ok := key.(*LuaString) + if !ok { + continue + } + + luaKey := keyStr.ToString() + var field reflect.Value + var found bool + var priority int // 0 = explicit annotation, 1 = direct match, 2 = lowercase fallback + var fieldIndex int + + for i := 0; i < t.NumField(); i++ { + // Annotation-based field name overrides (eg: `lua:"field_name"`) + structField := t.Field(i) + tagVal, ok := structField.Tag.Lookup("lua") + if ok && tagVal == luaKey { + field = res.Field(i) + found, priority = true, 0 + fieldIndex = i + break + } + + // Exact matches + if structField.Name == luaKey { + if !found || priority > 1 { + field = res.Field(i) + found, priority = true, 1 + fieldIndex = i + } + continue + } + + // If field is exported, try also using lowercase first character + if name := structField.Name; structField.IsExported() { + lower := strings.ToLower(name[:1]) + name[1:] + if lower == luaKey { + if !found || priority > 2 { + field = res.Field(i) + found, priority = true, 2 + fieldIndex = i + } + } + } + } + + if found && field.IsValid() && field.CanSet() { + // We keep track of whether the field has been found, its priority, and the + // index at which it was found within the struct. If there is an explicit + // annotation, we set the field value directly, otherwise we check that + // the field hasn't already been set in another match, and only set it then + if !fieldSet[fieldIndex] || priority == 0 { + // Recursively convert value to a reflect value + vVal, err := asReflectValue(value, field.Type()) + if err != nil { + return zero, err + } + + field.Set(vVal) + fieldSet[fieldIndex] = true + } + } + } + + return res, nil + + } + + case *LuaNil: + return zero, nil + + case *LuaUserData: + if downcasted := val.Downcast(); downcasted != nil { + return reflect.ValueOf(downcasted).Convert(t), nil + } + + return zero, fmt.Errorf("value isn't userdata") + + case *LuaVector: + return reflect.ValueOf(v), nil + + case *LuaBuffer: + kind := t.Kind() + if kind == reflect.Array { + return reflect.ValueOf(val.Read(0, uint64(t.Len()))).Convert(t), nil + } + + if kind == reflect.Slice { + return reflect.ValueOf(val.Read(0, val.size)).Convert(t), nil + } + } + + return zero, fmt.Errorf("cannot convert LuaValue(%T) into %T", v, zero.Type().Name()) +} + +func intoLuaValue(lua *Lua, index int32) LuaValue { + state := lua.state() + + switch ffi.Type(state, index) { + case ffi.LUA_TNUMBER: + num := ffi.ToNumber(state, index) + li := LuaNumber(float64(num)) + return &li + case ffi.LUA_TSTRING: + ref := ffi.Ref(state, index) + return &LuaString{vm: lua, index: int(ref)} + case ffi.LUA_TTABLE: + ref := ffi.Ref(state, index) + return &LuaTable{vm: lua, index: int(ref)} + case ffi.LUA_TNIL: + return &LuaNil{} + case ffi.LUA_TUSERDATA: + ref := ffi.Ref(state, index) + return &LuaUserData{vm: lua, index: int(ref)} + case ffi.LUA_TVECTOR: + x, y, z := ffi.ToVector(state, index) + return &LuaVector{*x, *y, *z} + case ffi.LUA_TBUFFER: + ref := ffi.Ref(state, index) + return &LuaBuffer{vm: lua, index: int(ref), size: ffi.ObjLen(state, ref)} + case ffi.LUA_TTHREAD: + ref := ffi.Ref(state, index) + return &LuaThread{vm: lua, index: int(ref)} // NOTE: no chunk, can only be executed once + default: + panic("unsupported Lua type") + } +} + +func valueUnrefer[T LuaValue](lua *Lua) func(T) { + return func(value T) { + ffi.Unref(lua.state(), int32(value.ref())) + } +} diff --git a/lua/value_test.go b/lua/value_test.go new file mode 100644 index 0000000..26cc165 --- /dev/null +++ b/lua/value_test.go @@ -0,0 +1,163 @@ +package lua_test + +import ( + "testing" + + "github.com/CompeyDev/lei/lua" +) + +func TestAs(t *testing.T) { + state := lua.New() + + // 1. Tag match + t.Run("Tag match", func(t *testing.T) { + type Person struct { + Name string `lua:"username"` + } + + table := state.CreateTable() + table.Set(state.CreateString("username"), state.CreateString("Alice")) + + res, err := lua.As[Person](table) + if err != nil { + t.Fatal(err) + } + + if res.Name != "Alice" { + t.Fatalf("expected Alice, got %v", res.Name) + } + }) + + // 2. Exact field name match + t.Run("Exact match", func(t *testing.T) { + type Person struct { + Age int + } + + table := state.CreateTable() + table.Set(state.CreateString("Age"), lua.LuaNumber(30)) + + res, err := lua.As[Person](table) + if err != nil { + t.Fatal(err) + } + + if res.Age != 30 { + t.Fatalf("expected 30, got %v", res.Age) + } + }) + + // 3. Lowercase-first-letter fallback + t.Run("Lowercase fallback", func(t *testing.T) { + type Person struct{ Country string } + + table := state.CreateTable() + table.Set(state.CreateString("country"), state.CreateString("Germany")) + + res, err := lua.As[Person](table) + if err != nil { + t.Fatal(err) + } + + if res.Country != "Germany" { + t.Fatalf("expected 'Germany', got %v", res.Country) + } + }) + + // 4. Unexported field ignored + t.Run("Unexported field", func(t *testing.T) { + type Box struct{ secret string } + + table := state.CreateTable() + table.Set(state.CreateString("secret"), state.CreateString("trains r cool")) + + res, err := lua.As[Box](table) + if err != nil { + t.Fatal(err) + } + + if res.secret != "" { + t.Fatalf("expected empty, got %v", res.secret) + } + }) + + // 5. Mixed fields + t.Run("Mixed fields", func(t *testing.T) { + type Person struct { + Name string `lua:"username"` + Age int + Email string + } + + table := state.CreateTable() + table.Set(state.CreateString("username"), state.CreateString("Bob")) + table.Set(state.CreateString("Age"), lua.LuaNumber(25)) + table.Set(state.CreateString("email"), state.CreateString("bobby@example.com")) + + res, err := lua.As[Person](table) + if err != nil { + t.Fatal(err) + } + + if res.Name != "Bob" || res.Age != 25 || res.Email != "bobby@example.com" { + t.Fatalf("unexpected result: %+v", res) + } + }) + + // 6. Missing Lua key + t.Run("Missing key", func(t *testing.T) { + type Person struct { + Name string + Age int + } + + table := state.CreateTable() + table.Set(state.CreateString("Name"), state.CreateString("Johnny")) + + res, err := lua.As[Person](table) + if err != nil { + t.Fatal(err) + } + + if res.Name != "Johnny" || res.Age != 0 { + t.Fatalf("unexpected result: %+v", res) + } + }) + + // 7. Extra Lua key ignored + t.Run("Extra key ignored", func(t *testing.T) { + type Person struct{ Name string } + + table := state.CreateTable() + table.Set(state.CreateString("unknown"), state.CreateTable()) + + res, err := lua.As[Person](table) + if err != nil { + t.Fatal(err) + } + + if res.Name != "" { + t.Fatalf("expected Name empty, got %v", res.Name) + } + }) + + // 8. Tag overrides lowercase fallback + t.Run("Tag overrides fallback", func(t *testing.T) { + type Person struct { + Name string `lua:"user"` + } + + table := state.CreateTable() + table.Set(state.CreateString("name"), state.CreateString("Dave")) + table.Set(state.CreateString("user"), state.CreateString("Eve")) + + res, err := lua.As[Person](table) + if err != nil { + t.Fatal(err) + } + + if res.Name != "Eve" { + t.Fatalf("expected 'Eve', got %v", res.Name) + } + }) +} diff --git a/lua/vector.go b/lua/vector.go new file mode 100644 index 0000000..450f605 --- /dev/null +++ b/lua/vector.go @@ -0,0 +1,19 @@ +package lua + +import "github.com/CompeyDev/lei/ffi" + +type LuaVector struct{ X, Y, Z float32 } + +// +// LuaValue implementation +// + +var _ LuaValue = (*LuaVector)(nil) + +func (v LuaVector) lua() *Lua { return nil } +func (v LuaVector) ref() int { return ffi.LUA_NOREF } +func (v LuaVector) deref(lua *Lua) int { + state := lua.state() + ffi.PushVector(state, v.X, v.Y, v.Z) + return int(ffi.GetTop(state)) +} diff --git a/main.go b/main.go index ebc32e3..5e16b66 100644 --- a/main.go +++ b/main.go @@ -1,24 +1,229 @@ package main -import lualib "github.com/CompeyDev/lei/ffi" +import ( + "fmt" + + "github.com/CompeyDev/lei/lua" +) func main() { - lua := lualib.LNewState() - println("Lua VM Address: ", lua) + mem := lua.NewMemoryState() + // mem.SetLimit(250 * 1024) // 250KB max + state := lua.NewWith(lua.StdLibALLSAFE, lua.LuaOptions{InitMemoryState: mem, CatchPanics: true, EnableCodegen: true}) + + table := state.CreateTable() + key, value := state.CreateString("hello"), state.CreateString("lei") + table.RawSet(key, value) + table.Push(state.CreateString("world")) + + mt, indexMt := state.CreateTable(), state.CreateTable() + indexKey := state.CreateString("hej") + indexMt.Set(indexKey, value) + mt.RawSet(state.CreateString("__index"), indexMt) + + table.SetMetatable(mt) + + fmt.Printf("Used: %d, Limit: %d\n", mem.Used(), mem.Limit()) + + fmt.Println(key.ToString(), table.RawGet(key).(*lua.LuaString).ToString()) + fmt.Println("key fetched by metatable:", table.Get(indexKey).(*lua.LuaString).ToString()) + fmt.Println("key fetched without metatable:", table.RawGet(indexKey).(*lua.LuaNil)) + fmt.Println("popped value:", table.Pop().(*lua.LuaString).ToString()) + + fmt.Println("len:", table.Len()) + + table.RawPush(state.CreateString("raw")) + fmt.Println("raw popped value:", table.RawPop().(*lua.LuaString).ToString()) + + fmt.Println("equals self:", table.Equals(table)) + fmt.Println("equals other:", table.Equals(indexMt)) + + table.Clear() + fmt.Println("len after clear:", table.Len()) + + chunk := state.Load("main", []byte("print('hello, lei!!!!', math.random()); return {['mrrp'] = 'foo', ['meow'] = 'bar'}, 'baz'")) + values, returnErr := chunk.Call() + + if returnErr != nil { + fmt.Println(returnErr) + return + } + + for i, value := range values { + fmt.Print(i, ": ") + + switch v := value.(type) { + case *lua.LuaString: + fmt.Println(v.ToString()) + case *lua.LuaTable: + fmt.Println() + + for key, val := range v.Iterable() { + k, kErr := lua.As[string](key) + v, vErr := lua.As[string](val) - lualib.PushCFunction(lua, func(L *lualib.LuaState) int32 { - println("hi from closure?") - return 0 + if kErr != nil || vErr != nil { + fmt.Println(" (non-string key or value)") + } + + fmt.Printf(" %v: %v\n", k, v) + } + } + } + + iterable, iterErr := lua.As[map[string]string](table) + if iterErr != nil { + fmt.Println(iterErr) + return + } + + for k, v := range iterable { // or, we can use `.Iterable` + fmt.Printf("%s %s\n", k, v) + } + + cFnChunk := state.CreateFunction(nil, func(luaState *lua.Lua, args ...lua.LuaValue) ([]lua.LuaValue, error) { + someNumber := lua.LuaNumber(22713) + return []lua.LuaValue{ + luaState.CreateString("Hello"), + luaState.CreateString("from"), + luaState.CreateString(fmt.Sprintf("Go, %s!", args[0].(*lua.LuaString).ToString())), + &someNumber, + }, nil }) - lualib.PushString(lua, "123") - lualib.PushNumber(lua, lualib.ToNumber(lua, 2)) + returns, callErr := cFnChunk.Call(state.CreateString("Lua")) + if callErr != nil { + fmt.Println(callErr) + return + } + + for i, ret := range returns { + str, err := lua.As[string](ret) + if err == nil { + fmt.Printf("Return %d: %s\n", i+1, str) + } else { + num, _ := lua.As[float64](ret) + fmt.Printf("Return %d: %f\n", i+1, num) + } + } + + class := &Class{value: 420.0} + classUd := state.CreateUserData(class) + state.SetGlobal("classUd", classUd) + + got := state.GetGlobal("classUd").(*lua.LuaUserData).Downcast() + fmt.Println("got:", got.(*Class).value) - if !lualib.IsCFunction(lua, 1) { - panic("CFunction was not correctly pushed onto stack") + conv, err := lua.As[*Class](classUd) + if err != nil { + fmt.Println(err) } - if !lualib.IsNumber(lua, 3) { - panic("Number was not correctly pushed onto stack") + fmt.Println("with as:", *conv) + + udChunk := state.Load("udChunk", []byte("print(tostring(classUd), classUd.toggle); classUd.flip(); print(classUd.toggle, classUd.fakeToggle); return vector.one")) + + vectorReturn, udCallErr := udChunk.Call() + if udCallErr != nil { + fmt.Println(udCallErr) + return } + + fmt.Println(vectorReturn[0].(*lua.LuaVector)) + + bufChunk := state.Load( + "bufChunk", + []byte( + `local str = buffer.readstring(b, 0, 5) + print(str) + buffer.writestring(b, 4, "lei")`, + ), + ) + + buf := state.CreateBuffer(10) + buf.Write(0, []byte("hello")) + state.SetGlobal("b", buf) + + _, bufErr := bufChunk.Call() + if bufErr != nil { + fmt.Println(bufErr) + return + } + + fmt.Println(string(buf.Read(4, 3))) + + bufArr, err := lua.As[[4]uint8](buf) + fmt.Println("rapidly approaching", string(bufArr[:])) + + thread, threadErr := state.CreateThread(state.CreateFunction(nil, func(luaState *lua.Lua, args ...lua.LuaValue) ([]lua.LuaValue, error) { + returns := []lua.LuaValue{ + luaState.CreateString("Hello"), + luaState.CreateString("thread"), + } + + if len(args) != 0 { + returns = append(returns, args[0]) + } else { + fmt.Println("No args for coroutine!") + } + + return returns, nil + })) + + if threadErr != nil { + fmt.Println(threadErr) + return + } + + resultsA, errA := thread.Resume() + resultsB, errB := thread.ResumeWith(state.CreateString("B!")) + + if errA != nil || errB != nil { + fmt.Println("Either thread resume failed") + fmt.Println("A:", errA) + fmt.Println("B:", errB) + return + } + + for _, result := range resultsA { + fmt.Println("Thread A =>", result.(*lua.LuaString).ToString()) + } + + fmt.Println() + + for _, result := range resultsB { + fmt.Println("Thread B =>", result.(*lua.LuaString).ToString()) + } +} + +type Class struct{ value float64 } + +var _ lua.IntoUserData = (*Class)(nil) + +func (c *Class) Fields(fields *lua.FieldMap) { + // NOTE: this references takes a copy of the value and mutations hence do + // not persist here. Instead we need a getter which captures the class + // itself + funnyNumber := lua.LuaNumber(c.value) + fields.Insert("fakeToggle", &funnyNumber) + + fields.Insert("toggle", func(*lua.Lua) lua.LuaValue { + value := lua.LuaNumber(c.value) + return &value + }) } + +func (c *Class) MetaMethods(metaMethods *lua.MethodMap) { + metaMethods.Insert("__tostring", func(vm *lua.Lua, _ ...lua.LuaValue) ([]lua.LuaValue, error) { + return []lua.LuaValue{vm.CreateString("Class")}, nil + }) +} + +func (c *Class) Methods(methods *lua.MethodMap) { + methods.Insert("flip", func(_G *lua.Lua, args ...lua.LuaValue) ([]lua.LuaValue, error) { + c.toggle() + return []lua.LuaValue{}, nil + }) +} + +func (c *Class) toggle() { c.value = 69.0 }