Skip to content

Commit

Permalink
Make FlatState a Mapping instead of a dict
Browse files Browse the repository at this point in the history
Also, make traverse_util.flatten_dict accept mappings.

Fixes #3879
  • Loading branch information
NeilGirdhar committed Apr 30, 2024
1 parent 2c7d7cd commit b1512f3
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 12 deletions.
7 changes: 4 additions & 3 deletions flax/experimental/nnx/nnx/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# limitations under the License.
from __future__ import annotations

from collections.abc import Mapping
import typing as tp
import typing_extensions as tpe

Expand All @@ -42,7 +43,7 @@
A = tp.TypeVar('A')

StateLeaf = tp.Union[VariableState[tp.Any], np.ndarray, jax.Array]
FlatState = dict[PathParts, StateLeaf]
FlatState = Mapping[PathParts, StateLeaf]


def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
Expand All @@ -66,8 +67,8 @@ class State(tp.MutableMapping[Key, tp.Any], reprlib.Representable):
def __init__(
self,
mapping: tp.Union[
tp.Mapping[Key, tp.Mapping | StateLeaf],
tp.Iterator[tuple[Key, tp.Mapping | StateLeaf]],
Mapping[Key, Mapping | StateLeaf],
tp.Iterator[tuple[Key, Mapping | StateLeaf]],
],
/,
):
Expand Down
51 changes: 42 additions & 9 deletions flax/traverse_util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# pytype: skip-file
# Skipped until https://github.com/google/pytype/issues/1619 is resolved.
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -52,18 +54,19 @@
Traversals never mutate the original data. Therefore, an update essentially
returns a copy of the data including the provided updates.
"""
from __future__ import annotations

import abc
from collections.abc import Callable, Mapping
import copy
import dataclasses
import warnings
from typing import Any, Callable
from typing import Any, overload

import jax

import flax
from flax.core.scope import VariableDict
from flax.typing import PathParts
from flax.typing import PathParts, VariableDict

from . import struct

Expand All @@ -77,7 +80,37 @@ class _EmptyNode:
empty_node = _EmptyNode()


def flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None):
# TODO: In Python 3.10, use TypeAlias.
IsLeafCallable = Callable[[tuple[Any, ...], Mapping[Any, Any]], bool]


@overload
def flatten_dict(xs: Mapping[Any, Any],
/,
*,
keep_empty_nodes: bool = False,
is_leaf: None | IsLeafCallable = None,
sep: None = None
) -> dict[tuple[Any, ...], Any]:
...

@overload
def flatten_dict(xs: Mapping[Any, Any],
/,
*,
keep_empty_nodes: bool = False,
is_leaf: None | IsLeafCallable = None,
sep: str,
) -> dict[str, Any]:
...

def flatten_dict(xs: Mapping[Any, Any],
/,
*,
keep_empty_nodes: bool = False,
is_leaf: None | IsLeafCallable = None,
sep: None | str = None
) -> dict[Any, Any]:
"""Flatten a nested dictionary.
The nested keys are flattened to a tuple.
Expand Down Expand Up @@ -111,16 +144,16 @@ def flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None):
The flattened dictionary.
"""
assert isinstance(
xs, (flax.core.FrozenDict, dict)
), f'expected (frozen)dict; got {type(xs)}'
xs, Mapping
), f'expected Mapping; got {type(xs).__qualname__}'

def _key(path):
def _key(path: tuple[Any, ...]) -> tuple[Any, ...] | str:
if sep is None:
return path
return sep.join(path)

def _flatten(xs, prefix):
if not isinstance(xs, (flax.core.FrozenDict, dict)) or (
def _flatten(xs: Any, prefix: tuple[Any, ...]) -> dict[Any, Any]:
if not isinstance(xs, Mapping) or (
is_leaf and is_leaf(prefix, xs)
):
return {_key(prefix): xs}
Expand Down

0 comments on commit b1512f3

Please sign in to comment.