diff --git a/jsonpath.lua b/jsonpath.lua index 07bf2c9..ab526f7 100755 --- a/jsonpath.lua +++ b/jsonpath.lua @@ -291,22 +291,46 @@ end)() local function eval_ast(ast, obj) -- Helper helper: match type of second operand to type of first operand - local function match_type(op1, op2) + local function match_cmp_type(op1, op2, compare) if type(op1) == 'boolean' then if is_null(op2) then - -- null must never be equal to other boolean, invert op1 - return not op1 + return not op1, nil else - return (op2 and true or false) + if type(op2) == 'string' then + return nil, "cannot compare boolean with string" + end + if type(op2) == 'number' then + if compare then + return nil, "cannot compare boolean with number" + end + return op2 ~= 0, nil + end + if type(op2) == 'boolean' then + return op2, nil + end + return (op2 and true or false), nil end elseif type(op1) == 'number' then - return tonumber(op2) + if type(op2) == 'boolean' then + return op2 and 1 or 0, nil + end + if type(op2) == 'string' then + local num = tonumber(op2) + if num == nil then + return nil, "cannot compare number with non-numeric string" + end + return num, nil + end + return tonumber(op2), nil elseif type(op1) == 'cdata' and tostring(ffi.typeof(op1)) == 'ctype' then - return tonumber(op2) + return tonumber(op2), nil elseif is_null(op1) then - return op2 + if compare then + return nil, "cannot compare null with other values" + end + return op2, nil end - return tostring(op2 or '') + return tostring(op2 or ''), nil end -- Helper helper: convert operand to boolean @@ -314,6 +338,12 @@ local function eval_ast(ast, obj) return op1 and true or false end + local function is_str_or_int(val) + return type(val) == 'string' or + type(val) == 'number' or + (type(val) == 'cdata' and tostring(ffi.typeof(val)) == 'ctype') + end + -- Helper helper: evaluate variable expression inside abstract syntax tree local function eval_var(expr, obj) if obj == nil then @@ -408,43 +438,76 @@ local function eval_ast(ast, obj) return nil, err end if operator == '+' then - op1 = tonumber(op1) + tonumber(op2) + if is_str_or_int(op1) and is_str_or_int(op2) then + op1 = tonumber(op1) + tonumber(op2) + else + return nil, "Only operations on strings and numbers are allowed." + end elseif operator == '-' then - op1 = tonumber(op1) - tonumber(op2) + if is_str_or_int(op1) and is_str_or_int(op2) then + op1 = tonumber(op1) - tonumber(op2) + else + return nil, "Only operations on strings and numbers are allowed." + end elseif operator == '*' then - op1 = tonumber(op1) * tonumber(op2) + if is_str_or_int(op1) and is_str_or_int(op2) then + op1 = tonumber(op1) * tonumber(op2) + else + return nil, "Only operations on strings and numbers are allowed." + end elseif operator == '/' then - op1 = tonumber(op1) / tonumber(op2) + if is_str_or_int(op1) and is_str_or_int(op2) then + op1 = tonumber(op1) / tonumber(op2) + else + return nil, "Only operations on strings and numbers are allowed." + end elseif operator == '%' then + if is_str_or_int(op1) and is_str_or_int(op2) then + op1 = tonumber(op1) % tonumber(op2) + else + return nil, "Only operations on strings and numbers are allowed." + end op1 = tonumber(op1) % tonumber(op2) elseif operator:upper() == 'AND' or operator == '&&' then op1 = notempty(op1) and notempty(op2) elseif operator:upper() == 'OR' or operator == '||' then op1 = notempty(op1) or notempty(op2) elseif operator == '=' or operator == '==' then - op1 = op1 == match_type(op1, op2) + local op2, err = match_cmp_type(op1, op2, false) + if err then + return nil, err + end + op1 = op1 == op2 elseif operator == '<>' or operator == '!=' then - op1 = op1 ~= match_type(op1, op2) + local op2, err = match_cmp_type(op1, op2, false) + if err then + return nil, err + end + op1 = op1 ~= op2 elseif operator == '>' then - if is_null(op1) then - return false + local op2, err = match_cmp_type(op1, op2, true) + if err then + return nil, err end - op1 = op1 > match_type(op1, op2) + op1 = op1 > op2 elseif operator == '>=' then - if is_null(op1) then - return false + local op2, err = match_cmp_type(op1, op2, true) + if err then + return nil, err end - op1 = op1 >= match_type(op1, op2) + op1 = op1 >= op2 elseif operator == '<' then - if is_null(op1) then - return false + local op2, err = match_cmp_type(op1, op2, true) + if err then + return nil, err end - op1 = op1 < match_type(op1, op2) + op1 = op1 < op2 elseif operator == '<=' then - if is_null(op1) then - return false + local op2, err = match_cmp_type(op1, op2, true) + if err then + return nil, err end - op1 = op1 <= match_type(op1, op2) + op1 = op1 <= op2 else return nil, 'unknown expression operator "' .. operator .. '"' end diff --git a/test/test.lua b/test/test.lua index ff565de..d392ec8 100755 --- a/test/test.lua +++ b/test/test.lua @@ -946,6 +946,112 @@ testQuery = { lu.assertNil(err) lu.assertItemsEquals(result, { array[2], array[3] }) end, + + testFilterIntBoolComparison = function () + local array = { + { id = 1, value = 0 }, + { id = 2, value = 1 }, + { id = 3, value = 2 }, + } + local result, err = jp.query(array, '$[?(@.value==true)]') + lu.assertNil(err) + lu.assertItemsEquals(result, { array[2] }) + + local result, err = jp.query(array, '$[?(@.value>true)]') + lu.assertNil(err) + lu.assertItemsEquals(result, { array[3] }) + + local result, err = jp.query(array, '$[?(@.value>=true)]') + lu.assertNil(err) + lu.assertItemsEquals(result, { array[2], array[3] }) + + local result, err = jp.query(array, '$[?(@.value1)]') + lu.assertNil(err) + lu.assertItemsEquals(result, {}) + + local result, err = jp.query(array, '$[?(@.value>=1)]') + lu.assertNil(err) + lu.assertItemsEquals(result, {}) + + local result, err = jp.query(array, '$[?(@.value<1)]') + lu.assertNil(err) + lu.assertItemsEquals(result, {}) + + local result, err = jp.query(array, '$[?(@.value<=1)]') + lu.assertNil(err) + lu.assertItemsEquals(result, {}) + end, + + testFilterBoolStrComparison = function () + local array = { + { id = 1, value = true }, + { id = 2, value = false }, + } + local result, err = jp.query(array, '$[?(@.value=="1")]') + lu.assertNil(err) + lu.assertItemsEquals(result, {}) + + local result, err = jp.query(array, '$[?(@.value>"1")]') + lu.assertNil(err) + lu.assertItemsEquals(result, {}) + + local result, err = jp.query(array, '$[?(@.value>="1")]') + lu.assertNil(err) + lu.assertItemsEquals(result, {}) + + local result, err = jp.query(array, '$[?(@.value<"1")]') + lu.assertNil(err) + lu.assertItemsEquals(result, {}) + + local result, err = jp.query(array, '$[?(@.value<="1")]') + lu.assertNil(err) + lu.assertItemsEquals(result, {}) + end, + + testFilterArithmeticOpOnBool = function () + local array = { + { id = 1, value = 0 }, + { id = 1, value = 1 }, + { id = 2, value = 2 }, + } + local result, err = jp.query(array, '$[?(@.value==true+1)]') + lu.assertNil(err) + lu.assertItemsEquals(result, {}) + + local result, err = jp.query(array, '$[?(@.value==true*1)]') + lu.assertNil(err) + lu.assertItemsEquals(result, {}) + + local result, err = jp.query(array, '$[?(@.value==true/1)]') + lu.assertNil(err) + lu.assertItemsEquals(result, {}) + + local result, err = jp.query(array, '$[?(@.value==true%1)]') + lu.assertNil(err) + lu.assertItemsEquals(result, {}) + + local result, err = jp.query(array, '$[?(@.value<>false+1)]') + lu.assertNil(err) + lu.assertItemsEquals(result, {}) + end, }