// Integer saturating arithmetic (with some branch-free procedures on x64).

#import "Basic";
#import "Compiler";
#import "Math";

// TODO Comparing implementations using dump

// #run test_math_ext();
// test_math_ext :: () {
//     set_build_options_dc(.{do_output=false});

// Self-test entry point.
// Runs add/sub/mul/div over boundary values of every supported integer width, prints each
// call with its expected outcome, and reports the number of mismatches found.
main :: () {
    write_strings("=====================\n", "--- Test Math_Ext ---\n");

    // Calls `operation` ("add", "sub", "mul" or "div") on (x, y) and compares the returned
    // value, its type, the saturation flag and (for "div" only) the remainder against the
    // expected ones. Returns the number of mismatches detected for this call.
    test_op :: (operation: string, x: $Tx, y: $Ty, result: $Tr, type: Type, saturated: bool, remainder: Tr = 0) -> errors_found: int #expand {
        // Builds the source text of the call to test; "div" has an extra return value (remainder).
        print_test_call :: (operation: string) -> string {
            #import "String";
            str: string = ---;
            if operation != "div" {
                TEST_CALL :: #string DONE
                    t_result, t_saturated := OP(cast(Tx)x, cast(Ty)y);
                    print("%_%(%, %) = %0%0\n", operation, type, x, y, result, ifx saturated then " : saturated");
                DONE
                str = replace(TEST_CALL, "OP", operation);
            } else {
                TEST_CALL :: #string DONE
                    t_result, t_remainder, t_saturated := OP(cast(Tx)x, cast(Ty)y);
                    print("%_%(%, %) = % + %0%0\n", operation, type, x, y, result, remainder, ifx saturated then " : saturated");
                DONE
                str = replace(TEST_CALL, "OP", operation);
            }
            return str;
        }

        // Declares t_result / t_saturated (and t_remainder for "div") by running the operation.
        #insert #run print_test_call(operation);

        errors := 0;
        if result != t_result {
            errors += 1;
            print(" > incorrect result value: got % expected %\n", t_result, result);
        };
        if type != type_of(t_result) {
            errors += 1;
            print(" > incorrect result type: got % expected %\n", type_of(t_result), type);
        };
        if saturated != t_saturated {
            errors += 1;
            print(" > incorrect saturated flag: got % expected %\n", t_saturated, saturated);
        };
        #if operation == "div" {
            if remainder != t_remainder {
                errors += 1;
                print(" > incorrect remainder value: got % expected %\n", t_remainder, remainder);
            };
        }
        return errors;
    }

    errors := 0;

    // Test signed add.
    errors += test_op("add", cast( s8) S8_MAX, cast( s8)1, S8_MAX, s8, true);
    errors += test_op("add", cast(s16)S16_MAX, cast( u8)1, S16_MAX, s16, true);
    errors += test_op("add", cast(s32)S32_MAX, cast(s32)1, S32_MAX, s32, true);
    errors += test_op("add", cast(s64)S64_MAX, cast(u32)1, S64_MAX, s64, true);
    errors += test_op("add", cast( s8) S8_MAX, cast( s8) S8_MIN, -1, s8, false);
    errors += test_op("add", cast(s16)S16_MAX, cast(s16)S16_MIN, -1, s16, false);
    errors += test_op("add", cast(s32)S32_MAX, cast(s32)S32_MIN, -1, s32, false);
    errors += test_op("add", cast(s64)S64_MAX, cast(s64)S64_MIN, -1, s64, false);

    // Test unsigned add.
    errors += test_op("add", cast( u8) U8_MAX, cast( u8)1, U8_MAX, u8, true);
    errors += test_op("add", cast(u16)U16_MAX, cast(u16)1, U16_MAX, u16, true);
    errors += test_op("add", cast(u32)U32_MAX, cast(u32)1, U32_MAX, u32, true);
    errors += test_op("add", cast(u64)U64_MAX, cast(u64)1, U64_MAX, u64, true);
    errors += test_op("add", cast( u8) U8_MAX, cast( u8)0, U8_MAX, u8, false);
    errors += test_op("add", cast(u16)U16_MAX, cast(u16)0, U16_MAX, u16, false);
    errors += test_op("add", cast(u32)U32_MAX, cast(u32)0, U32_MAX, u32, false);
    errors += test_op("add", cast(u64)U64_MAX, cast(u64)0, U64_MAX, u64, false);

    // Test signed sub.
    errors += test_op("sub", cast( s8) S8_MIN, cast( s8)1, S8_MIN, s8, true);
    errors += test_op("sub", cast(s16)S16_MIN, cast( u8)1, S16_MIN, s16, true);
    errors += test_op("sub", cast(s32)S32_MIN, cast(s32)1, S32_MIN, s32, true);
    errors += test_op("sub", cast(s64)S64_MIN, cast(u32)1, S64_MIN, s64, true);
    errors += test_op("sub", cast( s8)-1, cast( s8) S8_MAX, S8_MIN, s8, false);
    errors += test_op("sub", cast(s16)-1, cast(s16)S16_MAX, S16_MIN, s16, false);
    errors += test_op("sub", cast(s32)-1, cast(s32)S32_MAX, S32_MIN, s32, false);
    errors += test_op("sub", cast(s64)-1, cast(s64)S64_MAX, S64_MIN, s64, false);

    // Test unsigned sub.
    errors += test_op("sub", cast( u8)1, cast( u8) U8_MAX, 0, u8, true);
    errors += test_op("sub", cast( u8)1, cast(u16)U16_MAX, 0, u16, true);
    errors += test_op("sub", cast(u32)1, cast(u32)U32_MAX, 0, u32, true);
    errors += test_op("sub", cast(u32)1, cast(u64)U64_MAX, 0, u64, true);
    errors += test_op("sub", cast( u8) U8_MAX, cast( u8)0, U8_MAX, u8, false);
    errors += test_op("sub", cast(u16)U16_MAX, cast( u8)0, U16_MAX, u16, false);
    errors += test_op("sub", cast(u32)U32_MAX, cast(u32)0, U32_MAX, u32, false);
    errors += test_op("sub", cast(u64)U64_MAX, cast(u32)0, U64_MAX, u64, false);

    // Test signed mul.
    errors += test_op("mul", cast( s8) S8_MIN, cast( s8)-1, S8_MAX, s8, true);
    errors += test_op("mul", cast(s16)S16_MIN, cast( s8)-1, S16_MAX, s16, true);
    errors += test_op("mul", cast(s32)S32_MIN, cast(s32)-1, S32_MAX, s32, true);
    errors += test_op("mul", cast(s64)S64_MIN, cast(s32)-1, S64_MAX, s64, true);
    errors += test_op("mul", cast( s8) S8_MAX, cast( s8)-2, S8_MIN, s8, true);
    errors += test_op("mul", cast(s16)S16_MAX, cast( s8)-2, S16_MIN, s16, true);
    errors += test_op("mul", cast(s32)S32_MAX, cast(s32)-2, S32_MIN, s32, true);
    errors += test_op("mul", cast(s64)S64_MAX, cast(s32)-2, S64_MIN, s64, true);
    errors += test_op("mul", cast( s8)-2, cast( s8) S8_MAX, S8_MIN, s8, true);
    errors += test_op("mul", cast( s8)-2, cast(s16)S16_MAX, S16_MIN, s16, true);
    errors += test_op("mul", cast(s32)-2, cast(s32)S32_MAX, S32_MIN, s32, true);
    errors += test_op("mul", cast(s32)-2, cast(s64)S64_MAX, S64_MIN, s64, true);
    errors += test_op("mul", cast( s8) S8_MAX, cast( s8)2, S8_MAX, s8, true);
    errors += test_op("mul", cast(s16)S16_MAX, cast( s8)2, S16_MAX, s16, true);
    errors += test_op("mul", cast(s32)S32_MAX, cast(s32)2, S32_MAX, s32, true);
    errors += test_op("mul", cast(s64)S64_MAX, cast(s32)2, S64_MAX, s64, true);
    errors += test_op("mul", cast( s8) S8_MAX, cast( s8)-1, -S8_MAX, s8, false);
    errors += test_op("mul", cast(s16)S16_MAX, cast( s8)-1, -S16_MAX, s16, false);
    errors += test_op("mul", cast(s32)S32_MAX, cast(s32)-1, -S32_MAX, s32, false);
    errors += test_op("mul", cast(s64)S64_MAX, cast(s32)-1, -S64_MAX, s64, false);
    errors += test_op("mul", cast( s8) S8_MAX, cast( s8)0, 0, s8, false);
    errors += test_op("mul", cast(s16)S16_MAX, cast( u8)0, 0, s16, false);
    errors += test_op("mul", cast(s32)S32_MAX, cast(s32)0, 0, s32, false);
    errors += test_op("mul", cast(s64)S64_MAX, cast(u32)0, 0, s64, false);

    // Test unsigned mul.
    errors += test_op("mul", cast( u8) U8_MAX, cast( u8)1, U8_MAX, u8, false);
    errors += test_op("mul", cast(u16)U16_MAX, cast( u8)1, U16_MAX, u16, false);
    errors += test_op("mul", cast(u32)U32_MAX, cast(u32)1, U32_MAX, u32, false);
    errors += test_op("mul", cast(u64)U64_MAX, cast(u32)1, U64_MAX, u64, false);
    errors += test_op("mul", cast( u8) U8_MAX, cast( u8)2, U8_MAX, u8, true);
    errors += test_op("mul", cast(u16)U16_MAX, cast( u8)2, U16_MAX, u16, true);
    errors += test_op("mul", cast(u32)U32_MAX, cast(u32)2, U32_MAX, u32, true);
    errors += test_op("mul", cast(u64)U64_MAX, cast(u32)2, U64_MAX, u64, true);

    // Test signed div.
    errors += test_op("div", cast( s8) S8_MIN, cast( s8)-1, S8_MAX, s8, true, -1);
    errors += test_op("div", cast(s16)S16_MIN, cast( s8)-1, S16_MAX, s16, true, -1);
    errors += test_op("div", cast(s32)S32_MIN, cast(s32)-1, S32_MAX, s32, true, -1);
    errors += test_op("div", cast(s64)S64_MIN, cast(s32)-1, S64_MAX, s64, true, -1);
    errors += test_op("div", cast( s8) S8_MAX, cast( s8)-2, - S8_MAX/2, s8, false, 1);
    errors += test_op("div", cast(s16)S16_MAX, cast( s8)-2, -S16_MAX/2, s16, false, 1);
    errors += test_op("div", cast(s32)S32_MAX, cast(s32)-2, -S32_MAX/2, s32, false, 1);
    errors += test_op("div", cast(s64)S64_MAX, cast(s32)-2, -S64_MAX/2, s64, false, 1);
    errors += test_op("div", cast( s8)15, cast( s8)5, 3, s8, false, 0);
    errors += test_op("div", cast( u8)15, cast(s16)7, 2, s16, false, 1);
    errors += test_op("div", cast(s16)15, cast(s32)13, 1, s32, false, 2);
    errors += test_op("div", cast(u16)100, cast(s64)3, 33, s64, false, 1);

    // Test unsigned div.
    errors += test_op("div", cast( u8) U8_MAX, cast( u8)2, U8_MAX/2, u8, false, 1);
    errors += test_op("div", cast(u16)U16_MAX, cast( u8)2, U16_MAX/2, u16, false, 1);
    errors += test_op("div", cast(u32)U32_MAX, cast(u32)2, U32_MAX/2, u32, false, 1);
    errors += test_op("div", cast(u64)U64_MAX, cast(u32)2, U64_MAX/2, u64, false, 1);

    if errors > 0 print("# Found % %!\n", errors, ifx errors == 1 then "error" else "errors");
    else print(" No errors found.\n");

    /*
    // Performance test.
    #import "Random";
    best_generic: float;
    best_asm: float;
    for 0..100 {
        size, time_generic, time_asm := performance_test();
        perf_generic := cast(float)size/cast(float)to_microseconds(time_generic);
        perf_asm := cast(float)size/cast(float)to_microseconds(time_asm);
        best_generic = max(best_generic, perf_generic);
        best_asm = max(best_asm, perf_asm);
    }
    print("method : ops/usec\ngeneric : %\nasm : %\n", best_generic, best_asm);

    performance_test :: () -> sum_size: s64, time_generic: Apollo_Time, time_asm: Apollo_Time {
        SUM_SIZE := 200;//0000;
        numbers: [..] s64;
        array_reserve(*numbers, SUM_SIZE);
        for 0..SUM_SIZE-1 {
            array_add(*numbers, cast(s64)random_get());
        }

        sum := 0;
        start := current_time_monotonic();
        for numbers sum = add(sum, it, true);
        time := current_time_monotonic() - start;

        sum_asm := 0;
        start_asm := current_time_monotonic();
        for numbers sum_asm = add(sum_asm, it);
        time_asm := current_time_monotonic() - start_asm;

        assert(sum == sum_asm);
        return SUM_SIZE, time, time_asm;
    }
    */
}

