Skip to content

Commit 13c61f5

Browse files
ericpre and fcollonval authored
Speed up list_available (#133)
* Use mamba repoquery to retrieve packages instead of conda search when mamba is available, to speed up the list_available method.
* Avoid duplicated code
* Simplify processing of mamba repoquery output
* Add testing for mamba repoquery

Co-authored-by: Frederic COLLONVAL <[email protected]>
1 parent 8f8003c commit 13c61f5

File tree

2 files changed

+89
-4
lines changed

2 files changed

+89
-4
lines changed

mamba_gator/envmanager.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# Copyright (c) 2016-2020 Jupyter Development Team.
22
# Distributed under the terms of the Modified BSD License.
33
import asyncio
4+
import collections
45
import json
56
import logging
67
import os
78
import re
8-
import ssl
99
import sys
1010
import tempfile
11-
from functools import partial
11+
from functools import partial, lru_cache
1212
from pathlib import Path
1313
from subprocess import PIPE, Popen
1414
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -233,6 +233,10 @@ def manager(self) -> str:
233233

234234
return EnvManager._manager_exe
235235

236+
@lru_cache()
237+
def is_mamba(self):
238+
return Path(self.manager).stem == "mamba"
239+
236240
async def env_channels(
237241
self, configuration: Optional[Dict[str, Any]] = None
238242
) -> Dict[str, Dict[str, List[str]]]:
@@ -552,7 +556,7 @@ async def pkg_depends(self, pkg: str) -> Dict[str, List[str]]:
552556
Returns:
553557
{"package": List[dependencies]}
554558
"""
555-
if Path(self.manager).stem != "mamba":
559+
if not self.is_mamba():
556560
self.log.warning(
557561
"Package manager '{}' does not support dependency query.".format(
558562
self.manager
@@ -585,7 +589,10 @@ async def list_available(self) -> Dict[str, List[Dict[str, str]]]:
585589
"with_description": bool # Whether we succeed in get some channeldata.json files
586590
}
587591
"""
588-
ans = await self._execute(self.manager, "search", "--json")
592+
if self.is_mamba():
593+
ans = await self._execute(self.manager, "repoquery", "search", "*", "--json")
594+
else:
595+
ans = await self._execute(self.manager, "search", "--json")
589596
_, output = ans
590597

591598
current_loop = tornado.ioloop.IOLoop.current()
@@ -596,6 +603,23 @@ async def list_available(self) -> Dict[str, List[Dict[str, str]]]:
596603
# dictionary with error info
597604
return data
598605

606+
def process_mamba_repoquery_output(data: Dict) -> Dict:
607+
"""Make a dictionary with keys as packages name and values
608+
containing the list of available packages to match the json output
609+
of "conda search --json".
610+
"""
611+
612+
data_ = collections.defaultdict(lambda : [])
613+
for entry in data['result']['pkgs']:
614+
name = entry.get('name')
615+
if name is not None:
616+
data_[name].append(entry)
617+
618+
return data_
619+
620+
if self.is_mamba():
621+
data = await current_loop.run_in_executor(None, process_mamba_repoquery_output, data)
622+
599623
def format_packages(data: Dict) -> List:
600624
packages = []
601625

mamba_gator/tests/test_api.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sys
66
import unittest
77
import unittest.mock as mock
8+
from itertools import chain
89

910
try:
1011
from unittest.mock import AsyncMock
@@ -971,6 +972,14 @@ def test_package_list_available(self):
971972
"ssl_verify": False,
972973
}
973974

975+
if has_mamba:
976+
# Change dummy to match mamba repoquery format
977+
dummy = {
978+
"result": {
979+
"pkgs": list(chain(*dummy.values()))
980+
}
981+
}
982+
974983
rvalue = [
975984
(0, json.dumps(dummy)),
976985
(0, json.dumps(channels)),
@@ -984,6 +993,13 @@ def test_package_list_available(self):
984993

985994
r = self.wait_for_task(self.conda_api.get, ["packages"])
986995
self.assertEqual(r.status_code, 200)
996+
997+
args, _ = f.call_args_list[0]
998+
if has_mamba:
999+
self.assertSequenceEqual(args[1:], ["repoquery", "search", "*", "--json"])
1000+
else:
1001+
self.assertSequenceEqual(args[1:], ["search", "--json"])
1002+
9871003
body = r.json()
9881004

9891005
expected = {
@@ -1164,6 +1180,14 @@ def test_package_list_available_local_channel(self):
11641180
],
11651181
}
11661182

1183+
if has_mamba:
1184+
# Change dummy to match mamba repoquery format
1185+
dummy = {
1186+
"result": {
1187+
"pkgs": list(chain(*dummy.values()))
1188+
}
1189+
}
1190+
11671191
with tempfile.TemporaryDirectory() as local_channel:
11681192
with open(
11691193
os.path.join(local_channel, "channeldata.json"), "w+"
@@ -1202,6 +1226,13 @@ def test_package_list_available_local_channel(self):
12021226

12031227
r = self.wait_for_task(self.conda_api.get, ["packages"])
12041228
self.assertEqual(r.status_code, 200)
1229+
1230+
args, _ = f.call_args_list[0]
1231+
if has_mamba:
1232+
self.assertSequenceEqual(args[1:], ["repoquery", "search", "*", "--json"])
1233+
else:
1234+
self.assertSequenceEqual(args[1:], ["search", "--json"])
1235+
12051236
body = r.json()
12061237

12071238
expected = {
@@ -1382,6 +1413,14 @@ def test_package_list_available_no_description(self):
13821413
],
13831414
}
13841415

1416+
if has_mamba:
1417+
# Change dummy to match mamba repoquery format
1418+
dummy = {
1419+
"result": {
1420+
"pkgs": list(chain(*dummy.values()))
1421+
}
1422+
}
1423+
13851424
with tempfile.TemporaryDirectory() as local_channel:
13861425
local_name = local_channel.strip("/")
13871426
channels = {
@@ -1414,6 +1453,13 @@ def test_package_list_available_no_description(self):
14141453

14151454
r = self.wait_for_task(self.conda_api.get, ["packages"])
14161455
self.assertEqual(r.status_code, 200)
1456+
1457+
args, _ = f.call_args_list[0]
1458+
if has_mamba:
1459+
self.assertSequenceEqual(args[1:], ["repoquery", "search", "*", "--json"])
1460+
else:
1461+
self.assertSequenceEqual(args[1:], ["search", "--json"])
1462+
14171463
body = r.json()
14181464

14191465
expected = {
@@ -1620,6 +1666,15 @@ def test_package_list_available_caching(self):
16201666
"ssl_verify": False,
16211667
}
16221668

1669+
1670+
if has_mamba:
1671+
# Change dummy to match mamba repoquery format
1672+
dummy = {
1673+
"result": {
1674+
"pkgs": list(chain(*dummy.values()))
1675+
}
1676+
}
1677+
16231678
rvalue = [
16241679
(0, json.dumps(dummy)),
16251680
(0, json.dumps(channels)),
@@ -1635,6 +1690,12 @@ def test_package_list_available_caching(self):
16351690
r = self.wait_for_task(self.conda_api.get, ["packages"])
16361691
self.assertEqual(r.status_code, 200)
16371692

1693+
args, _ = f.call_args_list[0]
1694+
if has_mamba:
1695+
self.assertSequenceEqual(args[1:], ["repoquery", "search", "*", "--json"])
1696+
else:
1697+
self.assertSequenceEqual(args[1:], ["search", "--json"])
1698+
16381699
expected = {
16391700
"packages": [
16401701
{

0 commit comments

Comments
 (0)