Module:Grind/fml

From Fallen London Wiki

Documentation for this module may be created at Module:Grind/fml/doc

local p = {}

local util = require('Module:Grind/util')

-- == Formula manipulation utilities ==

-- Reads one lexical token from `s`, starting at `initial_pos`.
-- Returns three values: the token payload, the token type and the position
-- at which scanning should resume.
-- Token types produced here: 'end', 'tok', 'num', 'var', 'func', 'err'.
-- NOTE(review): relies on util.skip_whitespaces, util.advance and
-- util.find_match; util.advance is presumably "step one character and return
-- the new position plus the character there" - confirm against Module:Grind/util.
local function next_token(s, initial_pos)
	local pos = util.skip_whitespaces(s, initial_pos)
	if pos > #s then
		return nil, 'end', pos
	end
	local c = s:sub(pos, pos)

	-- Single-character punctuation and operators are tokens by themselves.
	local punctuation = {
		['('] = true, [')'] = true, [','] = true,
		['+'] = true, ['-'] = true, ['*'] = true, ['/'] = true
	}
	if punctuation[c] then
		return c, 'tok', pos + 1
	end

	-- Appends the run of digits at the cursor to `prefix`, moving the cursor.
	local function take_digits(prefix)
		local text = prefix
		while c:match('%d') do
			text = text .. c
			pos, c = util.advance(s, pos)
		end
		return text
	end

	-- A number with an integer part and an optional fraction: `12`, `12.`, `12.5`.
	if c:match('%d') then
		local text = take_digits('')
		if c == '.' then
			text = text .. c
			pos, c = util.advance(s, pos)
			text = take_digits(text)
		end
		return tonumber(text), 'num', pos
	end

	-- A number without an integer part: `.5`.
	if c == '.' then
		pos, c = util.advance(s, pos)
		if not c:match('%d') then
			return 'invalid float', 'err', pos
		end
		local text = take_digits('.')
		return tonumber(text), 'num', pos
	end

	-- A variable reference: `$(name)`.
	if c == '$' then
		pos, c = util.advance(s, pos)
		if c ~= '(' then
			return 'invalid variable', 'err', pos
		end
		local closing = util.find_match(s, pos)
		if closing > #s then
			return 'unmatched `(` in variable', 'err', pos
		end
		return s:sub(pos + 1, closing - 1), 'var', closing + 1
	end

	-- A function name: a letter followed by letters, digits, `_` or `.`.
	if c:match('%a') then
		local name = c
		pos, c = util.advance(s, pos)
		while c:match('[%w_%.]') do
			name = name .. c
			pos, c = util.advance(s, pos)
		end
		return name, 'func', pos
	end

	return 'unexpected token: `' .. c .. '`', 'err', pos + 1
end

-- Splits the formula string `s` into a list of `{data, type}` tokens.
-- The 'end' marker returned by next_token() when only whitespace is left is
-- NOT stored: the parser layers accept only tok, num, var, func and err, and
-- a stray 'end' entry made formulas with trailing whitespace unparsable.
local function collect_tokens(s)
	local tokens = {}
	local pos = 1
	while pos <= #s do
		local tok_data, tok_type, new_pos = next_token(s, pos)
		if tok_type ~= 'end' then
			table.insert(tokens, {tok_data, tok_type})
		end
		-- Guard against a tokenizer that fails to make progress.
		if new_pos > pos then
			pos = new_pos
		else
			break
		end
	end
	return tokens
end