// Returns the `signed` field of the integer type info of `t`.
// The cast is unchecked: `t` is expected to be an integer type.
is_signed :: ($t: Type) -> bool {
    return (cast(*Type_Info_Integer)type_info(t)).signed;
}

// Shared #modify body for add/sub/mul/div.
// Rejects non-integer arguments, then widens Tx/Ty to the larger of the two types and sets
// Tr to it. Operands of different signedness are accepted only when the larger type is the
// signed one; same-size operands of different signedness are rejected.
INTEGER_ARITHMETIC_TYPES_CHECK :: #string DONE
    type_info_x := cast(*Type_Info)Tx;
    type_info_y := cast(*Type_Info)Ty;
    if type_info_x.type != .INTEGER || type_info_y.type != .INTEGER return false, "Non integers values passed.";

    tx := cast(*Type_Info_Integer)type_info_x;
    ty := cast(*Type_Info_Integer)type_info_y;
    largest_type := ifx tx.runtime_size > ty.runtime_size then Tx else ifx ty.runtime_size > tx.runtime_size then Ty else ifx tx.signed == ty.signed then Tx else void;

    // Only allow to add different signedness values if largest type is the signed one (as in JAI).
    if tx.signed == ty.signed {
        Tx = largest_type;
        Ty = largest_type;
        Tr = largest_type;
    } else if tx.signed && Tx == largest_type {
        Ty = largest_type;
        Tr = largest_type;
    } else if ty.signed && Ty == largest_type {
        Tx = largest_type;
        Tr = largest_type;
    } else return false, "Number signedness mismatch.";
    return true;
