diff options
| author | dam <dam@gudinoff> | 2023-05-24 01:44:50 +0100 |
|---|---|---|
| committer | dam <dam@gudinoff> | 2023-05-24 01:44:50 +0100 |
| commit | 981170fcaf7eea3c1cc2f0c0a14a53d877276997 (patch) | |
| tree | 42442818c1f175b436fb28cd19a16a19ceb7b6dd /Math_Ext.jai | |
| parent | e029b883686de9f37e147914b8b0fb0045c9f395 (diff) | |
| download | task-time-tracker-981170fcaf7eea3c1cc2f0c0a14a53d877276997.tar.zst task-time-tracker-981170fcaf7eea3c1cc2f0c0a14a53d877276997.zip | |
Implemented saturating integer mul and div.
Diffstat (limited to 'Math_Ext.jai')
| -rw-r--r-- | Math_Ext.jai | 319 |
1 files changed, 290 insertions, 29 deletions
diff --git a/Math_Ext.jai b/Math_Ext.jai index 0fb0014..d4fd0cc 100644 --- a/Math_Ext.jai +++ b/Math_Ext.jai @@ -1,3 +1,5 @@ +// Integer saturating arithmetic (with some branch-free procedures on x64). + #import "Basic"; #import "Compiler"; #import "Math"; @@ -11,17 +13,6 @@ test_math_ext :: () { set_build_options_dc(.{do_output=false}); write_strings("=====================\n", "--- Test Math_Ext ---\n"); - // Different signals: only works if signaled variable is higher. - /* - #run cena(); - cena :: () { - a: s64 = -232; - b: u32 = 4; - c := a+b; - print("\n\n--- --- ---\ntttt : % : % + % = %\n--- --- ---\n\n", type_of(c), a, b, c); - } - */ - test_op :: (op: string, x: $Tx, y: $Ty, r: $Tr, t: Type, o: bool) -> errors_found: int #expand { #import "String"; @@ -81,19 +72,64 @@ test_math_ext :: () { set_build_options_dc(.{do_output=false}); errors += test_op("sub", cast(u32)U32_MAX, cast(u32)0, U32_MAX, u32, false); errors += test_op("sub", cast(u64)U64_MAX, cast(u32)0, U64_MAX, u64, false); - // errors += test_add(cast(s32)66, cast(s64)-2, 64, s64, false); - // errors += test_add(cast(u32)66, cast(s64)4, 70, s64, false); - // errors += test_add(cast(s32)S32_MAX, cast(s64)1, 2147483648, s64, false); - // errors += test_add(cast(s32)S32_MAX, cast(s32)1, S32_MAX, s32, true); - // errors += test_add(cast(s64)S64_MAX, cast(s64)0, S64_MAX, s64, false); - // errors += test_add(cast(s64)9223372036854775806, cast(s64)1, S64_MAX, s64, false); - // errors += test_add(cast(s64)9223372036854775806, cast(s64)2, S64_MAX, s64, true); - - // errors += test_add(cast(u8)7, cast(u8)1, 8, u8, false); - // errors += test_add(cast(u8)U8_MAX, cast(u8)1, U8_MAX, u8, true); - - // errors += test_add(cast(u16)10, cast(u8)3, 13, u16, false); - // errors += test_add(cast(u8)1, cast(u16)U16_MAX, U16_MAX, u16, true); + // Test signed mul. 
+ errors += test_op("mul", cast( s8) S8_MIN, cast( s8)-1, S8_MAX, s8, true); + errors += test_op("mul", cast(s16)S16_MIN, cast( s8)-1, S16_MAX, s16, true); + errors += test_op("mul", cast(s32)S32_MIN, cast(s32)-1, S32_MAX, s32, true); + errors += test_op("mul", cast(s64)S64_MIN, cast(s32)-1, S64_MAX, s64, true); + + errors += test_op("mul", cast( s8) S8_MAX, cast( s8)-2, S8_MIN, s8, true); + errors += test_op("mul", cast(s16)S16_MAX, cast( s8)-2, S16_MIN, s16, true); + errors += test_op("mul", cast(s32)S32_MAX, cast(s32)-2, S32_MIN, s32, true); + errors += test_op("mul", cast(s64)S64_MAX, cast(s32)-2, S64_MIN, s64, true); + + errors += test_op("mul", cast( s8)-2, cast( s8) S8_MAX, S8_MIN, s8, true); + errors += test_op("mul", cast( s8)-2, cast(s16)S16_MAX, S16_MIN, s16, true); + errors += test_op("mul", cast(s32)-2, cast(s32)S32_MAX, S32_MIN, s32, true); + errors += test_op("mul", cast(s32)-2, cast(s64)S64_MAX, S64_MIN, s64, true); + + errors += test_op("mul", cast( s8) S8_MAX, cast( s8)2, S8_MAX, s8, true); + errors += test_op("mul", cast(s16)S16_MAX, cast( s8)2, S16_MAX, s16, true); + errors += test_op("mul", cast(s32)S32_MAX, cast(s32)2, S32_MAX, s32, true); + errors += test_op("mul", cast(s64)S64_MAX, cast(s32)2, S64_MAX, s64, true); + + errors += test_op("mul", cast( s8) S8_MAX, cast( s8)-1, -S8_MAX, s8, false); + errors += test_op("mul", cast(s16)S16_MAX, cast( s8)-1, -S16_MAX, s16, false); + errors += test_op("mul", cast(s32)S32_MAX, cast(s32)-1, -S32_MAX, s32, false); + errors += test_op("mul", cast(s64)S64_MAX, cast(s32)-1, -S64_MAX, s64, false); + + errors += test_op("mul", cast( s8) S8_MAX, cast( s8)0, 0, s8, false); + errors += test_op("mul", cast(s16)S16_MAX, cast( u8)0, 0, s16, false); + errors += test_op("mul", cast(s32)S32_MAX, cast(s32)0, 0, s32, false); + errors += test_op("mul", cast(s64)S64_MAX, cast(u32)0, 0, s64, false); + + // Test unsigned mul. 
+ errors += test_op("mul", cast( u8) U8_MAX, cast( u8)1, U8_MAX, u8, false); + errors += test_op("mul", cast(u16)U16_MAX, cast( u8)1, U16_MAX, u16, false); + errors += test_op("mul", cast(u32)U32_MAX, cast(u32)1, U32_MAX, u32, false); + errors += test_op("mul", cast(u64)U64_MAX, cast(u32)1, U64_MAX, u64, false); + + errors += test_op("mul", cast( u8) U8_MAX, cast( u8)2, U8_MAX, u8, true); + errors += test_op("mul", cast(u16)U16_MAX, cast( u8)2, U16_MAX, u16, true); + errors += test_op("mul", cast(u32)U32_MAX, cast(u32)2, U32_MAX, u32, true); + errors += test_op("mul", cast(u64)U64_MAX, cast(u32)2, U64_MAX, u64, true); + + // Test signed div. + errors += test_op("div", cast( s8) S8_MIN, cast( s8)-1, S8_MAX, s8, true); + errors += test_op("div", cast(s16)S16_MIN, cast( s8)-1, S16_MAX, s16, true); + errors += test_op("div", cast(s32)S32_MIN, cast(s32)-1, S32_MAX, s32, true); + errors += test_op("div", cast(s64)S64_MIN, cast(s32)-1, S64_MAX, s64, true); + + errors += test_op("div", cast( s8) S8_MAX, cast( s8)-2, - S8_MAX/2, s8, false); + errors += test_op("div", cast(s16)S16_MAX, cast( s8)-2, -S16_MAX/2, s16, false); + errors += test_op("div", cast(s32)S32_MAX, cast(s32)-2, -S32_MAX/2, s32, false); + errors += test_op("div", cast(s64)S64_MAX, cast(s32)-2, -S64_MAX/2, s64, false); + + // Test unsigned div. 
+ errors += test_op("div", cast( u8) U8_MAX, cast( u8)2, U8_MAX/2, u8, false); + errors += test_op("div", cast(u16)U16_MAX, cast( u8)2, U16_MAX/2, u16, false); + errors += test_op("div", cast(u32)U32_MAX, cast(u32)2, U32_MAX/2, u32, false); + errors += test_op("div", cast(u64)U64_MAX, cast(u32)2, U64_MAX/2, u64, false); if errors > 0 print("# Found % %!\n", errors, ifx errors == 1 then "error" else "errors"); else print(" No errors found.\n"); @@ -138,6 +174,8 @@ test_math_ext :: () { set_build_options_dc(.{do_output=false}); */ } +is_signed :: ($t: Type) -> bool { return (cast(*Type_Info_Integer)type_info(t)).signed; } + INTEGER_ARITHMETIC_TYPES_CHECK :: #string DONE type_info_x := cast(*Type_Info)Tx; type_info_y := cast(*Type_Info)Ty; @@ -174,7 +212,8 @@ add :: (x: $Tx, y: $Ty) -> result: $Tr, saturated: bool #modify { #insert INTEGE { #if CPU != .X64 { - + + // #if #run is_signed(Tr) { // TODO Maybe use this? #if Tr == s8 || Tr == s16 || Tr == s32 || Tr == s64 { #if Tr == s8 { MAX :: S8_MAX; MIN :: S8_MIN; } @@ -184,7 +223,6 @@ add :: (x: $Tx, y: $Ty) -> result: $Tr, saturated: bool #modify { #insert INTEGE if (y > 0 && x > MAX - y) then return MAX, true; if (y < 0 && x < MIN - y) then return MIN, true; - return x + y, false; } else { @@ -194,9 +232,10 @@ add :: (x: $Tx, y: $Ty) -> result: $Tr, saturated: bool #modify { #insert INTEGE #if Tr == u64 { MAX :: U64_MAX; } if (x > MAX - y) then return MAX, true; - return x + y, false; } + + return x + y, false; } else { @@ -269,14 +308,14 @@ sub :: (x: $Tx, y: $Ty) -> result: $Tr, overflow: bool #modify { #insert INTEGER if (y < 0 && x > MAX + y) then return MAX, true; if (y > 0 && x < MIN + y) then return MIN, true; - return x - y, false; } else { if (y > x) then return 0, true; - return x - y, false; } + + return x - y, false; } else { @@ -334,3 +373,225 @@ sub :: (x: $Tx, y: $Ty) -> result: $Tr, overflow: bool #modify { #insert INTEGER } } + +mul :: (x: $Tx, y: $Ty) -> result: $Tr, overflow: bool #modify { #insert 
INTEGER_ARITHMETIC_TYPES_CHECK; } //#dump +{ + + #if CPU != .X64 { + + // #if #run is_signed(Tr) { // TODO Maybe use this? + #if Tr == s8 || Tr == s16 || Tr == s32 || Tr == s64 { + + #if Tr == s8 { MAX :: S8_MAX; MIN :: S8_MIN; } + #if Tr == s16 { MAX :: S16_MAX; MIN :: S16_MIN; } + #if Tr == s32 { MAX :: S32_MAX; MIN :: S32_MIN; } + #if Tr == s64 { MAX :: S64_MAX; MIN :: S64_MIN; } + + if x == 0 || y == 0 then return 0, false; + if x > 0 && y > 0 && x > MAX / y then return MAX, true; + if x < 0 && y < 0 && x < MAX / y then return MAX, true; + if (y < 0 && y < MIN / x) || (x < 0 && x < MIN / y) then return MIN, true; + + } else { + + if x > MAX / y then return MAX, true; + + } + + return x * y, false; + + } else { + + #import "String"; + result: Tr = ---; + saturated: bool = ---; + + S_MUL_ASM :: #string DONE + #asm { + result === a; + + // Calculate limit based on (x^y)'s sign. + mov limit: gpr, MAX; + mov sign: gpr, x; + xor sign, y; + shr.SIZE sign, BITS; + add.SIZE limit, sign; + + mov result, x; + imul.SIZE result, y; + seto saturated; + cmovo result, limit; + } + DONE + + #if Tr == s8 + #insert #run replace(replace(replace(S_MUL_ASM, ".SIZE", ".b"), "MAX", "127"), "BITS", "7"); + #if Tr == s16 + #insert #run replace(replace(replace(S_MUL_ASM, ".SIZE", ".w"), "MAX", "32767"), "BITS", "15"); + #if Tr == s32 + #insert #run replace(replace(replace(S_MUL_ASM, ".SIZE", ".d"), "MAX", "2147483647"), "BITS", "31"); + #if Tr == s64 + #insert #run replace(replace(replace(S_MUL_ASM, ".SIZE", ".q"), "MAX", "9223372036854775807"), "BITS", "63"); + + + U_MUL_ASM :: #string DONE + #asm { + result === a; + result_h: gpr === d; + + mov tmp: gpr, x; + mov result, x; + mul.SIZE result_h, result, y; + setc saturated; + sbb tmp, tmp; // SBB performs: dst = dst - (src + CF). Thus, max = 0xFF...FF if CF is 1, otherwise 0x00...00. TODO Improve comment. 
+ or result, tmp; + } + DONE + + U_MUL_ASM_8BITS :: #string DONE + #asm { + result === a; + max: gpr; + + mov tmp: gpr, x; + mov result, x; + mul.SIZE result, y; + setc saturated; + sbb max, max; // SBB performs: dst = dst - (src + CF). Thus, max = 0xFF...FF if CF is 1, otherwise 0x00...00. TODO Improve comment. + or result, max; + } + DONE + + #if Tr == u8 + #insert #run replace(U_MUL_ASM_8BITS, ".SIZE", ".b"); + #if Tr == u16 + #insert #run replace(U_MUL_ASM, ".SIZE", ".w"); + #if Tr == u32 + #insert #run replace(U_MUL_ASM, ".SIZE", ".d"); + #if Tr == u64 + #insert #run replace(U_MUL_ASM, ".SIZE", ".q"); + + + return result, saturated; + + } +} + +div :: (x: $Tx, y: $Ty) -> result: $Tr, overflow: bool #modify { #insert INTEGER_ARITHMETIC_TYPES_CHECK; } //#dump +{ + + #if CPU != .X64 { + + // #if #run is_signed(Tr) { // TODO Maybe use this? + #if Tr == s8 || Tr == s16 || Tr == s32 || Tr == s64 { + + #if Tr == s8 { MAX :: S8_MAX; MIN :: S8_MIN; } + #if Tr == s16 { MAX :: S16_MAX; MIN :: S16_MIN; } + #if Tr == s32 { MAX :: S32_MAX; MIN :: S32_MIN; } + #if Tr == s64 { MAX :: S64_MAX; MIN :: S64_MIN; } + + if x == MIN && y == -1 then return MAX, true; + + } + + return x / y, false; + + } else { + + #import "String"; + result: Tr = ---; + saturated: bool = ---; + + S_DIV_ASM :: #string DONE + #asm { + result === a; + + // Calculate dividend limit (MIN+1) for the div(MIN/-1) problem. + mov limit: gpr, MIN; + inc limit; + + // Detect div(MIN/-1) and flag it on ZF. + mov xT: gpr, MIN; + mov xV: gpr, x; + xor.SIZE xT, xV; + mov yT: gpr, -1; + mov yV: gpr, y; + xor.SIZE yT, yV; + // + or.SIZE xT, yT; + + mov result, x; + cmovz result, limit; // Apply dividend limit if ZF. + cqo rdx:, result; // Prepare dividend high bits. + setz saturated; + idiv.SIZE rdx, result, y; + } + DONE + + S_DIV_ASM_8BITS :: #string DONE + #asm { + result === a; + + // Calculate dividend limit (MIN+1) for the div(MIN/-1) problem. 
+ mov limit: gpr, MIN; + inc limit; + + // Detect div(MIN/-1) and flag it on ZF. + mov t_x: gpr, x; + mov t_y: gpr, y; + xor.SIZE t_x, MIN; + xor.SIZE t_y, -1; + or.SIZE t_x, t_y; + + mov result, x; + cmovz result, limit; // Apply dividend limit if ZF. + setz saturated; + idiv.SIZE result, y; + } + DONE + + #if Tr == s8 + #insert #run replace(replace(S_DIV_ASM_8BITS, ".SIZE", ".b"), "MIN", "-128"); + #if Tr == s16 + #insert #run replace(replace(S_DIV_ASM, ".SIZE", ".w"), "MIN", "-32768"); + #if Tr == s32 + #insert #run replace(replace(S_DIV_ASM, ".SIZE", ".d"), "MIN", "-2147483648"); + #if Tr == s64 + #insert #run replace(replace(S_DIV_ASM, ".SIZE", ".q"), "MIN", "-9223372036854775808"); + + + U_DIV_ASM :: #string DONE + #asm { + result === a; + + mov result, x; + mov rdx: gpr === d, 0; // Prepare dividend high bits. + div.SIZE rdx, result, y; + mov saturated, 0; + } + DONE + + U_DIV_ASM_8BITS :: #string DONE + #asm { + result === a; + + mov result, x; + div.SIZE result, y; + mov saturated, 0; + } + DONE + + #if Tr == u8 + #insert #run replace(U_DIV_ASM_8BITS, ".SIZE", ".b"); + #if Tr == u16 + #insert #run replace(U_DIV_ASM, ".SIZE", ".w"); + #if Tr == u32 + #insert #run replace(U_DIV_ASM, ".SIZE", ".d"); + #if Tr == u64 + #insert #run replace(U_DIV_ASM, ".SIZE", ".q"); + + + return result, saturated; + + } +} |