-- Given the index of an opening `(` token, returns the index of the `)` token
-- that closes it. When there is no such token the returned index lies past
-- the end of `tokens`.
local function find_match(tokens, initial_pos)
	local depth = 1
	for index = initial_pos + 1, #tokens do
		local data, kind = unpack(tokens[index])
		if kind == 'tok' then
			if data == '(' then
				depth = depth + 1
			elseif data == ')' then
				depth = depth - 1
				if depth == 0 then
					return index
				end
			end
		end
	end
	-- No closing bracket: report the first position that was not scanned.
	return math.max(initial_pos + 1, #tokens + 1)
end

-- Splits the token list found between the brackets of a function call into
-- one token list per argument. Only top-level commas separate arguments:
-- bracketed groups are copied into the current argument as a whole.
-- An empty token list yields no arguments at all.
local function split_args(tokens)
	local result = {}
	local current = {}
	local total = #tokens
	local index = 1
	while index <= total do
		local token = tokens[index]
		local data, kind = unpack(token)
		local is_punct = (kind == 'tok')
		if is_punct and data == ',' then
			-- close the current argument and start a new one
			result[#result + 1] = current
			current = {}
			index = index + 1
		elseif is_punct and data == '(' then
			local closing = find_match(tokens, index)
			if closing > total then
				current[#current + 1] = {'unmatched `(`', 'err'}
				index = index + 1
			else
				-- copy the whole group, brackets included
				for j = index, closing do
					current[#current + 1] = tokens[j]
				end
				index = closing + 1
			end
		elseif is_punct and data == ')' then
			current[#current + 1] = {'unexpected `)`', 'err'}
			index = index + 1
		else
			current[#current + 1] = token
			index = index + 1
		end
	end
	-- Any token at all means there is a last (possibly empty) argument.
	if total > 0 then
		result[#result + 1] = current
	end
	return result
end

-- Layer 0: brackets and function calls.
-- Input token types: tok, num, var, func, err.
-- Output token types: tok, num, var, expr, func_call, err.
--
-- * `( ... )` becomes `{inner_tree, 'expr'}`, where inner_tree is the
--   recursively processed content of the brackets.
-- * `name( a, b )` becomes `{{name, tree_a, tree_b}, 'func_call'}`.
-- * `err($(message))` becomes the error token `{message, 'err'}`.
-- * stray `)` and `,` tokens, and unmatched `(`, become error tokens.
local function parse_layer0(tokens)
	local tree = {}
	local pos = 1
	while pos <= #tokens do
		local tok_data, tok_type = unpack(tokens[pos])
		if tok_type == 'tok' and tok_data == '(' then
			-- process ()
			local end_pos = find_match(tokens, pos)
			if end_pos > #tokens then
				table.insert(tree, {'unmatched `(`', 'err'})
			else
				local inner_tokens = {}
				pos = pos + 1
				while pos < end_pos do
					table.insert(inner_tokens, tokens[pos])
					pos = pos + 1
				end
				inner_tokens = parse_layer0(inner_tokens)
				table.insert(tree, {inner_tokens, 'expr'})
			end
			-- skip the closing bracket (or everything, if it is missing)
			pos = end_pos + 1
		elseif tok_type == 'tok' and tok_data == ')' then
			-- a closing bracket that was not consumed by any opening one
			table.insert(tree, {'unexpected `)`', 'err'})
			pos = pos + 1
		elseif tok_type == 'tok' and tok_data == ',' then
			-- commas are only valid inside function calls (see split_args)
			table.insert(tree, {'unexpected `,`', 'err'})
			pos = pos + 1
		elseif tok_type == 'func' then
			-- process f(a,b,c)
			local func_name = tok_data
			pos = pos + 1
			if pos > #tokens then
				table.insert(tree, {'unexpected function name at the end of an expression: `' .. func_name .. '`', 'err'})
				break
			end
			tok_data, tok_type = unpack(tokens[pos])
			if tok_type ~= 'tok' or tok_data ~= '(' then
				table.insert(tree, {'invalid function call, `(` expected', 'err'})
				pos = pos + 1
			else
				local end_pos = find_match(tokens, pos)
				if end_pos > #tokens then
					table.insert(tree, {'unmatched `(`', 'err'})
				else
					local inner_tokens = {}
					pos = pos + 1
					while pos < end_pos do
						table.insert(inner_tokens, tokens[pos])
						pos = pos + 1
					end
					local func_args = split_args(inner_tokens)
					local func_data = {func_name}
					for _, arg in ipairs(func_args) do
						arg = parse_layer0(arg)
						table.insert(func_data, arg)
					end
					if func_name == 'err' then
						-- `err($(message))` is the textual form of an error token
						if #func_data == 2 then
							local arg = func_data[2]
							local arg_data, arg_type = unpack(arg[1])
							if arg_type == 'var' then
								table.insert(tree, {arg_data, 'err'})
							else
								table.insert(tree, {'err: var expected, ' .. tostring(arg_type) .. ' provided', 'err'})
							end
						else
							table.insert(tree, {'err: 1 argument expected, ' .. (#func_data - 1) .. ' provided', 'err'})
						end
					else
						table.insert(tree, {func_data, 'func_call'})
					end
				end
				pos = end_pos + 1
			end
		else
			-- num, var, err and operator tokens pass through unchanged
			table.insert(tree, tokens[pos])
			pos = pos + 1
		end
	end
	return tree
end

-- Layer 1: unary operations `+`, `-`.
-- Input token types: tok, num, var, expr, func_call, err.
-- Output token types: tok, num, var, expr, unary, func_call, err.
--
-- A sign is treated as unary when it is the very first token of the list or
-- when it directly follows another operator (`+`, `-`, `*`, `/`).
-- It may be applied to `expr`, `var`, `num` and `func_call` tokens; a sign in
-- front of a number is folded into the number itself.
local function parse_layer1(tokens)
	local tree = {}
	if #tokens == 0 then
		return tree
	end

	-- Rebuilds the payload of a func_call with every argument processed.
	local function rebuild_call(call)
		local rebuilt = {call[1]}
		for i = 2, #call do
			rebuilt[#rebuilt + 1] = parse_layer1(call[i])
		end
		return rebuilt
	end

	-- Applies the sign `op` to `operand`.
	-- Returns the resulting token and whether the operand was consumed
	-- (it is not consumed when the sign cannot be applied to it).
	local function apply_sign(op, operand)
		local data, kind = unpack(operand)
		if kind == 'var' then
			return {{op, operand}, 'unary'}, true
		elseif kind == 'num' then
			if op == '-' then
				data = -data
			end
			return {data, 'num'}, true
		elseif kind == 'expr' then
			return {{op, {parse_layer1(data), 'expr'}}, 'unary'}, true
		elseif kind == 'func_call' then
			return {{op, {rebuild_call(data), 'func_call'}}, 'unary'}, true
		end
		return {'unary `' .. op .. '` cannot be applied to ' .. kind, 'err'}, false
	end

	local pos = 1

	-- A sign at the very start of the list is always unary.
	local head = tokens[1]
	if head[2] == 'tok' and head[1]:match('[+-]') == head[1] then
		local op = head[1]
		if #tokens >= 2 then
			local node, consumed = apply_sign(op, tokens[2])
			tree[#tree + 1] = node
			pos = consumed and 3 or 2
		else
			tree[#tree + 1] = {'`' .. op .. '` is not an expression', 'err'}
			pos = 2
		end
	end

	local streak = 0 -- operators seen in a row
	while pos <= #tokens do
		local token = tokens[pos]
		local data, kind = unpack(token)
		local is_operator = kind == 'tok' and data:match('[+*/-]') == data
		if not is_operator then
			streak = 0
			if kind == 'expr' then
				tree[#tree + 1] = {parse_layer1(data), 'expr'}
			elseif kind == 'func_call' then
				tree[#tree + 1] = {rebuild_call(data), 'func_call'}
			else
				tree[#tree + 1] = token
			end
			pos = pos + 1
		else
			streak = streak + 1
			if streak ~= 2 then
				-- the first operator of a run is binary: keep it as is
				tree[#tree + 1] = token
				pos = pos + 1
			elseif not data:match('[+-]') then
				-- `*` and `/` can never be unary
				tree[#tree + 1] = {'unexpected `' .. data .. '`', 'err'}
				streak = 0
				pos = pos + 1
			elseif pos + 1 > #tokens then
				tree[#tree + 1] = {'unexpected `' .. data .. '` at the end of an expression', 'err'}
				streak = 0
				pos = pos + 1
			else
				local node, consumed = apply_sign(data, tokens[pos + 1])
				tree[#tree + 1] = node
				pos = pos + (consumed and 2 or 1)
				streak = 0
			end
		end
	end
	return tree
end

-- Layer 2: binary operations `*`,`/`.
-- Input token types: tok, num, var, expr, unary, func_call, err.
-- Output token types: tok, num, var, expr, unary, binary, func_call, err.
--
-- Accepts either a list of tokens or a single token (recognised by its
-- second element being a type string); the result is always a list.
-- Operations are grouped left to right: `a / b / c` becomes `(a / b) / c`.
local function parse_layer2(tokens)
	local tree = {}
	if #tokens == 0 then
		return tree
	end
	local items = tokens
	if type(items[2]) == 'string' then
		items = {items}
	end

	-- Rebuilds the payload of a func_call with every argument processed.
	local function rebuild_call(call)
		local rebuilt = {call[1]}
		for i = 2, #call do
			rebuilt[#rebuilt + 1] = parse_layer2(call[i])
		end
		return rebuilt
	end

	-- Processes the inside of a composite token (expr, unary, func_call).
	-- Returns nil for every other token type.
	local function descend(token)
		local data, kind = unpack(token)
		if kind == 'expr' then
			return {parse_layer2(data), 'expr'}
		elseif kind == 'unary' then
			return {{data[1], parse_layer2(data[2])}, 'unary'}
		elseif kind == 'func_call' then
			return {rebuild_call(data), 'func_call'}
		end
		return nil
	end

	-- token types that may stand on the left of `*` or `/`
	local valid_left = {
		num = true, var = true, expr = true,
		unary = true, binary = true, func_call = true
	}

	local pos = 1
	while pos <= #items do
		local token = items[pos]
		local data, kind = unpack(token)
		if kind == 'tok' and data:match('[*/]') == data then
			local step = 1
			if #tree == 0 then
				tree[#tree + 1] = {'unexpected `' .. data .. '` at the start of an expression', 'err'}
			elseif pos == #items then
				tree[#tree + 1] = {'unexpected `' .. data .. '` at the end of an expression', 'err'}
			else
				local left_kind = tree[#tree][2]
				if not valid_left[left_kind] then
					tree[#tree + 1] = {'`' .. data .. '` cannot use ' .. left_kind .. ' as the first argument', 'err'}
				else
					local right = items[pos + 1]
					local right_kind = right[2]
					local operand
					if right_kind == 'num' or right_kind == 'var' then
						operand = right
					else
						operand = descend(right)
					end
					if operand then
						-- fold the previous node and the next token into one node
						local left = table.remove(tree)
						tree[#tree + 1] = {{data, left, operand}, 'binary'}
						step = 2
					else
						tree[#tree + 1] = {'`' .. data .. '` cannot use ' .. right_kind .. ' as the second argument', 'err'}
					end
				end
			end
			pos = pos + step
		else
			tree[#tree + 1] = descend(token) or token
			pos = pos + 1
		end
	end
	return tree
end

-- Layer 3: binary operations `+`,`-`.
-- Input token types: tok, num, var, expr, unary, binary, func_call, err.
-- Output token types: num, var, expr, unary, binary, func_call, err.
--
-- Accepts either a list of tokens or a single token (recognised by its
-- second element being a type string); the result is always a list.
-- Operations are grouped left to right: `a - b + c` becomes `(a - b) + c`.
local function parse_layer3(tokens)
	local tree = {}
	if #tokens == 0 then
		return tree
	end
	local items = tokens
	if type(items[2]) == 'string' then
		items = {items}
	end

	-- Rebuilds the payload of a func_call with every argument processed.
	local function rebuild_call(call)
		local rebuilt = {call[1]}
		for i = 2, #call do
			rebuilt[#rebuilt + 1] = parse_layer3(call[i])
		end
		return rebuilt
	end

	-- Processes the inside of a composite token (expr, unary, binary,
	-- func_call). Returns nil for every other token type.
	local function descend(token)
		local data, kind = unpack(token)
		if kind == 'expr' then
			return {parse_layer3(data), 'expr'}
		elseif kind == 'unary' then
			return {{data[1], parse_layer3(data[2])}, 'unary'}
		elseif kind == 'binary' then
			local inner_op, lhs, rhs = unpack(data)
			lhs = parse_layer3(lhs)
			rhs = parse_layer3(rhs)
			return {{inner_op, lhs, rhs}, 'binary'}
		elseif kind == 'func_call' then
			return {rebuild_call(data), 'func_call'}
		end
		return nil
	end

	-- token types that may stand on the left of `+` or `-`
	local valid_left = {
		num = true, var = true, expr = true,
		unary = true, binary = true, func_call = true
	}

	local pos = 1
	while pos <= #items do
		local token = items[pos]
		local data, kind = unpack(token)
		if kind == 'tok' and data:match('[+-]') == data then
			local step = 1
			if #tree == 0 then
				tree[#tree + 1] = {'unexpected `' .. data .. '` at the start of an expression', 'err'}
			elseif pos == #items then
				tree[#tree + 1] = {'unexpected `' .. data .. '` at the end of an expression', 'err'}
			else
				local left_kind = tree[#tree][2]
				if not valid_left[left_kind] then
					tree[#tree + 1] = {'`' .. data .. '` cannot use ' .. left_kind .. ' as the first argument', 'err'}
				else
					local right = items[pos + 1]
					local right_kind = right[2]
					local operand
					if right_kind == 'num' or right_kind == 'var' then
						operand = right
					else
						operand = descend(right)
					end
					if operand then
						-- fold the previous node and the next token into one node
						local left = table.remove(tree)
						tree[#tree + 1] = {{data, left, operand}, 'binary'}
						step = 2
					else
						tree[#tree + 1] = {'`' .. data .. '` cannot use ' .. right_kind .. ' as the second argument', 'err'}
					end
				end
			end
			pos = pos + step
		else
			tree[#tree + 1] = descend(token) or token
			pos = pos + 1
		end
	end
	return tree
end

-- Removes expr, validates other nodes.
-- Input token types: num, var, expr, unary, binary, func_call, err.
-- Output token types: num, var, unary, binary, func_call, err.
--
-- Accepts either a single token or a list of tokens; a list must contain
-- exactly one token, otherwise an error token is returned.
-- Always returns a single token. Error tokens are returned unchanged so that
-- the original message is preserved.
local function parse_postprocess(tokens)
	local tokens = tokens
	-- a single token is recognised by its second element being a type string
	if type(tokens[2]) == 'string' then
		tokens = {tokens}
	end
	if #tokens > 1 then
		return {'multiple trees in expression', 'err'}
	end
	if #tokens == 0 then
		return {'empty expression', 'err'}
	end
	local token = tokens[1]
	local tok_data, tok_type = unpack(token)
	if tok_type == 'num' then
		if type(tok_data) ~= 'number' then
			return {'a number is no number', 'err'}
		end
		return token
	elseif tok_type == 'var' then
		if type(tok_data) ~= 'string' then
			return {'a variable is no variable', 'err'}
		end
		return token
	elseif tok_type == 'err' then
		-- keep the original error message
		return token
	elseif tok_type == 'expr' then
		if type(tok_data) ~= 'table' then
			return {'an expression is no expression', 'err'}
		end
		-- brackets carry no information once the tree is built
		return parse_postprocess(tok_data)
	elseif tok_type == 'unary' then
		if type(tok_data) ~= 'table' then
			return {'an unary operation is no operation', 'err'}
		end
		if #tok_data ~= 2 then
			return {'an unary operation is not unary', 'err'}
		end
		local op, inner = unpack(tok_data)
		if op:match('[+-]') ~= op then
			return {'unknown unary operation `' .. op .. '`', 'err'}
		end
		inner = parse_postprocess(inner)
		return {{op, inner}, 'unary'}
	elseif tok_type == 'binary' then
		if type(tok_data) ~= 'table' then
			return {'a binary operation is no operation', 'err'}
		end
		if #tok_data ~= 3 then
			return {'a binary operation is not binary', 'err'}
		end
		local op, a, b = unpack(tok_data)
		if op:match('[+*/-]') ~= op then
			return {'unknown binary operation `' .. op .. '`', 'err'}
		end
		a = parse_postprocess(a)
		b = parse_postprocess(b)
		return {{op, a, b}, 'binary'}
	elseif tok_type == 'func_call' then
		if type(tok_data) ~= 'table' then
			return {'a function call is no function call', 'err'}
		end
		if #tok_data == 0 then
			return {'no function name is provided', 'err'}
		end
		local func_name = tok_data[1]
		if type(func_name) ~= 'string' then
			return {'a function name is no function name', 'err'}
		end
		local func_data = {func_name}
		for i = 2, #tok_data do
			local arg = parse_postprocess(tok_data[i])
			table.insert(func_data, arg)
		end
		return {func_data, 'func_call'}
	else
		return {'unknown token of type ' .. tostring(tok_type), 'err'}
	end
end

-- Finds and returns the first error encountered.
-- `tree` is either a single token or a list of tokens; the message of the
-- first `err` token found (depth first) is returned, or nil if there is none.
-- Operands of unary and binary nodes may themselves be single tokens or token
-- lists (the intermediate parser layers produce lists), so they are searched
-- recursively without any extra wrapping.
function p.find_error(tree)
	if type(tree) ~= 'table' then
		return nil
	end
	local nodes = tree
	-- a single token is recognised by its second element being a type string
	if type(nodes[2]) == 'string' then
		nodes = {nodes}
	end
	for i = 1, #nodes do
		local tok_data, tok_type = unpack(nodes[i])
		if tok_type == 'err' then
			return tok_data
		elseif tok_type == 'expr' then
			local err = p.find_error(tok_data)
			if err then return err end
		elseif tok_type == 'func_call' then
			-- element 1 is the function name, the rest are arguments
			for j = 2, #tok_data do
				local err = p.find_error(tok_data[j])
				if err then return err end
			end
		elseif tok_type == 'unary' then
			local err = p.find_error(tok_data[2])
			if err then return err end
		elseif tok_type == 'binary' then
			local err = p.find_error(tok_data[2])
			if err then return err end
			err = p.find_error(tok_data[3])
			if err then return err end
		end
	end
	return nil
end

-- Parses the formula string `s`.
-- Returns two values: the tree and an error string (nil on success).
-- The string is tokenized and then run through the parser layers in order;
-- processing stops after the first stage whose output contains an error, and
-- the output of that stage is what gets returned as the tree.
function p.parse(s)
	local stages = {
		parse_layer0,
		parse_layer1,
		parse_layer2,
		parse_layer3,
		parse_postprocess
	}
	local tokens = collect_tokens(s)
	local err = p.find_error(tokens)
	local stage = 1
	while not err and stage <= #stages do
		tokens = stages[stage](tokens)
		err = p.find_error(tokens)
		stage = stage + 1
	end
	return tokens, err
end

-- Returns a table of variables used, as a set: `{[name] = true, ...}`.
-- `tree` is a single token as produced by p.parse(); anything that is not a
-- two-element table yields an empty set.
-- Note that the `err($(message))` pattern requires no special handling:
-- it is resolved into a proper error in parse_layer0().
-- Not that you should ever call this function on a formula with errors.
function p.variables(tree)
	local vars = {}
	if type(tree) ~= 'table' or #tree ~= 2 then
		return vars
	end
	local tok_data, tok_type = unpack(tree)
	if tok_type == 'var' then
		return {[tok_data]=true}
	elseif tok_type == 'unary' then
		return p.variables(tok_data[2])
	elseif tok_type == 'binary' then
		local a, b = tok_data[2], tok_data[3]
		for v, _ in pairs(p.variables(a)) do
			vars[v] = true
		end
		for v, _ in pairs(p.variables(b)) do
			vars[v] = true
		end
	elseif tok_type == 'func_call' then
		-- element 1 is the function name, the rest are arguments
		for i = 2, #tok_data do
			local arg = tok_data[i]
			for v, _ in pairs(p.variables(arg)) do
				-- store `true` like every other branch, so the result is a uniform set
				vars[v] = true
			end
		end
	end
	return vars
end

-- Returns an iterator over every combination of argument values.
-- `args` is a list of distributions, each a table mapping value -> probability.
-- Every step yields a list holding one value per argument together with the
-- product of the matching probabilities. The order in which the values of one
-- distribution are visited follows pairs() and is therefore unspecified.
local function iterate_arg_dist(args)
	local n = #args
	local outcomes = {} -- outcomes[i]: list of {value, probability} pairs
	local cursor = {}   -- cursor[i]: index of the current pair in outcomes[i]
	for i = 1, n do
		local list = {}
		for value, weight in pairs(args[i]) do
			list[#list + 1] = {value, weight}
		end
		outcomes[i] = list
		cursor[i] = 1
	end
	local exhausted = false
	return function()
		if exhausted then
			return nil
		end
		-- collect the combination the cursors currently point at
		local combo = {}
		local prob = 1
		for i = 1, n do
			local entry = outcomes[i][cursor[i]]
			combo[i] = entry[1]
			prob = prob * entry[2]
		end
		-- move to the next combination: bump the last cursor and carry into
		-- the previous ones on overflow, like an odometer
		local slot = n
		while true do
			cursor[slot] = cursor[slot] + 1
			if cursor[slot] <= #outcomes[slot] then
				break
			end
			cursor[slot] = 1
			if slot == 1 then
				exhausted = true
				break
			end
			slot = slot - 1
		end
		return combo, prob
	end
end

-- Evaluates the built-in function `func` on `args`, a list of plain numbers.
-- Returns the numeric result, or `nil` plus an error message when the
-- function is unknown, the number of arguments is wrong or the result is
-- not defined.
local function eval_func(func, args)
	-- number of arguments every known function takes
	local arg_n = {
		min = 2, max = 2,
		exp = 1, ln = 1, pow = 2, sqrt = 1,
		sign = 1, abs = 1,
		round = 1, floor = 1, ceil = 1,
		sin = 1, cos = 1, tan = 1,
		pi = 0
	}
	if arg_n[func] == nil then
		return nil, 'unknown function ' .. tostring(func)
	end
	local n = arg_n[func]
	if n ~= #args then
		return nil, func .. ': ' .. n .. ' args expected, ' .. #args .. ' provided'
	end
	if func == 'min' then
		return math.min(args[1], args[2])
	elseif func == 'max' then
		return math.max(args[1], args[2])
	elseif func == 'exp' then
		return math.exp(args[1])
	elseif func == 'ln' then
		if args[1] <= 0 then
			return nil, 'ln: the argument must be positive'
		end
		return math.log(args[1])
	elseif func == 'pow' then
		-- the `^` operator is equivalent to math.pow and exists in every Lua version
		local val = args[1] ^ args[2]
		if val ~= val then
			return nil, 'pow: result is NaN'
		end
		return val
	elseif func == 'sqrt' then
		local val = math.sqrt(args[1])
		if val ~= val then
			return nil, 'sqrt: result is NaN'
		end
		return val
	elseif func == 'sign' then
		if args[1] < 0 then
			return -1
		elseif args[1] > 0 then
			return 1
		else
			return 0
		end
	elseif func == 'abs' then
		return math.abs(args[1])
	elseif func == 'round' then
		-- rounds to the nearest integer, halves go up
		return math.floor((math.floor(args[1] * 2) + 1) / 2)
	elseif func == 'floor' then
		return math.floor(args[1])
	elseif func == 'ceil' then
		return math.ceil(args[1])
	elseif func == 'sin' then
		return math.sin(args[1])
	elseif func == 'cos' then
		return math.cos(args[1])
	elseif func == 'tan' then
		return math.tan(args[1])
	elseif func == 'pi' then
		return math.pi
	end
	-- reached only if a name is listed in arg_n but has no branch above
	return nil, 'FUNCTION ' .. func .. ' IS DECLARED BUT NOT DEFINED'
end

-- Evaluates the formula, returns a distribution or a tree with an error.
-- `tree` is a single token as produced by p.parse(); `data` maps variable
-- names to values (missing variables count as 0).
-- A distribution is `{d, 'dist'}` where `d` maps value -> probability.
-- When an operand evaluates to an error, the node is rebuilt around the
-- evaluated operands instead of producing a distribution.
function p.eval(tree, data)
	local tok_data, tok_type = unpack(tree)
	if tok_type == 'num' then
		local val = tok_data
		local d = {[val]=1}
		return {d, 'dist'}
	elseif tok_type == 'var' then
		local key = tok_data
		if key:sub(1, 6) == 'Input:' then
			local _
			key, _ = util.normalise_input(key)
		end
		local val = data[key] or 0
		local d = {[val]=1}
		return {d, 'dist'}
	elseif tok_type == 'unary' then
		local op, inner = unpack(tok_data)
		inner = p.eval(inner, data)
		if inner[2] == 'dist' then
			if op == '+' then
				return inner
			elseif op == '-' then
				local d = {}
				for delta, prob in pairs(inner[1]) do
					d[-delta] = prob
				end
				return {d, 'dist'}
			end
			-- never fall through without a result
			return {'unknown unary operation `' .. tostring(op) .. '`', 'err'}
		else
			return {{op, inner}, 'unary'}
		end
	elseif tok_type == 'binary' then
		local op, a, b = unpack(tok_data)
		a = p.eval(a, data)
		b = p.eval(b, data)
		if a[2] == 'dist' and b[2] == 'dist' then
			local d = {}
			-- combine every pair of outcomes, treating the operands as independent
			for delta_a, prob_a in pairs(a[1]) do
				for delta_b, prob_b in pairs(b[1]) do
					local prob = prob_a * prob_b
					local delta
					if op == '+' then
						delta = delta_a + delta_b
					elseif op == '-' then
						delta = delta_a - delta_b
					elseif op == '*' then
						delta = delta_a * delta_b
					elseif op == '/' then
						delta = delta_a / delta_b
					else
						return {'unknown binary operation `' .. tostring(op) .. '`', 'err'}
					end
					local is_nan = delta ~= delta
					-- compare against math.huge: `x == x + 1` also holds for
					-- large finite values and misreported them as infinite
					local is_inf = delta == math.huge or delta == -math.huge
					if prob > 0 and is_nan then
						return {'not a number', 'err'}
					end
					if prob > 0 and is_inf then
						return {'division by zero', 'err'}
					end
					-- NaN cannot be a table key; such an outcome has zero
					-- probability here and is dropped
					if not is_nan then
						if d[delta] == nil then
							d[delta] = 0
						end
						d[delta] = d[delta] + prob
					end
				end
			end
			return {d, 'dist'}
		else
			return {{op, a, b}, 'binary'}
		end
	elseif tok_type == 'func_call' then
		local func = tok_data[1]
		local args = {}
		for i = 2, #tok_data do
			local arg = p.eval(tok_data[i], data)
			table.insert(args, arg)
		end
		if #args == 0 then
			local val, err = eval_func(func, {})
			if err then
				return {err, 'err'}
			else
				local d = {[val]=1}
				return {d, 'dist'}
			end
		else
			-- the call can only be evaluated when every argument is a distribution
			local no_eval = false
			for i = 1, #args do
				if args[i][2] ~= 'dist' then
					no_eval = true
				end
			end
			if no_eval then
				local func_data = {func}
				for i = 1, #args do
					table.insert(func_data, args[i])
				end
				return {func_data, 'func_call'}
			end
			local args_dist = {}
			for i = 1, #args do
				table.insert(args_dist, args[i][1])
			end
			local d = {}
			for args_item, prob in iterate_arg_dist(args_dist) do
				if func == 'random.range' then
					-- uniform distribution over the integers a..b
					if #args_item ~= 2 then
						return {'random.range: 2 args expected, ' .. #args_item .. ' provided', 'err'}
					end
					local a, b = unpack(args_item)
					if math.floor(a) ~= a or math.floor(b) ~= b then
						return {'random.range requires its arguments to be integer', 'err'}
					end
					local n = b - a + 1
					if n <= 0 then
						return {'random.range(' .. a .. ',' .. b .. ') is invalid: ' .. a .. ' > ' .. b, 'err'}
					end
					if n ~= n or n == n + 1 then
						return {'random.range: infinite arguments are not supported', 'err'}
					end
					for delta = a, b do
						if d[delta] == nil then
							d[delta] = 0
						end
						d[delta] = d[delta] + prob / n
					end
				else
					local delta, err = eval_func(func, args_item)
					if err then
						if prob > 0 then
							return {err, 'err'}
						end
						-- an impossible outcome failed: there is no value to record
					elseif delta ~= delta then
						-- NaN cannot be a table key
						if prob > 0 then
							return {'not a number', 'err'}
						end
					else
						if d[delta] == nil then
							d[delta] = 0
						end
						d[delta] = d[delta] + prob
					end
				end
			end
			return {d, 'dist'}
		end
	elseif tok_type == 'err' then
		return tree
	else
		return {'cannot evaluate ' .. tok_type, 'err'}
	end
end

-- Substitutes the specified values into the formula.
-- Also partially evaluates the formula.
-- `tree`: a single token as produced by p.parse().
-- `data`: table mapping variable names to values; values are converted with
-- tonumber() and variables without a numeric value are left in place.
-- `assume_unknown`: if true, unknown variables will be assumed to be zero.
-- Returns the new tree; `random.range` calls are never evaluated here.
function p.substitute(tree, data, assume_unknown)
	assume_unknown = assume_unknown or false
	local tok_data, tok_type = unpack(tree)
	if tok_type == 'num' then
		-- there is nothing to do
	elseif tok_type == 'var' then
		local key = tok_data
		if key:sub(1, 6) == 'Input:' then
			local _
			key, _ = util.normalise_input(key)
		end
		local val = tonumber(data[key])
		if val then
			tree = {val, 'num'}
		elseif assume_unknown then
			tree = {0, 'num'}
		else
			-- there is nothing to do
		end
	elseif tok_type == 'unary' then
		local op, inner = unpack(tok_data)
		inner = p.substitute(inner, data, assume_unknown)
		if inner[2] == 'num' then
			-- evaluate the operation
			local val = inner[1]
			if op == '-' then
				val = -val
			end
			tree = {val, 'num'}
		else
			-- cannot evaluate further
			tok_data = {op, inner}
			tree = {tok_data, tok_type}
		end
	elseif tok_type == 'binary' then
		local op, a, b = unpack(tok_data)
		a = p.substitute(a, data, assume_unknown)
		b = p.substitute(b, data, assume_unknown)
		if a[2] == 'num' and b[2] == 'num' then
			-- evaluate the operation; NaN marks an unknown operator
			local val = 0 / 0
			if op == '+' then
				val = a[1] + b[1]
			elseif op == '-' then
				val = a[1] - b[1]
			elseif op == '*' then
				val = a[1] * b[1]
			elseif op == '/' then
				val = a[1] / b[1]
			end
			if val ~= val then
				tree = {'not a number', 'err'}
			elseif val == math.huge or val == -math.huge then
				-- compare against math.huge: `x == x + 1` also holds for
				-- large finite values and misreported them as infinite
				tree = {'division by zero', 'err'}
			else
				tree = {val, 'num'}
			end
		else
			-- cannot evaluate further
			tok_data = {op, a, b}
			tree = {tok_data, tok_type}
		end
	elseif tok_type == 'func_call' then
		local func_name = tok_data[1]
		local args = {}
		local all_numbers = true
		for i = 2, #tok_data do
			local arg = tok_data[i]
			arg = p.substitute(arg, data, assume_unknown)
			if arg[2] ~= 'num' then
				all_numbers = false
			end
			table.insert(args, arg)
		end
		if all_numbers and func_name ~= 'random.range' then
			-- evaluate the function; eval_func() works on plain numbers,
			-- so the values have to be taken out of the `num` tokens first
			local values = {}
			for i = 1, #args do
				values[i] = args[i][1]
			end
			local val, f_err = eval_func(func_name, values)
			if f_err then
				tree = {f_err, 'err'}
			else
				tree = {val, 'num'}
			end
		else
			-- cannot evaluate further
			local func_data = {func_name}
			for _, arg in ipairs(args) do
				table.insert(func_data, arg)
			end
			tree = {func_data, 'func_call'}
		end
	elseif tok_type == 'err' then
		-- there is nothing to do
	end
	return tree
end

-- Converts a tree back into formula text.
-- Returns: the string; whether the last layer has binary +/-.
-- Brackets are added wherever leaving them out would change the meaning of
-- the formula or make it unparsable:
-- * a sum on either side of `*` or `/`, and on the right of `-`;
-- * a product or quotient on the right of `/`;
-- * anything but a plain operand after an unary sign (`--x` is rejected by
--   the parser, `-(-x)` is not).
function p.encode(tree)
	local function wrap(text, needed)
		if needed then
			return '(' .. text .. ')'
		end
		return text
	end

	-- Classifies the text of a number: a leading minus acts like an unary sign.
	local function number_kind(text)
		if text:sub(1, 1) == '-' then
			return 'neg'
		end
		return 'atom'
	end

	-- Returns the text of `node` and the kind of its outermost operation:
	-- 'add' (binary + or -), 'mul' (binary * or /), 'neg' (leading sign),
	-- 'atom' (anything that never needs brackets).
	local function render(node)
		local tok_data, tok_type = unpack(node)
		if tok_type == 'num' then
			local text = tostring(tok_data)
			return text, number_kind(text)
		elseif tok_type == 'var' then
			return '$(' .. tok_data .. ')', 'atom'
		elseif tok_type == 'unary' then
			local op, inner = unpack(tok_data)
			local inner_str, inner_kind = render(inner)
			return op .. wrap(inner_str, inner_kind ~= 'atom'), 'neg'
		elseif tok_type == 'binary' then
			local op, a, b = unpack(tok_data)
			local a_str, a_kind = render(a)
			local b_str, b_kind = render(b)
			local additive = op:match('[+-]') == op
			local wrap_a, wrap_b
			if additive then
				wrap_a = false
				-- `a - (b + c)` is not `a - b + c`
				wrap_b = (op == '-' and b_kind == 'add')
			else
				wrap_a = (a_kind == 'add')
				-- `a / (b * c)` is not `a / b * c`
				wrap_b = (b_kind == 'add') or (op == '/' and b_kind == 'mul')
			end
			local s = wrap(a_str, wrap_a) .. ' ' .. op .. ' ' .. wrap(b_str, wrap_b)
			if additive then
				return s, 'add'
			end
			return s, 'mul'
		elseif tok_type == 'func_call' then
			local parts = {}
			-- element 1 is the function name, the rest are arguments
			for i = 2, #tok_data do
				parts[#parts + 1] = (render(tok_data[i]))
			end
			return tok_data[1] .. '(' .. table.concat(parts, ', ') .. ')', 'atom'
		elseif tok_type == 'err' then
			return 'err($(' .. tostring(tok_data) .. '))', 'atom'
		elseif tok_type == 'dist' then
			-- only a distribution with a single outcome has a textual form
			local first = next(tok_data)
			if first ~= nil and next(tok_data, first) == nil then
				local text = tostring(first)
				return text, number_kind(text)
			end
			return 'err($(cannot encode a disttribution))', 'atom'
		else
			return 'err($(cannot encode ' .. tok_type .. '))', 'atom'
		end
	end

	local s, kind = render(tree)
	return s, kind == 'add'
end

return p