DONE

// Saturating addition.
// Returns x + y clamped to the range of Tr, and `saturated` = true when clamping happened.
// Signed results clamp to MAX on positive overflow and MIN on negative overflow; unsigned
// results clamp to MAX. On x64 an inline-assembly version is used, otherwise a portable one.
add :: (x: $Tx, y: $Ty) -> result: $Tr, saturated: bool
#modify { #insert INTEGER_ARITHMETIC_TYPES_CHECK; }
// #dump
{
    #if CPU != .X64 {
        // #if #run is_signed(Tr) { // TODO Maybe use this?
        #if Tr == s8 || Tr == s16 || Tr == s32 || Tr == s64 {
            #if Tr == s8  { MAX :: S8_MAX;  MIN :: S8_MIN;  }
            #if Tr == s16 { MAX :: S16_MAX; MIN :: S16_MIN; }
            #if Tr == s32 { MAX :: S32_MAX; MIN :: S32_MIN; }
            #if Tr == s64 { MAX :: S64_MAX; MIN :: S64_MIN; }
            // Compare against the remaining headroom so the check itself cannot overflow.
            if (y > 0 && x > MAX - y) then return MAX, true;
            if (y < 0 && x < MIN - y) then return MIN, true;
        } else {
            #if Tr == u8  { MAX :: U8_MAX;  }
            #if Tr == u16 { MAX :: U16_MAX; }
            #if Tr == u32 { MAX :: U32_MAX; }
            #if Tr == u64 { MAX :: U64_MAX; }
            if (x > MAX - y) then return MAX, true;
        }
        return x + y, false;
    } else {
        #import "String";
        result: Tr = ---;
        saturated: bool = ---;

        // Signed template: the saturation limit is MAX plus x's sign bit, i.e. MAX for
        // non-negative x and MIN (MAX wrapped around) for negative x; it is selected when
        // the overflow flag is set by the addition.
        S_ADD_ASM :: #string DONE
            #asm {
                // Calculate limit based on x's sign.
                mov limit: gpr, MAX;
                mov sign: gpr, x;
                shr.SIZE sign, BITS;
                add.SIZE limit, sign; // If sign is 1, then limit will overflow from MAX to MIN.

                mov result, x;
                add.SIZE result, y;
                seto saturated;
                cmovo result, limit;
            }
        DONE
        #if Tr == s8  #insert #run replace(replace(replace(S_ADD_ASM, ".SIZE", ".b"), "MAX", "127"), "BITS", "7");
        #if Tr == s16 #insert #run replace(replace(replace(S_ADD_ASM, ".SIZE", ".w"), "MAX", "32767"), "BITS", "15");
        #if Tr == s32 #insert #run replace(replace(replace(S_ADD_ASM, ".SIZE", ".d"), "MAX", "2147483647"), "BITS", "31");
        #if Tr == s64 #insert #run replace(replace(replace(S_ADD_ASM, ".SIZE", ".q"), "MAX", "9223372036854775807"), "BITS", "63");

        // Unsigned template: the carry flag signals overflow and selects MAX.
        U_ADD_ASM :: #string DONE
            #asm {
                mov max: gpr, MAX;
                mov result, x;
                add.SIZE result, y;
                setc saturated;
                cmovc result, max;
            }
        DONE
        #if Tr == u8  #insert #run replace(replace(U_ADD_ASM, ".SIZE", ".b"), "MAX", "255");
        #if Tr == u16 #insert #run replace(replace(U_ADD_ASM, ".SIZE", ".w"), "MAX", "65535");
        #if Tr == u32 #insert #run replace(replace(U_ADD_ASM, ".SIZE", ".d"), "MAX", "4294967295");
        #if Tr == u64 #insert #run replace(replace(U_ADD_ASM, ".SIZE", ".q"), "MAX", "18446744073709551615");

        return result, saturated;
    }
}

