summaryrefslogtreecommitdiff
path: root/auth.lua
blob: 978fbc4650bd7a4fa45f5bc10a4e72bbe10a13e9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
local request = require 'http.request'
local util = require 'http.util'
local json = require 'json'

--- Module table. It doubles as the metatable for the user objects built by
-- M.user, hence `__index` pointing back at itself.
local M = {
	-- toggle for the on-disk LMDB cache
	cache_enabled = true,
	-- directory holding the cache database file
	tmpdir = "/tmp",
	-- host of the auth API
	api_domain = "auth.citrons.xyz",
	-- timeout passed to lua-http request/stream calls
	timeout = 8,
}
M.__index = M

local db
-- Memoized result of loading the optional 'lmdb' binding:
-- nil = not tried yet, false = unavailable, otherwise the module.
local lmdb_lib

--- Begin a transaction on the on-disk cache database.
-- The database is opened lazily on first use, at
-- `<M.tmpdir>/<M.api_domain>[.<$USER>].mdb`.
-- @tparam[opt] boolean write_enabled passed through to `db:txn_begin`
-- @return a transaction, or nil when the 'lmdb' module cannot be loaded or
--   the database could not be opened
local function txn(write_enabled)
	if lmdb_lib == nil then
		-- `require` raises on a missing module instead of returning a falsy
		-- value, so probe with pcall to keep the cache optional. The outcome
		-- is memoized so a missing module is not searched for on every call.
		local ok, lib = pcall(require, 'lmdb')
		lmdb_lib = ok and lib or false
	end
	if not lmdb_lib then
		return
	end
	if not db then
		local user = os.getenv "USER"
		local suffix = user and "."..user or ""
		db = lmdb_lib.open(M.tmpdir.."/"..M.api_domain..suffix..".mdb", {
			nosubdir = true, mapsize = 2^20 * 5, maxdbs = 256,
		})
	end
	if not db then return end
	return db:txn_begin(write_enabled)
end

--- Look up a non-expired entry in a cache database.
-- @tparam string name LMDB sub-database name (e.g. 'users', 'tokens')
-- @tparam string k key within that database
-- @treturn boolean|nil true when a live entry exists, nil on a miss, when
--   caching is disabled, or when no transaction could be started
-- @return the cached value; may itself be nil for a negatively-cached entry
local function cache_get(name, k)
	if not M.cache_enabled then return end
	local t = txn()
	if not t then return end

	local entry
	local d = t:open(name)
	if d then
		local raw = d[k]
		if raw then
			-- A corrupt entry is treated as a miss rather than an error.
			local ok, decoded = pcall(json.decode, raw)
			if ok then entry = decoded end
		end
	end
	t:commit()

	if type(entry) == 'table' and type(entry.expires) == 'number'
		and os.time() <= entry.expires then
		return true, entry.v
	end
end

--- Delete one entry from a cache database.
-- @tparam string name LMDB sub-database name
-- @tparam string k key to remove
local function cache_clear(name, k)
	if not M.cache_enabled then return end
	local t = txn(true)
	if not t then return end
	local d = t:open(name, true)
	if d and d[k] then d[k] = nil end
	-- Commit so the deletion is persisted and the write transaction is
	-- released: LMDB permits only one active write transaction at a time.
	t:commit()
end

--- Store a value in a cache database and return it unchanged.
-- A table value carrying a `ttl` field (seconds) sets its own lifetime;
-- anything else expires after 300 seconds. nil may be stored as a negative
-- entry.
-- @tparam string name LMDB sub-database name
-- @tparam string k key
-- @param v value to store (must be JSON-encodable)
-- @return v, so callers can write `return cache_put(...)`
local function cache_put(name, k, v)
	if not M.cache_enabled then return v end
	local t = txn(true)
	if not t then return v end
	local d = t:open(name, true)
	if d then
		local ttl = type(v) == 'table' and v.ttl or 300
		d[k] = json.encode {v = v, expires = os.time() + ttl}
	end
	t:commit()
	return v
end

-- One request object is reused for every API call. It is rebuilt whenever
-- M.api_domain changes, so reconfiguring the domain takes effect.
local rq
local default_headers
local rq_domain

--- Prepare the shared request object for a GET of `path` on the API host.
-- Headers are reset to a fresh copy of the defaults and any body left over
-- from a previous POST is cleared.
-- @tparam string path request path (sent as the `:path` pseudo-header)
-- @return the prepared lua-http request object
local function get_rq(path)
	if not rq or rq_domain ~= M.api_domain then
		rq = request.new_from_uri("https://"..M.api_domain)
		default_headers = rq.headers
		rq_domain = M.api_domain
	end
	rq.headers = default_headers:clone()
	rq.headers:upsert(':path', path)
	rq:set_body(nil)
	-- set_body adds `expect: 100-continue` when it cannot compute a length,
	-- which includes a nil body; a bodiless request must not carry it.
	rq.headers:delete('expect')
	return rq
end

--- Shared response handling for api_get/api_post.
-- @return nil on HTTP 404, the decoded JSON body on HTTP 200
-- @raise on any other status, or if the body cannot be read
local function read_json_response(path, headers, stream)
	local status = headers:get ':status'
	if status == "404" then
		-- release the stream/connection; the body is not needed
		stream:shutdown()
		return nil
	end
	if status ~= "200" then
		stream:shutdown()
		error("unexpected HTTP status "..tostring(status).." for "..path, 0)
	end
	local data = assert(stream:get_body_as_string(M.timeout))
	return json.decode(data)
end

