Package upgrades and SQLAlchemy 2.0 compatibility

pull/225/head
drnushooz 2023-01-29 17:40:18 -08:00 committed by Abhinav Chawade
parent cc4633cd33
commit 9fbf83baff
12 changed files with 82 additions and 98 deletions

View File

@ -1,4 +1,4 @@
psycopg2
psycopg2-binary
pandas
pytest
wheel
@ -6,4 +6,3 @@ twine
readme-renderer
black
isort

View File

@ -1,6 +1,8 @@
prettytable==0.7.2
ipython>=1.0
sqlalchemy>=0.6.7,<2.0
prettytable
ipython
sqlalchemy
sqlparse
six
ipython-genutils>=0.1.0
ipython-genutils
traitlets
matplotlib

View File

@ -8,15 +8,15 @@ README = open(os.path.join(here, "README.rst"), encoding="utf-8").read()
NEWS = open(os.path.join(here, "NEWS.rst"), encoding="utf-8").read()
version = "0.4.1"
version = "0.5.0"
install_requires = [
"prettytable<1",
"ipython>=1.0",
"sqlalchemy>=0.6.7",
"prettytable",
"ipython",
"sqlalchemy",
"sqlparse",
"six",
"ipython-genutils>=0.1.0",
"ipython-genutils",
]

View File

