Skip to content

Commit 9a1cd6f

Browse files
committed
Add janet_sysir_scalarize
Makes it easier to add simpler backends without needing to completely handle vectorization.
1 parent 768c9b2 commit 9a1cd6f

File tree

2 files changed

+236
-30
lines changed

2 files changed

+236
-30
lines changed

examples/sysir/arrays2.janet

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,7 @@
1717
(def ctx (sysir/context))
1818
(sysir/asm ctx ir-asm)
1919
(print (sysir/to-c ctx))
20+
(printf "%.99M" (sysir/to-ir ctx))
21+
(print (sysir/scalarize ctx))
22+
(printf "%.99M" (sysir/to-ir ctx))
23+
(print (sysir/to-c ctx))

src/core/sysir.c

Lines changed: 232 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,9 @@ const char *janet_sysop_names[] = {
103103
"type-struct", /* JANET_SYSOP_TYPE_STRUCT */
104104
"type-bind", /* JANET_SYSOP_TYPE_BIND */
105105
"arg", /* JANET_SYSOP_ARG */
106-
"field-getp", /* JANET_SYSOP_FIELD_GETP */
107-
"array-getp", /* JANET_SYSOP_ARRAY_GETP */
108-
"array-pgetp", /* JANET_SYSOP_ARRAY_PGETP */
106+
"fgetp", /* JANET_SYSOP_FIELD_GETP */
107+
"agetp", /* JANET_SYSOP_ARRAY_GETP */
108+
"apgetp", /* JANET_SYSOP_ARRAY_PGETP */
109109
"type-pointer", /* JANET_SYSOP_TYPE_POINTER */
110110
"type-array", /* JANET_SYSOP_TYPE_ARRAY */
111111
"type-union", /* JANET_SYSOP_TYPE_UNION */
@@ -198,9 +198,9 @@ static JanetString *table_to_string_array(JanetTable *strings_to_indices, int32_
198198
return NULL;
199199
}
200200
janet_assert(count > 0, "bad count");
201-
JanetString *strings = janet_malloc(count * sizeof(JanetString));
201+
JanetString *strings = NULL;
202202
for (int32_t i = 0; i < count; i++) {
203-
strings[i] = NULL;
203+
janet_v_push(strings, NULL);
204204
}
205205
for (int32_t i = 0; i < strings_to_indices->capacity; i++) {
206206
JanetKV *kv = strings_to_indices->data + i;
@@ -307,24 +307,29 @@ static uint32_t instr_read_type_operand(Janet x, JanetSysIR *ir, ReadOpMode rmod
307307
return operand;
308308
}
309309

310+
static uint32_t janet_sys_makeconst(JanetSysIR *sysir, uint32_t type, Janet x) {
311+
JanetSysConstant jsc;
312+
jsc.type = type;
313+
jsc.value = x;
314+
for (int32_t i = 0; i < janet_v_count(sysir->constants); i++) {
315+
if (sysir->constants[i].type != jsc.type) continue;
316+
if (!janet_equals(sysir->constants[i].value, x)) continue;
317+
/* Found a constant */
318+
return JANET_SYS_CONSTANT_PREFIX + i;
319+
}
320+
uint32_t index = (uint32_t) janet_v_count(sysir->constants);
321+
janet_v_push(sysir->constants, jsc);
322+
sysir->constant_count++;
323+
return JANET_SYS_CONSTANT_PREFIX + index;
324+
}
325+
310326
static uint32_t instr_read_operand_or_const(Janet x, JanetSysIR *ir) {
311327
if (janet_checktype(x, JANET_TUPLE)) {
312-
JanetSysConstant jsc;
313328
const Janet *tup = janet_unwrap_tuple(x);
314329
if (janet_tuple_length(tup) != 2) janet_panicf("expected constant wrapped in tuple, got %p", x);
315330
Janet c = tup[1];
316-
jsc.type = instr_read_type_operand(tup[0], ir, READ_TYPE_REFERENCE);
317-
jsc.value = c;
318-
/* TODO - Use a hash table or something better than linear lookup */
319-
for (int32_t i = 0; i < janet_v_count(ir->constants); i++) {
320-
if (ir->constants[i].type != jsc.type) continue;
321-
if (!janet_equals(ir->constants[i].value, c)) continue;
322-
/* Found a constant */
323-
return JANET_SYS_CONSTANT_PREFIX + i;
324-
}
325-
uint32_t index = (uint32_t) janet_v_count(ir->constants);
326-
janet_v_push(ir->constants, jsc);
327-
return JANET_SYS_CONSTANT_PREFIX + index;
331+
uint32_t t = instr_read_type_operand(tup[0], ir, READ_TYPE_REFERENCE);
332+
return janet_sys_makeconst(ir, t, c);
328333
}
329334
return instr_read_operand(x, ir);
330335
}
@@ -665,7 +670,6 @@ static void janet_sysir_init_instructions(JanetSysIR *out, JanetView instruction
665670

666671
/* Build constants */
667672
out->constant_count = janet_v_count(out->constants);
668-
out->constants = janet_v_flatten(out->constants);
669673
}
670674

671675
/* Get a type index given an operand */
@@ -724,14 +728,19 @@ static void tcheck_redef(JanetSysIR *ir, uint32_t typeid) {
724728
static void janet_sysir_init_types(JanetSysIR *ir) {
725729
JanetSysIRLinkage *linkage = ir->linkage;
726730
JanetSysTypeField *fields = NULL;
727-
JanetSysTypeInfo *type_defs = janet_realloc(linkage->type_defs, sizeof(JanetSysTypeInfo) * (linkage->type_def_count));
731+
JanetSysTypeInfo td;
732+
memset(&td, 0, sizeof(td));
733+
for (uint32_t i = 0; i < linkage->type_def_count; i++) {
734+
janet_v_push(linkage->type_defs, td);
735+
}
736+
JanetSysTypeInfo *type_defs = linkage->type_defs;
728737
uint32_t field_offset = linkage->field_def_count;
729-
uint32_t *types = janet_malloc(sizeof(uint32_t) * ir->register_count);
738+
uint32_t *types = NULL;
730739
linkage->type_defs = type_defs;
731-
ir->types = types;
732740
for (uint32_t i = 0; i < ir->register_count; i++) {
733-
ir->types[i] = 0;
741+
janet_v_push(types, 0);
734742
}
743+
ir->types = types;
735744
for (uint32_t i = linkage->old_type_def_count; i < linkage->type_def_count; i++) {
736745
type_defs[i].prim = JANET_PRIM_UNKNOWN;
737746
}
@@ -795,7 +804,7 @@ static void janet_sysir_init_types(JanetSysIR *ir) {
795804
if (janet_v_count(fields)) {
796805
uint32_t new_field_count = field_offset + janet_v_count(fields);
797806
linkage->field_defs = janet_realloc(linkage->field_defs, sizeof(JanetSysTypeField) * new_field_count);
798-
memcpy(linkage->field_defs + field_offset, fields, janet_v_count(fields) * sizeof(JanetSysTypeField));
807+
safe_memcpy(linkage->field_defs + field_offset, fields, janet_v_count(fields) * sizeof(JanetSysTypeField));
799808
linkage->field_def_count = new_field_count;
800809
janet_v_free(fields);
801810
}
@@ -1332,7 +1341,7 @@ static void janet_sys_ir_init(JanetSysIR *out, JanetView instructions, JanetSysI
13321341

13331342
/* Patch up name mapping arrays */
13341343
/* TODO - make more efficient, don't rebuild from scratch every time */
1335-
if (linkage->type_names) janet_free((void *) linkage->type_names);
1344+
if (linkage->type_names) janet_v_free((void *) linkage->type_names);
13361345
linkage->type_names = table_to_string_array(linkage->type_name_lookup, linkage->type_def_count);
13371346
ir.register_names = table_to_string_array(ir.register_name_lookup, ir.register_count);
13381347

@@ -1346,6 +1355,189 @@ static void janet_sys_ir_init(JanetSysIR *out, JanetView instructions, JanetSysI
13461355
janet_array_push(linkage->ir_ordered, janet_wrap_abstract(out));
13471356
}
13481357

1358+
/*
1359+
* Passes
1360+
*/
1361+
1362+
static JanetSysInstruction makethree(JanetSysInstruction source, JanetSysOp opcode, uint32_t dest, uint32_t lhs, uint32_t rhs) {
1363+
source.opcode = opcode;
1364+
source.three.dest = dest;
1365+
source.three.lhs = lhs;
1366+
source.three.rhs = rhs;
1367+
return source;
1368+
}
1369+
1370+
static JanetSysInstruction maketwo(JanetSysInstruction source, JanetSysOp opcode, uint32_t dest, uint32_t src) {
1371+
source.opcode = opcode;
1372+
source.two.dest = dest;
1373+
source.two.src = src;
1374+
return source;
1375+
}
1376+
1377+
static JanetSysInstruction makejmp(JanetSysInstruction source, JanetSysOp opcode, uint32_t to) {
1378+
source.opcode = opcode;
1379+
source.jump.to = to;
1380+
return source;
1381+
}
1382+
1383+
static JanetSysInstruction makebranch(JanetSysInstruction source, JanetSysOp opcode, uint32_t cond, uint32_t labelid) {
1384+
source.opcode = opcode;
1385+
source.branch.cond = cond;
1386+
source.branch.to = labelid;
1387+
return source;
1388+
}
1389+
1390+
static JanetSysInstruction makelabel(JanetSysInstruction source, JanetSysOp opcode, uint32_t id) {
1391+
source.opcode = opcode;
1392+
source.label.id = id;
1393+
return source;
1394+
}
1395+
1396+
static JanetSysInstruction makebind(JanetSysInstruction source, JanetSysOp opcode, uint32_t reg, uint32_t type) {
1397+
source.opcode = opcode;
1398+
source.type_bind.dest = reg;
1399+
source.type_bind.type = type;
1400+
return source;
1401+
}
1402+
1403+
1404+
static uint32_t janet_sysir_getreg(JanetSysIR *sysir, uint32_t type) {
1405+
uint32_t ret = sysir->register_count++;
1406+
janet_v_push(sysir->types, type);
1407+
return ret;
1408+
}
1409+
1410+
/* Find primitive types in the current linkage to avoid creating tons
1411+
* of copies of duplicate types. */
1412+
static uint32_t janet_sysir_findprim(JanetSysIRLinkage *linkage, JanetPrim prim, const char *type_name) {
1413+
for (uint32_t i = 0; i < linkage->type_def_count; i++) {
1414+
if (linkage->type_defs[i].prim == prim) {
1415+
return i;
1416+
}
1417+
}
1418+
/* Add new type */
1419+
JanetSysTypeInfo td;
1420+
memset(&td, 0, sizeof(td));
1421+
td.prim = prim;
1422+
janet_v_push(linkage->type_defs, td);
1423+
janet_table_put(linkage->type_name_lookup,
1424+
janet_csymbolv(type_name),
1425+
janet_wrap_number(linkage->type_def_count));
1426+
janet_v_push(linkage->type_names, janet_csymbol(type_name));
1427+
return linkage->type_def_count++;
1428+
}
1429+
1430+
/* Get a type that is a pointer to another type */
1431+
static uint32_t janet_sysir_findpointer(JanetSysIRLinkage *linkage, uint32_t to, const char *type_name) {
1432+
for (uint32_t i = 0; i < linkage->type_def_count; i++) {
1433+
if (linkage->type_defs[i].prim == JANET_PRIM_POINTER) {
1434+
if (linkage->type_defs[i].pointer.type == to) {
1435+
return i;
1436+
}
1437+
}
1438+
}
1439+
/* Add new type */
1440+
JanetSysTypeInfo td;
1441+
memset(&td, 0, sizeof(td));
1442+
td.prim = JANET_PRIM_POINTER;
1443+
td.pointer.type = to;
1444+
janet_v_push(linkage->type_defs, td);
1445+
janet_table_put(linkage->type_name_lookup,
1446+
janet_csymbolv(type_name),
1447+
janet_wrap_number(linkage->type_def_count));
1448+
janet_v_push(linkage->type_names, janet_csymbol(type_name));
1449+
return linkage->type_def_count++;
1450+
}
1451+
1452+
/* Unwrap vectorized binops to scalars in one pass to make certain lowering easier. */
1453+
static void janet_sysir_scalarize(JanetSysIRLinkage *linkage) {
1454+
uint32_t index_type = janet_sysir_findprim(linkage, JANET_PRIM_U32, "U32Index");
1455+
uint32_t boolean_type = janet_sysir_findprim(linkage, JANET_PRIM_BOOLEAN, "Boolean");
1456+
for (int32_t j = 0; j < linkage->ir_ordered->count; j++) {
1457+
JanetSysIR *sysir = janet_unwrap_abstract(linkage->ir_ordered->data[j]);
1458+
for (uint32_t i = 0; i < sysir->instruction_count; i++) {
1459+
JanetSysInstruction instruction = sysir->instructions[i];
1460+
sysir->error_ctx = janet_cstringv(janet_sysop_names[instruction.opcode]);
1461+
switch (instruction.opcode) {
1462+
default:
1463+
break;
1464+
case JANET_SYSOP_ADD:
1465+
case JANET_SYSOP_SUBTRACT:
1466+
case JANET_SYSOP_MULTIPLY:
1467+
case JANET_SYSOP_DIVIDE:
1468+
case JANET_SYSOP_BAND:
1469+
case JANET_SYSOP_BOR:
1470+
case JANET_SYSOP_BXOR:
1471+
case JANET_SYSOP_GT:
1472+
case JANET_SYSOP_LT:
1473+
case JANET_SYSOP_EQ:
1474+
case JANET_SYSOP_NEQ:
1475+
case JANET_SYSOP_GTE:
1476+
case JANET_SYSOP_LTE:
1477+
case JANET_SYSOP_SHL:
1478+
case JANET_SYSOP_SHR:
1479+
;
1480+
{
1481+
uint32_t dest_type = janet_sys_optype(sysir, instruction.three.dest);
1482+
uint32_t test_type = dest_type;
1483+
if (linkage->type_defs[dest_type].prim == JANET_PRIM_POINTER) {
1484+
test_type = linkage->type_defs[dest_type].pointer.type;
1485+
}
1486+
if (linkage->type_defs[test_type].prim != JANET_PRIM_ARRAY) {
1487+
break;
1488+
}
1489+
uint32_t pel_type = janet_sysir_findpointer(linkage, linkage->type_defs[test_type].array.type, "PointerTo"); // fixme - type name would need to be unique
1490+
uint32_t lhs_type = janet_sys_optype(sysir, instruction.three.lhs);
1491+
uint32_t rhs_type = janet_sys_optype(sysir, instruction.three.rhs);
1492+
uint32_t array_size = linkage->type_defs[dest_type].array.fixed_count;
1493+
uint32_t index_reg = janet_sysir_getreg(sysir, index_type);
1494+
uint32_t compare_reg = janet_sysir_getreg(sysir, boolean_type);
1495+
uint32_t temp_lhs = janet_sysir_getreg(sysir, pel_type);
1496+
uint32_t temp_rhs = janet_sysir_getreg(sysir, pel_type);
1497+
uint32_t temp_dest = janet_sysir_getreg(sysir, pel_type);
1498+
uint32_t loopstart_label = sysir->label_count++;
1499+
uint32_t loopend_label = sysir->label_count++;
1500+
Janet labelkw_loopstart = janet_wrap_keyword(janet_symbol_gen());
1501+
Janet labelkw_loopend = janet_wrap_keyword(janet_symbol_gen());
1502+
JanetSysOp lhs_getp = (linkage->type_defs[lhs_type].prim == JANET_PRIM_POINTER) ? JANET_SYSOP_ARRAY_PGETP : JANET_SYSOP_ARRAY_GETP;
1503+
JanetSysOp rhs_getp = (linkage->type_defs[rhs_type].prim == JANET_PRIM_POINTER) ? JANET_SYSOP_ARRAY_PGETP : JANET_SYSOP_ARRAY_GETP;
1504+
JanetSysOp dest_getp = (linkage->type_defs[dest_type].prim == JANET_PRIM_POINTER) ? JANET_SYSOP_ARRAY_PGETP : JANET_SYSOP_ARRAY_GETP;
1505+
JanetSysInstruction patch[] = {
1506+
makebind(instruction, JANET_SYSOP_TYPE_BIND, index_reg, index_type),
1507+
makebind(instruction, JANET_SYSOP_TYPE_BIND, temp_lhs, pel_type),
1508+
makebind(instruction, JANET_SYSOP_TYPE_BIND, temp_rhs, pel_type),
1509+
makebind(instruction, JANET_SYSOP_TYPE_BIND, temp_dest, pel_type),
1510+
makebind(instruction, JANET_SYSOP_TYPE_BIND, compare_reg, boolean_type),
1511+
maketwo(instruction, JANET_SYSOP_LOAD, index_reg, janet_sys_makeconst(sysir, index_type, janet_wrap_number(0))),
1512+
makelabel(instruction, JANET_SYSOP_LABEL, loopstart_label),
1513+
makethree(instruction, JANET_SYSOP_GTE, compare_reg, index_reg, janet_sys_makeconst(sysir, index_type, janet_wrap_number(array_size))),
1514+
makebranch(instruction, JANET_SYSOP_BRANCH, compare_reg, loopend_label),
1515+
makethree(instruction, lhs_getp, temp_lhs, instruction.three.lhs, index_reg),
1516+
makethree(instruction, rhs_getp, temp_rhs, instruction.three.rhs, index_reg),
1517+
makethree(instruction, dest_getp, temp_dest, instruction.three.dest, index_reg),
1518+
makethree(instruction, instruction.opcode, temp_dest, temp_lhs, temp_rhs),
1519+
makethree(instruction, JANET_SYSOP_ADD, index_reg, index_reg, janet_sys_makeconst(sysir, index_type, janet_wrap_number(1))),
1520+
makejmp(instruction, JANET_SYSOP_JUMP, loopstart_label),
1521+
makelabel(instruction, JANET_SYSOP_LABEL, loopend_label)
1522+
};
1523+
size_t patchcount = sizeof(patch) / sizeof(patch[0]);
1524+
janet_table_put(sysir->labels, labelkw_loopstart, janet_wrap_number(loopstart_label));
1525+
janet_table_put(sysir->labels, labelkw_loopend, janet_wrap_number(loopend_label));
1526+
janet_table_put(sysir->labels, janet_wrap_number(loopstart_label), janet_wrap_number(i + 1));
1527+
janet_table_put(sysir->labels, janet_wrap_number(loopend_label), janet_wrap_number(i + patchcount - 1));
1528+
size_t remaining = (sysir->instruction_count - i - 1) * sizeof(JanetSysInstruction);
1529+
sysir->instructions = janet_realloc(sysir->instructions, (sysir->instruction_count + patchcount - 1) * sizeof(JanetSysInstruction));
1530+
if (remaining) memmove(sysir->instructions + i + patchcount, sysir->instructions + i + 1, remaining);
1531+
safe_memcpy(sysir->instructions + i, patch, sizeof(patch));
1532+
i += patchcount - 2;
1533+
sysir->instruction_count += patchcount - 1;
1534+
break;
1535+
}
1536+
}
1537+
}
1538+
}
1539+
}
1540+
13491541
/* Lowering to C */
13501542

13511543
static const char *c_prim_names[] = {
@@ -1917,10 +2109,10 @@ void janet_sys_ir_lower_to_ir(JanetSysIRLinkage *linkage, JanetArray *into) {
19172109
static int sysir_gc(void *p, size_t s) {
19182110
JanetSysIR *ir = (JanetSysIR *)p;
19192111
(void) s;
1920-
janet_free(ir->constants);
1921-
janet_free(ir->types);
2112+
janet_v_free(ir->constants);
2113+
janet_v_free(ir->types);
2114+
janet_v_free(ir->register_names);
19222115
janet_free(ir->instructions);
1923-
janet_free((void *) ir->register_names);
19242116
return 0;
19252117
}
19262118

@@ -1949,8 +2141,8 @@ static int sysir_context_gc(void *p, size_t s) {
19492141
JanetSysIRLinkage *linkage = (JanetSysIRLinkage *)p;
19502142
(void) s;
19512143
janet_free(linkage->field_defs);
1952-
janet_free(linkage->type_defs);
1953-
janet_free((void *) linkage->type_names);
2144+
janet_v_free(linkage->type_defs);
2145+
janet_v_free((void *) linkage->type_names);
19542146
return 0;
19552147
}
19562148

@@ -2024,6 +2216,15 @@ JANET_CORE_FN(cfun_sysir_toir,
20242216
return janet_wrap_array(array);
20252217
}
20262218

2219+
JANET_CORE_FN(cfun_sysir_scalarize,
2220+
"(sysir/scalarize context)",
2221+
"Lower all vectorized instrinsics to loops of scalar operations.") {
2222+
janet_fixarity(argc, 1);
2223+
JanetSysIRLinkage *ir = janet_getabstract(argv, 0, &janet_sysir_context_type);
2224+
janet_sysir_scalarize(ir);
2225+
return argv[0];
2226+
}
2227+
20272228
JANET_CORE_FN(cfun_sysir_tox64,
20282229
"(sysir/to-x64 context &opt buffer target)",
20292230
"Lower IR to x64 machine code.") {
@@ -2052,6 +2253,7 @@ void janet_lib_sysir(JanetTable *env) {
20522253
JanetRegExt cfuns[] = {
20532254
JANET_CORE_REG("sysir/context", cfun_sysir_context),
20542255
JANET_CORE_REG("sysir/asm", cfun_sysir_asm),
2256+
JANET_CORE_REG("sysir/scalarize", cfun_sysir_scalarize),
20552257
JANET_CORE_REG("sysir/to-c", cfun_sysir_toc),
20562258
JANET_CORE_REG("sysir/to-ir", cfun_sysir_toir),
20572259
JANET_CORE_REG("sysir/to-x64", cfun_sysir_tox64),

0 commit comments

Comments
 (0)