Package upgrades and SQLAlchemy 2.0 compatibility
parent
cc4633cd33
commit
9fbf83baff
|
@ -1,4 +1,4 @@
|
|||
psycopg2
|
||||
psycopg2-binary
|
||||
pandas
|
||||
pytest
|
||||
wheel
|
||||
|
@ -6,4 +6,3 @@ twine
|
|||
readme-renderer
|
||||
black
|
||||
isort
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
prettytable==0.7.2
|
||||
ipython>=1.0
|
||||
sqlalchemy>=0.6.7,<2.0
|
||||
prettytable
|
||||
ipython
|
||||
sqlalchemy
|
||||
sqlparse
|
||||
six
|
||||
ipython-genutils>=0.1.0
|
||||
ipython-genutils
|
||||
traitlets
|
||||
matplotlib
|
||||
|
|
10
setup.py
10
setup.py
|
@ -8,15 +8,15 @@ README = open(os.path.join(here, "README.rst"), encoding="utf-8").read()
|
|||
NEWS = open(os.path.join(here, "NEWS.rst"), encoding="utf-8").read()
|
||||
|
||||
|
||||
version = "0.4.1"
|
||||
version = "0.5.0"
|
||||
|
||||
install_requires = [
|
||||
"prettytable<1",
|
||||
"ipython>=1.0",
|
||||
"sqlalchemy>=0.6.7",
|
||||
"prettytable",
|
||||
"ipython",
|
||||
"sqlalchemy",
|
||||
"sqlparse",
|
||||
"six",
|
||||
"ipython-genutils>=0.1.0",
|
||||
"ipython-genutils",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ makes guesses about the role of each column for plotting purposes
|
|||
|
||||
|
||||
class Column(list):
|
||||
"Store a column of tabular data; record its name and whether it is numeric"
|
||||
"""Store a column of tabular data; record its name and whether it is numeric"""
|
||||
is_quantity = True
|
||||
name = ""
|
||||
|
||||
|
@ -28,6 +28,9 @@ class ColumnGuesserMixin(object):
|
|||
pie: ... y
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.keys = None
|
||||
|
||||
def _build_columns(self):
|
||||
self.columns = [Column() for col in self.keys]
|
||||
for row in self:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import os
|
||||
import re
|
||||
import traceback
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
|
@ -45,20 +45,21 @@ class Connection(object):
|
|||
engine = sqlalchemy.create_engine(
|
||||
connect_str, connect_args=connect_args
|
||||
)
|
||||
except: # TODO: bare except; but what's an ArgumentError?
|
||||
except Exception as ex: # TODO: bare except; but what's an ArgumentError?
|
||||
print(traceback.format_exc())
|
||||
print(self.tell_format())
|
||||
raise
|
||||
self.url = engine.url
|
||||
self.dialect = engine.url.get_dialect()
|
||||
self.metadata = sqlalchemy.MetaData(bind=engine)
|
||||
self.name = self.assign_name(engine)
|
||||
self.session = engine.connect()
|
||||
self.connections[repr(self.metadata.bind.url)] = self
|
||||
self.internal_connection = engine.connect()
|
||||
self.connections[repr(self.url)] = self
|
||||
self.connect_args = connect_args
|
||||
Connection.current = self
|
||||
|
||||
@classmethod
|
||||
def set(cls, descriptor, displaycon, connect_args={}, creator=None):
|
||||
"Sets the current database connection"
|
||||
"""Sets the current database connection"""
|
||||
|
||||
if descriptor:
|
||||
if isinstance(descriptor, Connection):
|
||||
|
@ -94,16 +95,16 @@ class Connection(object):
|
|||
for key in sorted(cls.connections):
|
||||
engine_url = cls.connections[
|
||||
key
|
||||
].metadata.bind.url # type: sqlalchemy.engine.url.URL
|
||||
].url # type: sqlalchemy.engine.url.URL
|
||||
if cls.connections[key] == cls.current:
|
||||
template = " * {}"
|
||||
else:
|
||||
template = " {}"
|
||||
result.append(template.format(engine_url.__repr__()))
|
||||
return "\n".join(result)
|
||||
|
||||
|
||||
@classmethod
|
||||
def _close(cls, descriptor):
|
||||
def close(cls, descriptor):
|
||||
if isinstance(descriptor, Connection):
|
||||
conn = descriptor
|
||||
else:
|
||||
|
@ -115,8 +116,5 @@ class Connection(object):
|
|||
"Could not close connection because it was not found amongst these: %s"
|
||||
% str(cls.connections.keys())
|
||||
)
|
||||
cls.connections.pop(str(conn.metadata.bind.url))
|
||||
conn.session.close()
|
||||
|
||||
def close(self):
|
||||
self.__class__._close(self)
|
||||
cls.connections.pop(str(conn.url))
|
||||
conn.internal_connection.close()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
import re
|
||||
from string import Formatter
|
||||
import traceback
|
||||
|
||||
from IPython.core.magic import (
|
||||
Magics,
|
||||
|
@ -10,7 +10,6 @@ from IPython.core.magic import (
|
|||
needs_local_scope,
|
||||
)
|
||||
from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring
|
||||
from IPython.display import display_javascript
|
||||
from sqlalchemy.exc import OperationalError, ProgrammingError, DatabaseError
|
||||
|
||||
import sql.connection
|
||||
|
@ -46,7 +45,8 @@ class SqlMagic(Magics, Configurable):
|
|||
style = Unicode(
|
||||
"DEFAULT",
|
||||
config=True,
|
||||
help="Set the table printing style to any of prettytable's defined styles (currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM)",
|
||||
help="Set the table printing style to any of prettytable's defined styles "
|
||||
"(currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM)",
|
||||
)
|
||||
short_errors = Bool(
|
||||
True,
|
||||
|
@ -72,9 +72,9 @@ class SqlMagic(Magics, Configurable):
|
|||
"odbc.ini",
|
||||
config=True,
|
||||
help="Path to DSN file. "
|
||||
"When the first argument is of the form [section], "
|
||||
"a sqlalchemy connection string is formed from the "
|
||||
"matching section in the DSN file.",
|
||||
"When the first argument is of the form [section], "
|
||||
"a sqlalchemy connection string is formed from the "
|
||||
"matching section in the DSN file.",
|
||||
)
|
||||
autocommit = Bool(True, config=True, help="Set autocommit mode")
|
||||
|
||||
|
@ -82,7 +82,7 @@ class SqlMagic(Magics, Configurable):
|
|||
Configurable.__init__(self, config=shell.config)
|
||||
Magics.__init__(self, shell=shell)
|
||||
|
||||
# Add ourself to the list of module configurable via %config
|
||||
# Add ourselves to the list of module configurable via %config
|
||||
self.shell.configurables.append(self)
|
||||
|
||||
@needs_local_scope
|
||||
|
@ -121,7 +121,7 @@ class SqlMagic(Magics, Configurable):
|
|||
help="specify dictionary of connection arguments to pass to SQL driver",
|
||||
)
|
||||
@argument("-f", "--file", type=str, help="Run SQL from file at this path")
|
||||
def execute(self, line="", cell="", local_ns={}):
|
||||
def execute(self, line="", cell="", local_ns=None):
|
||||
"""Runs SQL statement against a database, specified by SQLAlchemy connect string.
|
||||
|
||||
If no database connection has been established, first word
|
||||
|
@ -147,15 +147,17 @@ class SqlMagic(Magics, Configurable):
|
|||
|
||||
"""
|
||||
# Parse variables (words wrapped in {}) for %%sql magic (for %sql this is done automatically)
|
||||
if local_ns is None:
|
||||
local_ns = {}
|
||||
cell = self.shell.var_expand(cell)
|
||||
line = sql.parse.without_sql_comment(parser=self.execute.parser, line=line)
|
||||
args = parse_argstring(self.execute, line)
|
||||
if args.connections:
|
||||
return sql.connection.Connection.connections
|
||||
elif args.close:
|
||||
return sql.connection.Connection._close(args.close)
|
||||
return sql.connection.Connection.close(args.close)
|
||||
|
||||
# save globals and locals so they can be referenced in bind vars
|
||||
# save globals and locals, so they can be referenced in bind vars
|
||||
user_ns = self.shell.user_ns.copy()
|
||||
user_ns.update(local_ns)
|
||||
|
||||
|
@ -173,7 +175,7 @@ class SqlMagic(Magics, Configurable):
|
|||
|
||||
if args.connection_arguments:
|
||||
try:
|
||||
# check for string deliniators, we need to strip them for json parse
|
||||
# check for string delineators, we need to strip them for json parse
|
||||
raw_args = args.connection_arguments
|
||||
if len(raw_args) > 1:
|
||||
targets = ['"', "'"]
|
||||
|
@ -183,7 +185,7 @@ class SqlMagic(Magics, Configurable):
|
|||
raw_args = raw_args[1:-1]
|
||||
args.connection_arguments = json.loads(raw_args)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(traceback.format_exc())
|
||||
raise e
|
||||
else:
|
||||
args.connection_arguments = {}
|
||||
|
@ -197,8 +199,10 @@ class SqlMagic(Magics, Configurable):
|
|||
connect_args=args.connection_arguments,
|
||||
creator=args.creator,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
# Rollback just in case there was an error in previous statement
|
||||
conn.internal_connection.rollback()
|
||||
except Exception:
|
||||
print(traceback.format_exc())
|
||||
print(sql.connection.Connection.tell_format())
|
||||
return None
|
||||
|
||||
|
@ -220,7 +224,7 @@ class SqlMagic(Magics, Configurable):
|
|||
and self.column_local_vars
|
||||
):
|
||||
# Instead of returning values, set variables directly in the
|
||||
# users namespace. Variable names given by column names
|
||||
# user's namespace. Variable names given by column names
|
||||
|
||||
if self.autopandas:
|
||||
keys = result.keys()
|
||||
|
@ -253,7 +257,8 @@ class SqlMagic(Magics, Configurable):
|
|||
if self.short_errors:
|
||||
print(e)
|
||||
else:
|
||||
raise
|
||||
print(traceback.format_exc())
|
||||
raise e
|
||||
|
||||
legal_sql_identifier = re.compile(r"^[A-Za-z0-9#_$]+")
|
||||
|
||||
|
@ -279,7 +284,7 @@ class SqlMagic(Magics, Configurable):
|
|||
table_name = self.legal_sql_identifier.search(table_name).group(0)
|
||||
|
||||
if_exists = "append" if append else "fail"
|
||||
frame.to_sql(table_name, conn.session.engine, if_exists=if_exists)
|
||||
frame.to_sql(table_name, conn.internal_connection.engine, if_exists=if_exists)
|
||||
return "Persisted %s" % table_name
|
||||
|
||||
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
import itertools
|
||||
import json
|
||||
import re
|
||||
import shlex
|
||||
from os.path import expandvars
|
||||
|
||||
import six
|
||||
from six.moves import configparser as CP
|
||||
from sqlalchemy.engine.url import URL
|
||||
|
||||
|
@ -17,7 +14,6 @@ def connection_from_dsn_section(section, config):
|
|||
|
||||
|
||||
def _connection_string(s, config):
|
||||
|
||||
s = expandvars(s) # for environment variables
|
||||
if "@" in s or "://" in s:
|
||||
return s
|
||||
|
|
|
@ -3,6 +3,7 @@ import csv
|
|||
import operator
|
||||
import os.path
|
||||
import re
|
||||
import traceback
|
||||
from functools import reduce
|
||||
|
||||
import prettytable
|
||||
|
@ -103,21 +104,16 @@ class ResultSet(list, ColumnGuesserMixin):
|
|||
Can access rows listwise, or by string value of leftmost column.
|
||||
"""
|
||||
|
||||
def __init__(self, sqlaproxy, sql, config):
|
||||
self.keys = sqlaproxy.keys()
|
||||
self.sql = sql
|
||||
def __init__(self, sqlaproxy, config):
|
||||
self.config = config
|
||||
self.limit = config.autolimit
|
||||
style_name = config.style
|
||||
self.style = prettytable.__dict__[style_name.upper()]
|
||||
if sqlaproxy.returns_rows:
|
||||
if self.limit:
|
||||
list.__init__(self, sqlaproxy.fetchmany(size=self.limit))
|
||||
self.keys = sqlaproxy.keys()
|
||||
if config.autolimit:
|
||||
list.__init__(self, sqlaproxy.fetchmany(size=config.autolimit))
|
||||
else:
|
||||
list.__init__(self, sqlaproxy.fetchall())
|
||||
self.field_names = unduplicate_field_names(self.keys)
|
||||
self.pretty = PrettyTable(self.field_names, style=self.style)
|
||||
# self.pretty.set_style(self.style)
|
||||
self.pretty = PrettyTable(self.field_names, style=prettytable.__dict__[config.style.upper()])
|
||||
else:
|
||||
list.__init__(self, [])
|
||||
self.pretty = None
|
||||
|
@ -163,12 +159,12 @@ class ResultSet(list, ColumnGuesserMixin):
|
|||
return dict(zip(self.keys, zip(*self)))
|
||||
|
||||
def dicts(self):
|
||||
"Iterator yielding a dict for each row"
|
||||
"""Iterator yielding a dict for each row"""
|
||||
for row in self:
|
||||
yield dict(zip(self.keys, row))
|
||||
|
||||
def DataFrame(self):
|
||||
"Returns a Pandas DataFrame instance built from the result set."
|
||||
"""Returns a Pandas DataFrame instance built from the result set."""
|
||||
import pandas as pd
|
||||
|
||||
frame = pd.DataFrame(self, columns=(self and self.keys) or [])
|
||||
|
@ -315,7 +311,7 @@ class FakeResultProxy(object):
|
|||
self.returns_rows = True
|
||||
|
||||
def from_list(self, source_list):
|
||||
"Simulates SQLA ResultProxy from a list."
|
||||
"""Simulates SQLA ResultProxy from a list."""
|
||||
|
||||
self.fetchall = lambda: source_list
|
||||
self.rowcount = len(source_list)
|
||||
|
@ -323,7 +319,7 @@ class FakeResultProxy(object):
|
|||
def fetchmany(size):
|
||||
pos = 0
|
||||
while pos < len(source_list):
|
||||
yield source_list[pos : pos + size]
|
||||
yield source_list[pos: pos + size]
|
||||
pos += size
|
||||
|
||||
self.fetchmany = fetchmany
|
||||
|
@ -344,9 +340,13 @@ def _commit(conn, config):
|
|||
|
||||
if _should_commit:
|
||||
try:
|
||||
conn.session.execute("commit")
|
||||
conn.internal_connection.commit()
|
||||
except sqlalchemy.exc.OperationalError:
|
||||
pass # not all engines can commit
|
||||
except Exception as ex:
|
||||
conn.internal_connection.rollback()
|
||||
print(traceback.format_exc())
|
||||
raise ex
|
||||
|
||||
|
||||
def run(conn, sql, config, user_namespace):
|
||||
|
@ -356,22 +356,22 @@ def run(conn, sql, config, user_namespace):
|
|||
if first_word == "begin":
|
||||
raise Exception("ipython_sql does not support transactions")
|
||||
if first_word.startswith("\\") and \
|
||||
("postgres" in str(conn.dialect) or \
|
||||
"redshift" in str(conn.dialect)):
|
||||
("postgres" in str(conn.dialect) or
|
||||
"redshift" in str(conn.dialect)):
|
||||
if not PGSpecial:
|
||||
raise ImportError("pgspecial not installed")
|
||||
pgspecial = PGSpecial()
|
||||
_, cur, headers, _ = pgspecial.execute(
|
||||
conn.session.connection.cursor(), statement
|
||||
conn.internal_connection.connection.cursor(), statement
|
||||
)[0]
|
||||
result = FakeResultProxy(cur, headers)
|
||||
else:
|
||||
txt = sqlalchemy.sql.text(statement)
|
||||
result = conn.session.execute(txt, user_namespace)
|
||||
result = conn.internal_connection.execute(txt, user_namespace)
|
||||
_commit(conn=conn, config=config)
|
||||
if result and config.feedback:
|
||||
print(interpret_rowcount(result.rowcount))
|
||||
resultset = ResultSet(result, statement, config)
|
||||
resultset = ResultSet(result, config)
|
||||
if config.autopandas:
|
||||
return resultset.DataFrame()
|
||||
else:
|
||||
|
|
|
@ -1,6 +1,3 @@
|
|||
import re
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from sql.magic import SqlMagic
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
[DB_CONFIG_1]
|
||||
drivername=postgres
|
||||
host=my.remote.host
|
||||
port=5432
|
||||
database=pgmain
|
||||
username=goesto11
|
||||
password=seentheelephant
|
||||
drivername = postgres
|
||||
host = my.remote.host
|
||||
port = 5432
|
||||
database = pgmain
|
||||
username = goesto11
|
||||
password = seentheelephant
|
||||
|
||||
[DB_CONFIG_2]
|
||||
drivername = mysql
|
||||
host = 127.0.0.1
|
||||
database = dolfin
|
||||
username = thefin
|
||||
password = fishputsfishonthetable
|
||||
drivername = mysql
|
||||
host = 127.0.0.1
|
||||
database = dolfin
|
||||
username = thefin
|
||||
password = fishputsfishonthetable
|
||||
|
|
|
@ -5,8 +5,6 @@ from textwrap import dedent
|
|||
|
||||
import pytest
|
||||
|
||||
from sql.magic import SqlMagic
|
||||
|
||||
|
||||
def runsql(ip_session, statements):
|
||||
if isinstance(statements, str):
|
||||
|
@ -321,7 +319,6 @@ def test_dicts(ip):
|
|||
|
||||
|
||||
def test_bracket_var_substitution(ip):
|
||||
|
||||
ip.user_global_ns["col"] = "first_name"
|
||||
assert runsql(ip, "SELECT * FROM author" " WHERE {col} = 'William' ")[0] == (
|
||||
u"William",
|
||||
|
@ -335,7 +332,6 @@ def test_bracket_var_substitution(ip):
|
|||
|
||||
|
||||
def test_multiline_bracket_var_substitution(ip):
|
||||
|
||||
ip.user_global_ns["col"] = "first_name"
|
||||
assert runsql(ip, "SELECT * FROM author\n" " WHERE {col} = 'William' ")[0] == (
|
||||
u"William",
|
||||
|
@ -370,7 +366,7 @@ def test_multiline_bracket_var_substitution(ip):
|
|||
""",
|
||||
)
|
||||
assert not result
|
||||
|
||||
|
||||
|
||||
def test_json_in_select(ip):
|
||||
# Variable expansion does not work within json, but
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from six.moves import configparser
|
||||
|
||||
from sql.parse import connection_from_dsn_section, parse, without_sql_comment
|
||||
|
||||
try:
|
||||
|
@ -112,7 +109,6 @@ class DummyConfig:
|
|||
|
||||
|
||||
def test_connection_from_dsn_section():
|
||||
|
||||
result = connection_from_dsn_section(section="DB_CONFIG_1", config=DummyConfig())
|
||||
assert result == "postgres://goesto11:seentheelephant@my.remote.host:5432/pgmain"
|
||||
result = connection_from_dsn_section(section="DB_CONFIG_2", config=DummyConfig())
|
||||
|
@ -143,54 +139,46 @@ parser_stub = ParserStub()
|
|||
|
||||
|
||||
def test_without_sql_comment_plain():
|
||||
|
||||
line = "SELECT * FROM author"
|
||||
assert without_sql_comment(parser=parser_stub, line=line) == line
|
||||
|
||||
|
||||
def test_without_sql_comment_with_arg():
|
||||
|
||||
line = "--file moo.txt --persist SELECT * FROM author"
|
||||
assert without_sql_comment(parser=parser_stub, line=line) == line
|
||||
|
||||
|
||||
def test_without_sql_comment_with_comment():
|
||||
|
||||
line = "SELECT * FROM author -- uff da"
|
||||
expected = "SELECT * FROM author"
|
||||
assert without_sql_comment(parser=parser_stub, line=line) == expected
|
||||
|
||||
|
||||
def test_without_sql_comment_with_arg_and_comment():
|
||||
|
||||
line = "--file moo.txt --persist SELECT * FROM author -- uff da"
|
||||
expected = "--file moo.txt --persist SELECT * FROM author"
|
||||
assert without_sql_comment(parser=parser_stub, line=line) == expected
|
||||
|
||||
|
||||
def test_without_sql_comment_unspaced_comment():
|
||||
|
||||
line = "SELECT * FROM author --uff da"
|
||||
expected = "SELECT * FROM author"
|
||||
assert without_sql_comment(parser=parser_stub, line=line) == expected
|
||||
|
||||
|
||||
def test_without_sql_comment_dashes_in_string():
|
||||
|
||||
line = "SELECT '--very --confusing' FROM author -- uff da"
|
||||
expected = "SELECT '--very --confusing' FROM author"
|
||||
assert without_sql_comment(parser=parser_stub, line=line) == expected
|
||||
|
||||
|
||||
def test_without_sql_comment_with_arg_and_leading_comment():
|
||||
|
||||
line = "--file moo.txt --persist --comment, not arg"
|
||||
expected = "--file moo.txt --persist"
|
||||
assert without_sql_comment(parser=parser_stub, line=line) == expected
|
||||
|
||||
|
||||
def test_without_sql_persist():
|
||||
|
||||
line = "--persist my_table --uff da"
|
||||
expected = "--persist my_table"
|
||||
assert without_sql_comment(parser=parser_stub, line=line) == expected
|
||||
|
|
Loading…
Reference in New Issue