@ -6,7 +6,7 @@ makes guesses about the role of each column for plotting purposes
class Column(list):
"Store a column of tabular data; record its name and whether it is numeric"
"""Store a column of tabular data; record its name and whether it is numeric"""
is_quantity = True
name = ""
@ -28,6 +28,9 @@ class ColumnGuesserMixin(object):
pie: ... y
"""
def __init__(self):
self.keys = None
def _build_columns(self):
self.columns = [Column() for col in self.keys]
for row in self:

View File

@ -1,5 +1,5 @@
import os
import re
import traceback
import sqlalchemy
@ -45,20 +45,21 @@ class Connection(object):
engine = sqlalchemy.create_engine(
connect_str, connect_args=connect_args
)
except: # TODO: bare except; but what's an ArgumentError?
except Exception as ex: # TODO: bare except; but what's an ArgumentError?
print(traceback.format_exc())
print(self.tell_format())
raise
self.url = engine.url
self.dialect = engine.url.get_dialect()
self.metadata = sqlalchemy.MetaData(bind=engine)
self.name = self.assign_name(engine)
self.session = engine.connect()
self.connections[repr(self.metadata.bind.url)] = self
self.internal_connection = engine.connect()
self.connections[repr(self.url)] = self
self.connect_args = connect_args
Connection.current = self
@classmethod
def set(cls, descriptor, displaycon, connect_args={}, creator=None):
"Sets the current database connection"
"""Sets the current database connection"""
if descriptor:
if isinstance(descriptor, Connection):
@ -94,16 +95,16 @@ class Connection(object):
for key in sorted(cls.connections):
engine_url = cls.connections[
key
].metadata.bind.url # type: sqlalchemy.engine.url.URL
].url # type: sqlalchemy.engine.url.URL
if cls.connections[key] == cls.current:
template = " * {}"
else:
template = " {}"
result.append(template.format(engine_url.__repr__()))
return "\n".join(result)
@classmethod
def _close(cls, descriptor):
def close(cls, descriptor):
if isinstance(descriptor, Connection):
conn = descriptor
else:
@ -115,8 +116,5 @@ class Connection(object):
"Could not close connection because it was not found amongst these: %s"
% str(cls.connections.keys())
)
cls.connections.pop(str(conn.metadata.bind.url))
conn.session.close()
def close(self):
self.__class__._close(self)
cls.connections.pop(str(conn.url))
conn.internal_connection.close()

View File

@ -1,6 +1,6 @@
import json
import re
from string import Formatter
import traceback
from IPython.core.magic import (
Magics,
@ -10,7 +10,6 @@ from IPython.core.magic import (
needs_local_scope,
)
from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring
from IPython.display import display_javascript
from sqlalchemy.exc import OperationalError, ProgrammingError, DatabaseError
import sql.connection
@ -46,7 +45,8 @@ class SqlMagic(Magics, Configurable):
style = Unicode(
"DEFAULT",
config=True,
help="Set the table printing style to any of prettytable's defined styles (currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM)",
help="Set the table printing style to any of prettytable's defined styles "
"(currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM)",
)
short_errors = Bool(
True,
@ -72,9 +72,9 @@ class SqlMagic(Magics, Configurable):
"odbc.ini",
config=True,
help="Path to DSN file. "
"When the first argument is of the form [section], "
"a sqlalchemy connection string is formed from the "
"matching section in the DSN file.",
"When the first argument is of the form [section], "
"a sqlalchemy connection string is formed from the "
"matching section in the DSN file.",
)
autocommit = Bool(True, config=True, help="Set autocommit mode")
@ -82,7 +82,7 @@ class SqlMagic(Magics, Configurable):
Configurable.__init__(self, config=shell.config)
Magics.__init__(self, shell=shell)
# Add ourself to the list of module configurable via %config
# Add ourselves to the list of module configurable via %config
self.shell.configurables.append(self)
@needs_local_scope
@ -121,7 +121,7 @@ class SqlMagic(Magics, Configurable):
help="specify dictionary of connection arguments to pass to SQL driver",
)
@argument("-f", "--file", type=str, help="Run SQL from file at this path")
def execute(self, line="", cell="", local_ns={}):
def execute(self, line="", cell="", local_ns=None):
"""Runs SQL statement against a database, specified by SQLAlchemy connect string.
If no database connection has been established, first word
@ -147,15 +147,17 @@ class SqlMagic(Magics, Configurable):
"""
# Parse variables (words wrapped in {}) for %%sql magic (for %sql this is done automatically)
if local_ns is None:
local_ns = {}
cell = self.shell.var_expand(cell)
line = sql.parse.without_sql_comment(parser=self.execute.parser, line=line)
args = parse_argstring(self.execute, line)
if args.connections:
return sql.connection.Connection.connections
elif args.close:
return sql.connection.Connection._close(args.close)
return sql.connection.Connection.close(args.close)
# save globals and locals so they can be referenced in bind vars
# save globals and locals, so they can be referenced in bind vars
user_ns = self.shell.user_ns.copy()
user_ns.update(local_ns)
@ -173,7 +175,7 @@ class SqlMagic(Magics, Configurable):
if args.connection_arguments:
try:
# check for string deliniators, we need to strip them for json parse
# check for string delineators, we need to strip them for json parse
raw_args = args.connection_arguments
if len(raw_args) > 1:
targets = ['"', "'"]
@ -183,7 +185,7 @@ class SqlMagic(Magics, Configurable):
raw_args = raw_args[1:-1]
args.connection_arguments = json.loads(raw_args)
except Exception as e:
print(e)
print(traceback.format_exc())
raise e
else:
args.connection_arguments = {}
@ -197,8 +199,10 @@ class SqlMagic(Magics, Configurable):
connect_args=args.connection_arguments,
creator=args.creator,
)
except Exception as e:
print(e)
# Rollback just in case there was an error in previous statement
conn.internal_connection.rollback()
except Exception:
print(traceback.format_exc())
print(sql.connection.Connection.tell_format())
return None
@ -220,7 +224,7 @@ class SqlMagic(Magics, Configurable):
and self.column_local_vars
):
# Instead of returning values, set variables directly in the
# users namespace. Variable names given by column names
# user's namespace. Variable names given by column names
if self.autopandas:
keys = result.keys()
@ -253,7 +257,8 @@ class SqlMagic(Magics, Configurable):
if self.short_errors:
print(e)
else:
raise
print(traceback.format_exc())
raise e
legal_sql_identifier = re.compile(r"^[A-Za-z0-9#_$]+")
@ -279,7 +284,7 @@ class SqlMagic(Magics, Configurable):
table_name = self.legal_sql_identifier.search(table_name).group(0)
if_exists = "append" if append else "fail"
frame.to_sql(table_name, conn.session.engine, if_exists=if_exists)
frame.to_sql(table_name, conn.internal_connection.engine, if_exists=if_exists)
return "Persisted %s" % table_name

View File

@ -1,10 +1,7 @@
import itertools
import json
import re
import shlex
from os.path import expandvars
import six
from six.moves import configparser as CP
from sqlalchemy.engine.url import URL
@ -17,7 +14,6 @@ def connection_from_dsn_section(section, config):
def _connection_string(s, config):
s = expandvars(s) # for environment variables
if "@" in s or "://" in s:
return s

View File

@ -3,6 +3,7 @@ import csv
import operator
import os.path
import re
import traceback
from functools import reduce
import prettytable
@ -103,21 +104,16 @@ class ResultSet(list, ColumnGuesserMixin):
Can access rows listwise, or by string value of leftmost column.
"""
def __init__(self, sqlaproxy, sql, config):
self.keys = sqlaproxy.keys()
self.sql = sql
def __init__(self, sqlaproxy, config):
self.config = config
self.limit = config.autolimit
style_name = config.style
self.style = prettytable.__dict__[style_name.upper()]
if sqlaproxy.returns_rows:
if self.limit:
list.__init__(self, sqlaproxy.fetchmany(size=self.limit))
self.keys = sqlaproxy.keys()
if config.autolimit:
list.__init__(self, sqlaproxy.fetchmany(size=config.autolimit))
else:
list.__init__(self, sqlaproxy.fetchall())
self.field_names = unduplicate_field_names(self.keys)
self.pretty = PrettyTable(self.field_names, style=self.style)
# self.pretty.set_style(self.style)
self.pretty = PrettyTable(self.field_names, style=prettytable.__dict__[config.style.upper()])
else:
list.__init__(self, [])
self.pretty = None
@ -163,12 +159,12 @@ class ResultSet(list, ColumnGuesserMixin):
return dict(zip(self.keys, zip(*self)))
def dicts(self):
"Iterator yielding a dict for each row"
"""Iterator yielding a dict for each row"""
for row in self:
yield dict(zip(self.keys, row))
def DataFrame(self):
"Returns a Pandas DataFrame instance built from the result set."
"""Returns a Pandas DataFrame instance built from the result set."""
import pandas as pd
frame = pd.DataFrame(self, columns=(self and self.keys) or [])
@ -315,7 +311,7 @@ class FakeResultProxy(object):
self.returns_rows = True
def from_list(self, source_list):
"Simulates SQLA ResultProxy from a list."
"""Simulates SQLA ResultProxy from a list."""
self.fetchall = lambda: source_list
self.rowcount = len(source_list)
@ -323,7 +319,7 @@ class FakeResultProxy(object):
def fetchmany(size):
pos = 0
while pos < len(source_list):
yield source_list[pos : pos + size]
yield source_list[pos: pos + size]
pos += size
self.fetchmany = fetchmany
@ -344,9 +340,13 @@ def _commit(conn, config):
if _should_commit:
try:
conn.session.execute("commit")
conn.internal_connection.commit()
except sqlalchemy.exc.OperationalError:
pass # not all engines can commit
except Exception as ex:
conn.internal_connection.rollback()
print(traceback.format_exc())
raise ex
def run(conn, sql, config, user_namespace):
@ -356,22 +356,22 @@ def run(conn, sql, config, user_namespace):
if first_word == "begin":
raise Exception("ipython_sql does not support transactions")
if first_word.startswith("\\") and \
("postgres" in str(conn.dialect) or \
"redshift" in str(conn.dialect)):
("postgres" in str(conn.dialect) or
"redshift" in str(conn.dialect)):
if not PGSpecial:
raise ImportError("pgspecial not installed")
pgspecial = PGSpecial()
_, cur, headers, _ = pgspecial.execute(
conn.session.connection.cursor(), statement
conn.internal_connection.connection.cursor(), statement
)[0]
result = FakeResultProxy(cur, headers)
else:
txt = sqlalchemy.sql.text(statement)
result = conn.session.execute(txt, user_namespace)
result = conn.internal_connection.execute(txt, user_namespace)
_commit(conn=conn, config=config)
if result and config.feedback:
print(interpret_rowcount(result.rowcount))
resultset = ResultSet(result, statement, config)
resultset = ResultSet(result, config)
if config.autopandas:
return resultset.DataFrame()
else:

