Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed Dec 2, 2024
1 parent f2d4b52 commit a89676f
Show file tree
Hide file tree
Showing 11 changed files with 77 additions and 52 deletions.
4 changes: 2 additions & 2 deletions python_modules/dagster-graphql/dagster_graphql/test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def infer_repository(graphql_context: WorkspaceRequestContext) -> RemoteReposito
assert len(repositories) == 1
return next(iter(repositories.values()))

code_location = graphql_context.get_code_location("test")
code_location = graphql_context.get_code_location(main_repo_location_name())
return code_location.get_repository("test_repo")


Expand All @@ -177,7 +177,7 @@ def infer_repository_selector(graphql_context: WorkspaceRequestContext) -> Selec
assert len(repositories) == 1
repository = next(iter(repositories.values()))
else:
code_location = graphql_context.get_code_location("test")
code_location = graphql_context.get_code_location(main_repo_location_name())
repository = code_location.get_repository("test_repo")

return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
from dagster_graphql import DagsterGraphQLClientError, ReloadRepositoryLocationStatus
from dagster_graphql.test.utils import main_repo_location_name

from dagster_graphql_tests.client_tests.conftest import MockClient, python_client_test_suite
from dagster_graphql_tests.graphql.graphql_context_test_suite import (
Expand Down Expand Up @@ -96,6 +97,6 @@ def test_failure_with_query_error(mock_client: MockClient):
class TestReloadRepositoryLocationWithClient(BaseTestSuite):
def test_reload_location_real(self, graphql_client):
assert (
graphql_client.reload_repository_location("test").status
graphql_client.reload_repository_location(main_repo_location_name()).status
== ReloadRepositoryLocationStatus.SUCCESS
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dagster._core.errors import DagsterUserCodeUnreachableError
from dagster_graphql import ShutdownRepositoryLocationStatus
from dagster_graphql.client.client_queries import SHUTDOWN_REPOSITORY_LOCATION_MUTATION
from dagster_graphql.test.utils import execute_dagster_graphql
from dagster_graphql.test.utils import execute_dagster_graphql, main_repo_location_name

from dagster_graphql_tests.graphql.graphql_context_test_suite import (
GraphQLContextVariant,
Expand All @@ -22,7 +22,7 @@ def test_shutdown_repository_location_permission_failure(self, graphql_context):
result = execute_dagster_graphql(
graphql_context,
SHUTDOWN_REPOSITORY_LOCATION_MUTATION,
{"repositoryLocationName": "test"},
{"repositoryLocationName": main_repo_location_name()},
)

assert result
Expand All @@ -36,7 +36,7 @@ def test_shutdown_repository_location(self, graphql_client, graphql_context):
origin = next(iter(graphql_context.get_code_location_entries().values())).origin
origin.create_client().heartbeat()

result = graphql_client.shutdown_repository_location("test")
result = graphql_client.shutdown_repository_location(main_repo_location_name())

assert result.status == ShutdownRepositoryLocationStatus.SUCCESS, result.message

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ load_from:
- python_file:
relative_path: repo.py
attribute: test_repo
location_name: test
location_name: test_location
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
execute_dagster_graphql,
infer_job_selector,
infer_repository_selector,
main_repo_location_name,
)

from dagster_graphql_tests.graphql.graphql_context_test_suite import (
Expand Down Expand Up @@ -1241,7 +1242,7 @@ def test_asset_node_in_pipeline(self, graphql_context: WorkspaceRequestContext):

assert len(result.data["assetNodes"]) == 1
asset_node = result.data["assetNodes"][0]
assert asset_node["id"] == 'test.test_repo.["asset_one"]'
assert asset_node["id"] == f'{main_repo_location_name()}.test_repo.["asset_one"]'
assert asset_node["hasMaterializePermission"]
assert asset_node["hasReportRunlessAssetEventPermission"]

Expand All @@ -1256,7 +1257,7 @@ def test_asset_node_in_pipeline(self, graphql_context: WorkspaceRequestContext):

assert len(result.data["assetNodes"]) == 2
asset_node = result.data["assetNodes"][0]
assert asset_node["id"] == 'test.test_repo.["asset_one"]'
assert asset_node["id"] == f'{main_repo_location_name()}.test_repo.["asset_one"]'

def test_asset_node_is_executable(self, graphql_context: WorkspaceRequestContext):
result = execute_dagster_graphql(
Expand Down Expand Up @@ -2126,7 +2127,7 @@ def test_reexecute_subset(self, graphql_context: WorkspaceRequestContext):
def test_named_groups(self, graphql_context: WorkspaceRequestContext):
_create_run(graphql_context, "named_groups_job")
selector = {
"repositoryLocationName": "test",
"repositoryLocationName": main_repo_location_name(),
"repositoryName": "test_repo",
}

Expand Down Expand Up @@ -2640,7 +2641,7 @@ def test_auto_materialize_policy(self, graphql_context: WorkspaceRequestContext)
fresh_diamond_bottom = [
a
for a in result.data["assetNodes"]
if a["id"] == 'test.test_repo.["fresh_diamond_bottom"]'
if a["id"] == f'{main_repo_location_name()}.test_repo.["fresh_diamond_bottom"]'
]
assert len(fresh_diamond_bottom) == 1
assert fresh_diamond_bottom[0]["autoMaterializePolicy"]["policyType"] == "LAZY"
Expand All @@ -2654,7 +2655,8 @@ def test_automation_condition(self, graphql_context: WorkspaceRequestContext):
automation_condition_asset = [
a
for a in result.data["assetNodes"]
if a["id"] == 'test.test_repo.["asset_with_automation_condition"]'
if a["id"]
== f'{main_repo_location_name()}.test_repo.["asset_with_automation_condition"]'
]
assert len(automation_condition_asset) == 1
condition = automation_condition_asset[0]["automationCondition"]
Expand All @@ -2664,7 +2666,8 @@ def test_automation_condition(self, graphql_context: WorkspaceRequestContext):
custom_automation_condition_asset = [
a
for a in result.data["assetNodes"]
if a["id"] == 'test.test_repo.["asset_with_custom_automation_condition"]'
if a["id"]
== f'{main_repo_location_name()}.test_repo.["asset_with_custom_automation_condition"]'
]
assert len(custom_automation_condition_asset) == 1
condition = custom_automation_condition_asset[0]["automationCondition"]
Expand Down Expand Up @@ -3084,7 +3087,7 @@ def test_asset_unstarted_after_materialization(self, graphql_context: WorkspaceR
)

# Create two enqueued runs
code_location = graphql_context.get_code_location("test")
code_location = graphql_context.get_code_location(main_repo_location_name())
repository = code_location.get_repository("test_repo")
job = repository.get_full_job("hanging_job")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from dagster._core.workspace.context import WorkspaceRequestContext
from dagster_graphql.test.utils import execute_dagster_graphql, infer_repository_selector
from dagster_graphql.test.utils import (
execute_dagster_graphql,
infer_repository_selector,
main_repo_location_name,
)

from dagster_graphql_tests.graphql.graphql_context_test_suite import (
NonLaunchableGraphQLContextTestMatrix,
Expand Down Expand Up @@ -74,12 +78,17 @@ def test_basic_jobs(self, graphql_context: WorkspaceRequestContext):
repo_locations = {
blob["name"]: blob for blob in result.data["workspaceOrError"]["locationEntries"]
}
assert "test" in repo_locations
assert repo_locations["test"]["locationOrLoadError"]["__typename"] == "RepositoryLocation"
assert main_repo_location_name() in repo_locations
assert (
repo_locations[main_repo_location_name()]["locationOrLoadError"]["__typename"]
== "RepositoryLocation"
)

jobs = {
blob["name"]: blob
for blob in repo_locations["test"]["locationOrLoadError"]["repositories"][0]["jobs"]
for blob in repo_locations[main_repo_location_name()]["locationOrLoadError"][
"repositories"
][0]["jobs"]
}

assert "simple_job_a" in jobs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
execute_dagster_graphql_and_finish_runs,
infer_job_selector,
infer_repository_selector,
main_repo_location_name,
)

from dagster_graphql_tests.graphql.graphql_context_test_suite import (
Expand Down Expand Up @@ -324,7 +325,7 @@ def dummy_asset():

class TestPartitionBackillReadonlyFailure(ReadonlyGraphQLContextTestMatrix):
def _create_backfill(self, graphql_context):
code_location = graphql_context.get_code_location("test")
code_location = graphql_context.get_code_location(main_repo_location_name())
repository = code_location.get_repository("test_repo")

backfill = PartitionBackfill(
Expand Down Expand Up @@ -656,7 +657,7 @@ def test_cancel_asset_backfill(self, graphql_context):

# Update asset backfill data to contain requested partition, but does not execute side effects,
# since launching the run will cause test process will hang forever.
code_location = graphql_context.get_code_location("test")
code_location = graphql_context.get_code_location(main_repo_location_name())
repository = code_location.get_repository("test_repo")
asset_graph = repository.asset_graph
_execute_asset_backfill_iteration_no_side_effects(graphql_context, backfill_id, asset_graph)
Expand Down Expand Up @@ -732,7 +733,7 @@ def test_cancel_then_retry_asset_backfill(self, graphql_context):

# Update asset backfill data to contain requested partition, but does not execute side effects,
# since launching the run will cause test process will hang forever.
code_location = graphql_context.get_code_location("test")
code_location = graphql_context.get_code_location(main_repo_location_name())
repository = code_location.get_repository("test_repo")
asset_graph = repository.asset_graph
_execute_asset_backfill_iteration_no_side_effects(graphql_context, backfill_id, asset_graph)
Expand Down Expand Up @@ -1043,7 +1044,7 @@ def test_asset_backfill_partition_stats(self, graphql_context):
assert result.data["launchPartitionBackfill"]["__typename"] == "LaunchBackfillSuccess"
backfill_id = result.data["launchPartitionBackfill"]["backfillId"]

code_location = graphql_context.get_code_location("test")
code_location = graphql_context.get_code_location(main_repo_location_name())
repository = code_location.get_repository("test_repo")
asset_graph = repository.asset_graph

Expand Down Expand Up @@ -1086,7 +1087,7 @@ def test_asset_backfill_partition_stats(self, graphql_context):
assert asset_partition_status_counts[0]["numPartitionsFailed"] == 2

def test_asset_backfill_status_with_upstream_failure(self, graphql_context):
code_location = graphql_context.get_code_location("test")
code_location = graphql_context.get_code_location(main_repo_location_name())
repository = code_location.get_repository("test_repo")
asset_graph = repository.asset_graph

Expand Down Expand Up @@ -1557,7 +1558,7 @@ def test_launch_backfill_with_all_partitions_flag(self, graphql_context):
assert len(result.data["partitionBackfillOrError"]["partitionNames"]) == 10

def test_reexecute_asset_backfill_from_failure(self, graphql_context):
code_location = graphql_context.get_code_location("test")
code_location = graphql_context.get_code_location(main_repo_location_name())
repository = code_location.get_repository("test_repo")
asset_graph = repository.asset_graph

Expand Down Expand Up @@ -1654,7 +1655,7 @@ def test_reexecute_asset_backfill_from_failure(self, graphql_context):
assert retried_backfill.tags.get(ROOT_BACKFILL_ID_TAG) == backfill_id

def test_reexecute_successful_asset_backfill(self, graphql_context):
code_location = graphql_context.get_code_location("test")
code_location = graphql_context.get_code_location(main_repo_location_name())
repository = code_location.get_repository("test_repo")
asset_graph = repository.asset_graph

Expand Down Expand Up @@ -1741,7 +1742,7 @@ def test_reexecute_successful_asset_backfill(self, graphql_context):
assert retried_backfill.tags.get(ROOT_BACKFILL_ID_TAG) == backfill_id

def test_reexecute_asset_backfill_still_in_progress(self, graphql_context):
code_location = graphql_context.get_code_location("test")
code_location = graphql_context.get_code_location(main_repo_location_name())
repository = code_location.get_repository("test_repo")
asset_graph = repository.asset_graph

Expand Down Expand Up @@ -1865,7 +1866,7 @@ def test_reexecute_asset_backfill_still_in_progress(self, graphql_context):
assert retried_backfill.tags.get(ROOT_BACKFILL_ID_TAG) == backfill_id

def test_reexecute_asset_backfill_twice(self, graphql_context):
code_location = graphql_context.get_code_location("test")
code_location = graphql_context.get_code_location(main_repo_location_name())
repository = code_location.get_repository("test_repo")
asset_graph = repository.asset_graph

Expand Down
Loading

0 comments on commit a89676f

Please sign in to comment.