chore: Enable SIM Ruff checks (#1509)
edgarrmondragon authored Mar 20, 2023
1 parent c48bd8c commit 0d0eaec
Showing 17 changed files with 63 additions and 67 deletions.
1 change: 1 addition & 0 deletions pyproject.toml

@@ -238,6 +238,7 @@ select = [
     "PT",  # flake8-pytest-style
     "RSE", # flake8-raise
     "RET", # flake8-return
+    "SIM", # flake8-simplify
 ]
 src = ["samples", "singer_sdk", "tests"]
 target-version = "py37"
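This is the only configuration change: the SIM family (flake8-simplify) is added to Ruff's select list, and every other file in the commit fixes a finding from that family. Below is a minimal sketch, not part of the commit, of re-running just those checks locally; it assumes Ruff is installed and that the `ruff check --select` interface is available in your version.

    # Hypothetical helper: re-run only the SIM rules over the same paths
    # listed in pyproject.toml's `src` setting.
    import subprocess

    def run_sim_checks(paths: list) -> int:
        """Invoke Ruff with only the flake8-simplify (SIM) rules selected."""
        result = subprocess.run(
            ["ruff", "check", "--select", "SIM", *paths],
            check=False,  # a non-zero exit code just means findings were reported
        )
        return result.returncode

    if __name__ == "__main__":
        raise SystemExit(run_sim_checks(["samples", "singer_sdk", "tests"]))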
2 changes: 1 addition & 1 deletion samples/sample_tap_hostile/hostile_streams.py

@@ -34,7 +34,7 @@ def get_records(self, context: dict | None) -> Iterable[dict | tuple[dict, dict]
         return (
             {
                 key: self.get_random_lowercase_string()
-                for key in self.schema["properties"].keys()
+                for key in self.schema["properties"]
             }
             for _ in range(10)
         )
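This hunk, like several below (_dict_config.py, _flattening.py, plugin_base.py, sinks/core.py, sinks/sql.py, rest.py), drops a redundant .keys() call: iterating over or testing membership in a dict already uses its keys. This appears to be flake8-simplify's SIM118 rule. A small self-contained sketch of the pattern, using a made-up schema dict:

    # Illustrative only: `schema` stands in for a stream's JSON schema.
    schema = {"properties": {"id": {"type": "integer"}, "name": {"type": "string"}}}

    # Before: .keys() is redundant for iteration and membership tests.
    before = [key for key in schema["properties"].keys()]  # noqa: SIM118

    # After: iterating the dict directly yields the same keys.
    after = [key for key in schema["properties"]]

    assert before == after == ["id", "name"]
    assert "id" in schema["properties"]  # membership works the same way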
10 changes: 6 additions & 4 deletions singer_sdk/_singerlib/catalog.py

@@ -191,10 +191,12 @@ def get_standard_metadata(
     if schema_name:
         root.schema_name = schema_name

-    for field_name in schema.get("properties", {}).keys():
-        if key_properties and field_name in key_properties:
-            entry = Metadata(inclusion=Metadata.InclusionType.AUTOMATIC)
-        elif valid_replication_keys and field_name in valid_replication_keys:
+    for field_name in schema.get("properties", {}):
+        if (
+            key_properties
+            and field_name in key_properties
+            or (valid_replication_keys and field_name in valid_replication_keys)
+        ):
             entry = Metadata(inclusion=Metadata.InclusionType.AUTOMATIC)
         else:
             entry = Metadata(inclusion=Metadata.InclusionType.AVAILABLE)
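Two branches with identical bodies (both producing AUTOMATIC inclusion) are merged into a single condition joined with or, which looks like flake8-simplify's check for duplicate if/elif arms (likely SIM114). A standalone sketch of the same transformation with simplified, made-up names:

    # Illustrative only: `field`, `key_props`, and `replication_keys` are stand-ins.
    from __future__ import annotations

    def inclusion(field: str, key_props: list[str], replication_keys: list[str]) -> str:
        # Before: two branches that both returned "automatic".
        #     if key_props and field in key_props:
        #         return "automatic"
        #     elif replication_keys and field in replication_keys:
        #         return "automatic"
        # After: one combined condition.
        if (key_props and field in key_props) or (
            replication_keys and field in replication_keys
        ):
            return "automatic"
        return "available"

    assert inclusion("id", ["id"], ["updated_at"]) == "automatic"
    assert inclusion("name", ["id"], ["updated_at"]) == "available"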
2 changes: 1 addition & 1 deletion singer_sdk/configuration/_dict_config.py

@@ -43,7 +43,7 @@ def parse_environment_config(
     logger.debug("Loading configuration from %s", dotenv_path)
     DotEnv(dotenv_path).set_as_environment_variables()

-    for config_key in config_schema["properties"].keys():
+    for config_key in config_schema["properties"]:
         env_var_name = prefix + config_key.upper().replace("-", "_")
         if env_var_name in os.environ:
             env_var_value = os.environ[env_var_name]
6 changes: 1 addition & 5 deletions singer_sdk/connectors/sql.py

@@ -854,11 +854,7 @@ def merge_sql_types(
         if issubclass(
             generic_type,
             (sqlalchemy.types.String, sqlalchemy.types.Unicode),
-        ):
-            # If length None or 0 then is varchar max ?
-            if (opt_len is None) or (opt_len == 0):
-                return opt
-        elif isinstance(
+        ) or issubclass(
             generic_type,
             (sqlalchemy.types.String, sqlalchemy.types.Unicode),
         ):
2 changes: 1 addition & 1 deletion singer_sdk/helpers/_flattening.py

@@ -236,7 +236,7 @@ def _flatten_schema(

     for k, v in schema_node["properties"].items():
         new_key = flatten_key(k, parent_keys, separator)
-        if "type" in v.keys():
+        if "type" in v:
             if "object" in v["type"] and "properties" in v and level < max_level:
                 items.extend(
                     _flatten_schema(
24 changes: 13 additions & 11 deletions singer_sdk/helpers/_state.py

@@ -235,17 +235,19 @@ def finalize_state_progress_markers(stream_or_partition_state: dict) -> dict | N
     """Promote or wipe progress markers once sync is complete."""
     signpost_value = stream_or_partition_state.pop(SIGNPOST_MARKER, None)
     stream_or_partition_state.pop(STARTING_MARKER, None)
-    if PROGRESS_MARKERS in stream_or_partition_state:
-        if "replication_key" in stream_or_partition_state[PROGRESS_MARKERS]:
-            # Replication keys valid (only) after sync is complete
-            progress_markers = stream_or_partition_state[PROGRESS_MARKERS]
-            stream_or_partition_state["replication_key"] = progress_markers.pop(
-                "replication_key",
-            )
-            new_rk_value = progress_markers.pop("replication_key_value")
-            if signpost_value and _greater_than_signpost(signpost_value, new_rk_value):
-                new_rk_value = signpost_value
-            stream_or_partition_state["replication_key_value"] = new_rk_value
+    if (
+        PROGRESS_MARKERS in stream_or_partition_state
+        and "replication_key" in stream_or_partition_state[PROGRESS_MARKERS]
+    ):
+        # Replication keys valid (only) after sync is complete
+        progress_markers = stream_or_partition_state[PROGRESS_MARKERS]
+        stream_or_partition_state["replication_key"] = progress_markers.pop(
+            "replication_key",
+        )
+        new_rk_value = progress_markers.pop("replication_key_value")
+        if signpost_value and _greater_than_signpost(signpost_value, new_rk_value):
+            new_rk_value = signpost_value
+        stream_or_partition_state["replication_key_value"] = new_rk_value
     # Wipe and return any markers that have not been promoted
     return reset_state_progress_markers(stream_or_partition_state)
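Here a nested pair of if statements collapses into a single if joined with and (flake8-simplify's collapsible-if check, likely SIM102); the body loses one level of indentation but is otherwise unchanged. A tiny sketch with placeholder state:

    # Illustrative only: a simplified stand-in for the stream state dict.
    state = {"progress_markers": {"replication_key": "updated_at"}}

    # Before (nested):
    #     if "progress_markers" in state:
    #         if "replication_key" in state["progress_markers"]:
    #             promote(state)
    # After: one condition, one indentation level saved.
    if "progress_markers" in state and "replication_key" in state["progress_markers"]:
        print("promote replication key:", state["progress_markers"]["replication_key"])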
5 changes: 1 addition & 4 deletions singer_sdk/helpers/_typing.py

@@ -124,10 +124,7 @@ def is_datetime_type(type_dict: dict) -> bool:
             "Did you forget to define a property in the stream schema?",
         )
     if "anyOf" in type_dict:
-        for type_dict in type_dict["anyOf"]:
-            if is_datetime_type(type_dict):
-                return True
-        return False
+        return any(is_datetime_type(type_dict) for type_dict in type_dict["anyOf"])
     if "type" in type_dict:
         return type_dict.get("format") == "date-time"
     raise ValueError(
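The loop that returned True on the first match and False otherwise is replaced by any(), matching flake8-simplify's reimplemented-builtin check (likely SIM110); any() short-circuits exactly as the explicit loop did. A self-contained sketch of the simplified function:

    # Illustrative only: a trimmed-down version of the date-time type check.
    def is_datetime_type(type_dict: dict) -> bool:
        if "anyOf" in type_dict:
            # any() returns True as soon as one option matches, like the old loop.
            return any(is_datetime_type(option) for option in type_dict["anyOf"])
        return type_dict.get("format") == "date-time"

    assert is_datetime_type({"anyOf": [{"type": "null"}, {"type": "string", "format": "date-time"}]})
    assert not is_datetime_type({"type": "integer"})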
4 changes: 2 additions & 2 deletions singer_sdk/plugin_base.py

@@ -87,7 +87,7 @@ def __init__(
         """
         if not config:
             config_dict = {}
-        elif isinstance(config, str) or isinstance(config, PurePath):
+        elif isinstance(config, (str, PurePath)):
             config_dict = read_json_file(config)
         elif isinstance(config, list):
             config_dict = {}

@@ -339,7 +339,7 @@ def print_about(cls: type[PluginBase], format: str | None = None) -> None:
         elif format == "markdown":
             max_setting_len = cast(
                 int,
-                max(len(k) for k in info["settings"]["properties"].keys()),
+                max(len(k) for k in info["settings"]["properties"]),
             )

             # Set table base for markdown
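The first hunk folds two isinstance() calls on the same value into a single call with a tuple of types, which appears to be flake8-simplify's duplicate-isinstance check (likely SIM101); the second hunk is another .keys() removal. A short sketch of the isinstance merge:

    # Illustrative only: mirrors the config-handling branch with a made-up value.
    from pathlib import PurePath

    config = PurePath("config.json")

    # Before: isinstance(config, str) or isinstance(config, PurePath)
    # After: one isinstance call with a tuple of accepted types.
    if isinstance(config, (str, PurePath)):
        print(f"would read JSON config from {config}")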
16 changes: 9 additions & 7 deletions singer_sdk/sinks/core.py

@@ -327,7 +327,7 @@ def _parse_timestamps_in_record(
             schema: TODO
             treatment: TODO
         """
-        for key in record.keys():
+        for key in record:
             datelike_type = get_datelike_property_type(schema["properties"][key])
             if datelike_type:
                 try:

@@ -479,12 +479,14 @@ def process_batch_files(
         storage = StorageTarget.from_url(head)

         if encoding.format == BatchFileFormat.JSONL:
-            with storage.fs(create=False) as batch_fs:
-                with batch_fs.open(tail, mode="rb") as file:
-                    if encoding.compression == "gzip":
-                        file = gzip_open(file)
-                    context = {"records": [json.loads(line) for line in file]}
-                    self.process_batch(context)
+            with storage.fs(create=False) as batch_fs, batch_fs.open(
+                tail,
+                mode="rb",
+            ) as file:
+                if encoding.compression == "gzip":
+                    file = gzip_open(file)
+                context = {"records": [json.loads(line) for line in file]}
+                self.process_batch(context)
         else:
             raise NotImplementedError(
                 f"Unsupported batch encoding format: {encoding.format}",
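The second hunk merges two nested with blocks into a single statement with comma-separated context managers (flake8-simplify's multiple-with-statements check, likely SIM117); the same rewrite appears in streams/core.py and in the two test files further down. A self-contained sketch that reads a gzipped JSONL file through two combined context managers:

    # Illustrative only: plain temp files stand in for the SDK's storage abstraction.
    import gzip
    import json
    import tempfile
    from pathlib import Path

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = Path(tmp_dir) / "batch.jsonl.gz"
        path.write_bytes(gzip.compress(b'{"id": 1}\n{"id": 2}\n'))

        # Before (nested):
        #     with storage.fs(create=False) as batch_fs:
        #         with batch_fs.open(tail, mode="rb") as file:
        #             ...
        # After: one `with`, two managers, one indentation level saved.
        with path.open("rb") as raw, gzip.GzipFile(fileobj=raw) as file:
            records = [json.loads(line) for line in file]

        assert records == [{"id": 1}, {"id": 2}]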
2 changes: 1 addition & 1 deletion singer_sdk/sinks/sql.py

@@ -199,7 +199,7 @@ def conform_schema(self, schema: dict) -> dict:
         """
         conformed_schema = copy(schema)
         conformed_property_names = {
-            key: self.conform_name(key) for key in conformed_schema["properties"].keys()
+            key: self.conform_name(key) for key in conformed_schema["properties"]
         }
         self._check_conformed_names_not_duplicated(conformed_property_names)
         conformed_schema["properties"] = {
14 changes: 8 additions & 6 deletions singer_sdk/streams/core.py

@@ -1266,12 +1266,14 @@ def get_batches(
         ):
             filename = f"{prefix}{sync_id}-{i}.json.gz"
             with batch_config.storage.fs() as fs:
-                with fs.open(filename, "wb") as f:
-                    # TODO: Determine compression from config.
-                    with gzip.GzipFile(fileobj=f, mode="wb") as gz:
-                        gz.writelines(
-                            (json.dumps(record) + "\n").encode() for record in chunk
-                        )
+                # TODO: Determine compression from config.
+                with fs.open(filename, "wb") as f, gzip.GzipFile(
+                    fileobj=f,
+                    mode="wb",
+                ) as gz:
+                    gz.writelines(
+                        (json.dumps(record) + "\n").encode() for record in chunk
+                    )
                 file_url = fs.geturl(filename)

             yield batch_config.encoding, [file_url]
14 changes: 3 additions & 11 deletions singer_sdk/streams/rest.py

@@ -103,11 +103,7 @@ def _url_encode(val: str | datetime | bool | int | list[str]) -> str:
         Returns:
             TODO
         """
-        if isinstance(val, str):
-            result = val.replace("/", "%2F")
-        else:
-            result = str(val)
-        return result
+        return val.replace("/", "%2F") if isinstance(val, str) else str(val)

     def get_url(self, context: dict | None) -> str:
         """Get stream entity URL.

@@ -199,10 +195,7 @@ def response_error_message(self, response: requests.Response) -> str:
             str: The error message
         """
         full_path = urlparse(response.url).path or self.path
-        if 400 <= response.status_code < 500:
-            error_type = "Client"
-        else:
-            error_type = "Server"
+        error_type = "Client" if 400 <= response.status_code < 500 else "Server"

         return (
             f"{response.status_code} {error_type} Error: "

@@ -430,8 +423,7 @@ def update_sync_costs(
         """
         call_costs = self.calculate_sync_cost(request, response, context)
         self._sync_costs = {
-            k: self._sync_costs.get(k, 0) + call_costs.get(k, 0)
-            for k in call_costs.keys()
+            k: self._sync_costs.get(k, 0) + call_costs.get(k, 0) for k in call_costs
        }
        return self._sync_costs
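The first two hunks in this file swap a small if/else assignment for a one-line conditional expression (flake8-simplify's if-else-to-ternary check, likely SIM108); the third is another .keys() removal. A sketch of the error-type example:

    # Illustrative only: classifies an HTTP status the same way the diff does.
    def error_type(status_code: int) -> str:
        # Before:
        #     if 400 <= status_code < 500:
        #         error_type = "Client"
        #     else:
        #         error_type = "Server"
        # After: a single conditional expression.
        return "Client" if 400 <= status_code < 500 else "Server"

    assert error_type(404) == "Client"
    assert error_type(503) == "Server"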
5 changes: 2 additions & 3 deletions singer_sdk/tap_base.py

@@ -3,6 +3,7 @@
 from __future__ import annotations

 import abc
+import contextlib
 import json
 from enum import Enum
 from pathlib import Path, PurePath

@@ -197,10 +198,8 @@ def run_connection_test(self) -> bool:
                     "Skipping direct invocation.",
                 )
                 continue
-            try:
+            with contextlib.suppress(MaxRecordsLimitException):
                 stream.sync()
-            except MaxRecordsLimitException:
-                pass
         return True

     @final
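A try/except/pass block becomes contextlib.suppress(), with the matching import added at the top of the module (flake8-simplify's suppressible-exception check, likely SIM105). A standalone sketch with a stand-in exception class:

    # Illustrative only: MaxRecordsLimitException stands in for the SDK's exception.
    import contextlib

    class MaxRecordsLimitException(Exception):
        """Raised when a stream hits its configured record limit."""

    def sync() -> None:
        raise MaxRecordsLimitException("hit the configured record limit")

    # Before:
    #     try:
    #         sync()
    #     except MaxRecordsLimitException:
    #         pass
    # After: the intent (ignore exactly this exception) reads directly.
    with contextlib.suppress(MaxRecordsLimitException):
        sync()

    print("still running")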
2 changes: 1 addition & 1 deletion singer_sdk/testing/runners.py

@@ -231,7 +231,7 @@ def input(self) -> IO[str]:
         if self.input_io:
             self._input = self.input_io
         elif self.input_filepath:
-            self._input = open(self.input_filepath)
+            self._input = open(self.input_filepath)  # noqa: SIM115
         return cast(IO[str], self._input)

     @input.setter
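This is the one place the rule is silenced instead of applied: SIM115 flags open() calls that are not wrapped in a with block, but this property intentionally returns a long-lived file handle that the runner closes later, so a targeted noqa is the pragmatic fix. A sketch of that trade-off under assumed names:

    # Illustrative only: a cut-down runner whose `input` property must not close
    # the file it opens, so wrapping open() in `with` here would be wrong.
    from __future__ import annotations

    from typing import IO

    class FileRunner:
        def __init__(self, input_filepath: str) -> None:
            self.input_filepath = input_filepath
            self._input: IO[str] | None = None

        @property
        def input(self) -> IO[str]:
            if self._input is None:
                self._input = open(self.input_filepath)  # noqa: SIM115
            return self._input

        def close(self) -> None:
            if self._input is not None:
                self._input.close()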
5 changes: 2 additions & 3 deletions tests/core/configuration/test_dict_config.py

@@ -89,9 +89,8 @@ def test_get_env_var_config_not_parsable():
             "PLUGIN_TEST_PROP1": "hello",
             "PLUGIN_TEST_PROP3": '["repeated"]',
         },
-    ):
-        with pytest.raises(ValueError, match="A bracketed list was detected"):
-            parse_environment_config(CONFIG_JSONSCHEMA, "PLUGIN_TEST_")
+    ), pytest.raises(ValueError, match="A bracketed list was detected"):
+        parse_environment_config(CONFIG_JSONSCHEMA, "PLUGIN_TEST_")


 def test_merge_config_sources(config_file1, config_file2):
16 changes: 10 additions & 6 deletions tests/core/test_connector_sql.py

@@ -157,15 +157,19 @@ def test_deprecated_functions_warn(self, connector):
             connector.connection

     def test_connect_calls_engine(self, connector):
-        with mock.patch.object(SQLConnector, "_engine") as mock_engine:
-            with connector._connect() as _:
-                mock_engine.connect.assert_called_once()
+        with mock.patch.object(
+            SQLConnector,
+            "_engine",
+        ) as mock_engine, connector._connect() as _:
+            mock_engine.connect.assert_called_once()

     def test_connect_calls_connect(self, connector):
         attached_engine = connector._engine
-        with mock.patch.object(attached_engine, "connect") as mock_conn:
-            with connector._connect() as _:
-                mock_conn.assert_called_once()
+        with mock.patch.object(
+            attached_engine,
+            "connect",
+        ) as mock_conn, connector._connect() as _:
+            mock_conn.assert_called_once()

     def test_connect_raises_on_operational_failure(self, connector):
         with pytest.raises(
