refactor: async processes

This commit is contained in:
Folke Lemaitre 2024-06-28 16:08:26 +02:00
parent 4319846b8c
commit a36ebd2a75
No known key found for this signature in database
GPG key ID: 41F8B1FBACAE2040
12 changed files with 394 additions and 379 deletions

View file

@ -1,78 +1,122 @@
---@class AsyncOpts
---@field on_done? fun()
---@field on_error? fun(err:string)
---@field on_yield? fun(res:any)
local M = {}
---@type Async[]
M._queue = {}
M._executor = assert(vim.loop.new_check())
M._running = false
---@type Async
M.current = nil
---@type table<thread, Async>
M._threads = setmetatable({}, { __mode = "k" })
---@alias AsyncEvent "done" | "error" | "yield" | "ok"
---@class Async
---@field co thread
---@field opts AsyncOpts
---@field sleeping? boolean
---@field _co thread
---@field _fn fun()
---@field _suspended? boolean
---@field _on table<AsyncEvent, fun(res:any, async:Async)[]>
local Async = {}
---@param fn async fun()
---@param opts? AsyncOpts
---@return Async
function Async.new(fn, opts)
function Async.new(fn)
local self = setmetatable({}, { __index = Async })
self.co = coroutine.create(fn)
self.opts = opts or {}
return self:init(fn)
end
---@param fn async fun()
---@return Async
function Async:init(fn)
self._fn = fn
self._on = {}
self._co = coroutine.create(function()
local ok, err = pcall(self._fn)
if not ok then
self:_emit("error", err)
end
self:_emit("done")
end)
M._threads[self._co] = self
return M.add(self)
end
function Async:restart()
assert(not self:running(), "Cannot restart a running async")
self:init(self._fn)
end
---@param event AsyncEvent
---@param cb async fun(res:any, async:Async)
function Async:on(event, cb)
self._on[event] = self._on[event] or {}
table.insert(self._on[event], cb)
return self
end
function Async:running()
return coroutine.status(self.co) ~= "dead"
---@private
---@param event AsyncEvent
---@param res any
function Async:_emit(event, res)
for _, cb in ipairs(self._on[event] or {}) do
cb(res, self)
end
end
function Async:running()
return coroutine.status(self._co) ~= "dead"
end
---@async
function Async:sleep(ms)
self.sleeping = true
self._suspended = true
vim.defer_fn(function()
self.sleeping = false
self._suspended = false
end, ms)
coroutine.yield()
end
---@async
function Async:suspend()
self.sleeping = true
self._suspended = true
if coroutine.running() == self._co then
coroutine.yield()
end
end
function Async:resume()
self.sleeping = false
self._suspended = false
end
function Async:wait()
local async = M.running()
if coroutine.running() == self._co then
error("Cannot wait on self")
end
while self:running() do
if async then
coroutine.yield()
else
vim.wait(10)
end
end
return self
end
function Async:step()
if self.sleeping then
if self._suspended then
return true
end
local status = coroutine.status(self.co)
local status = coroutine.status(self._co)
if status == "suspended" then
M.current = self
local ok, res = coroutine.resume(self.co)
M.current = nil
local ok, res = coroutine.resume(self._co)
if not ok then
if self.opts.on_error then
self.opts.on_error(tostring(res))
end
error(res)
elseif res then
if self.opts.on_yield then
self.opts.on_yield(res)
end
self:_emit("yield", res)
end
end
if self:running() then
return true
end
if self.opts.on_done then
self.opts.on_done()
end
return self:running()
end
function M.step()
@ -107,32 +151,24 @@ function M.add(async)
return async
end
---@param fn async fun()
---@param opts? AsyncOpts
function M.run(fn, opts)
return M.add(Async.new(fn, opts))
end
---@generic T: async fun()
---@param fn T
---@param opts? AsyncOpts
---@return T
function M.wrap(fn, opts)
return function(...)
local args = { ... }
---@async
local wrapped = function()
return fn(unpack(args))
end
return M.run(wrapped, opts)
function M.running()
local co = coroutine.running()
if co then
local async = M._threads[co]
assert(async, "In coroutine without async context")
return async
end
end
---@async
---@param ms number
function M.sleep(ms)
assert(M.current, "Not in an async context")
M.current:sleep(ms)
local async = M.running()
assert(async, "Not in an async context")
async:sleep(ms)
end
M.Async = Async
M.new = Async.new
return M