diff --git a/examples/writers.ipynb b/examples/writers.ipynb index 7857cca..9b41b00 100644 --- a/examples/writers.ipynb +++ b/examples/writers.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": { "collapsed": false, "jupyter": { @@ -16,7 +16,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": { "collapsed": false, "jupyter": { diff --git a/requirements-dev.txt b/requirements-dev.txt index 5a0bfe1..60ebd32 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,4 +4,6 @@ pytest wheel twine readme-renderer +black +isort diff --git a/src/sql/column_guesser.py b/src/sql/column_guesser.py index 8e46ac2..34d4db8 100644 --- a/src/sql/column_guesser.py +++ b/src/sql/column_guesser.py @@ -4,26 +4,30 @@ makes guesses about the role of each column for plotting purposes (X values, Y values, and text labels). """ + class Column(list): - 'Store a column of tabular data; record its name and whether it is numeric' + "Store a column of tabular data; record its name and whether it is numeric" is_quantity = True - name = '' + name = "" + def __init__(self, *arg, **kwarg): pass - + def is_quantity(val): """Is ``val`` a quantity (int, float, datetime, etc) (not str, bool)? Relies on presence of __sub__. """ - return hasattr(val, '__sub__') + return hasattr(val, "__sub__") + class ColumnGuesserMixin(object): """ plot: [x, y, y...], y pie: ... y """ + def _build_columns(self): self.columns = [Column() for col in self.keys] for row in self: @@ -32,39 +36,40 @@ class ColumnGuesserMixin(object): col.append(col_val) if (col_val is not None) and (not is_quantity(col_val)): col.is_quantity = False - + for (idx, key_name) in enumerate(self.keys): self.columns[idx].name = key_name - + self.x = Column() self.ys = [] - + def _get_y(self): - for idx in range(len(self.columns)-1,-1,-1): + for idx in range(len(self.columns) - 1, -1, -1): if self.columns[idx].is_quantity: self.ys.insert(0, self.columns.pop(idx)) return True - def _get_x(self): + def _get_x(self): for idx in range(len(self.columns)): if self.columns[idx].is_quantity: self.x = self.columns.pop(idx) return True - + def _get_xlabel(self, xlabel_sep=" "): self.xlabels = [] if self.columns: for row_idx in range(len(self.columns[0])): - self.xlabels.append(xlabel_sep.join( - str(c[row_idx]) for c in self.columns)) + self.xlabels.append( + xlabel_sep.join(str(c[row_idx]) for c in self.columns) + ) self.xlabel = ", ".join(c.name for c in self.columns) - + def _guess_columns(self): self._build_columns() self._get_y() if not self.ys: raise AttributeError("No quantitative columns found for chart") - + def guess_pie_columns(self, xlabel_sep=" "): """ Assigns x, y, and x labels from the data set for a pie chart. @@ -75,7 +80,7 @@ class ColumnGuesserMixin(object): """ self._guess_columns() self._get_xlabel(xlabel_sep) - + def guess_plot_columns(self): """ Assigns ``x`` and ``y`` series from the data set for a plot. @@ -88,4 +93,4 @@ class ColumnGuesserMixin(object): self._guess_columns() self._get_x() while self._get_y(): - pass \ No newline at end of file + pass diff --git a/src/sql/connection.py b/src/sql/connection.py index fe930c9..75cab6d 100644 --- a/src/sql/connection.py +++ b/src/sql/connection.py @@ -1,20 +1,22 @@ -import sqlalchemy import os import re +import sqlalchemy + + class ConnectionError(Exception): pass def rough_dict_get(dct, sought, default=None): - ''' + """ Like dct.get(sought), but any key containing sought will do. If there is a `@` in sought, seek each piece separately. This lets `me@server` match `me:***@myserver/db` - ''' - - sought = sought.split('@') + """ + + sought = sought.split("@") for (key, val) in dct.items(): if not any(s.lower() not in key.lower() for s in sought): return val @@ -29,15 +31,21 @@ class Connection(object): def tell_format(cls): return """Connection info needed in SQLAlchemy format, example: postgresql://username:password@hostname/dbname - or an existing connection: %s""" % str(cls.connections.keys()) + or an existing connection: %s""" % str( + cls.connections.keys() + ) def __init__(self, connect_str=None, connect_args={}, creator=None): try: if creator: - engine = sqlalchemy.create_engine(connect_str, connect_args=connect_args, creator=creator) + engine = sqlalchemy.create_engine( + connect_str, connect_args=connect_args, creator=creator + ) else: - engine = sqlalchemy.create_engine(connect_str, connect_args=connect_args) - except: # TODO: bare except; but what's an ArgumentError? + engine = sqlalchemy.create_engine( + connect_str, connect_args=connect_args + ) + except: # TODO: bare except; but what's an ArgumentError? print(self.tell_format()) raise self.dialect = engine.url.get_dialect() @@ -65,42 +73,50 @@ class Connection(object): if displaycon: print(cls.connection_list()) else: - if os.getenv('DATABASE_URL'): - cls.current = Connection(os.getenv('DATABASE_URL'), connect_args, creator) + if os.getenv("DATABASE_URL"): + cls.current = Connection( + os.getenv("DATABASE_URL"), connect_args, creator + ) else: - raise ConnectionError('Environment variable $DATABASE_URL not set, and no connect string given.') + raise ConnectionError( + "Environment variable $DATABASE_URL not set, and no connect string given." + ) return cls.current @classmethod def assign_name(cls, engine): - name = '%s@%s' % (engine.url.username or '', engine.url.database) + name = "%s@%s" % (engine.url.username or "", engine.url.database) return name @classmethod def connection_list(cls): result = [] for key in sorted(cls.connections): - engine_url = cls.connections[key].metadata.bind.url # type: sqlalchemy.engine.url.URL + engine_url = cls.connections[ + key + ].metadata.bind.url # type: sqlalchemy.engine.url.URL if cls.connections[key] == cls.current: - template = ' * {}' + template = " * {}" else: - template = ' {}' + template = " {}" result.append(template.format(engine_url.__repr__())) - return '\n'.join(result) + return "\n".join(result) + def _close(cls, descriptor): if isinstance(descriptor, Connection): conn = descriptor else: - conn = cls.connections.get(descriptor) or \ - cls.connections.get(descriptor.lower()) + conn = cls.connections.get(descriptor) or cls.connections.get( + descriptor.lower() + ) if not conn: - raise Exception("Could not close connection because it was not found amongst these: %s" \ - %str(cls.connections.keys())) + raise Exception( + "Could not close connection because it was not found amongst these: %s" + % str(cls.connections.keys()) + ) cls.connections.pop(conn.name) cls.connections.pop(str(conn.metadata.bind.url)) conn.session.close() def close(self): self.__class__._close(self) - - diff --git a/src/sql/magic.py b/src/sql/magic.py index aec97d7..a94e38a 100644 --- a/src/sql/magic.py +++ b/src/sql/magic.py @@ -1,26 +1,30 @@ import json import re from string import Formatter -from IPython.core.magic import Magics, magics_class, cell_magic, line_magic, needs_local_scope + +from IPython.core.magic import (Magics, cell_magic, line_magic, magics_class, + needs_local_scope) +from IPython.core.magic_arguments import (argument, magic_arguments, + parse_argstring) from IPython.display import display_javascript +from sqlalchemy.exc import OperationalError, ProgrammingError + +import sql.connection +import sql.parse +import sql.run + try: from traitlets.config.configurable import Configurable from traitlets import Bool, Int, Unicode except ImportError: from IPython.config.configurable import Configurable from IPython.utils.traitlets import Bool, Int, Unicode -from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring try: from pandas.core.frame import DataFrame, Series except ImportError: DataFrame = None Series = None -from sqlalchemy.exc import ProgrammingError, OperationalError - -import sql.connection -import sql.parse -import sql.run @magics_class class SqlMagic(Magics, Configurable): @@ -29,20 +33,47 @@ class SqlMagic(Magics, Configurable): Provides the %%sql magic.""" displaycon = Bool(True, config=True, help="Show connection string after execute") - autolimit = Int(0, config=True, allow_none=True, help="Automatically limit the size of the returned result sets") - style = Unicode('DEFAULT', config=True, help="Set the table printing style to any of prettytable's defined styles (currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM)") - short_errors = Bool(True, config=True, help="Don't display the full traceback on SQL Programming Error") - displaylimit = Int(None, config=True, allow_none=True, help="Automatically limit the number of rows displayed (full result set is still stored)") - autopandas = Bool(False, config=True, help="Return Pandas DataFrames instead of regular result sets") - column_local_vars = Bool(False, config=True, help="Return data into local variables from column names") + autolimit = Int( + 0, + config=True, + allow_none=True, + help="Automatically limit the size of the returned result sets", + ) + style = Unicode( + "DEFAULT", + config=True, + help="Set the table printing style to any of prettytable's defined styles (currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM)", + ) + short_errors = Bool( + True, + config=True, + help="Don't display the full traceback on SQL Programming Error", + ) + displaylimit = Int( + None, + config=True, + allow_none=True, + help="Automatically limit the number of rows displayed (full result set is still stored)", + ) + autopandas = Bool( + False, + config=True, + help="Return Pandas DataFrames instead of regular result sets", + ) + column_local_vars = Bool( + False, config=True, help="Return data into local variables from column names" + ) feedback = Bool(True, config=True, help="Print number of rows affected by DML") - dsn_filename = Unicode('odbc.ini', config=True, help="Path to DSN file. " - "When the first argument is of the form [section], " - "a sqlalchemy connection string is formed from the " - "matching section in the DSN file.") + dsn_filename = Unicode( + "odbc.ini", + config=True, + help="Path to DSN file. " + "When the first argument is of the form [section], " + "a sqlalchemy connection string is formed from the " + "matching section in the DSN file.", + ) autocommit = Bool(True, config=True, help="Set autocommit mode") - def __init__(self, shell): Configurable.__init__(self, config=shell.config) Magics.__init__(self, shell=shell) @@ -51,19 +82,42 @@ class SqlMagic(Magics, Configurable): self.shell.configurables.append(self) @needs_local_scope - @line_magic('sql') - @cell_magic('sql') + @line_magic("sql") + @cell_magic("sql") @magic_arguments() - @argument('line', default='', nargs='*', type=str, help='sql') - @argument('-l', '--connections', action='store_true', help="list active connections") - @argument('-x', '--close', type=str, help="close a session by name") - @argument('-c', '--creator', type=str, help="specify creator function for new connection") - @argument('-s', '--section', type=str, help="section of dsn_file to be used for generating a connection string") - @argument('-p', '--persist', action='store_true', help="create a table name in the database from the named DataFrame") - @argument('--append', action='store_true', help="create, or append to, a table name in the database from the named DataFrame") - @argument('-a', '--connection_arguments', type=str, help="specify dictionary of connection arguments to pass to SQL driver") - @argument('-f', '--file', type=str, help="Run SQL from file at this path") - def execute(self, line='', cell='', local_ns={}): + @argument("line", default="", nargs="*", type=str, help="sql") + @argument( + "-l", "--connections", action="store_true", help="list active connections" + ) + @argument("-x", "--close", type=str, help="close a session by name") + @argument( + "-c", "--creator", type=str, help="specify creator function for new connection" + ) + @argument( + "-s", + "--section", + type=str, + help="section of dsn_file to be used for generating a connection string", + ) + @argument( + "-p", + "--persist", + action="store_true", + help="create a table name in the database from the named DataFrame", + ) + @argument( + "--append", + action="store_true", + help="create, or append to, a table name in the database from the named DataFrame", + ) + @argument( + "-a", + "--connection_arguments", + type=str, + help="specify dictionary of connection arguments to pass to SQL driver", + ) + @argument("-f", "--file", type=str, help="Run SQL from file at this path") + def execute(self, line="", cell="", local_ns={}): """Runs SQL statement against a database, specified by SQLAlchemy connect string. If no database connection has been established, first word @@ -89,7 +143,9 @@ class SqlMagic(Magics, Configurable): """ # Parse variables (words wrapped in {}) for %%sql magic (for %sql this is done automatically) - cell_variables = [fn for _, fn, _, _ in Formatter().parse(cell) if fn is not None] + cell_variables = [ + fn for _, fn, _, _ in Formatter().parse(cell) if fn is not None + ] cell_params = {} for variable in cell_variables: cell_params[variable] = local_ns[variable] @@ -105,15 +161,15 @@ class SqlMagic(Magics, Configurable): user_ns = self.shell.user_ns.copy() user_ns.update(local_ns) - command_text = ' '.join(args.line) + '\n' + cell + command_text = " ".join(args.line) + "\n" + cell if args.file: - with open(args.file, 'r') as infile: - command_text = infile.read() + "\n" + command_text + with open(args.file, "r") as infile: + command_text = infile.read() + "\n" + command_text - parsed = sql.parse.parse(command_text, self) + parsed = sql.parse.parse(command_text, self) - connect_str = parsed['connection'] + connect_str = parsed["connection"] if args.section: connect_str = sql.parse.connection_from_dsn_section(args.section, self) @@ -137,27 +193,36 @@ class SqlMagic(Magics, Configurable): args.creator = user_ns[args.creator] try: - conn = sql.connection.Connection.set(parsed['connection'], displaycon=self.displaycon, connect_args=args.connection_arguments, creator=args.creator) + conn = sql.connection.Connection.set( + parsed["connection"], + displaycon=self.displaycon, + connect_args=args.connection_arguments, + creator=args.creator, + ) except Exception as e: print(e) print(sql.connection.Connection.tell_format()) return None if args.persist: - return self._persist_dataframe(parsed['sql'], conn, user_ns, append=False) + return self._persist_dataframe(parsed["sql"], conn, user_ns, append=False) if args.append: - return self._persist_dataframe(parsed['sql'], conn, user_ns, append=True) + return self._persist_dataframe(parsed["sql"], conn, user_ns, append=True) - if not parsed['sql']: + if not parsed["sql"]: return try: - result = sql.run.run(conn, parsed['sql'], self, user_ns) + result = sql.run.run(conn, parsed["sql"], self, user_ns) - if result is not None and not isinstance(result, str) and self.column_local_vars: - #Instead of returning values, set variables directly in the - #users namespace. Variable names given by column names + if ( + result is not None + and not isinstance(result, str) + and self.column_local_vars + ): + # Instead of returning values, set variables directly in the + # users namespace. Variable names given by column names if self.autopandas: keys = result.keys() @@ -166,21 +231,22 @@ class SqlMagic(Magics, Configurable): result = result.dict() if self.feedback: - print('Returning data to local variables [{}]'.format( - ', '.join(keys))) + print( + "Returning data to local variables [{}]".format(", ".join(keys)) + ) self.shell.user_ns.update(result) return None else: - if parsed['result_var']: - result_var = parsed['result_var'] + if parsed["result_var"]: + result_var = parsed["result_var"] print("Returning data to local variable {}".format(result_var)) self.shell.user_ns.update({result_var: result}) return None - #Return results into the default ipython _ variable + # Return results into the default ipython _ variable return result except (ProgrammingError, OperationalError) as e: @@ -190,28 +256,29 @@ class SqlMagic(Magics, Configurable): else: raise - legal_sql_identifier = re.compile(r'^[A-Za-z0-9#_$]+') + legal_sql_identifier = re.compile(r"^[A-Za-z0-9#_$]+") + def _persist_dataframe(self, raw, conn, user_ns, append=False): """Implements PERSIST, which writes a DataFrame to the RDBMS""" if not DataFrame: raise ImportError("Must `pip install pandas` to use DataFrames") - frame_name = raw.strip(';') + frame_name = raw.strip(";") # Get the DataFrame from the user namespace if not frame_name: - raise SyntaxError('Syntax: %sql PERSIST ') + raise SyntaxError("Syntax: %sql PERSIST ") frame = eval(frame_name, user_ns) if not isinstance(frame, DataFrame) and not isinstance(frame, Series): - raise TypeError('%s is not a Pandas DataFrame or Series' % frame_name) + raise TypeError("%s is not a Pandas DataFrame or Series" % frame_name) # Make a suitable name for the resulting database table table_name = frame_name.lower() table_name = self.legal_sql_identifier.search(table_name).group(0) - if_exists = 'append' if append else 'fail' + if_exists = "append" if append else "fail" frame.to_sql(table_name, conn.session.engine, if_exists=if_exists) - return 'Persisted %s' % table_name + return "Persisted %s" % table_name def load_ipython_extension(ip): diff --git a/src/sql/parse.py b/src/sql/parse.py index be70276..49be5ae 100644 --- a/src/sql/parse.py +++ b/src/sql/parse.py @@ -1,9 +1,11 @@ +import json +import re from os.path import expandvars + import six from six.moves import configparser as CP from sqlalchemy.engine.url import URL -import json -import re + def connection_from_dsn_section(section, config): parser = CP.ConfigParser() @@ -11,19 +13,19 @@ def connection_from_dsn_section(section, config): cfg_dict = dict(parser.items(section)) return str(URL(**cfg_dict)) + def _connection_string(s): s = expandvars(s) # for environment variables - if '@' in s or '://' in s: + if "@" in s or "://" in s: return s - if s.startswith('[') and s.endswith(']'): - section = s.lstrip('[').rstrip(']') + if s.startswith("[") and s.endswith("]"): + section = s.lstrip("[").rstrip("]") parser = CP.ConfigParser() parser.read(config.dsn_filename) cfg_dict = dict(parser.items(section)) return str(URL(**cfg_dict)) - return '' - + return "" def parse(cell, config): @@ -37,18 +39,18 @@ def parse(cell, config): connection string and `<<` operator in. """ - result = {'connection': '', 'sql': '', 'result_var': None} + result = {"connection": "", "sql": "", "result_var": None} pieces = cell.split(None, 3) if not pieces: return result - result['connection'] = _connection_string(pieces[0]) - if result['connection']: + result["connection"] = _connection_string(pieces[0]) + if result["connection"]: pieces.pop(0) - if len(pieces) > 1 and pieces[1] == '<<': - result['result_var'] = pieces.pop(0) + if len(pieces) > 1 and pieces[1] == "<<": + result["result_var"] = pieces.pop(0) pieces.pop(0) # discard << operator - result['sql'] = (' '.join(pieces)).strip() + result["sql"] = (" ".join(pieces)).strip() return result diff --git a/src/sql/run.py b/src/sql/run.py index a2632a8..5445bf6 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -24,9 +24,9 @@ def unduplicate_field_names(field_names): for k in field_names: if k in res: i = 1 - while k + '_' + str(i) in res: + while k + "_" + str(i) in res: i += 1 - k += '_' + str(i) + k += "_" + str(i) res.append(k) return res @@ -46,8 +46,7 @@ class UnicodeWriter(object): def writerow(self, row): if six.PY2: - _row = [s.encode("utf-8") if hasattr(s, "encode") else s - for s in row] + _row = [s.encode("utf-8") if hasattr(s, "encode") else s for s in row] else: _row = row self.writer.writerow(_row) @@ -75,12 +74,12 @@ class CsvResultDescriptor(object): self.file_path = file_path def __repr__(self): - return 'CSV results at %s' % os.path.join( - os.path.abspath('.'), self.file_path) + return "CSV results at %s" % os.path.join(os.path.abspath("."), self.file_path) def _repr_html_(self): - return 'CSV results' % os.path.join('.', 'files', - self.file_path) + return 'CSV results' % os.path.join( + ".", "files", self.file_path + ) def _nonbreaking_spaces(match_obj): @@ -90,11 +89,11 @@ def _nonbreaking_spaces(match_obj): Call with a ``re`` match object. Retain group 1, replace group 2 with nonbreaking speaces. """ - spaces = ' ' * len(match_obj.group(2)) - return '%s%s' % (match_obj.group(1), spaces) + spaces = " " * len(match_obj.group(2)) + return "%s%s" % (match_obj.group(1), spaces) -_cell_with_spaces_pattern = re.compile(r'()( {2,})') +_cell_with_spaces_pattern = re.compile(r"()( {2,})") class ResultSet(list, ColumnGuesserMixin): @@ -124,22 +123,23 @@ class ResultSet(list, ColumnGuesserMixin): self.pretty = None def _repr_html_(self): - _cell_with_spaces_pattern = re.compile(r'()( {2,})') + _cell_with_spaces_pattern = re.compile(r"()( {2,})") if self.pretty: self.pretty.add_rows(self) result = self.pretty.get_html_string() result = _cell_with_spaces_pattern.sub(_nonbreaking_spaces, result) - if self.config.displaylimit and len( - self) > self.config.displaylimit: - result = '%s\n%d rows, truncated to displaylimit of %d' % ( - result, len(self), self.config.displaylimit) + if self.config.displaylimit and len(self) > self.config.displaylimit: + result = ( + '%s\n%d rows, truncated to displaylimit of %d' + % (result, len(self), self.config.displaylimit) + ) return result else: return None def __str__(self, *arg, **kwarg): self.pretty.add_rows(self) - return str(self.pretty or '') + return str(self.pretty or "") def __getitem__(self, key): """ @@ -170,6 +170,7 @@ class ResultSet(list, ColumnGuesserMixin): def DataFrame(self): "Returns a Pandas DataFrame instance built from the result set." import pandas as pd + frame = pd.DataFrame(self, columns=(self and self.keys) or []) return frame @@ -196,6 +197,7 @@ class ResultSet(list, ColumnGuesserMixin): """ self.guess_pie_columns(xlabel_sep=key_word_sep) import matplotlib.pylab as plt + pie = plt.pie(self.ys[0], labels=self.xlabels, **kwargs) plt.title(title or self.ys[0].name) return pie @@ -219,11 +221,12 @@ class ResultSet(list, ColumnGuesserMixin): through to ``matplotlib.pylab.plot``. """ import matplotlib.pylab as plt + self.guess_plot_columns() self.x = self.x or range(len(self.ys[0])) coords = reduce(operator.add, [(self.x, y) for y in self.ys]) plot = plt.plot(*coords, **kwargs) - if hasattr(self.x, 'name'): + if hasattr(self.x, "name"): plt.xlabel(self.x.name) ylabel = ", ".join(y.name for y in self.ys) plt.title(title or ylabel) @@ -251,6 +254,7 @@ class ResultSet(list, ColumnGuesserMixin): through to ``matplotlib.pylab.bar``. """ import matplotlib.pylab as plt + self.guess_pie_columns(xlabel_sep=key_word_sep) plot = plt.bar(range(len(self.ys[0])), self.ys[0], **kwargs) if self.xlabels: @@ -266,11 +270,11 @@ class ResultSet(list, ColumnGuesserMixin): return None # no results self.pretty.add_rows(self) if filename: - encoding = format_params.get('encoding', 'utf-8') + encoding = format_params.get("encoding", "utf-8") if six.PY2: - outfile = open(filename, 'wb') + outfile = open(filename, "wb") else: - outfile = open(filename, 'w', newline='', encoding=encoding) + outfile = open(filename, "w", newline="", encoding=encoding) else: outfile = six.StringIO() writer = UnicodeWriter(outfile, **format_params) @@ -286,9 +290,9 @@ class ResultSet(list, ColumnGuesserMixin): def interpret_rowcount(rowcount): if rowcount < 0: - result = 'Done.' + result = "Done." else: - result = '%d rows affected.' % rowcount + result = "%d rows affected." % rowcount return result @@ -313,34 +317,33 @@ class FakeResultProxy(object): def from_list(self, source_list): "Simulates SQLA ResultProxy from a list." - self.fetchall = lambda: source_list + self.fetchall = lambda: source_list self.rowcount = len(source_list) def fetchmany(size): - pos = 0 + pos = 0 while pos < len(source_list): - yield source_list[pos:pos+size] - pos += size + yield source_list[pos : pos + size] + pos += size self.fetchmany = fetchmany - # some dialects have autocommit # specific dialects break when commit is used: -_COMMIT_BLACKLIST_DIALECTS = ('mssql', 'clickhouse', 'teradata', 'athena') +_COMMIT_BLACKLIST_DIALECTS = ("mssql", "clickhouse", "teradata", "athena") def _commit(conn, config): """Issues a commit, if appropriate for current config and dialect""" _should_commit = config.autocommit and all( - dialect not in str(conn.dialect) - for dialect in _COMMIT_BLACKLIST_DIALECTS) + dialect not in str(conn.dialect) for dialect in _COMMIT_BLACKLIST_DIALECTS + ) if _should_commit: try: - conn.session.execute('commit') + conn.session.execute("commit") except sqlalchemy.exc.OperationalError: pass # not all engines can commit @@ -349,14 +352,15 @@ def run(conn, sql, config, user_namespace): if sql.strip(): for statement in sqlparse.split(sql): first_word = sql.strip().split()[0].lower() - if first_word == 'begin': + if first_word == "begin": raise Exception("ipython_sql does not support transactions") - if first_word.startswith('\\') and 'postgres' in str(conn.dialect): + if first_word.startswith("\\") and "postgres" in str(conn.dialect): if not PGSpecial: - raise ImportError('pgspecial not installed') + raise ImportError("pgspecial not installed") pgspecial = PGSpecial() _, cur, headers, _ = pgspecial.execute( - conn.session.connection.cursor(), statement)[0] + conn.session.connection.cursor(), statement + )[0] result = FakeResultProxy(cur, headers) else: txt = sqlalchemy.sql.text(statement) @@ -369,9 +373,9 @@ def run(conn, sql, config, user_namespace): return resultset.DataFrame() else: return resultset - #returning only last result, intentionally + # returning only last result, intentionally else: - return 'Connected: %s' % conn.name + return "Connected: %s" % conn.name class PrettyTable(prettytable.PrettyTable): @@ -391,5 +395,5 @@ class PrettyTable(prettytable.PrettyTable): self.row_count = len(data) else: self.row_count = min(len(data), self.displaylimit) - for row in data[:self.displaylimit]: + for row in data[: self.displaylimit]: self.add_row(row) diff --git a/src/tests/test_column_guesser.py b/src/tests/test_column_guesser.py index ebbc781..0df2cce 100644 --- a/src/tests/test_column_guesser.py +++ b/src/tests/test_column_guesser.py @@ -13,10 +13,10 @@ class SqlEnv(object): self.connectstr = connectstr def query(self, txt): - return ip.run_line_magic('sql', "%s %s" % (self.connectstr, txt)) + return ip.run_line_magic("sql", "%s %s" % (self.connectstr, txt)) -sql_env = SqlEnv('sqlite://') +sql_env = SqlEnv("sqlite://") @pytest.fixture @@ -54,14 +54,14 @@ class TestOneNum(Harness): assert results.ys == [[1.01, 2.01, 3.01]] assert results.x == [] assert results.xlabels == [] - assert results.xlabel == '' + assert results.xlabel == "" def test_plot(self, tbl): results = self.run_query() results.guess_plot_columns() assert results.ys == [[1.01, 2.01, 3.01]] assert results.x == [] - assert results.x.name == '' + assert results.x.name == "" class TestOneStrOneNum(Harness): @@ -72,8 +72,8 @@ class TestOneStrOneNum(Harness): results.guess_pie_columns(xlabel_sep="//") assert results.ys[0].is_quantity assert results.ys == [[1.01, 2.01, 3.01]] - assert results.xlabels == ['r1-txt1', 'r2-txt1', 'r3-txt1'] - assert results.xlabel == 'name' + assert results.xlabels == ["r1-txt1", "r2-txt1", "r3-txt1"] + assert results.xlabel == "name" def test_plot(self, tbl): results = self.run_query() @@ -91,10 +91,11 @@ class TestTwoStrTwoNum(Harness): assert results.ys[0].is_quantity assert results.ys == [[1.01, 2.01, 3.01]] assert results.xlabels == [ - 'r1-txt2//1.04//r1-txt1', 'r2-txt2//2.04//r2-txt1', - 'r3-txt2//3.04//r3-txt1' + "r1-txt2//1.04//r1-txt1", + "r2-txt2//2.04//r2-txt1", + "r3-txt2//3.04//r3-txt1", ] - assert results.xlabel == 'name2, y3, name' + assert results.xlabel == "name2, y3, name" def test_plot(self, tbl): results = self.run_query() @@ -112,8 +113,9 @@ class TestTwoStrThreeNum(Harness): assert results.ys[0].is_quantity assert results.ys == [[1.04, 2.04, 3.04]] assert results.xlabels == [ - 'r1-txt1//1.01//r1-txt2//1.02', 'r2-txt1//2.01//r2-txt2//2.02', - 'r3-txt1//3.01//r3-txt2//3.02' + "r1-txt1//1.01//r1-txt2//1.02", + "r2-txt1//2.01//r2-txt2//2.02", + "r3-txt1//3.01//r3-txt2//3.02", ] def test_plot(self, tbl): diff --git a/src/tests/test_magic.py b/src/tests/test_magic.py index a551e43..09c2352 100644 --- a/src/tests/test_magic.py +++ b/src/tests/test_magic.py @@ -14,7 +14,7 @@ def runsql(ip_session, statements): statements, ] for statement in statements: - result = ip_session.run_line_magic('sql', 'sqlite:// %s' % statement) + result = ip_session.run_line_magic("sql", "sqlite:// %s" % statement) return result # returns only last result @@ -23,87 +23,108 @@ def ip(): """Provides an IPython session in which tables have been created""" ip_session = get_ipython() - runsql(ip_session, [ - "CREATE TABLE test (n INT, name TEXT)", - "INSERT INTO test VALUES (1, 'foo')", - "INSERT INTO test VALUES (2, 'bar')", - "CREATE TABLE author (first_name, last_name, year_of_death)", - "INSERT INTO author VALUES ('William', 'Shakespeare', 1616)", - "INSERT INTO author VALUES ('Bertold', 'Brecht', 1956)" - ]) + runsql( + ip_session, + [ + "CREATE TABLE test (n INT, name TEXT)", + "INSERT INTO test VALUES (1, 'foo')", + "INSERT INTO test VALUES (2, 'bar')", + "CREATE TABLE author (first_name, last_name, year_of_death)", + "INSERT INTO author VALUES ('William', 'Shakespeare', 1616)", + "INSERT INTO author VALUES ('Bertold', 'Brecht', 1956)", + ], + ) yield ip_session - runsql(ip_session, 'DROP TABLE test') - runsql(ip_session, 'DROP TABLE author') + runsql(ip_session, "DROP TABLE test") + runsql(ip_session, "DROP TABLE author") def test_memory_db(ip): assert runsql(ip, "SELECT * FROM test;")[0][0] == 1 - assert runsql(ip, "SELECT * FROM test;")[1]['name'] == 'bar' + assert runsql(ip, "SELECT * FROM test;")[1]["name"] == "bar" def test_html(ip): result = runsql(ip, "SELECT * FROM test;") - assert 'foo' in result._repr_html_().lower() + assert "foo" in result._repr_html_().lower() def test_print(ip): result = runsql(ip, "SELECT * FROM test;") - assert re.search(r'1\s+\|\s+foo', str(result)) + assert re.search(r"1\s+\|\s+foo", str(result)) def test_plain_style(ip): - ip.run_line_magic('config', "SqlMagic.style = 'PLAIN_COLUMNS'") + ip.run_line_magic("config", "SqlMagic.style = 'PLAIN_COLUMNS'") result = runsql(ip, "SELECT * FROM test;") - assert re.search(r'1\s+\|\s+foo', str(result)) + assert re.search(r"1\s+\|\s+foo", str(result)) @pytest.mark.skip def test_multi_sql(ip): - result = ip.run_cell_magic('sql', '', """ + result = ip.run_cell_magic( + "sql", + "", + """ sqlite:// SELECT last_name FROM author; - """) - assert 'Shakespeare' in str(result) and 'Brecht' in str(result) + """, + ) + assert "Shakespeare" in str(result) and "Brecht" in str(result) def test_result_var(ip): - ip.run_cell_magic('sql', '', """ + ip.run_cell_magic( + "sql", + "", + """ sqlite:// x << SELECT last_name FROM author; - """) - result = ip.user_global_ns['x'] - assert 'Shakespeare' in str(result) and 'Brecht' in str(result) + """, + ) + result = ip.user_global_ns["x"] + assert "Shakespeare" in str(result) and "Brecht" in str(result) def test_result_var_multiline_shovel(ip): - ip.run_cell_magic('sql', '', """ + ip.run_cell_magic( + "sql", + "", + """ sqlite:// x << SELECT last_name FROM author; - """) - result = ip.user_global_ns['x'] - assert 'Shakespeare' in str(result) and 'Brecht' in str(result) + """, + ) + result = ip.user_global_ns["x"] + assert "Shakespeare" in str(result) and "Brecht" in str(result) def test_access_results_by_keys(ip): - assert runsql(ip, - "SELECT * FROM author;")['William'] == (u'William', - u'Shakespeare', 1616) + assert runsql(ip, "SELECT * FROM author;")["William"] == ( + u"William", + u"Shakespeare", + 1616, + ) def test_duplicate_column_names_accepted(ip): - result = ip.run_cell_magic('sql', '', """ + result = ip.run_cell_magic( + "sql", + "", + """ sqlite:// SELECT last_name, last_name FROM author; - """) - assert (u'Brecht', u'Brecht') in result + """, + ) + assert (u"Brecht", u"Brecht") in result def test_autolimit(ip): - ip.run_line_magic('config', "SqlMagic.autolimit = 0") + ip.run_line_magic("config", "SqlMagic.autolimit = 0") result = runsql(ip, "SELECT * FROM test;") assert len(result) == 2 - ip.run_line_magic('config', "SqlMagic.autolimit = 1") + ip.run_line_magic("config", "SqlMagic.autolimit = 1") result = runsql(ip, "SELECT * FROM test;") assert len(result) == 1 @@ -113,8 +134,8 @@ def test_persist(ip): ip.run_cell("results = %sql SELECT * FROM test;") ip.run_cell("results_dframe = results.DataFrame()") ip.run_cell("%sql --persist sqlite:// results_dframe") - persisted = runsql(ip, 'SELECT * FROM results_dframe') - assert 'foo' in str(persisted) + persisted = runsql(ip, "SELECT * FROM results_dframe") + assert "foo" in str(persisted) def test_append(ip): @@ -122,9 +143,9 @@ def test_append(ip): ip.run_cell("results = %sql SELECT * FROM test;") ip.run_cell("results_dframe = results.DataFrame()") ip.run_cell("%sql --persist sqlite:// results_dframe") - persisted = runsql(ip, 'SELECT COUNT(*) FROM results_dframe') + persisted = runsql(ip, "SELECT COUNT(*) FROM results_dframe") ip.run_cell("%sql --append sqlite:// results_dframe") - appended = runsql(ip, 'SELECT COUNT(*) FROM results_dframe') + appended = runsql(ip, "SELECT COUNT(*) FROM results_dframe") assert appended[0][0] == persisted[0][0] * 2 @@ -149,28 +170,32 @@ def test_persist_bare(ip): def test_persist_frame_at_its_creation(ip): ip.run_cell("results = %sql SELECT * FROM author;") ip.run_cell("%sql --persist sqlite:// results.DataFrame()") - persisted = runsql(ip, 'SELECT * FROM results') - assert 'Shakespeare' in str(persisted) + persisted = runsql(ip, "SELECT * FROM results") + assert "Shakespeare" in str(persisted) def test_connection_args_enforce_json(ip): - result = ip.run_cell("%sql --connection_arguments {\"badlyformed\":true") + result = ip.run_cell('%sql --connection_arguments {"badlyformed":true') assert result.error_in_exec + def test_connection_args_in_connection(ip): - ip.run_cell("%sql --connection_arguments {\"timeout\":10} sqlite:///:memory:") + ip.run_cell('%sql --connection_arguments {"timeout":10} sqlite:///:memory:') result = ip.run_cell("%sql --connections") - assert 'timeout' in result.result['sqlite:///:memory:'].connect_args + assert "timeout" in result.result["sqlite:///:memory:"].connect_args + def test_connection_args_single_quotes(ip): ip.run_cell("%sql --connection_arguments '{\"timeout\": 10}' sqlite:///:memory:") result = ip.run_cell("%sql --connections") - assert 'timeout' in result.result['sqlite:///:memory:'].connect_args + assert "timeout" in result.result["sqlite:///:memory:"].connect_args + def test_connection_args_double_quotes(ip): - ip.run_cell('%sql --connection_arguments \"{\\\"timeout\\\": 10}\" sqlite:///:memory:') + ip.run_cell('%sql --connection_arguments "{\\"timeout\\": 10}" sqlite:///:memory:') result = ip.run_cell("%sql --connections") - assert 'timeout' in result.result['sqlite:///:memory:'].connect_args + assert "timeout" in result.result["sqlite:///:memory:"].connect_args + # TODO: support # @with_setup(_setup_author, _teardown_author) @@ -182,162 +207,168 @@ def test_connection_args_double_quotes(ip): def test_displaylimit(ip): - ip.run_line_magic('config', "SqlMagic.autolimit = None") - ip.run_line_magic('config', "SqlMagic.displaylimit = None") + ip.run_line_magic("config", "SqlMagic.autolimit = None") + ip.run_line_magic("config", "SqlMagic.displaylimit = None") result = runsql( ip, - "SELECT * FROM (VALUES ('apple'), ('banana'), ('cherry')) AS Result ORDER BY 1;" + "SELECT * FROM (VALUES ('apple'), ('banana'), ('cherry')) AS Result ORDER BY 1;", ) - assert 'apple' in result._repr_html_() - assert 'banana' in result._repr_html_() - assert 'cherry' in result._repr_html_() - ip.run_line_magic('config', "SqlMagic.displaylimit = 1") + assert "apple" in result._repr_html_() + assert "banana" in result._repr_html_() + assert "cherry" in result._repr_html_() + ip.run_line_magic("config", "SqlMagic.displaylimit = 1") result = runsql( ip, - "SELECT * FROM (VALUES ('apple'), ('banana'), ('cherry')) AS Result ORDER BY 1;" + "SELECT * FROM (VALUES ('apple'), ('banana'), ('cherry')) AS Result ORDER BY 1;", ) - assert 'apple' in result._repr_html_() - assert 'cherry' not in result._repr_html_() + assert "apple" in result._repr_html_() + assert "cherry" not in result._repr_html_() def test_column_local_vars(ip): - ip.run_line_magic('config', "SqlMagic.column_local_vars = True") + ip.run_line_magic("config", "SqlMagic.column_local_vars = True") result = runsql(ip, "SELECT * FROM author;") assert result is None - assert 'William' in ip.user_global_ns['first_name'] - assert 'Shakespeare' in ip.user_global_ns['last_name'] - assert len(ip.user_global_ns['first_name']) == 2 - ip.run_line_magic('config', "SqlMagic.column_local_vars = False") + assert "William" in ip.user_global_ns["first_name"] + assert "Shakespeare" in ip.user_global_ns["last_name"] + assert len(ip.user_global_ns["first_name"]) == 2 + ip.run_line_magic("config", "SqlMagic.column_local_vars = False") def test_userns_not_changed(ip): ip.run_cell( - dedent(""" + dedent( + """ def function(): local_var = 'local_val' %sql sqlite:// INSERT INTO test VALUES (2, 'bar'); - function()""")) - assert 'local_var' not in ip.user_ns + function()""" + ) + ) + assert "local_var" not in ip.user_ns def test_bind_vars(ip): - ip.user_global_ns['x'] = 22 + ip.user_global_ns["x"] = 22 result = runsql(ip, "SELECT :x") assert result[0][0] == 22 def test_autopandas(ip): - ip.run_line_magic('config', "SqlMagic.autopandas = True") + ip.run_line_magic("config", "SqlMagic.autopandas = True") dframe = runsql(ip, "SELECT * FROM test;") assert not dframe.empty assert dframe.ndim == 2 - assert dframe.name[0] == 'foo' + assert dframe.name[0] == "foo" def test_csv(ip): - ip.run_line_magic('config', "SqlMagic.autopandas = False") # uh-oh + ip.run_line_magic("config", "SqlMagic.autopandas = False") # uh-oh result = runsql(ip, "SELECT * FROM test;") result = result.csv() for row in result.splitlines(): - assert row.count(',') == 1 + assert row.count(",") == 1 assert len(result.splitlines()) == 3 def test_csv_to_file(ip): - ip.run_line_magic('config', "SqlMagic.autopandas = False") # uh-oh + ip.run_line_magic("config", "SqlMagic.autopandas = False") # uh-oh result = runsql(ip, "SELECT * FROM test;") with tempfile.TemporaryDirectory() as tempdir: - fname = os.path.join(tempdir, 'test.csv') + fname = os.path.join(tempdir, "test.csv") output = result.csv(fname) assert os.path.exists(output.file_path) with open(output.file_path) as csvfile: content = csvfile.read() for row in content.splitlines(): - assert row.count(',') == 1 + assert row.count(",") == 1 assert len(content.splitlines()) == 3 def test_sql_from_file(ip): - ip.run_line_magic('config', "SqlMagic.autopandas = False") + ip.run_line_magic("config", "SqlMagic.autopandas = False") with tempfile.TemporaryDirectory() as tempdir: - fname = os.path.join(tempdir, 'test.sql') - with open(fname, 'w') as tempf: + fname = os.path.join(tempdir, "test.sql") + with open(fname, "w") as tempf: tempf.write("SELECT * FROM test;") - result = ip.run_cell("%sql --file " + fname) - assert result.result == [(1, 'foo'), (2, 'bar')] - + result = ip.run_cell("%sql --file " + fname) + assert result.result == [(1, "foo"), (2, "bar")] + + def test_sql_from_nonexistent_file(ip): - ip.run_line_magic('config', "SqlMagic.autopandas = False") + ip.run_line_magic("config", "SqlMagic.autopandas = False") with tempfile.TemporaryDirectory() as tempdir: - fname = os.path.join(tempdir, 'nonexistent.sql') - result = ip.run_cell("%sql --file " + fname) + fname = os.path.join(tempdir, "nonexistent.sql") + result = ip.run_cell("%sql --file " + fname) assert isinstance(result.error_in_exec, FileNotFoundError) - + + def test_dict(ip): result = runsql(ip, "SELECT * FROM author;") result = result.dict() assert isinstance(result, dict) - assert 'first_name' in result - assert 'last_name' in result - assert 'year_of_death' in result - assert len(result['last_name']) == 2 + assert "first_name" in result + assert "last_name" in result + assert "year_of_death" in result + assert len(result["last_name"]) == 2 def test_dicts(ip): result = runsql(ip, "SELECT * FROM author;") for row in result.dicts(): assert isinstance(row, dict) - assert 'first_name' in row - assert 'last_name' in row - assert 'year_of_death' in row + assert "first_name" in row + assert "last_name" in row + assert "year_of_death" in row def test_bracket_var_substitution(ip): - ip.user_global_ns['col'] = 'first_name' - assert runsql(ip, - "SELECT * FROM author" - " WHERE {col} = 'William' ")[0] == (u'William', - u'Shakespeare', 1616) + ip.user_global_ns["col"] = "first_name" + assert runsql(ip, "SELECT * FROM author" " WHERE {col} = 'William' ")[0] == ( + u"William", + u"Shakespeare", + 1616, + ) + + ip.user_global_ns["col"] = "last_name" + result = runsql(ip, "SELECT * FROM author" " WHERE {col} = 'William' ") + assert not result - ip.user_global_ns['col'] = 'last_name' - result = runsql(ip, - "SELECT * FROM author" - " WHERE {col} = 'William' ") - assert not result - def test_multiline_bracket_var_substitution(ip): - ip.user_global_ns['col'] = 'first_name' - assert runsql(ip, - "SELECT * FROM author\n" - " WHERE {col} = 'William' ")[0] == (u'William', - u'Shakespeare', 1616) + ip.user_global_ns["col"] = "first_name" + assert runsql(ip, "SELECT * FROM author\n" " WHERE {col} = 'William' ")[0] == ( + u"William", + u"Shakespeare", + 1616, + ) + + ip.user_global_ns["col"] = "last_name" + result = runsql(ip, "SELECT * FROM author" " WHERE {col} = 'William' ") + assert not result - ip.user_global_ns['col'] = 'last_name' - result = runsql(ip, - "SELECT * FROM author" - " WHERE {col} = 'William' ") - assert not result - def test_multiline_bracket_var_substitution(ip): - ip.user_global_ns['col'] = 'first_name' - result = ip.run_cell_magic('sql', '', """ + ip.user_global_ns["col"] = "first_name" + result = ip.run_cell_magic( + "sql", + "", + """ sqlite:// SELECT * FROM author WHERE {col} = 'William' - """) - assert (u'William', u'Shakespeare', 1616) in result + """, + ) + assert (u"William", u"Shakespeare", 1616) in result - ip.user_global_ns['col'] = 'last_name' - result = ip.run_cell_magic('sql', '', """ + ip.user_global_ns["col"] = "last_name" + result = ip.run_cell_magic( + "sql", + "", + """ sqlite:// SELECT * FROM author WHERE {col} = 'William' - """) - assert not result - - - - - + """, + ) + assert not result diff --git a/src/tests/test_parse.py b/src/tests/test_parse.py index 31a73d7..2269664 100644 --- a/src/tests/test_parse.py +++ b/src/tests/test_parse.py @@ -1,79 +1,103 @@ -from pathlib import Path +import json import os -from sql.parse import parse, connection_from_dsn_section +from pathlib import Path + from six.moves import configparser + +from sql.parse import connection_from_dsn_section, parse + try: from traitlets.config.configurable import Configurable except ImportError: from IPython.config.configurable import Configurable -import json empty_config = Configurable() -default_connect_args = {'options': '-csearch_path=test'} +default_connect_args = {"options": "-csearch_path=test"} + + def test_parse_no_sql(): - assert parse("will:longliveliz@localhost/shakes", empty_config) == \ - {'connection': "will:longliveliz@localhost/shakes", - 'sql': '', - 'result_var': None} + assert parse("will:longliveliz@localhost/shakes", empty_config) == { + "connection": "will:longliveliz@localhost/shakes", + "sql": "", + "result_var": None, + } + def test_parse_with_sql(): - assert parse("postgresql://will:longliveliz@localhost/shakes SELECT * FROM work", - empty_config) == \ - {'connection': "postgresql://will:longliveliz@localhost/shakes", - 'sql': 'SELECT * FROM work', - 'result_var': None} + assert parse( + "postgresql://will:longliveliz@localhost/shakes SELECT * FROM work", + empty_config, + ) == { + "connection": "postgresql://will:longliveliz@localhost/shakes", + "sql": "SELECT * FROM work", + "result_var": None, + } + def test_parse_sql_only(): - assert parse("SELECT * FROM work", empty_config) == \ - {'connection': "", - 'sql': 'SELECT * FROM work', - 'result_var': None} + assert parse("SELECT * FROM work", empty_config) == { + "connection": "", + "sql": "SELECT * FROM work", + "result_var": None, + } + def test_parse_postgresql_socket_connection(): - assert parse("postgresql:///shakes SELECT * FROM work", empty_config) == \ - {'connection': "postgresql:///shakes", - 'sql': 'SELECT * FROM work', - 'result_var': None} + assert parse("postgresql:///shakes SELECT * FROM work", empty_config) == { + "connection": "postgresql:///shakes", + "sql": "SELECT * FROM work", + "result_var": None, + } + def test_expand_environment_variables_in_connection(): - os.environ['DATABASE_URL'] = 'postgresql:///shakes' - assert parse("$DATABASE_URL SELECT * FROM work", empty_config) == \ - {'connection': "postgresql:///shakes", - 'sql': 'SELECT * FROM work', - 'result_var': None} + os.environ["DATABASE_URL"] = "postgresql:///shakes" + assert parse("$DATABASE_URL SELECT * FROM work", empty_config) == { + "connection": "postgresql:///shakes", + "sql": "SELECT * FROM work", + "result_var": None, + } + def test_parse_shovel_operator(): - assert parse("dest << SELECT * FROM work", empty_config) == \ - {'connection': "", - 'sql': 'SELECT * FROM work', - 'result_var': "dest"} + assert parse("dest << SELECT * FROM work", empty_config) == { + "connection": "", + "sql": "SELECT * FROM work", + "result_var": "dest", + } + def test_parse_connect_plus_shovel(): - assert parse("sqlite:// dest << SELECT * FROM work", empty_config) == \ - {'connection': "sqlite://", - 'sql': 'SELECT * FROM work', - 'result_var': None} + assert parse("sqlite:// dest << SELECT * FROM work", empty_config) == { + "connection": "sqlite://", + "sql": "SELECT * FROM work", + "result_var": None, + } + def test_parse_shovel_operator(): - assert parse("dest << SELECT * FROM work", empty_config) == \ - {'connection': "", - 'sql': 'SELECT * FROM work', - 'result_var': "dest"} + assert parse("dest << SELECT * FROM work", empty_config) == { + "connection": "", + "sql": "SELECT * FROM work", + "result_var": "dest", + } + def test_parse_connect_plus_shovel(): - assert parse("sqlite:// dest << SELECT * FROM work", empty_config) == \ - {'connection': "sqlite://", - 'sql': 'SELECT * FROM work', - 'result_var': "dest"} + assert parse("sqlite:// dest << SELECT * FROM work", empty_config) == { + "connection": "sqlite://", + "sql": "SELECT * FROM work", + "result_var": "dest", + } + class DummyConfig: - dsn_filename = Path('src/tests/test_dsn_config.ini') + dsn_filename = Path("src/tests/test_dsn_config.ini") + def test_connection_from_dsn_section(): - result = connection_from_dsn_section(section='DB_CONFIG_1', - config = DummyConfig()) - assert result == 'postgres://goesto11:seentheelephant@my.remote.host:5432/pgmain' - result = connection_from_dsn_section(section='DB_CONFIG_2', - config = DummyConfig()) - assert result == 'mysql://thefin:fishputsfishonthetable@127.0.0.1/dolfin' + result = connection_from_dsn_section(section="DB_CONFIG_1", config=DummyConfig()) + assert result == "postgres://goesto11:seentheelephant@my.remote.host:5432/pgmain" + result = connection_from_dsn_section(section="DB_CONFIG_2", config=DummyConfig()) + assert result == "mysql://thefin:fishputsfishonthetable@127.0.0.1/dolfin"