From f0f5b15fbe827572e335ca825d627f05325ee6e9 Mon Sep 17 00:00:00 2001 From: heav Date: Thu, 27 Oct 2022 01:31:54 +0000 Subject: added filter, folds and sum/product to table library --- czzc/table.lua | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/czzc/table.lua b/czzc/table.lua index bb34f2f..92619eb 100644 --- a/czzc/table.lua +++ b/czzc/table.lua @@ -23,11 +23,86 @@ end -- replaces all values of a table with fn(value). M.map = function(fn, tbl) - local res = M.copy(tbl) + local res = {} for k, v in pairs(tbl) do res[k] = fn(v) end return res end +-- in-place version of the above. +M.map_ip = function(fn, tbl) + for k, v in pairs(tbl) do + tbl[k] = fn(v) + end + return tbl -- redundant, technically, but i thought i'd include it anyway. +end + +-- discards or keeps values from a table based on a predicate. +-- this one only supports arrays, not maps. +M.filter = function(pred, arr) + local res = {} + for i, v in ipairs(arr) do + if pred(v) then + table.insert(res, v) + end + end + return res +end + +M.filter_ip = function(pred, arr) -- in-place. + for i = #arr, 1, -1 do -- we need to iterate backwards here to prevent the. + if not pred(arr[i]) then + table.remove(arr, i) + end + end + return arr +end + +-- like filter, but works on map-style tables instead. +M.filter_map = function(pred, map) + local res = {} + for k, v in pairs(map) do + if pred(v) then + res[k] = v + end + end + return res +end + +M.filter_map_ip = function(pred, map) -- in-place. + for k, v in pairs(map) do + if not pred(v) then + map[k] = nil + end + end + return map +end + +-- "folds" an operation over an array. i.e foldr(*, e, {a, b, c, d}) = a * (b * (c * (d * e))). +M.foldr = function(op, startval, arr) + local res = startval or 0 -- sensible default? + for i = #arr, 1, -1 do + res = op(res, arr[i]) + end + return res +end + +-- same as above, but the other way i.e ((e * a) * b) * c +M.foldl = function(op, startval, arr) + local res = startval or 0 + for i=1, #arr do + res = op(res, arr[i]) + end + return res +end + +M.sum = function(arr) + return M.foldl(function(a,b) return a+b end, 0, arr) +end + +M.product = function(arr) + return M.foldl(function(a,b) return a*b end, 0, arr) +end + return M -- cgit v1.2.3