View File

@ -1,6 +1,3 @@
import re
import sys
import pytest
from sql.magic import SqlMagic

View File

@ -1,14 +1,14 @@
[DB_CONFIG_1]
drivername=postgres
host=my.remote.host
port=5432
database=pgmain
username=goesto11
password=seentheelephant
drivername = postgres
host = my.remote.host
port = 5432
database = pgmain
username = goesto11
password = seentheelephant
[DB_CONFIG_2]
drivername = mysql
host = 127.0.0.1
database = dolfin
username = thefin
password = fishputsfishonthetable
drivername = mysql
host = 127.0.0.1
database = dolfin
username = thefin
password = fishputsfishonthetable

View File

@ -5,8 +5,6 @@ from textwrap import dedent
import pytest
from sql.magic import SqlMagic
def runsql(ip_session, statements):
if isinstance(statements, str):
@ -321,7 +319,6 @@ def test_dicts(ip):
def test_bracket_var_substitution(ip):
ip.user_global_ns["col"] = "first_name"
assert runsql(ip, "SELECT * FROM author" " WHERE {col} = 'William' ")[0] == (
u"William",
@ -335,7 +332,6 @@ def test_bracket_var_substitution(ip):
def test_multiline_bracket_var_substitution(ip):
ip.user_global_ns["col"] = "first_name"
assert runsql(ip, "SELECT * FROM author\n" " WHERE {col} = 'William' ")[0] == (
u"William",
@ -370,7 +366,7 @@ def test_multiline_bracket_var_substitution(ip):
""",
)
assert not result
def test_json_in_select(ip):
# Variable expansion does not work within json, but

View File

@ -1,9 +1,6 @@
import json
import os
from pathlib import Path
from six.moves import configparser
from sql.parse import connection_from_dsn_section, parse, without_sql_comment
try:
@ -112,7 +109,6 @@ class DummyConfig:
def test_connection_from_dsn_section():
result = connection_from_dsn_section(section="DB_CONFIG_1", config=DummyConfig())
assert result == "postgres://goesto11:seentheelephant@my.remote.host:5432/pgmain"
result = connection_from_dsn_section(section="DB_CONFIG_2", config=DummyConfig())
@ -143,54 +139,46 @@ parser_stub = ParserStub()
def test_without_sql_comment_plain():
line = "SELECT * FROM author"
assert without_sql_comment(parser=parser_stub, line=line) == line
def test_without_sql_comment_with_arg():
line = "--file moo.txt --persist SELECT * FROM author"
assert without_sql_comment(parser=parser_stub, line=line) == line
def test_without_sql_comment_with_comment():
line = "SELECT * FROM author -- uff da"
expected = "SELECT * FROM author"
assert without_sql_comment(parser=parser_stub, line=line) == expected
def test_without_sql_comment_with_arg_and_comment():
line = "--file moo.txt --persist SELECT * FROM author -- uff da"
expected = "--file moo.txt --persist SELECT * FROM author"
assert without_sql_comment(parser=parser_stub, line=line) == expected
def test_without_sql_comment_unspaced_comment():
line = "SELECT * FROM author --uff da"
expected = "SELECT * FROM author"
assert without_sql_comment(parser=parser_stub, line=line) == expected
def test_without_sql_comment_dashes_in_string():
line = "SELECT '--very --confusing' FROM author -- uff da"
expected = "SELECT '--very --confusing' FROM author"
assert without_sql_comment(parser=parser_stub, line=line) == expected
def test_without_sql_comment_with_arg_and_leading_comment():
line = "--file moo.txt --persist --comment, not arg"
expected = "--file moo.txt --persist"
assert without_sql_comment(parser=parser_stub, line=line) == expected
def test_without_sql_persist():
line = "--persist my_table --uff da"
expected = "--persist my_table"
assert without_sql_comment(parser=parser_stub, line=line) == expected