|
6 | 6 | import itertools
|
7 | 7 | import os
|
8 | 8 | import warnings
|
| 9 | +from functools import partial |
9 | 10 | from pathlib import Path
|
10 | 11 | from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping, MutableMapping
|
11 | 12 |
|
@@ -143,16 +144,32 @@ def configure_connection(dbapi_connection, connection_record):
|
143 | 144 | super().do_connect(engine)
|
144 | 145 |
|
145 | 146 | self._meta = sa.MetaData()
|
146 |
| - self._extensions = set() |
147 | 147 |
|
148 | 148 | def _load_extensions(self, extensions):
|
149 |
| - for extension in extensions: |
150 |
| - if extension not in self._extensions: |
151 |
| - with self.begin() as con: |
152 |
| - c = con.connection |
153 |
| - c.install_extension(extension) |
154 |
| - c.load_extension(extension) |
155 |
| - self._extensions.add(extension) |
| 149 | + extension_name = sa.column("extension_name") |
| 150 | + loaded = sa.column("loaded") |
| 151 | + installed = sa.column("installed") |
| 152 | + aliases = sa.column("aliases") |
| 153 | + query = ( |
| 154 | + sa.select(extension_name) |
| 155 | + .select_from(sa.func.duckdb_extensions()) |
| 156 | + .where( |
| 157 | + sa.and_( |
| 158 | + # extension isn't loaded or isn't installed |
| 159 | + sa.not_(loaded & installed), |
| 160 | + # extension is one that we're requesting, or an alias of it |
| 161 | + sa.or_( |
| 162 | + extension_name.in_(extensions), |
| 163 | + *map(partial(sa.func.array_has, aliases), extensions), |
| 164 | + ), |
| 165 | + ) |
| 166 | + ) |
| 167 | + ) |
| 168 | + with self.begin() as con: |
| 169 | + c = con.connection |
| 170 | + for extension in con.execute(query).scalars(): |
| 171 | + c.install_extension(extension) |
| 172 | + c.load_extension(extension) |
156 | 173 |
|
157 | 174 | def register(
|
158 | 175 | self,
|
|
0 commit comments