Skip to content

Commit 2ccdd9b

Browse files
authored
Merge pull request #7 from acryldata/use-pure-sasl
2 parents 508c941 + c6085e0 commit 2ccdd9b

File tree

6 files changed

+436
-25
lines changed

6 files changed

+436
-25
lines changed

dev_requirements.txt

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ sqlalchemy==1.3.24
1313
requests>=1.0.0
1414
requests_kerberos>=0.12.0
1515
sasl>=0.2.1
16+
pure-sasl>=0.6.2
17+
kerberos>=1.3.0
1618
thrift>=0.10.0
1719
#thrift_sasl>=0.1.0
1820
git+https://github.com/cloudera/thrift_sasl # Using master branch in order to get Python 3 SASL patches

pyhive/hive.py

+41-15
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,45 @@
4949
}
5050

5151

52+
def get_sasl_client(host, sasl_auth, service=None, username=None, password=None):
    """Create and initialise a client from the native ``sasl`` package.

    :param host: value stored in the client's ``host`` attribute
    :param sasl_auth: SASL mechanism name, either ``'GSSAPI'`` or ``'PLAIN'``
    :param service: value stored in ``service`` (used for ``'GSSAPI'`` only)
    :param username: value stored in ``username`` (used for ``'PLAIN'`` only)
    :param password: value stored in ``password`` (used for ``'PLAIN'`` only)
    :returns: the ``sasl.Client`` instance after ``init()`` has been called
    :raises ImportError: if the ``sasl`` package cannot be imported
    :raises ValueError: if ``sasl_auth`` is neither ``'GSSAPI'`` nor ``'PLAIN'``
    """
    # Imported here rather than at module level so that a missing ``sasl``
    # package only surfaces as an ImportError when this function is called.
    import sasl

    client = sasl.Client()
    client.setAttr('host', host)

    # Select the mechanism-specific attributes; anything else is rejected.
    if sasl_auth == 'GSSAPI':
        mechanism_attrs = (('service', service),)
    elif sasl_auth == 'PLAIN':
        mechanism_attrs = (('username', username), ('password', password))
    else:
        raise ValueError("sasl_auth only supports GSSAPI and PLAIN")

    for attr_name, attr_value in mechanism_attrs:
        client.setAttr(attr_name, attr_value)

    client.init()
    return client
67+
68+
69+
def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password=None):
    """Create a ``PureSASLClient`` from ``pyhive.sasl_compat``.

    :param host: passed to the client as ``host``
    :param sasl_auth: SASL mechanism name, either ``'GSSAPI'`` or ``'PLAIN'``
    :param service: passed as ``service`` when ``sasl_auth`` is ``'GSSAPI'``
    :param username: passed as ``username`` when ``sasl_auth`` is ``'PLAIN'``
    :param password: passed as ``password`` when ``sasl_auth`` is ``'PLAIN'``
    :returns: a new ``PureSASLClient``
    :raises ImportError: if ``pyhive.sasl_compat`` cannot be imported
    :raises ValueError: if ``sasl_auth`` is neither ``'GSSAPI'`` nor ``'PLAIN'``
    """
    # Imported here rather than at module level, so an import failure only
    # happens when this function is actually called.
    from pyhive.sasl_compat import PureSASLClient

    if sasl_auth == 'GSSAPI':
        return PureSASLClient(host=host, service=service)
    if sasl_auth == 'PLAIN':
        return PureSASLClient(host=host, username=username, password=password)
    raise ValueError("sasl_auth only supports GSSAPI and PLAIN")
80+
81+
82+
def get_installed_sasl(host, sasl_auth, service=None, username=None, password=None):
    """Return a SASL client from whichever supported SASL library imports.

    ``get_sasl_client`` is tried first; if it raises ``ImportError`` the
    same arguments are handed to ``get_pure_sasl_client`` instead. Any other
    exception (for example the ``ValueError`` for an unsupported
    ``sasl_auth``) propagates unchanged.

    :param host: forwarded to the client factory
    :param sasl_auth: SASL mechanism name, forwarded to the client factory
    :param service: forwarded to the client factory
    :param username: forwarded to the client factory
    :param password: forwarded to the client factory
    :returns: the client built by whichever factory succeeded
    """
    factory_kwargs = {
        'host': host,
        'sasl_auth': sasl_auth,
        'service': service,
        'username': username,
        'password': password,
    }
    try:
        # Succeeds only when the sasl library is available.
        client = get_sasl_client(**factory_kwargs)
    except ImportError:
        # Fallback to pure-sasl library
        client = get_pure_sasl_client(**factory_kwargs)
    return client
