Skip to content

Commit dfb818a

Browse files
authored
feat(postgres): add support for reading enum types as strings (#11028)
1 parent e1bbe55 commit dfb818a

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

ibis/backends/postgres/__init__.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,8 @@ def get_schema(
506506
a = ColGen(table="a")
507507
c = ColGen(table="c")
508508
n = ColGen(table="n")
509+
t = ColGen(table="t")
510+
e = ColGen(table="e")
509511

510512
format_type = self.compiler.f["pg_catalog.format_type"]
511513

@@ -522,7 +524,23 @@ def get_schema(
522524
type_info = (
523525
sg.select(
524526
a.attname.as_("column_name"),
525-
format_type(a.atttypid, a.atttypmod).as_("data_type"),
527+
sg.case()
528+
.when(
529+
sge.Exists(
530+
this=sg.select(1)
531+
.from_(sg.table("pg_type", db="pg_catalog").as_("t"))
532+
.join(
533+
sg.table("pg_enum", db="pg_catalog").as_("e"),
534+
on=sg.and_(
535+
e.enumtypid.eq(t.oid),
536+
t.typname.eq(format_type(a.atttypid, a.atttypmod)),
537+
),
538+
)
539+
),
540+
sge.convert("enum"),
541+
)
542+
.else_(format_type(a.atttypid, a.atttypmod))
543+
.as_("data_type"),
526544
sg.not_(a.attnotnull).as_("nullable"),
527545
)
528546
.from_(sg.table("pg_attribute", db="pg_catalog").as_("a"))

ibis/backends/postgres/tests/test_client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,14 +429,18 @@ def enum_table(con):
429429
name = gen_name("enum_table")
430430
with con._safe_raw_sql("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')") as cur:
431431
cur.execute(f"CREATE TEMP TABLE {name} (mood mood)")
432+
cur.execute(f"INSERT INTO {name} (mood) VALUES ('happy'), ('ok')")
432433
yield name
433434
cur.execute(f"DROP TABLE {name}")
434435
cur.execute("DROP TYPE mood")
435436

436437

437438
def test_enum_table(con, enum_table):
438439
t = con.table(enum_table)
439-
assert t.mood.type() == dt.unknown
440+
assert t.mood.type().is_string()
441+
e = t.filter(t.mood == "ok")
442+
result = e.execute()
443+
assert len(result) == 1
440444

441445

442446
def test_parsing_oid_dtype(con):

0 commit comments

Comments
 (0)