diff --git a/internal/bzlmod/go_mod.bzl b/internal/bzlmod/go_mod.bzl
index 9d37fcf5a..283f128da 100644
--- a/internal/bzlmod/go_mod.bzl
+++ b/internal/bzlmod/go_mod.bzl
@@ -57,7 +57,7 @@ def parse_go_mod(content, path):
             continue
 
         if not current_directive:
-            if tokens[0] not in ["module", "go", "require", "replace", "exclude", "retract"]:
+            if tokens[0] not in ["module", "go", "require", "replace", "exclude", "retract", "toolchain"]:
                 fail("{}:{}: unexpected token '{}' at start of line".format(path, line_no, tokens[0]))
             if len(tokens) == 1:
                 fail("{}:{}: expected another token after '{}'".format(path, line_no, tokens[0]))
@@ -98,7 +98,9 @@ def parse_go_mod(content, path):
     if not go:
         # "As of the Go 1.17 release, if the go directive is missing, go 1.16 is assumed."
         go = "1.16"
-    major, minor = go.split(".")
+
+    # The go directive can contain patch and pre-release versions, but we omit them.
+    major, minor = go.split(".")[:2]
 
     return struct(
         module = module,
diff --git a/tests/bzlmod/go_mod_test.bzl b/tests/bzlmod/go_mod_test.bzl
index a1fcedeb1..ba7519ee2 100644
--- a/tests/bzlmod/go_mod_test.bzl
+++ b/tests/bzlmod/go_mod_test.bzl
@@ -46,6 +46,27 @@ def _go_mod_test_impl(ctx):
 
 go_mod_test = unittest.make(_go_mod_test_impl)
 
+_GO_MOD_21_CONTENT = """go 1.21.0rc1
+
+module example.com
+
+toolchain go1.22.2
+"""
+
+_EXPECTED_GO_MOD_21_PARSE_RESULT = struct(
+    go = (1, 21),
+    module = "example.com",
+    replace_map = {},
+    require = (),
+)
+
+def _go_mod_21_test_impl(ctx):
+    env = unittest.begin(ctx)
+    asserts.equals(env, _EXPECTED_GO_MOD_21_PARSE_RESULT, parse_go_mod(_GO_MOD_21_CONTENT, "/go.mod"))
+    return unittest.end(env)
+
+go_mod_21_test = unittest.make(_go_mod_21_test_impl)
+
 _GO_SUM_CONTENT = """cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
 github.com/bazelbuild/buildtools v0.0.0-20220531122519-a43aed7014c8 h1:fmdo+fvvWlhldUcqkhAMpKndSxMN3vH5l7yow5cEaiQ=
@@ -70,5 +91,6 @@ def go_mod_test_suite(name):
     unittest.suite(
         name,
         go_mod_test,
+        go_mod_21_test,
         go_sum_test,
     )