diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..84b7670 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,30 @@ +import unittest + +from .context import yfinance as yf + +import datetime as _dt +import pandas as _pd +import pytz as _pytz + +class TestUtils(unittest.TestCase): + + def test_parse_user_dt(self): + """ Purpose of _parse_user_dt() is to take any date-like value, + combine with specified timezone and + return its localized timestamp. + """ + tz_name = "America/New_York" + tz = _pytz.timezone(tz_name) + dt_answer = tz.localize(_dt.datetime(2023,1,1)) + + # All possible versions of 'dt_answer' + values = ["2023-01-01", _dt.date(2023,1,1), _dt.datetime(2023,1,1), _pd.Timestamp(_dt.date(2023,1,1))] + # - now add localized versions + values.append(tz.localize(_dt.datetime(2023,1,1))) + values.append(_pd.Timestamp(_dt.date(2023,1,1)).tz_localize(tz_name)) + values.append(int(_pd.Timestamp(_dt.date(2023,1,1)).tz_localize(tz_name).timestamp())) + + for v in values: + v2 = yf.utils._parse_user_dt(v, tz_name) + self.assertEqual(v2, dt_answer.timestamp()) + diff --git a/yfinance/utils.py b/yfinance/utils.py index 48b0434..b2a74b9 100644 --- a/yfinance/utils.py +++ b/yfinance/utils.py @@ -288,21 +288,30 @@ def camel2title(strings: List[str], sep: str = ' ', acronyms: Optional[List[str] return strings -def _parse_user_dt(dt, exchange_tz): +def _parse_user_dt(dt, tz_name): if isinstance(dt, int): # Should already be epoch, test with conversion: - _datetime.datetime.fromtimestamp(dt) - else: - # Convert str/date -> datetime, set tzinfo=exchange, get timestamp: - if isinstance(dt, str): - dt = _datetime.datetime.strptime(str(dt), '%Y-%m-%d') - if isinstance(dt, _datetime.date) and not isinstance(dt, _datetime.datetime): - dt = _datetime.datetime.combine(dt, _datetime.time(0)) - if isinstance(dt, _datetime.datetime) and dt.tzinfo is None: - # Assume user is referring to exchange's timezone - dt = _tz.timezone(exchange_tz).localize(dt) - dt = int(dt.timestamp()) - return dt + try: + _datetime.datetime.fromtimestamp(dt) + except: + raise Exception(f"'dt' is not a valid epoch: '{dt}'") + return dt + + # Convert str/date -> datetime, set tzinfo=exchange, get timestamp: + dt2 = dt + if isinstance(dt2, str): + dt2 = _datetime.datetime.strptime(str(dt2), '%Y-%m-%d') + if isinstance(dt2, _datetime.date) and not isinstance(dt2, _datetime.datetime): + dt2 = _datetime.datetime.combine(dt2, _datetime.time(0)) + if isinstance(dt2, _datetime.datetime) and dt2.tzinfo is None: + # Assume user is referring to exchange's timezone + if tz_name is None: + raise Exception(f"Must provide a timezone for localizing '{dt}'") + dt2 = _tz.timezone(tz_name).localize(dt2) + if not isinstance(dt2, _datetime.datetime): + raise Exception(f"'dt' is not a date-like value: '{dt}'") + dt2 = int(dt2.timestamp()) + return dt2 def _interval_to_timedelta(interval):