applied black
parent
ba21a75616
commit
2af7e344fb
|
@ -2,7 +2,7 @@
|
|||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
|
@ -16,7 +16,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
|
|
|
@ -4,4 +4,6 @@ pytest
|
|||
wheel
|
||||
twine
|
||||
readme-renderer
|
||||
black
|
||||
isort
|
||||
|
||||
|
|
|
@ -4,26 +4,30 @@ makes guesses about the role of each column for plotting purposes
|
|||
(X values, Y values, and text labels).
|
||||
"""
|
||||
|
||||
|
||||
class Column(list):
|
||||
'Store a column of tabular data; record its name and whether it is numeric'
|
||||
"Store a column of tabular data; record its name and whether it is numeric"
|
||||
is_quantity = True
|
||||
name = ''
|
||||
name = ""
|
||||
|
||||
def __init__(self, *arg, **kwarg):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
def is_quantity(val):
|
||||
"""Is ``val`` a quantity (int, float, datetime, etc) (not str, bool)?
|
||||
|
||||
Relies on presence of __sub__.
|
||||
"""
|
||||
return hasattr(val, '__sub__')
|
||||
return hasattr(val, "__sub__")
|
||||
|
||||
|
||||
class ColumnGuesserMixin(object):
|
||||
"""
|
||||
plot: [x, y, y...], y
|
||||
pie: ... y
|
||||
"""
|
||||
|
||||
def _build_columns(self):
|
||||
self.columns = [Column() for col in self.keys]
|
||||
for row in self:
|
||||
|
@ -32,39 +36,40 @@ class ColumnGuesserMixin(object):
|
|||
col.append(col_val)
|
||||
if (col_val is not None) and (not is_quantity(col_val)):
|
||||
col.is_quantity = False
|
||||
|
||||
|
||||
for (idx, key_name) in enumerate(self.keys):
|
||||
self.columns[idx].name = key_name
|
||||
|
||||
|
||||
self.x = Column()
|
||||
self.ys = []
|
||||
|
||||
|
||||
def _get_y(self):
|
||||
for idx in range(len(self.columns)-1,-1,-1):
|
||||
for idx in range(len(self.columns) - 1, -1, -1):
|
||||
if self.columns[idx].is_quantity:
|
||||
self.ys.insert(0, self.columns.pop(idx))
|
||||
return True
|
||||
|
||||
def _get_x(self):
|
||||
def _get_x(self):
|
||||
for idx in range(len(self.columns)):
|
||||
if self.columns[idx].is_quantity:
|
||||
self.x = self.columns.pop(idx)
|
||||
return True
|
||||
|
||||
|
||||
def _get_xlabel(self, xlabel_sep=" "):
|
||||
self.xlabels = []
|
||||
if self.columns:
|
||||
for row_idx in range(len(self.columns[0])):
|
||||
self.xlabels.append(xlabel_sep.join(
|
||||
str(c[row_idx]) for c in self.columns))
|
||||
self.xlabels.append(
|
||||
xlabel_sep.join(str(c[row_idx]) for c in self.columns)
|
||||
)
|
||||
self.xlabel = ", ".join(c.name for c in self.columns)
|
||||
|
||||
|
||||
def _guess_columns(self):
|
||||
self._build_columns()
|
||||
self._get_y()
|
||||
if not self.ys:
|
||||
raise AttributeError("No quantitative columns found for chart")
|
||||
|
||||
|
||||
def guess_pie_columns(self, xlabel_sep=" "):
|
||||
"""
|
||||
Assigns x, y, and x labels from the data set for a pie chart.
|
||||
|
@ -75,7 +80,7 @@ class ColumnGuesserMixin(object):
|
|||
"""
|
||||
self._guess_columns()
|
||||
self._get_xlabel(xlabel_sep)
|
||||
|
||||
|
||||
def guess_plot_columns(self):
|
||||
"""
|
||||
Assigns ``x`` and ``y`` series from the data set for a plot.
|
||||
|
@ -88,4 +93,4 @@ class ColumnGuesserMixin(object):
|
|||
self._guess_columns()
|
||||
self._get_x()
|
||||
while self._get_y():
|
||||
pass
|
||||
pass
|
||||
|
|
|
@ -1,20 +1,22 @@
|
|||
import sqlalchemy
|
||||
import os
|
||||
import re
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
|
||||
class ConnectionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def rough_dict_get(dct, sought, default=None):
|
||||
'''
|
||||
"""
|
||||
Like dct.get(sought), but any key containing sought will do.
|
||||
|
||||
If there is a `@` in sought, seek each piece separately.
|
||||
This lets `me@server` match `me:***@myserver/db`
|
||||
'''
|
||||
|
||||
sought = sought.split('@')
|
||||
"""
|
||||
|
||||
sought = sought.split("@")
|
||||
for (key, val) in dct.items():
|
||||
if not any(s.lower() not in key.lower() for s in sought):
|
||||
return val
|
||||
|
@ -29,15 +31,21 @@ class Connection(object):
|
|||
def tell_format(cls):
|
||||
return """Connection info needed in SQLAlchemy format, example:
|
||||
postgresql://username:password@hostname/dbname
|
||||
or an existing connection: %s""" % str(cls.connections.keys())
|
||||
or an existing connection: %s""" % str(
|
||||
cls.connections.keys()
|
||||
)
|
||||
|
||||
def __init__(self, connect_str=None, connect_args={}, creator=None):
|
||||
try:
|
||||
if creator:
|
||||
engine = sqlalchemy.create_engine(connect_str, connect_args=connect_args, creator=creator)
|
||||
engine = sqlalchemy.create_engine(
|
||||
connect_str, connect_args=connect_args, creator=creator
|
||||
)
|
||||
else:
|
||||
engine = sqlalchemy.create_engine(connect_str, connect_args=connect_args)
|
||||
except: # TODO: bare except; but what's an ArgumentError?
|
||||
engine = sqlalchemy.create_engine(
|
||||
connect_str, connect_args=connect_args
|
||||
)
|
||||
except: # TODO: bare except; but what's an ArgumentError?
|
||||
print(self.tell_format())
|
||||
raise
|
||||
self.dialect = engine.url.get_dialect()
|
||||
|
@ -65,42 +73,50 @@ class Connection(object):
|
|||
if displaycon:
|
||||
print(cls.connection_list())
|
||||
else:
|
||||
if os.getenv('DATABASE_URL'):
|
||||
cls.current = Connection(os.getenv('DATABASE_URL'), connect_args, creator)
|
||||
if os.getenv("DATABASE_URL"):
|
||||
cls.current = Connection(
|
||||
os.getenv("DATABASE_URL"), connect_args, creator
|
||||
)
|
||||
else:
|
||||
raise ConnectionError('Environment variable $DATABASE_URL not set, and no connect string given.')
|
||||
raise ConnectionError(
|
||||
"Environment variable $DATABASE_URL not set, and no connect string given."
|
||||
)
|
||||
return cls.current
|
||||
|
||||
@classmethod
|
||||
def assign_name(cls, engine):
|
||||
name = '%s@%s' % (engine.url.username or '', engine.url.database)
|
||||
name = "%s@%s" % (engine.url.username or "", engine.url.database)
|
||||
return name
|
||||
|
||||
@classmethod
|
||||
def connection_list(cls):
|
||||
result = []
|
||||
for key in sorted(cls.connections):
|
||||
engine_url = cls.connections[key].metadata.bind.url # type: sqlalchemy.engine.url.URL
|
||||
engine_url = cls.connections[
|
||||
key
|
||||
].metadata.bind.url # type: sqlalchemy.engine.url.URL
|
||||
if cls.connections[key] == cls.current:
|
||||
template = ' * {}'
|
||||
template = " * {}"
|
||||
else:
|
||||
template = ' {}'
|
||||
template = " {}"
|
||||
result.append(template.format(engine_url.__repr__()))
|
||||
return '\n'.join(result)
|
||||
return "\n".join(result)
|
||||
|
||||
def _close(cls, descriptor):
|
||||
if isinstance(descriptor, Connection):
|
||||
conn = descriptor
|
||||
else:
|
||||
conn = cls.connections.get(descriptor) or \
|
||||
cls.connections.get(descriptor.lower())
|
||||
conn = cls.connections.get(descriptor) or cls.connections.get(
|
||||
descriptor.lower()
|
||||
)
|
||||
if not conn:
|
||||
raise Exception("Could not close connection because it was not found amongst these: %s" \
|
||||
%str(cls.connections.keys()))
|
||||
raise Exception(
|
||||
"Could not close connection because it was not found amongst these: %s"
|
||||
% str(cls.connections.keys())
|
||||
)
|
||||
cls.connections.pop(conn.name)
|
||||
cls.connections.pop(str(conn.metadata.bind.url))
|
||||
conn.session.close()
|
||||
|
||||
def close(self):
|
||||
self.__class__._close(self)
|
||||
|
||||
|
||||
|
|
177
src/sql/magic.py
177
src/sql/magic.py
|
@ -1,26 +1,30 @@
|
|||
import json
|
||||
import re
|
||||
from string import Formatter
|
||||
from IPython.core.magic import Magics, magics_class, cell_magic, line_magic, needs_local_scope
|
||||
|
||||
from IPython.core.magic import (Magics, cell_magic, line_magic, magics_class,
|
||||
needs_local_scope)
|
||||
from IPython.core.magic_arguments import (argument, magic_arguments,
|
||||
parse_argstring)
|
||||
from IPython.display import display_javascript
|
||||
from sqlalchemy.exc import OperationalError, ProgrammingError
|
||||
|
||||
import sql.connection
|
||||
import sql.parse
|
||||
import sql.run
|
||||
|
||||
try:
|
||||
from traitlets.config.configurable import Configurable
|
||||
from traitlets import Bool, Int, Unicode
|
||||
except ImportError:
|
||||
from IPython.config.configurable import Configurable
|
||||
from IPython.utils.traitlets import Bool, Int, Unicode
|
||||
from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring
|
||||
try:
|
||||
from pandas.core.frame import DataFrame, Series
|
||||
except ImportError:
|
||||
DataFrame = None
|
||||
Series = None
|
||||
|
||||
from sqlalchemy.exc import ProgrammingError, OperationalError
|
||||
|
||||
import sql.connection
|
||||
import sql.parse
|
||||
import sql.run
|
||||
|
||||
@magics_class
|
||||
class SqlMagic(Magics, Configurable):
|
||||
|
@ -29,20 +33,47 @@ class SqlMagic(Magics, Configurable):
|
|||
Provides the %%sql magic."""
|
||||
|
||||
displaycon = Bool(True, config=True, help="Show connection string after execute")
|
||||
autolimit = Int(0, config=True, allow_none=True, help="Automatically limit the size of the returned result sets")
|
||||
style = Unicode('DEFAULT', config=True, help="Set the table printing style to any of prettytable's defined styles (currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM)")
|
||||
short_errors = Bool(True, config=True, help="Don't display the full traceback on SQL Programming Error")
|
||||
displaylimit = Int(None, config=True, allow_none=True, help="Automatically limit the number of rows displayed (full result set is still stored)")
|
||||
autopandas = Bool(False, config=True, help="Return Pandas DataFrames instead of regular result sets")
|
||||
column_local_vars = Bool(False, config=True, help="Return data into local variables from column names")
|
||||
autolimit = Int(
|
||||
0,
|
||||
config=True,
|
||||
allow_none=True,
|
||||
help="Automatically limit the size of the returned result sets",
|
||||
)
|
||||
style = Unicode(
|
||||
"DEFAULT",
|
||||
config=True,
|
||||
help="Set the table printing style to any of prettytable's defined styles (currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM)",
|
||||
)
|
||||
short_errors = Bool(
|
||||
True,
|
||||
config=True,
|
||||
help="Don't display the full traceback on SQL Programming Error",
|
||||
)
|
||||
displaylimit = Int(
|
||||
None,
|
||||
config=True,
|
||||
allow_none=True,
|
||||
help="Automatically limit the number of rows displayed (full result set is still stored)",
|
||||
)
|
||||
autopandas = Bool(
|
||||
False,
|
||||
config=True,
|
||||
help="Return Pandas DataFrames instead of regular result sets",
|
||||
)
|
||||
column_local_vars = Bool(
|
||||
False, config=True, help="Return data into local variables from column names"
|
||||
)
|
||||
feedback = Bool(True, config=True, help="Print number of rows affected by DML")
|
||||
dsn_filename = Unicode('odbc.ini', config=True, help="Path to DSN file. "
|
||||
"When the first argument is of the form [section], "
|
||||
"a sqlalchemy connection string is formed from the "
|
||||
"matching section in the DSN file.")
|
||||
dsn_filename = Unicode(
|
||||
"odbc.ini",
|
||||
config=True,
|
||||
help="Path to DSN file. "
|
||||
"When the first argument is of the form [section], "
|
||||
"a sqlalchemy connection string is formed from the "
|
||||
"matching section in the DSN file.",
|
||||
)
|
||||
autocommit = Bool(True, config=True, help="Set autocommit mode")
|
||||
|
||||
|
||||
def __init__(self, shell):
|
||||
Configurable.__init__(self, config=shell.config)
|
||||
Magics.__init__(self, shell=shell)
|
||||
|
@ -51,19 +82,42 @@ class SqlMagic(Magics, Configurable):
|
|||
self.shell.configurables.append(self)
|
||||
|
||||
@needs_local_scope
|
||||
@line_magic('sql')
|
||||
@cell_magic('sql')
|
||||
@line_magic("sql")
|
||||
@cell_magic("sql")
|
||||
@magic_arguments()
|
||||
@argument('line', default='', nargs='*', type=str, help='sql')
|
||||
@argument('-l', '--connections', action='store_true', help="list active connections")
|
||||
@argument('-x', '--close', type=str, help="close a session by name")
|
||||
@argument('-c', '--creator', type=str, help="specify creator function for new connection")
|
||||
@argument('-s', '--section', type=str, help="section of dsn_file to be used for generating a connection string")
|
||||
@argument('-p', '--persist', action='store_true', help="create a table name in the database from the named DataFrame")
|
||||
@argument('--append', action='store_true', help="create, or append to, a table name in the database from the named DataFrame")
|
||||
@argument('-a', '--connection_arguments', type=str, help="specify dictionary of connection arguments to pass to SQL driver")
|
||||
@argument('-f', '--file', type=str, help="Run SQL from file at this path")
|
||||
def execute(self, line='', cell='', local_ns={}):
|
||||
@argument("line", default="", nargs="*", type=str, help="sql")
|
||||
@argument(
|
||||
"-l", "--connections", action="store_true", help="list active connections"
|
||||
)
|
||||
@argument("-x", "--close", type=str, help="close a session by name")
|
||||
@argument(
|
||||
"-c", "--creator", type=str, help="specify creator function for new connection"
|
||||
)
|
||||
@argument(
|
||||
"-s",
|
||||
"--section",
|
||||
type=str,
|
||||
help="section of dsn_file to be used for generating a connection string",
|
||||
)
|
||||
@argument(
|
||||
"-p",
|
||||
"--persist",
|
||||
action="store_true",
|
||||
help="create a table name in the database from the named DataFrame",
|
||||
)
|
||||
@argument(
|
||||
"--append",
|
||||
action="store_true",
|
||||
help="create, or append to, a table name in the database from the named DataFrame",
|
||||
)
|
||||
@argument(
|
||||
"-a",
|
||||
"--connection_arguments",
|
||||
type=str,
|
||||
help="specify dictionary of connection arguments to pass to SQL driver",
|
||||
)
|
||||
@argument("-f", "--file", type=str, help="Run SQL from file at this path")
|
||||
def execute(self, line="", cell="", local_ns={}):
|
||||
"""Runs SQL statement against a database, specified by SQLAlchemy connect string.
|
||||
|
||||
If no database connection has been established, first word
|
||||
|
@ -89,7 +143,9 @@ class SqlMagic(Magics, Configurable):
|
|||
|
||||
"""
|
||||
# Parse variables (words wrapped in {}) for %%sql magic (for %sql this is done automatically)
|
||||
cell_variables = [fn for _, fn, _, _ in Formatter().parse(cell) if fn is not None]
|
||||
cell_variables = [
|
||||
fn for _, fn, _, _ in Formatter().parse(cell) if fn is not None
|
||||
]
|
||||
cell_params = {}
|
||||
for variable in cell_variables:
|
||||
cell_params[variable] = local_ns[variable]
|
||||
|
@ -105,15 +161,15 @@ class SqlMagic(Magics, Configurable):
|
|||
user_ns = self.shell.user_ns.copy()
|
||||
user_ns.update(local_ns)
|
||||
|
||||
command_text = ' '.join(args.line) + '\n' + cell
|
||||
command_text = " ".join(args.line) + "\n" + cell
|
||||
|
||||
if args.file:
|
||||
with open(args.file, 'r') as infile:
|
||||
command_text = infile.read() + "\n" + command_text
|
||||
with open(args.file, "r") as infile:
|
||||
command_text = infile.read() + "\n" + command_text
|
||||
|
||||
parsed = sql.parse.parse(command_text, self)
|
||||
parsed = sql.parse.parse(command_text, self)
|
||||
|
||||
connect_str = parsed['connection']
|
||||
connect_str = parsed["connection"]
|
||||
if args.section:
|
||||
connect_str = sql.parse.connection_from_dsn_section(args.section, self)
|
||||
|
||||
|
@ -137,27 +193,36 @@ class SqlMagic(Magics, Configurable):
|
|||
args.creator = user_ns[args.creator]
|
||||
|
||||
try:
|
||||
conn = sql.connection.Connection.set(parsed['connection'], displaycon=self.displaycon, connect_args=args.connection_arguments, creator=args.creator)
|
||||
conn = sql.connection.Connection.set(
|
||||
parsed["connection"],
|
||||
displaycon=self.displaycon,
|
||||
connect_args=args.connection_arguments,
|
||||
creator=args.creator,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(sql.connection.Connection.tell_format())
|
||||
return None
|
||||
|
||||
if args.persist:
|
||||
return self._persist_dataframe(parsed['sql'], conn, user_ns, append=False)
|
||||
return self._persist_dataframe(parsed["sql"], conn, user_ns, append=False)
|
||||
|
||||
if args.append:
|
||||
return self._persist_dataframe(parsed['sql'], conn, user_ns, append=True)
|
||||
return self._persist_dataframe(parsed["sql"], conn, user_ns, append=True)
|
||||
|
||||
if not parsed['sql']:
|
||||
if not parsed["sql"]:
|
||||
return
|
||||
|
||||
try:
|
||||
result = sql.run.run(conn, parsed['sql'], self, user_ns)
|
||||
result = sql.run.run(conn, parsed["sql"], self, user_ns)
|
||||
|
||||
if result is not None and not isinstance(result, str) and self.column_local_vars:
|
||||
#Instead of returning values, set variables directly in the
|
||||
#users namespace. Variable names given by column names
|
||||
if (
|
||||
result is not None
|
||||
and not isinstance(result, str)
|
||||
and self.column_local_vars
|
||||
):
|
||||
# Instead of returning values, set variables directly in the
|
||||
# users namespace. Variable names given by column names
|
||||
|
||||
if self.autopandas:
|
||||
keys = result.keys()
|
||||
|
@ -166,21 +231,22 @@ class SqlMagic(Magics, Configurable):
|
|||
result = result.dict()
|
||||
|
||||
if self.feedback:
|
||||
print('Returning data to local variables [{}]'.format(
|
||||
', '.join(keys)))
|
||||
print(
|
||||
"Returning data to local variables [{}]".format(", ".join(keys))
|
||||
)
|
||||
|
||||
self.shell.user_ns.update(result)
|
||||
|
||||
return None
|
||||
else:
|
||||
|
||||
if parsed['result_var']:
|
||||
result_var = parsed['result_var']
|
||||
if parsed["result_var"]:
|
||||
result_var = parsed["result_var"]
|
||||
print("Returning data to local variable {}".format(result_var))
|
||||
self.shell.user_ns.update({result_var: result})
|
||||
return None
|
||||
|
||||
#Return results into the default ipython _ variable
|
||||
# Return results into the default ipython _ variable
|
||||
return result
|
||||
|
||||
except (ProgrammingError, OperationalError) as e:
|
||||
|
@ -190,28 +256,29 @@ class SqlMagic(Magics, Configurable):
|
|||
else:
|
||||
raise
|
||||
|
||||
legal_sql_identifier = re.compile(r'^[A-Za-z0-9#_$]+')
|
||||
legal_sql_identifier = re.compile(r"^[A-Za-z0-9#_$]+")
|
||||
|
||||
def _persist_dataframe(self, raw, conn, user_ns, append=False):
|
||||
"""Implements PERSIST, which writes a DataFrame to the RDBMS"""
|
||||
if not DataFrame:
|
||||
raise ImportError("Must `pip install pandas` to use DataFrames")
|
||||
|
||||
frame_name = raw.strip(';')
|
||||
frame_name = raw.strip(";")
|
||||
|
||||
# Get the DataFrame from the user namespace
|
||||
if not frame_name:
|
||||
raise SyntaxError('Syntax: %sql PERSIST <name_of_data_frame>')
|
||||
raise SyntaxError("Syntax: %sql PERSIST <name_of_data_frame>")
|
||||
frame = eval(frame_name, user_ns)
|
||||
if not isinstance(frame, DataFrame) and not isinstance(frame, Series):
|
||||
raise TypeError('%s is not a Pandas DataFrame or Series' % frame_name)
|
||||
raise TypeError("%s is not a Pandas DataFrame or Series" % frame_name)
|
||||
|
||||
# Make a suitable name for the resulting database table
|
||||
table_name = frame_name.lower()
|
||||
table_name = self.legal_sql_identifier.search(table_name).group(0)
|
||||
|
||||
if_exists = 'append' if append else 'fail'
|
||||
if_exists = "append" if append else "fail"
|
||||
frame.to_sql(table_name, conn.session.engine, if_exists=if_exists)
|
||||
return 'Persisted %s' % table_name
|
||||
return "Persisted %s" % table_name
|
||||
|
||||
|
||||
def load_ipython_extension(ip):
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
import json
|
||||
import re
|
||||
from os.path import expandvars
|
||||
|
||||
import six
|
||||
from six.moves import configparser as CP
|
||||
from sqlalchemy.engine.url import URL
|
||||
import json
|
||||
import re
|
||||
|
||||
|
||||
def connection_from_dsn_section(section, config):
|
||||
parser = CP.ConfigParser()
|
||||
|
@ -11,19 +13,19 @@ def connection_from_dsn_section(section, config):
|
|||
cfg_dict = dict(parser.items(section))
|
||||
return str(URL(**cfg_dict))
|
||||
|
||||
|
||||
def _connection_string(s):
|
||||
|
||||
s = expandvars(s) # for environment variables
|
||||
if '@' in s or '://' in s:
|
||||
if "@" in s or "://" in s:
|
||||
return s
|
||||
if s.startswith('[') and s.endswith(']'):
|
||||
section = s.lstrip('[').rstrip(']')
|
||||
if s.startswith("[") and s.endswith("]"):
|
||||
section = s.lstrip("[").rstrip("]")
|
||||
parser = CP.ConfigParser()
|
||||
parser.read(config.dsn_filename)
|
||||
cfg_dict = dict(parser.items(section))
|
||||
return str(URL(**cfg_dict))
|
||||
return ''
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def parse(cell, config):
|
||||
|
@ -37,18 +39,18 @@ def parse(cell, config):
|
|||
connection string and `<<` operator in.
|
||||
"""
|
||||
|
||||
result = {'connection': '', 'sql': '', 'result_var': None}
|
||||
result = {"connection": "", "sql": "", "result_var": None}
|
||||
|
||||
pieces = cell.split(None, 3)
|
||||
if not pieces:
|
||||
return result
|
||||
result['connection'] = _connection_string(pieces[0])
|
||||
if result['connection']:
|
||||
result["connection"] = _connection_string(pieces[0])
|
||||
if result["connection"]:
|
||||
pieces.pop(0)
|
||||
if len(pieces) > 1 and pieces[1] == '<<':
|
||||
result['result_var'] = pieces.pop(0)
|
||||
if len(pieces) > 1 and pieces[1] == "<<":
|
||||
result["result_var"] = pieces.pop(0)
|
||||
pieces.pop(0) # discard << operator
|
||||
result['sql'] = (' '.join(pieces)).strip()
|
||||
result["sql"] = (" ".join(pieces)).strip()
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
@ -24,9 +24,9 @@ def unduplicate_field_names(field_names):
|
|||
for k in field_names:
|
||||
if k in res:
|
||||
i = 1
|
||||
while k + '_' + str(i) in res:
|
||||
while k + "_" + str(i) in res:
|
||||
i += 1
|
||||
k += '_' + str(i)
|
||||
k += "_" + str(i)
|
||||
res.append(k)
|
||||
return res
|
||||
|
||||
|
@ -46,8 +46,7 @@ class UnicodeWriter(object):
|
|||
|
||||
def writerow(self, row):
|
||||
if six.PY2:
|
||||
_row = [s.encode("utf-8") if hasattr(s, "encode") else s
|
||||
for s in row]
|
||||
_row = [s.encode("utf-8") if hasattr(s, "encode") else s for s in row]
|
||||
else:
|
||||
_row = row
|
||||
self.writer.writerow(_row)
|
||||
|
@ -75,12 +74,12 @@ class CsvResultDescriptor(object):
|
|||
self.file_path = file_path
|
||||
|
||||
def __repr__(self):
|
||||
return 'CSV results at %s' % os.path.join(
|
||||
os.path.abspath('.'), self.file_path)
|
||||
return "CSV results at %s" % os.path.join(os.path.abspath("."), self.file_path)
|
||||
|
||||
def _repr_html_(self):
|
||||
return '<a href="%s">CSV results</a>' % os.path.join('.', 'files',
|
||||
self.file_path)
|
||||
return '<a href="%s">CSV results</a>' % os.path.join(
|
||||
".", "files", self.file_path
|
||||
)
|
||||
|
||||
|
||||
def _nonbreaking_spaces(match_obj):
|
||||
|
@ -90,11 +89,11 @@ def _nonbreaking_spaces(match_obj):
|
|||
Call with a ``re`` match object. Retain group 1, replace group 2
|
||||
with nonbreaking speaces.
|
||||
"""
|
||||
spaces = ' ' * len(match_obj.group(2))
|
||||
return '%s%s' % (match_obj.group(1), spaces)
|
||||
spaces = " " * len(match_obj.group(2))
|
||||
return "%s%s" % (match_obj.group(1), spaces)
|
||||
|
||||
|
||||
_cell_with_spaces_pattern = re.compile(r'(<td>)( {2,})')
|
||||
_cell_with_spaces_pattern = re.compile(r"(<td>)( {2,})")
|
||||
|
||||
|
||||
class ResultSet(list, ColumnGuesserMixin):
|
||||
|
@ -124,22 +123,23 @@ class ResultSet(list, ColumnGuesserMixin):
|
|||
self.pretty = None
|
||||
|
||||
def _repr_html_(self):
|
||||
_cell_with_spaces_pattern = re.compile(r'(<td>)( {2,})')
|
||||
_cell_with_spaces_pattern = re.compile(r"(<td>)( {2,})")
|
||||
if self.pretty:
|
||||
self.pretty.add_rows(self)
|
||||
result = self.pretty.get_html_string()
|
||||
result = _cell_with_spaces_pattern.sub(_nonbreaking_spaces, result)
|
||||
if self.config.displaylimit and len(
|
||||
self) > self.config.displaylimit:
|
||||
result = '%s\n<span style="font-style:italic;text-align:center;">%d rows, truncated to displaylimit of %d</span>' % (
|
||||
result, len(self), self.config.displaylimit)
|
||||
if self.config.displaylimit and len(self) > self.config.displaylimit:
|
||||
result = (
|
||||
'%s\n<span style="font-style:italic;text-align:center;">%d rows, truncated to displaylimit of %d</span>'
|
||||
% (result, len(self), self.config.displaylimit)
|
||||
)
|
||||
return result
|
||||
else:
|
||||
return None
|
||||
|
||||
def __str__(self, *arg, **kwarg):
|
||||
self.pretty.add_rows(self)
|
||||
return str(self.pretty or '')
|
||||
return str(self.pretty or "")
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""
|
||||
|
@ -170,6 +170,7 @@ class ResultSet(list, ColumnGuesserMixin):
|
|||
def DataFrame(self):
|
||||
"Returns a Pandas DataFrame instance built from the result set."
|
||||
import pandas as pd
|
||||
|
||||
frame = pd.DataFrame(self, columns=(self and self.keys) or [])
|
||||
return frame
|
||||
|
||||
|
@ -196,6 +197,7 @@ class ResultSet(list, ColumnGuesserMixin):
|
|||
"""
|
||||
self.guess_pie_columns(xlabel_sep=key_word_sep)
|
||||
import matplotlib.pylab as plt
|
||||
|
||||
pie = plt.pie(self.ys[0], labels=self.xlabels, **kwargs)
|
||||
plt.title(title or self.ys[0].name)
|
||||
return pie
|
||||
|
@ -219,11 +221,12 @@ class ResultSet(list, ColumnGuesserMixin):
|
|||
through to ``matplotlib.pylab.plot``.
|
||||
"""
|
||||
import matplotlib.pylab as plt
|
||||
|
||||
self.guess_plot_columns()
|
||||
self.x = self.x or range(len(self.ys[0]))
|
||||
coords = reduce(operator.add, [(self.x, y) for y in self.ys])
|
||||
plot = plt.plot(*coords, **kwargs)
|
||||
if hasattr(self.x, 'name'):
|
||||
if hasattr(self.x, "name"):
|
||||
plt.xlabel(self.x.name)
|
||||
ylabel = ", ".join(y.name for y in self.ys)
|
||||
plt.title(title or ylabel)
|
||||
|
@ -251,6 +254,7 @@ class ResultSet(list, ColumnGuesserMixin):
|
|||
through to ``matplotlib.pylab.bar``.
|
||||
"""
|
||||
import matplotlib.pylab as plt
|
||||
|
||||
self.guess_pie_columns(xlabel_sep=key_word_sep)
|
||||
plot = plt.bar(range(len(self.ys[0])), self.ys[0], **kwargs)
|
||||
if self.xlabels:
|
||||
|
@ -266,11 +270,11 @@ class ResultSet(list, ColumnGuesserMixin):
|
|||
return None # no results
|
||||
self.pretty.add_rows(self)
|
||||
if filename:
|
||||
encoding = format_params.get('encoding', 'utf-8')
|
||||
encoding = format_params.get("encoding", "utf-8")
|
||||
if six.PY2:
|
||||
outfile = open(filename, 'wb')
|
||||
outfile = open(filename, "wb")
|
||||
else:
|
||||
outfile = open(filename, 'w', newline='', encoding=encoding)
|
||||
outfile = open(filename, "w", newline="", encoding=encoding)
|
||||
else:
|
||||
outfile = six.StringIO()
|
||||
writer = UnicodeWriter(outfile, **format_params)
|
||||
|
@ -286,9 +290,9 @@ class ResultSet(list, ColumnGuesserMixin):
|
|||
|
||||
def interpret_rowcount(rowcount):
|
||||
if rowcount < 0:
|
||||
result = 'Done.'
|
||||
result = "Done."
|
||||
else:
|
||||
result = '%d rows affected.' % rowcount
|
||||
result = "%d rows affected." % rowcount
|
||||
return result
|
||||
|
||||
|
||||
|
@ -313,34 +317,33 @@ class FakeResultProxy(object):
|
|||
def from_list(self, source_list):
|
||||
"Simulates SQLA ResultProxy from a list."
|
||||
|
||||
self.fetchall = lambda: source_list
|
||||
self.fetchall = lambda: source_list
|
||||
self.rowcount = len(source_list)
|
||||
|
||||
def fetchmany(size):
|
||||
pos = 0
|
||||
pos = 0
|
||||
while pos < len(source_list):
|
||||
yield source_list[pos:pos+size]
|
||||
pos += size
|
||||
yield source_list[pos : pos + size]
|
||||
pos += size
|
||||
|
||||
self.fetchmany = fetchmany
|
||||
|
||||
|
||||
|
||||
# some dialects have autocommit
|
||||
# specific dialects break when commit is used:
|
||||
_COMMIT_BLACKLIST_DIALECTS = ('mssql', 'clickhouse', 'teradata', 'athena')
|
||||
_COMMIT_BLACKLIST_DIALECTS = ("mssql", "clickhouse", "teradata", "athena")
|
||||
|
||||
|
||||
def _commit(conn, config):
|
||||
"""Issues a commit, if appropriate for current config and dialect"""
|
||||
|
||||
_should_commit = config.autocommit and all(
|
||||
dialect not in str(conn.dialect)
|
||||
for dialect in _COMMIT_BLACKLIST_DIALECTS)
|
||||
dialect not in str(conn.dialect) for dialect in _COMMIT_BLACKLIST_DIALECTS
|
||||
)
|
||||
|
||||
if _should_commit:
|
||||
try:
|
||||
conn.session.execute('commit')
|
||||
conn.session.execute("commit")
|
||||
except sqlalchemy.exc.OperationalError:
|
||||
pass # not all engines can commit
|
||||
|
||||
|
@ -349,14 +352,15 @@ def run(conn, sql, config, user_namespace):
|
|||
if sql.strip():
|
||||
for statement in sqlparse.split(sql):
|
||||
first_word = sql.strip().split()[0].lower()
|
||||
if first_word == 'begin':
|
||||
if first_word == "begin":
|
||||
raise Exception("ipython_sql does not support transactions")
|
||||
if first_word.startswith('\\') and 'postgres' in str(conn.dialect):
|
||||
if first_word.startswith("\\") and "postgres" in str(conn.dialect):
|
||||
if not PGSpecial:
|
||||
raise ImportError('pgspecial not installed')
|
||||
raise ImportError("pgspecial not installed")
|
||||
pgspecial = PGSpecial()
|
||||
_, cur, headers, _ = pgspecial.execute(
|
||||
conn.session.connection.cursor(), statement)[0]
|
||||
conn.session.connection.cursor(), statement
|
||||
)[0]
|
||||
result = FakeResultProxy(cur, headers)
|
||||
else:
|
||||
txt = sqlalchemy.sql.text(statement)
|
||||
|
@ -369,9 +373,9 @@ def run(conn, sql, config, user_namespace):
|
|||
return resultset.DataFrame()
|
||||
else:
|
||||
return resultset
|
||||
#returning only last result, intentionally
|
||||
# returning only last result, intentionally
|
||||
else:
|
||||
return 'Connected: %s' % conn.name
|
||||
return "Connected: %s" % conn.name
|
||||
|
||||
|
||||
class PrettyTable(prettytable.PrettyTable):
|
||||
|
@ -391,5 +395,5 @@ class PrettyTable(prettytable.PrettyTable):
|
|||
self.row_count = len(data)
|
||||
else:
|
||||
self.row_count = min(len(data), self.displaylimit)
|
||||
for row in data[:self.displaylimit]:
|
||||
for row in data[: self.displaylimit]:
|
||||
self.add_row(row)
|
||||
|
|
|
@ -13,10 +13,10 @@ class SqlEnv(object):
|
|||
self.connectstr = connectstr
|
||||
|
||||
def query(self, txt):
|
||||
return ip.run_line_magic('sql', "%s %s" % (self.connectstr, txt))
|
||||
return ip.run_line_magic("sql", "%s %s" % (self.connectstr, txt))
|
||||
|
||||
|
||||
sql_env = SqlEnv('sqlite://')
|
||||
sql_env = SqlEnv("sqlite://")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -54,14 +54,14 @@ class TestOneNum(Harness):
|
|||
assert results.ys == [[1.01, 2.01, 3.01]]
|
||||
assert results.x == []
|
||||
assert results.xlabels == []
|
||||
assert results.xlabel == ''
|
||||
assert results.xlabel == ""
|
||||
|
||||
def test_plot(self, tbl):
|
||||
results = self.run_query()
|
||||
results.guess_plot_columns()
|
||||
assert results.ys == [[1.01, 2.01, 3.01]]
|
||||
assert results.x == []
|
||||
assert results.x.name == ''
|
||||
assert results.x.name == ""
|
||||
|
||||
|
||||
class TestOneStrOneNum(Harness):
|
||||
|
@ -72,8 +72,8 @@ class TestOneStrOneNum(Harness):
|
|||
results.guess_pie_columns(xlabel_sep="//")
|
||||
assert results.ys[0].is_quantity
|
||||
assert results.ys == [[1.01, 2.01, 3.01]]
|
||||
assert results.xlabels == ['r1-txt1', 'r2-txt1', 'r3-txt1']
|
||||
assert results.xlabel == 'name'
|
||||
assert results.xlabels == ["r1-txt1", "r2-txt1", "r3-txt1"]
|
||||
assert results.xlabel == "name"
|
||||
|
||||
def test_plot(self, tbl):
|
||||
results = self.run_query()
|
||||
|
@ -91,10 +91,11 @@ class TestTwoStrTwoNum(Harness):
|
|||
assert results.ys[0].is_quantity
|
||||
assert results.ys == [[1.01, 2.01, 3.01]]
|
||||
assert results.xlabels == [
|
||||
'r1-txt2//1.04//r1-txt1', 'r2-txt2//2.04//r2-txt1',
|
||||
'r3-txt2//3.04//r3-txt1'
|
||||
"r1-txt2//1.04//r1-txt1",
|
||||
"r2-txt2//2.04//r2-txt1",
|
||||
"r3-txt2//3.04//r3-txt1",
|
||||
]
|
||||
assert results.xlabel == 'name2, y3, name'
|
||||
assert results.xlabel == "name2, y3, name"
|
||||
|
||||
def test_plot(self, tbl):
|
||||
results = self.run_query()
|
||||
|
@ -112,8 +113,9 @@ class TestTwoStrThreeNum(Harness):
|
|||
assert results.ys[0].is_quantity
|
||||
assert results.ys == [[1.04, 2.04, 3.04]]
|
||||
assert results.xlabels == [
|
||||
'r1-txt1//1.01//r1-txt2//1.02', 'r2-txt1//2.01//r2-txt2//2.02',
|
||||
'r3-txt1//3.01//r3-txt2//3.02'
|
||||
"r1-txt1//1.01//r1-txt2//1.02",
|
||||
"r2-txt1//2.01//r2-txt2//2.02",
|
||||
"r3-txt1//3.01//r3-txt2//3.02",
|
||||
]
|
||||
|
||||
def test_plot(self, tbl):
|
||||
|
|
|
@ -14,7 +14,7 @@ def runsql(ip_session, statements):
|
|||
statements,
|
||||
]
|
||||
for statement in statements:
|
||||
result = ip_session.run_line_magic('sql', 'sqlite:// %s' % statement)
|
||||
result = ip_session.run_line_magic("sql", "sqlite:// %s" % statement)
|
||||
return result # returns only last result
|
||||
|
||||
|
||||
|
@ -23,87 +23,108 @@ def ip():
|
|||
"""Provides an IPython session in which tables have been created"""
|
||||
|
||||
ip_session = get_ipython()
|
||||
runsql(ip_session, [
|
||||
"CREATE TABLE test (n INT, name TEXT)",
|
||||
"INSERT INTO test VALUES (1, 'foo')",
|
||||
"INSERT INTO test VALUES (2, 'bar')",
|
||||
"CREATE TABLE author (first_name, last_name, year_of_death)",
|
||||
"INSERT INTO author VALUES ('William', 'Shakespeare', 1616)",
|
||||
"INSERT INTO author VALUES ('Bertold', 'Brecht', 1956)"
|
||||
])
|
||||
runsql(
|
||||
ip_session,
|
||||
[
|
||||
"CREATE TABLE test (n INT, name TEXT)",
|
||||
"INSERT INTO test VALUES (1, 'foo')",
|
||||
"INSERT INTO test VALUES (2, 'bar')",
|
||||
"CREATE TABLE author (first_name, last_name, year_of_death)",
|
||||
"INSERT INTO author VALUES ('William', 'Shakespeare', 1616)",
|
||||
"INSERT INTO author VALUES ('Bertold', 'Brecht', 1956)",
|
||||
],
|
||||
)
|
||||
yield ip_session
|
||||
runsql(ip_session, 'DROP TABLE test')
|
||||
runsql(ip_session, 'DROP TABLE author')
|
||||
runsql(ip_session, "DROP TABLE test")
|
||||
runsql(ip_session, "DROP TABLE author")
|
||||
|
||||
|
||||
def test_memory_db(ip):
|
||||
assert runsql(ip, "SELECT * FROM test;")[0][0] == 1
|
||||
assert runsql(ip, "SELECT * FROM test;")[1]['name'] == 'bar'
|
||||
assert runsql(ip, "SELECT * FROM test;")[1]["name"] == "bar"
|
||||
|
||||
|
||||
def test_html(ip):
|
||||
result = runsql(ip, "SELECT * FROM test;")
|
||||
assert '<td>foo</td>' in result._repr_html_().lower()
|
||||
assert "<td>foo</td>" in result._repr_html_().lower()
|
||||
|
||||
|
||||
def test_print(ip):
|
||||
result = runsql(ip, "SELECT * FROM test;")
|
||||
assert re.search(r'1\s+\|\s+foo', str(result))
|
||||
assert re.search(r"1\s+\|\s+foo", str(result))
|
||||
|
||||
|
||||
def test_plain_style(ip):
|
||||
ip.run_line_magic('config', "SqlMagic.style = 'PLAIN_COLUMNS'")
|
||||
ip.run_line_magic("config", "SqlMagic.style = 'PLAIN_COLUMNS'")
|
||||
result = runsql(ip, "SELECT * FROM test;")
|
||||
assert re.search(r'1\s+\|\s+foo', str(result))
|
||||
assert re.search(r"1\s+\|\s+foo", str(result))
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
def test_multi_sql(ip):
|
||||
result = ip.run_cell_magic('sql', '', """
|
||||
result = ip.run_cell_magic(
|
||||
"sql",
|
||||
"",
|
||||
"""
|
||||
sqlite://
|
||||
SELECT last_name FROM author;
|
||||
""")
|
||||
assert 'Shakespeare' in str(result) and 'Brecht' in str(result)
|
||||
""",
|
||||
)
|
||||
assert "Shakespeare" in str(result) and "Brecht" in str(result)
|
||||
|
||||
|
||||
def test_result_var(ip):
|
||||
ip.run_cell_magic('sql', '', """
|
||||
ip.run_cell_magic(
|
||||
"sql",
|
||||
"",
|
||||
"""
|
||||
sqlite://
|
||||
x <<
|
||||
SELECT last_name FROM author;
|
||||
""")
|
||||
result = ip.user_global_ns['x']
|
||||
assert 'Shakespeare' in str(result) and 'Brecht' in str(result)
|
||||
""",
|
||||
)
|
||||
result = ip.user_global_ns["x"]
|
||||
assert "Shakespeare" in str(result) and "Brecht" in str(result)
|
||||
|
||||
|
||||
def test_result_var_multiline_shovel(ip):
|
||||
ip.run_cell_magic('sql', '', """
|
||||
ip.run_cell_magic(
|
||||
"sql",
|
||||
"",
|
||||
"""
|
||||
sqlite:// x << SELECT last_name
|
||||
FROM author;
|
||||
""")
|
||||
result = ip.user_global_ns['x']
|
||||
assert 'Shakespeare' in str(result) and 'Brecht' in str(result)
|
||||
""",
|
||||
)
|
||||
result = ip.user_global_ns["x"]
|
||||
assert "Shakespeare" in str(result) and "Brecht" in str(result)
|
||||
|
||||
|
||||
def test_access_results_by_keys(ip):
|
||||
assert runsql(ip,
|
||||
"SELECT * FROM author;")['William'] == (u'William',
|
||||
u'Shakespeare', 1616)
|
||||
assert runsql(ip, "SELECT * FROM author;")["William"] == (
|
||||
u"William",
|
||||
u"Shakespeare",
|
||||
1616,
|
||||
)
|
||||
|
||||
|
||||
def test_duplicate_column_names_accepted(ip):
|
||||
result = ip.run_cell_magic('sql', '', """
|
||||
result = ip.run_cell_magic(
|
||||
"sql",
|
||||
"",
|
||||
"""
|
||||
sqlite://
|
||||
SELECT last_name, last_name FROM author;
|
||||
""")
|
||||
assert (u'Brecht', u'Brecht') in result
|
||||
""",
|
||||
)
|
||||
assert (u"Brecht", u"Brecht") in result
|
||||
|
||||
|
||||
def test_autolimit(ip):
|
||||
ip.run_line_magic('config', "SqlMagic.autolimit = 0")
|
||||
ip.run_line_magic("config", "SqlMagic.autolimit = 0")
|
||||
result = runsql(ip, "SELECT * FROM test;")
|
||||
assert len(result) == 2
|
||||
ip.run_line_magic('config', "SqlMagic.autolimit = 1")
|
||||
ip.run_line_magic("config", "SqlMagic.autolimit = 1")
|
||||
result = runsql(ip, "SELECT * FROM test;")
|
||||
assert len(result) == 1
|
||||
|
||||
|
@ -113,8 +134,8 @@ def test_persist(ip):
|
|||
ip.run_cell("results = %sql SELECT * FROM test;")
|
||||
ip.run_cell("results_dframe = results.DataFrame()")
|
||||
ip.run_cell("%sql --persist sqlite:// results_dframe")
|
||||
persisted = runsql(ip, 'SELECT * FROM results_dframe')
|
||||
assert 'foo' in str(persisted)
|
||||
persisted = runsql(ip, "SELECT * FROM results_dframe")
|
||||
assert "foo" in str(persisted)
|
||||
|
||||
|
||||
def test_append(ip):
|
||||
|
@ -122,9 +143,9 @@ def test_append(ip):
|
|||
ip.run_cell("results = %sql SELECT * FROM test;")
|
||||
ip.run_cell("results_dframe = results.DataFrame()")
|
||||
ip.run_cell("%sql --persist sqlite:// results_dframe")
|
||||
persisted = runsql(ip, 'SELECT COUNT(*) FROM results_dframe')
|
||||
persisted = runsql(ip, "SELECT COUNT(*) FROM results_dframe")
|
||||
ip.run_cell("%sql --append sqlite:// results_dframe")
|
||||
appended = runsql(ip, 'SELECT COUNT(*) FROM results_dframe')
|
||||
appended = runsql(ip, "SELECT COUNT(*) FROM results_dframe")
|
||||
assert appended[0][0] == persisted[0][0] * 2
|
||||
|
||||
|
||||
|
@ -149,28 +170,32 @@ def test_persist_bare(ip):
|
|||
def test_persist_frame_at_its_creation(ip):
|
||||
ip.run_cell("results = %sql SELECT * FROM author;")
|
||||
ip.run_cell("%sql --persist sqlite:// results.DataFrame()")
|
||||
persisted = runsql(ip, 'SELECT * FROM results')
|
||||
assert 'Shakespeare' in str(persisted)
|
||||
persisted = runsql(ip, "SELECT * FROM results")
|
||||
assert "Shakespeare" in str(persisted)
|
||||
|
||||
|
||||
def test_connection_args_enforce_json(ip):
|
||||
result = ip.run_cell("%sql --connection_arguments {\"badlyformed\":true")
|
||||
result = ip.run_cell('%sql --connection_arguments {"badlyformed":true')
|
||||
assert result.error_in_exec
|
||||
|
||||
|
||||
def test_connection_args_in_connection(ip):
|
||||
ip.run_cell("%sql --connection_arguments {\"timeout\":10} sqlite:///:memory:")
|
||||
ip.run_cell('%sql --connection_arguments {"timeout":10} sqlite:///:memory:')
|
||||
result = ip.run_cell("%sql --connections")
|
||||
assert 'timeout' in result.result['sqlite:///:memory:'].connect_args
|
||||
assert "timeout" in result.result["sqlite:///:memory:"].connect_args
|
||||
|
||||
|
||||
def test_connection_args_single_quotes(ip):
|
||||
ip.run_cell("%sql --connection_arguments '{\"timeout\": 10}' sqlite:///:memory:")
|
||||
result = ip.run_cell("%sql --connections")
|
||||
assert 'timeout' in result.result['sqlite:///:memory:'].connect_args
|
||||
assert "timeout" in result.result["sqlite:///:memory:"].connect_args
|
||||
|
||||
|
||||
def test_connection_args_double_quotes(ip):
|
||||
ip.run_cell('%sql --connection_arguments \"{\\\"timeout\\\": 10}\" sqlite:///:memory:')
|
||||
ip.run_cell('%sql --connection_arguments "{\\"timeout\\": 10}" sqlite:///:memory:')
|
||||
result = ip.run_cell("%sql --connections")
|
||||
assert 'timeout' in result.result['sqlite:///:memory:'].connect_args
|
||||
assert "timeout" in result.result["sqlite:///:memory:"].connect_args
|
||||
|
||||
|
||||
# TODO: support
|
||||
# @with_setup(_setup_author, _teardown_author)
|
||||
|
@ -182,162 +207,168 @@ def test_connection_args_double_quotes(ip):
|
|||
|
||||
|
||||
def test_displaylimit(ip):
|
||||
ip.run_line_magic('config', "SqlMagic.autolimit = None")
|
||||
ip.run_line_magic('config', "SqlMagic.displaylimit = None")
|
||||
ip.run_line_magic("config", "SqlMagic.autolimit = None")
|
||||
ip.run_line_magic("config", "SqlMagic.displaylimit = None")
|
||||
result = runsql(
|
||||
ip,
|
||||
"SELECT * FROM (VALUES ('apple'), ('banana'), ('cherry')) AS Result ORDER BY 1;"
|
||||
"SELECT * FROM (VALUES ('apple'), ('banana'), ('cherry')) AS Result ORDER BY 1;",
|
||||
)
|
||||
assert 'apple' in result._repr_html_()
|
||||
assert 'banana' in result._repr_html_()
|
||||
assert 'cherry' in result._repr_html_()
|
||||
ip.run_line_magic('config', "SqlMagic.displaylimit = 1")
|
||||
assert "apple" in result._repr_html_()
|
||||
assert "banana" in result._repr_html_()
|
||||
assert "cherry" in result._repr_html_()
|
||||
ip.run_line_magic("config", "SqlMagic.displaylimit = 1")
|
||||
result = runsql(
|
||||
ip,
|
||||
"SELECT * FROM (VALUES ('apple'), ('banana'), ('cherry')) AS Result ORDER BY 1;"
|
||||
"SELECT * FROM (VALUES ('apple'), ('banana'), ('cherry')) AS Result ORDER BY 1;",
|
||||
)
|
||||
assert 'apple' in result._repr_html_()
|
||||
assert 'cherry' not in result._repr_html_()
|
||||
assert "apple" in result._repr_html_()
|
||||
assert "cherry" not in result._repr_html_()
|
||||
|
||||
|
||||
def test_column_local_vars(ip):
|
||||
ip.run_line_magic('config', "SqlMagic.column_local_vars = True")
|
||||
ip.run_line_magic("config", "SqlMagic.column_local_vars = True")
|
||||
result = runsql(ip, "SELECT * FROM author;")
|
||||
assert result is None
|
||||
assert 'William' in ip.user_global_ns['first_name']
|
||||
assert 'Shakespeare' in ip.user_global_ns['last_name']
|
||||
assert len(ip.user_global_ns['first_name']) == 2
|
||||
ip.run_line_magic('config', "SqlMagic.column_local_vars = False")
|
||||
assert "William" in ip.user_global_ns["first_name"]
|
||||
assert "Shakespeare" in ip.user_global_ns["last_name"]
|
||||
assert len(ip.user_global_ns["first_name"]) == 2
|
||||
ip.run_line_magic("config", "SqlMagic.column_local_vars = False")
|
||||
|
||||
|
||||
def test_userns_not_changed(ip):
|
||||
ip.run_cell(
|
||||
dedent("""
|
||||
dedent(
|
||||
"""
|
||||
def function():
|
||||
local_var = 'local_val'
|
||||
%sql sqlite:// INSERT INTO test VALUES (2, 'bar');
|
||||
function()"""))
|
||||
assert 'local_var' not in ip.user_ns
|
||||
function()"""
|
||||
)
|
||||
)
|
||||
assert "local_var" not in ip.user_ns
|
||||
|
||||
|
||||
def test_bind_vars(ip):
|
||||
ip.user_global_ns['x'] = 22
|
||||
ip.user_global_ns["x"] = 22
|
||||
result = runsql(ip, "SELECT :x")
|
||||
assert result[0][0] == 22
|
||||
|
||||
|
||||
def test_autopandas(ip):
|
||||
ip.run_line_magic('config', "SqlMagic.autopandas = True")
|
||||
ip.run_line_magic("config", "SqlMagic.autopandas = True")
|
||||
dframe = runsql(ip, "SELECT * FROM test;")
|
||||
assert not dframe.empty
|
||||
assert dframe.ndim == 2
|
||||
assert dframe.name[0] == 'foo'
|
||||
assert dframe.name[0] == "foo"
|
||||
|
||||
|
||||
def test_csv(ip):
|
||||
ip.run_line_magic('config', "SqlMagic.autopandas = False") # uh-oh
|
||||
ip.run_line_magic("config", "SqlMagic.autopandas = False") # uh-oh
|
||||
result = runsql(ip, "SELECT * FROM test;")
|
||||
result = result.csv()
|
||||
for row in result.splitlines():
|
||||
assert row.count(',') == 1
|
||||
assert row.count(",") == 1
|
||||
assert len(result.splitlines()) == 3
|
||||
|
||||
|
||||
def test_csv_to_file(ip):
|
||||
ip.run_line_magic('config', "SqlMagic.autopandas = False") # uh-oh
|
||||
ip.run_line_magic("config", "SqlMagic.autopandas = False") # uh-oh
|
||||
result = runsql(ip, "SELECT * FROM test;")
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
fname = os.path.join(tempdir, 'test.csv')
|
||||
fname = os.path.join(tempdir, "test.csv")
|
||||
output = result.csv(fname)
|
||||
assert os.path.exists(output.file_path)
|
||||
with open(output.file_path) as csvfile:
|
||||
content = csvfile.read()
|
||||
for row in content.splitlines():
|
||||
assert row.count(',') == 1
|
||||
assert row.count(",") == 1
|
||||
assert len(content.splitlines()) == 3
|
||||
|
||||
|
||||
def test_sql_from_file(ip):
|
||||
ip.run_line_magic('config', "SqlMagic.autopandas = False")
|
||||
ip.run_line_magic("config", "SqlMagic.autopandas = False")
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
fname = os.path.join(tempdir, 'test.sql')
|
||||
with open(fname, 'w') as tempf:
|
||||
fname = os.path.join(tempdir, "test.sql")
|
||||
with open(fname, "w") as tempf:
|
||||
tempf.write("SELECT * FROM test;")
|
||||
result = ip.run_cell("%sql --file " + fname)
|
||||
assert result.result == [(1, 'foo'), (2, 'bar')]
|
||||
|
||||
result = ip.run_cell("%sql --file " + fname)
|
||||
assert result.result == [(1, "foo"), (2, "bar")]
|
||||
|
||||
|
||||
def test_sql_from_nonexistent_file(ip):
|
||||
ip.run_line_magic('config', "SqlMagic.autopandas = False")
|
||||
ip.run_line_magic("config", "SqlMagic.autopandas = False")
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
fname = os.path.join(tempdir, 'nonexistent.sql')
|
||||
result = ip.run_cell("%sql --file " + fname)
|
||||
fname = os.path.join(tempdir, "nonexistent.sql")
|
||||
result = ip.run_cell("%sql --file " + fname)
|
||||
assert isinstance(result.error_in_exec, FileNotFoundError)
|
||||
|
||||
|
||||
|
||||
def test_dict(ip):
|
||||
result = runsql(ip, "SELECT * FROM author;")
|
||||
result = result.dict()
|
||||
assert isinstance(result, dict)
|
||||
assert 'first_name' in result
|
||||
assert 'last_name' in result
|
||||
assert 'year_of_death' in result
|
||||
assert len(result['last_name']) == 2
|
||||
assert "first_name" in result
|
||||
assert "last_name" in result
|
||||
assert "year_of_death" in result
|
||||
assert len(result["last_name"]) == 2
|
||||
|
||||
|
||||
def test_dicts(ip):
|
||||
result = runsql(ip, "SELECT * FROM author;")
|
||||
for row in result.dicts():
|
||||
assert isinstance(row, dict)
|
||||
assert 'first_name' in row
|
||||
assert 'last_name' in row
|
||||
assert 'year_of_death' in row
|
||||
assert "first_name" in row
|
||||
assert "last_name" in row
|
||||
assert "year_of_death" in row
|
||||
|
||||
|
||||
def test_bracket_var_substitution(ip):
|
||||
|
||||
ip.user_global_ns['col'] = 'first_name'
|
||||
assert runsql(ip,
|
||||
"SELECT * FROM author"
|
||||
" WHERE {col} = 'William' ")[0] == (u'William',
|
||||
u'Shakespeare', 1616)
|
||||
ip.user_global_ns["col"] = "first_name"
|
||||
assert runsql(ip, "SELECT * FROM author" " WHERE {col} = 'William' ")[0] == (
|
||||
u"William",
|
||||
u"Shakespeare",
|
||||
1616,
|
||||
)
|
||||
|
||||
ip.user_global_ns["col"] = "last_name"
|
||||
result = runsql(ip, "SELECT * FROM author" " WHERE {col} = 'William' ")
|
||||
assert not result
|
||||
|
||||
ip.user_global_ns['col'] = 'last_name'
|
||||
result = runsql(ip,
|
||||
"SELECT * FROM author"
|
||||
" WHERE {col} = 'William' ")
|
||||
assert not result
|
||||
|
||||
|
||||
def test_multiline_bracket_var_substitution(ip):
|
||||
|
||||
ip.user_global_ns['col'] = 'first_name'
|
||||
assert runsql(ip,
|
||||
"SELECT * FROM author\n"
|
||||
" WHERE {col} = 'William' ")[0] == (u'William',
|
||||
u'Shakespeare', 1616)
|
||||
ip.user_global_ns["col"] = "first_name"
|
||||
assert runsql(ip, "SELECT * FROM author\n" " WHERE {col} = 'William' ")[0] == (
|
||||
u"William",
|
||||
u"Shakespeare",
|
||||
1616,
|
||||
)
|
||||
|
||||
ip.user_global_ns["col"] = "last_name"
|
||||
result = runsql(ip, "SELECT * FROM author" " WHERE {col} = 'William' ")
|
||||
assert not result
|
||||
|
||||
ip.user_global_ns['col'] = 'last_name'
|
||||
result = runsql(ip,
|
||||
"SELECT * FROM author"
|
||||
" WHERE {col} = 'William' ")
|
||||
assert not result
|
||||
|
||||
|
||||
def test_multiline_bracket_var_substitution(ip):
|
||||
ip.user_global_ns['col'] = 'first_name'
|
||||
result = ip.run_cell_magic('sql', '', """
|
||||
ip.user_global_ns["col"] = "first_name"
|
||||
result = ip.run_cell_magic(
|
||||
"sql",
|
||||
"",
|
||||
"""
|
||||
sqlite:// SELECT * FROM author
|
||||
WHERE {col} = 'William'
|
||||
""")
|
||||
assert (u'William', u'Shakespeare', 1616) in result
|
||||
""",
|
||||
)
|
||||
assert (u"William", u"Shakespeare", 1616) in result
|
||||
|
||||
ip.user_global_ns['col'] = 'last_name'
|
||||
result = ip.run_cell_magic('sql', '', """
|
||||
ip.user_global_ns["col"] = "last_name"
|
||||
result = ip.run_cell_magic(
|
||||
"sql",
|
||||
"",
|
||||
"""
|
||||
sqlite:// SELECT * FROM author
|
||||
WHERE {col} = 'William'
|
||||
""")
|
||||
assert not result
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
""",
|
||||
)
|
||||
assert not result
|
||||
|
|
|
@ -1,79 +1,103 @@
|
|||
from pathlib import Path
|
||||
import json
|
||||
import os
|
||||
from sql.parse import parse, connection_from_dsn_section
|
||||
from pathlib import Path
|
||||
|
||||
from six.moves import configparser
|
||||
|
||||
from sql.parse import connection_from_dsn_section, parse
|
||||
|
||||
try:
|
||||
from traitlets.config.configurable import Configurable
|
||||
except ImportError:
|
||||
from IPython.config.configurable import Configurable
|
||||
import json
|
||||
|
||||
empty_config = Configurable()
|
||||
default_connect_args = {'options': '-csearch_path=test'}
|
||||
default_connect_args = {"options": "-csearch_path=test"}
|
||||
|
||||
|
||||
def test_parse_no_sql():
|
||||
assert parse("will:longliveliz@localhost/shakes", empty_config) == \
|
||||
{'connection': "will:longliveliz@localhost/shakes",
|
||||
'sql': '',
|
||||
'result_var': None}
|
||||
assert parse("will:longliveliz@localhost/shakes", empty_config) == {
|
||||
"connection": "will:longliveliz@localhost/shakes",
|
||||
"sql": "",
|
||||
"result_var": None,
|
||||
}
|
||||
|
||||
|
||||
def test_parse_with_sql():
|
||||
assert parse("postgresql://will:longliveliz@localhost/shakes SELECT * FROM work",
|
||||
empty_config) == \
|
||||
{'connection': "postgresql://will:longliveliz@localhost/shakes",
|
||||
'sql': 'SELECT * FROM work',
|
||||
'result_var': None}
|
||||
assert parse(
|
||||
"postgresql://will:longliveliz@localhost/shakes SELECT * FROM work",
|
||||
empty_config,
|
||||
) == {
|
||||
"connection": "postgresql://will:longliveliz@localhost/shakes",
|
||||
"sql": "SELECT * FROM work",
|
||||
"result_var": None,
|
||||
}
|
||||
|
||||
|
||||
def test_parse_sql_only():
|
||||
assert parse("SELECT * FROM work", empty_config) == \
|
||||
{'connection': "",
|
||||
'sql': 'SELECT * FROM work',
|
||||
'result_var': None}
|
||||
assert parse("SELECT * FROM work", empty_config) == {
|
||||
"connection": "",
|
||||
"sql": "SELECT * FROM work",
|
||||
"result_var": None,
|
||||
}
|
||||
|
||||
|
||||
def test_parse_postgresql_socket_connection():
|
||||
assert parse("postgresql:///shakes SELECT * FROM work", empty_config) == \
|
||||
{'connection': "postgresql:///shakes",
|
||||
'sql': 'SELECT * FROM work',
|
||||
'result_var': None}
|
||||
assert parse("postgresql:///shakes SELECT * FROM work", empty_config) == {
|
||||
"connection": "postgresql:///shakes",
|
||||
"sql": "SELECT * FROM work",
|
||||
"result_var": None,
|
||||
}
|
||||
|
||||
|
||||
def test_expand_environment_variables_in_connection():
|
||||
os.environ['DATABASE_URL'] = 'postgresql:///shakes'
|
||||
assert parse("$DATABASE_URL SELECT * FROM work", empty_config) == \
|
||||
{'connection': "postgresql:///shakes",
|
||||
'sql': 'SELECT * FROM work',
|
||||
'result_var': None}
|
||||
os.environ["DATABASE_URL"] = "postgresql:///shakes"
|
||||
assert parse("$DATABASE_URL SELECT * FROM work", empty_config) == {
|
||||
"connection": "postgresql:///shakes",
|
||||
"sql": "SELECT * FROM work",
|
||||
"result_var": None,
|
||||
}
|
||||
|
||||
|
||||
def test_parse_shovel_operator():
|
||||
assert parse("dest << SELECT * FROM work", empty_config) == \
|
||||
{'connection': "",
|
||||
'sql': 'SELECT * FROM work',
|
||||
'result_var': "dest"}
|
||||
assert parse("dest << SELECT * FROM work", empty_config) == {
|
||||
"connection": "",
|
||||
"sql": "SELECT * FROM work",
|
||||
"result_var": "dest",
|
||||
}
|
||||
|
||||
|
||||
def test_parse_connect_plus_shovel():
|
||||
assert parse("sqlite:// dest << SELECT * FROM work", empty_config) == \
|
||||
{'connection': "sqlite://",
|
||||
'sql': 'SELECT * FROM work',
|
||||
'result_var': None}
|
||||
assert parse("sqlite:// dest << SELECT * FROM work", empty_config) == {
|
||||
"connection": "sqlite://",
|
||||
"sql": "SELECT * FROM work",
|
||||
"result_var": None,
|
||||
}
|
||||
|
||||
|
||||
def test_parse_shovel_operator():
|
||||
assert parse("dest << SELECT * FROM work", empty_config) == \
|
||||
{'connection': "",
|
||||
'sql': 'SELECT * FROM work',
|
||||
'result_var': "dest"}
|
||||
assert parse("dest << SELECT * FROM work", empty_config) == {
|
||||
"connection": "",
|
||||
"sql": "SELECT * FROM work",
|
||||
"result_var": "dest",
|
||||
}
|
||||
|
||||
|
||||
def test_parse_connect_plus_shovel():
|
||||
assert parse("sqlite:// dest << SELECT * FROM work", empty_config) == \
|
||||
{'connection': "sqlite://",
|
||||
'sql': 'SELECT * FROM work',
|
||||
'result_var': "dest"}
|
||||
assert parse("sqlite:// dest << SELECT * FROM work", empty_config) == {
|
||||
"connection": "sqlite://",
|
||||
"sql": "SELECT * FROM work",
|
||||
"result_var": "dest",
|
||||
}
|
||||
|
||||
|
||||
class DummyConfig:
|
||||
dsn_filename = Path('src/tests/test_dsn_config.ini')
|
||||
dsn_filename = Path("src/tests/test_dsn_config.ini")
|
||||
|
||||
|
||||
def test_connection_from_dsn_section():
|
||||
|
||||
result = connection_from_dsn_section(section='DB_CONFIG_1',
|
||||
config = DummyConfig())
|
||||
assert result == 'postgres://goesto11:seentheelephant@my.remote.host:5432/pgmain'
|
||||
result = connection_from_dsn_section(section='DB_CONFIG_2',
|
||||
config = DummyConfig())
|
||||
assert result == 'mysql://thefin:fishputsfishonthetable@127.0.0.1/dolfin'
|
||||
result = connection_from_dsn_section(section="DB_CONFIG_1", config=DummyConfig())
|
||||
assert result == "postgres://goesto11:seentheelephant@my.remote.host:5432/pgmain"
|
||||
result = connection_from_dsn_section(section="DB_CONFIG_2", config=DummyConfig())
|
||||
assert result == "mysql://thefin:fishputsfishonthetable@127.0.0.1/dolfin"
|
||||
|
|
Loading…
Reference in New Issue