// Saturating subtraction.
// Returns x - y clamped to the range of Tr, and a flag that is true when clamping happened.
// Signed results clamp to MAX / MIN; unsigned results clamp to 0.
// NOTE: the flag is named `overflow` here but carries the same meaning as `saturated` in add/div.
sub :: (x: $Tx, y: $Ty) -> result: $Tr, overflow: bool
#modify { #insert INTEGER_ARITHMETIC_TYPES_CHECK; }
// #dump
{
    #if CPU != .X64 {
        #if Tr == s8 || Tr == s16 || Tr == s32 || Tr == s64 {
            #if Tr == s8  { MAX :: S8_MAX;  MIN :: S8_MIN;  }
            #if Tr == s16 { MAX :: S16_MAX; MIN :: S16_MIN; }
            #if Tr == s32 { MAX :: S32_MAX; MIN :: S32_MIN; }
            #if Tr == s64 { MAX :: S64_MAX; MIN :: S64_MIN; }
            // Compare against the remaining headroom so the check itself cannot overflow.
            if (y < 0 && x > MAX + y) then return MAX, true;
            if (y > 0 && x < MIN + y) then return MIN, true;
        } else {
            if (y > x) then return 0, true;
        }
        return x - y, false;
    } else {
        #import "String";
        result: Tr = ---;
        saturated: bool = ---;

        // Signed template: same limit selection as in add (MAX for non-negative x, MIN for
        // negative x), applied when the subtraction sets the overflow flag.
        S_SUB_ASM :: #string DONE
            #asm {
                // Calculate limit based on x's sign.
                mov limit: gpr, MAX;
                mov sign: gpr, x;
                shr.SIZE sign, BITS;
                add.SIZE limit, sign; // If sign is 1, then limit will overflow from MAX to MIN.

                mov result, x;
                sub.SIZE result, y;
                seto saturated;
                cmovo result, limit;
            }
        DONE
        #if Tr == s8  #insert #run replace(replace(replace(S_SUB_ASM, ".SIZE", ".b"), "MAX", "127"), "BITS", "7");
        #if Tr == s16 #insert #run replace(replace(replace(S_SUB_ASM, ".SIZE", ".w"), "MAX", "32767"), "BITS", "15");
        #if Tr == s32 #insert #run replace(replace(replace(S_SUB_ASM, ".SIZE", ".d"), "MAX", "2147483647"), "BITS", "31");
        #if Tr == s64 #insert #run replace(replace(replace(S_SUB_ASM, ".SIZE", ".q"), "MAX", "9223372036854775807"), "BITS", "63");

        // Unsigned template: the carry (borrow) flag signals underflow and selects 0.
        U_SUB_ASM :: #string DONE
            #asm {
                mov limit: gpr, 0;
                mov result, x;
                sub.SIZE result, y;
                setc saturated;
                cmovc result, limit;
            }
        DONE
        #if Tr == u8  #insert #run replace(U_SUB_ASM, ".SIZE", ".b");
        #if Tr == u16 #insert #run replace(U_SUB_ASM, ".SIZE", ".w");
        #if Tr == u32 #insert #run replace(U_SUB_ASM, ".SIZE", ".d");
        #if Tr == u64 #insert #run replace(U_SUB_ASM, ".SIZE", ".q");

        return result, saturated;
    }
}

