From 1e8161548fa20b376a8f9f91f81d1ba5e64491ba Mon Sep 17 00:00:00 2001 From: Catherine Devlin Date: Thu, 5 Apr 2018 22:14:12 -0400 Subject: [PATCH] Selection of existing connection works with concealed password --- NEWS.txt | 1 + run_tests.sh | 2 +- src/sql/connection.py | 29 +++++++++++++++++++---------- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/NEWS.txt b/NEWS.txt index c5dc1ad..3e3642e 100644 --- a/NEWS.txt +++ b/NEWS.txt @@ -149,3 +149,4 @@ Deleted Plugin import left behind in 0.2.2 * added README example (thanks tanhuil) * bugfix in executing column_local_vars (thanks tebeka) * pgspecial installation optional (thanks jstoebel and arjoe) +* conceal passwords in connection strings (thanks jstoebel) diff --git a/run_tests.sh b/run_tests.sh index 3c53115..66502f0 100755 --- a/run_tests.sh +++ b/run_tests.sh @@ -1,3 +1,3 @@ #!/bin/bash -python -c "import pytest; pytest.main(['.', '-x', '--pdb'])" +ipython -c "import pytest; pytest.main(['.', '-x', '--pdb'])" # Insert breakpoints with `import pytest; pytest.set_trace()` diff --git a/src/sql/connection.py b/src/sql/connection.py index d4c9bac..986c043 100644 --- a/src/sql/connection.py +++ b/src/sql/connection.py @@ -1,10 +1,26 @@ import sqlalchemy import os +import re class ConnectionError(Exception): pass +def rough_dict_get(dct, sought, default=None): + ''' + Like dct.get(sought), but any key containing sought will do. + + If there is a `@` in sought, seek each piece separately. + This lets `me@server` match `me:***@myserver/db` + ''' + + sought = sought.split('@') + for (key, val) in dct.items(): + if not any(s.lower() not in key.lower() for s in sought): + return val + return default + + class Connection(object): current = None connections = {} @@ -25,8 +41,7 @@ class Connection(object): self.metadata = sqlalchemy.MetaData(bind=engine) self.name = self.assign_name(engine) self.session = engine.connect() - self.connections[self.name] = self - self.connections[str(self.metadata.bind.url)] = self + self.connections[repr(self.metadata.bind.url)] = self Connection.current = self @classmethod @@ -37,8 +52,7 @@ class Connection(object): if isinstance(descriptor, Connection): cls.current = descriptor else: - existing = cls.connections.get(descriptor) or \ - cls.connections.get(descriptor.lower()) + existing = rough_dict_get(cls.connections, descriptor) cls.current = existing or Connection(descriptor) else: if cls.connections: @@ -52,12 +66,7 @@ class Connection(object): @classmethod def assign_name(cls, engine): - core_name = '%s@%s' % (engine.url.username or '', engine.url.database) - incrementer = 1 - name = core_name - while name in cls.connections: - name = '%s_%d' % (core_name, incrementer) - incrementer += 1 + name = '%s@%s' % (engine.url.username or '', engine.url.database) return name @classmethod