89+
90+
5291
def _parse_timestamp(value):
5392
if value:
5493
match = _TIMESTAMP_PATTERN.match(value)
@@ -224,7 +263,6 @@ def __init__(
224263
self._transport = thrift.transport.TTransport.TBufferedTransport(socket)
225264
elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'):
226265
# Defer import so package dependency is optional
227-
import sasl
228266
import thrift_sasl
229267

230268
if auth == 'KERBEROS':
@@ -235,20 +273,8 @@ def __init__(
235273
if password is None:
236274
# Password doesn't matter in NONE mode, just needs to be nonempty.
237275
password = 'x'
238-
239-
def sasl_factory():
240-
sasl_client = sasl.Client()
241-
sasl_client.setAttr('host', host)
242-
if sasl_auth == 'GSSAPI':
243-
sasl_client.setAttr('service', kerberos_service_name)
244-
elif sasl_auth == 'PLAIN':
245-
sasl_client.setAttr('username', username)
246-
sasl_client.setAttr('password', password)
247-
else:
248-
raise AssertionError
249-
sasl_client.init()
250-
return sasl_client
251-
self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
276+
277+
self._transport = thrift_sasl.TSaslClientTransport(lambda: get_installed_sasl(host=host, sasl_auth=sasl_auth, service=kerberos_service_name, username=username, password=password), sasl_auth, socket)
252278
else:
253279
# All HS2 config options:
254280
# https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2#SettingUpHiveServer2-Configuration

pyhive/sasl_compat.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Original source of this file is https://github.com/cloudera/impyla/blob/master/impala/sasl_compat.py
2+
# which uses Apache-2.0 license as of 21 May 2023.
3+
# This code was added to Impyla in 2016 as a compatibility layer to allow use of either python-sasl or pure-sasl
4+
# via PR https://github.com/cloudera/impyla/pull/179
5+
# Even though thrift_sasl lists pure-sasl as a dependency here https://github.com/cloudera/thrift_sasl/blob/master/setup.py#L34
6+
# it still calls functions native to python-sasl in this file https://github.com/cloudera/thrift_sasl/blob/master/thrift_sasl/__init__.py#L82
7+
# Hence this code is required for the fallback to work.
8+
9+
10+
from puresasl.client import SASLClient, SASLError
11+
from contextlib import contextmanager
12+
13+
@contextmanager
def error_catcher(self, Exc = Exception):
    """Context manager that records a failure on ``self`` instead of raising.

    On entry ``self.error`` is reset to ``None``. If the managed block raises
    an exception matching ``Exc``, the exception is suppressed and its
    ``str()`` is stored in ``self.error``; execution then continues after the
    ``with`` statement. Exceptions that do not match ``Exc`` propagate.

    :param self: object on which the ``error`` attribute is written
    :param Exc: exception class (or tuple of classes) to capture
    """
    try:
        # Clear any message left over from a previous operation.
        self.error = None
        yield
    except Exc as caught:
        self.error = str(caught)
20+
21+
22+
class PureSASLClient(SASLClient):
    """``SASLClient`` subclass adding ``start``/``step``/``encode``/``decode``/``getError``.

    Each operation returns a tuple whose first element is a success flag
    instead of raising: exceptions captured by ``error_catcher`` are stored
    as text in ``self.error`` and can be read back through ``getError()``.
    """

    def __init__(self, *args, **kwargs):
        # Message of the most recent captured failure, or None.
        self.error = None
        super(PureSASLClient, self).__init__(*args, **kwargs)

    def start(self, mechanism):
        """Choose a mechanism and produce the first response.

        :param mechanism: a mechanism name, or a list of candidate names
        :returns: ``(True, self.mechanism, self.process())`` on success, or
            ``(False, mechanism, None)`` if a ``SASLError`` was captured
        """
        with error_catcher(self, SASLError):
            # choose_mechanism is always given a list.
            candidates = mechanism if isinstance(mechanism, list) else [mechanism]
            self.choose_mechanism(candidates)
            return True, self.mechanism, self.process()
        # Reached only when error_catcher suppressed a SASLError.
        return False, mechanism, None

    def encode(self, incoming):
        """Pass ``incoming`` through ``self.unwrap``.

        :returns: ``(True, result)`` on success, ``(False, None)`` if an
            exception was captured
        """
        # NOTE(review): encode() delegating to unwrap() (and decode() to
        # wrap()) looks inverted -- confirm against how the transport calls
        # these methods before relying on a SASL security layer.
        with error_catcher(self):
            unwrapped = self.unwrap(incoming)
            return True, unwrapped
        # Reached only when error_catcher suppressed an exception.
        return False, None

    def decode(self, outgoing):
        """Pass ``outgoing`` through ``self.wrap``.

        :returns: ``(True, result)`` on success, ``(False, None)`` if an
            exception was captured
        """
        with error_catcher(self):
            wrapped = self.wrap(outgoing)
            return True, wrapped
        # Reached only when error_catcher suppressed an exception.
        return False, None

    def step(self, challenge=None):
        """Process ``challenge`` via ``self.process``.

        :returns: ``(True, result)`` on success, ``(False, None)`` if an
            exception was captured
        """
        with error_catcher(self):
            response = self.process(challenge)
            return True, response
        # Reached only when error_catcher suppressed an exception.
        return False, None

    def getError(self):
        """Return the message captured by the most recent operation, or None."""
        return self.error

pyhive/tests/test_hive.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from decimal import Decimal
1818

1919
import mock
20-
import sasl
2120
import thrift.transport.TSocket
2221
import thrift.transport.TTransport
2322
import thrift_sasl
@@ -204,15 +203,7 @@ def test_custom_transport(self):
204203
socket = thrift.transport.TSocket.TSocket('localhost', 10000)
205204
sasl_auth = 'PLAIN'
206205

207-
def sasl_factory():
208-
sasl_client = sasl.Client()
209-
sasl_client.setAttr('host', 'localhost')
210-
sasl_client.setAttr('username', 'test_username')
211-
sasl_client.setAttr('password', 'x')
212-
sasl_client.init()
213-
return sasl_client
214-
215-
transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
206+
transport = thrift_sasl.TSaslClientTransport(lambda: hive.get_installed_sasl(host='localhost', sasl_auth=sasl_auth, username='test_username', password='x'), sasl_auth, socket)
216207
conn = hive.connect(thrift_transport=transport)
217208
with contextlib.closing(conn):
218209
with contextlib.closing(conn.cursor()) as cursor:

0 commit comments

Comments
 (0)