// Saturating multiplication.
// Returns x * y clamped to the range of Tr, and a flag that is true when clamping happened.
// Signed results clamp to MAX when the operands have the same sign and to MIN when they
// have opposite signs; unsigned results clamp to MAX.
// NOTE: the flag is named `overflow` here but carries the same meaning as `saturated` in add/div.
mul :: (x: $Tx, y: $Ty) -> result: $Tr, overflow: bool
#modify { #insert INTEGER_ARITHMETIC_TYPES_CHECK; }
// #dump
{
    #if CPU != .X64 {
        // #if #run is_signed(Tr) { // TODO Maybe use this?
        #if Tr == s8 || Tr == s16 || Tr == s32 || Tr == s64 {
            #if Tr == s8  { MAX :: S8_MAX;  MIN :: S8_MIN;  }
            #if Tr == s16 { MAX :: S16_MAX; MIN :: S16_MIN; }
            #if Tr == s32 { MAX :: S32_MAX; MIN :: S32_MIN; }
            #if Tr == s64 { MAX :: S64_MAX; MIN :: S64_MIN; }
            // Zero operands never overflow; handling them first keeps the divisions below safe.
            if x == 0 || y == 0 then return 0, false;
            // Same sign: the product is positive and can only exceed MAX.
            if x > 0 && y > 0 && x > MAX / y then return MAX, true;
            if x < 0 && y < 0 && x < MAX / y then return MAX, true;
            // Opposite signs: the product is negative and can only go below MIN.
            // Each clause is restricted to its own sign combination so that two negative
            // operands are never compared against MIN / negative (a positive bound, which
            // would wrongly report saturation) and MIN / -1 is never evaluated.
            if (x > 0 && y < 0 && y < MIN / x) || (x < 0 && y > 0 && x < MIN / y) then return MIN, true;
        } else {
            #if Tr == u8  { MAX :: U8_MAX;  }
            #if Tr == u16 { MAX :: U16_MAX; }
            #if Tr == u32 { MAX :: U32_MAX; }
            #if Tr == u64 { MAX :: U64_MAX; }
            // Guard y == 0 to avoid dividing by zero; a zero factor cannot overflow.
            if y != 0 && x > MAX / y then return MAX, true;
        }
        return x * y, false;
    } else {
        #import "String";
        result: Tr = ---;
        saturated: bool = ---;

        // Signed template: the limit is MAX plus the sign bit of (x ^ y), i.e. MAX when the
        // operands share a sign and MIN otherwise; it is selected when imul sets the overflow flag.
        S_MUL_ASM :: #string DONE
            #asm {
                result === a;
                // Calculate limit based on (x^y)'s sign.
                mov limit: gpr, MAX;
                mov sign: gpr, x;
                xor sign, y;
                shr.SIZE sign, BITS;
                add.SIZE limit, sign; // If sign is 1, then limit will overflow from MAX to MIN.

                mov result, x;
                imul.SIZE result, y;
                seto saturated;
                cmovo result, limit;
            }
        DONE
        #if Tr == s8  #insert #run replace(replace(replace(S_MUL_ASM, ".SIZE", ".b"), "MAX", "127"), "BITS", "7");
        #if Tr == s16 #insert #run replace(replace(replace(S_MUL_ASM, ".SIZE", ".w"), "MAX", "32767"), "BITS", "15");
        #if Tr == s32 #insert #run replace(replace(replace(S_MUL_ASM, ".SIZE", ".d"), "MAX", "2147483647"), "BITS", "31");
        #if Tr == s64 #insert #run replace(replace(replace(S_MUL_ASM, ".SIZE", ".q"), "MAX", "9223372036854775807"), "BITS", "63");

        // Unsigned template: on carry, `sbb max, max` yields all ones, which is OR-ed into the
        // result to force it to the maximum value.
        U_MUL_ASM :: #string DONE
            #asm {
                result === a;
                mov result, x;
                mul.SIZE r_d:, result, y;
                setc saturated;
                sbb max:, max; // If CF: max = -1 (all bits set); otherwise: max = 0.
                or result, max;
            }
        DONE
        // The 8-bit form of mul takes no separate high-half operand, hence its own template.
        U_MUL_ASM_8BITS :: #string DONE
            #asm {
                result === a;
                mov result, x;
                mul.SIZE result, y;
                setc saturated;
                sbb max:, max; // If CF: max = -1 (all bits set); otherwise: max = 0.
                or result, max;
            }
        DONE
        #if Tr == u8  #insert #run replace(U_MUL_ASM_8BITS, ".SIZE", ".b");
        #if Tr == u16 #insert #run replace(U_MUL_ASM, ".SIZE", ".w");
        #if Tr == u32 #insert #run replace(U_MUL_ASM, ".SIZE", ".d");
        #if Tr == u64 #insert #run replace(U_MUL_ASM, ".SIZE", ".q");

        return result, saturated;
    }
}