--- GET `path` on the auth API.
-- @return decoded JSON, or nil on 404; raises on other failures
local function api_get(path)
	local rq = get_rq(path)
	local headers, stream = assert(rq:go(M.timeout))
	return read_json_response(path, headers, stream)
end

--- POST `form` (urlencoded) to `path` on the auth API.
-- @tparam string path request path
-- @tparam table form key/value pairs for the request body
-- @return decoded JSON, or nil on 404; raises on other failures
local function api_post(path, form)
	local rq = get_rq(path)
	rq.headers:upsert(':method', 'POST')
	rq.headers:upsert('content-type', 'application/x-www-form-urlencoded')
	rq:set_body(util.dict_to_query(form))
	local headers, stream = assert(rq:go(M.timeout))
	return read_json_response(path, headers, stream)
end

--- URL of the auth server's login page for a given service.
-- @tparam string service value for the `service` query argument
-- @treturn string
function M.login_url(service)
	local base = "https://"..M.api_domain.."/login"
	return base.."?"..util.dict_to_query {service = service}
end

--- Fetch a service's metadata from `https://<domain>/.well-known/citrons/auth`.
-- Successful lookups are cached in the 'services' database.
-- @tparam string domain service domain (letters, digits, `_`, `-`, `.`)
-- @return the decoded metadata, or nil plus an error message
function M.service_lookup(domain)
	if not domain:match "^[%w_%-%.]+$" then
		return nil, "invalid service name!"
	end
	-- collapse runs of dots so equivalent spellings share one cache key
	domain = domain:gsub('%.+', ".")

	local _, cached = cache_get('services', domain)
	if cached then return cached end

	-- NOTE(review): `domain` comes from the caller and is fetched over the
	-- network as-is; confirm callers do not pass untrusted hosts (SSRF).
	local meta_uri = "https://"..domain.."/.well-known/citrons/auth"
	local headers, stream = request.new_from_uri(meta_uri):go(M.timeout)
	if not headers then return nil, stream end
	local status = headers:get ":status"
	if status ~= "200" then
		stream:shutdown()
		return nil, "HTTP error: "..status
	end
	-- body is capped at 4096 chars, read with a 4-second timeout
	local body, err = stream:get_body_chars(4096, 4)
	-- the cap may leave unread data; release the stream either way
	stream:shutdown()
	if not body then return nil, err end
	local ok, result = pcall(json.decode, body)
	if not ok then return nil, "could not decode JSON" end

	cache_put('services', domain, result)

	return result
end

-- uid -> user object. Keys are uid strings, and strings are never removed
-- from weak tables, so weak *keys* would never free anything. Weak *values*
-- let user objects be collected once no caller holds them.
local users = setmetatable({}, {__mode = 'v'})

--- Get the user object for a uid, creating it on first use.
-- @tparam string uid alphanumeric user id
-- @return the user object, or nil when M:get_data() yields no record
-- @raise "invalid uid" (from M:get_data) for a malformed uid
function M.user(uid)
	local u = users[uid]
	if u then return u end
	u = setmetatable({uid = uid}, M)
	if not u:get_data() then return nil end
	users[uid] = u
	-- return the strong local, not users[uid], which is a weak reference
	return u
end

--- This user's record from `/api/user/<uid>`, going through the 'users'
-- cache (a nil result from the API is cached too).
-- @return the decoded record, or nil when the API answers 404
-- @raise "invalid uid" unless self.uid is a string of one or more `%w` chars
function M:get_data()
	local uid = self.uid
	assert(type(uid) == 'string' and uid:match '^%w+$', "invalid uid")

	local hit, record = cache_get('users', uid)
	if not hit then
		record = cache_put('users', uid, api_get('/api/user/'..uid))
	end
	return record
end

--- This user's username.
-- @return the `username` field of the user record, or nil when no record is
--   available (e.g. the API returned 404 and that result was cached)
function M:username()
	local data = self:get_data()
	if not data then return nil end
	return data.username
end

--- Verify a login token for this user and service.
-- API results are cached in the 'tokens' database, keyed by token. Each
-- entry records the uid and service it was checked for, so a token that was
-- valid for one user/service is never accepted from the cache for another.
-- @tparam string service service the token was issued for
-- @tparam string token token to verify
-- @return the API response's `valid` field, or false when the API returned
--   404 or the token was revoked through M:invalidate
function M:authenticate(service, token)
	local cached, entry = cache_get('tokens', token)
	if cached and entry == nil then
		-- M:invalidate caches an explicit nil for a revoked token.
		return false
	end

	local data
	if cached and type(entry) == 'table'
		and entry.uid == self.uid and entry.service == service then
		data = entry.data
	else
		data = api_post(
			'/api/user/'..self.uid..'/auth/'..util.encodeURIComponent(service),
			{token = token})
		cache_put('tokens', token, {
			uid = self.uid,
			service = service,
			data = data,
			-- keep honoring a server-supplied ttl (cache_put reads v.ttl)
			ttl = type(data) == 'table' and data.ttl or nil,
		})
		-- only refresh the user cache when the response actually has a user
		if type(data) == 'table' and data.user then
			cache_put('users', self.uid, data.user)
		end
	end

	if not data then return false end
	return data.valid
end

--- Revoke a token for this user through the auth API.
-- A truthy API response replaces the cached 'users' entry (presumably the
-- refreshed user record; confirm against the API). The token is then cached
-- as nil so that cached lookups treat it as revoked.
-- @tparam string token token to revoke
function M:invalidate(token)
	local path = '/api/user/'..self.uid..'/invalidate'
	local refreshed = api_post(path, {token = token})
	if refreshed then cache_put('users', self.uid, refreshed) end
	cache_put('tokens', token, nil)
end

return M