// Saturating division.
// Returns the quotient x / y, the remainder, and `saturated` = true when clamping happened.
// The only saturating case is signed MIN / -1, which yields MAX with remainder -1.
// NOTE(review): y == 0 is not handled in either path — callers must not pass a zero divisor.
div :: (x: $Tx, y: $Ty) -> result: $Tr, remainder: Tr, saturated: bool
#modify { #insert INTEGER_ARITHMETIC_TYPES_CHECK; }
// #dump
{
    #if CPU != .X64 {
        // #if #run is_signed(Tr) { // TODO Maybe use this?
        #if Tr == s8 || Tr == s16 || Tr == s32 || Tr == s64 {
            #if Tr == s8  { MAX :: S8_MAX;  MIN :: S8_MIN;  }
            #if Tr == s16 { MAX :: S16_MAX; MIN :: S16_MIN; }
            #if Tr == s32 { MAX :: S32_MAX; MIN :: S32_MIN; }
            #if Tr == s64 { MAX :: S64_MAX; MIN :: S64_MIN; }
            if x == MIN && y == -1 then return MAX, -1, true;
        }
        result := x / y;
        remainder := x - (y * result);
        return result, remainder, false;
    } else {
        #import "String";
        result: Tr = ---;
        remainder: Tr = ---;
        saturated: bool = ---;

        // Signed template: (x ^ MIN) | (y ^ -1) is zero only for MIN / -1. In that case the
        // dividend is replaced by LIMIT (which is MIN + 1, despite the "MIN-1" wording in the
        // embedded comment) so that idiv produces MAX with remainder 0, and the remainder is
        // then decremented by the saturated flag to give -1.
        S_DIV_ASM :: #string DONE
            #asm {
                result === a;
                remainder === d;
                // Detect div(MIN/-1) and flag it on ZF.
                mov xT: gpr, MIN;
                mov xV: gpr, x;
                xor.SIZE xT, xV;
                mov yT: gpr, y;
                xor.SIZE yT, -1;
                or.SIZE xT, yT;

                mov limit: gpr, LIMIT;
                mov result, x;
                cmovz result, limit; // If ZF: limit dividend to MIN-1.
                setz saturated;
                SIGN_EXT remainder, result; // Prepare dividend high bits.
                idiv.SIZE remainder, result, y;
                // If saturated: remainder = 0 - 1; otherwise: remainder = x - 0.
                sub.SIZE remainder, saturated;
            }
        DONE
        // 8-bit variant: the remainder is extracted from the high byte of the result register.
        S_DIV_ASM_8BITS :: #string DONE
            #asm {
                result === a;
                remainder === d;
                // Detect div(MIN/-1) and flag it on ZF.
                mov t_x: gpr, x;
                mov t_y: gpr, y;
                xor.SIZE t_x, MIN;
                xor.SIZE t_y, -1;
                or.SIZE t_x, t_y;

                mov limit: gpr, LIMIT;
                mov result, x;
                cmovz result, limit; // If ZF: limit dividend to MIN-1.
                setz saturated;
                idiv.SIZE result, y;
                // Extract remainder from result's high bits.
                mov remainder, result;
                sar remainder, 8;
                // If saturated: remainder = 0 - 1; otherwise: remainder = x - 0.
                sub.SIZE remainder, saturated;
            }
        DONE
        #if Tr == s8  #insert #run replace(replace(replace(S_DIV_ASM_8BITS, ".SIZE", ".b"), "MIN", "-128"), "LIMIT", "-127");
        #if Tr == s16 #insert #run replace(replace(replace(replace(S_DIV_ASM, ".SIZE", ".w"), "MIN", "-32768"), "LIMIT", "-32767"), "SIGN_EXT", "cwd");
        #if Tr == s32 #insert #run replace(replace(replace(replace(S_DIV_ASM, ".SIZE", ".d"), "MIN", "-2147483648"), "LIMIT", "-2147483647"), "SIGN_EXT", "cdq");
        #if Tr == s64 #insert #run replace(replace(replace(replace(S_DIV_ASM, ".SIZE", ".q"), "MIN", "-9223372036854775808"), "LIMIT", "-9223372036854775807"), "SIGN_EXT", "cqo");

        // Unsigned template: never saturates; the high half of the dividend is zeroed.
        U_DIV_ASM :: #string DONE
            #asm {
                result === a;
                remainder === d;
                mov saturated, 0;
                mov result, x;
                mov remainder, 0; // Prepare dividend high bits.
                div.SIZE remainder, result, y;
            }
        DONE
        // 8-bit variant: the remainder is extracted from the high byte of the result register.
        U_DIV_ASM_8BITS :: #string DONE
            #asm {
                result === a;
                remainder === d;
                mov saturated, 0;
                mov result, x;
                div.SIZE result, y;
                // Extract remainder from result's high bits.
                mov remainder, result;
                sar remainder, 8;
            }
        DONE
        #if Tr == u8  #insert #run replace(U_DIV_ASM_8BITS, ".SIZE", ".b");
        #if Tr == u16 #insert #run replace(U_DIV_ASM, ".SIZE", ".w");
        #if Tr == u32 #insert #run replace(U_DIV_ASM, ".SIZE", ".d");
        #if Tr == u64 #insert #run replace(U_DIV_ASM, ".SIZE", ".q");

        return result, remainder, saturated;
    }
}