Skip to content

Commit

Permalink
Merge branch 'develop' into refactor/argilla-server/remove-passlib-de…
Browse files Browse the repository at this point in the history
…pendency
  • Loading branch information
frascuchon authored Nov 6, 2024
2 parents 49f4215 + f346a84 commit a5ad670
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 25 deletions.
7 changes: 6 additions & 1 deletion argilla-frontend/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ These are the section headers that we use:
## [Unreleased]()

### Added
- Add a high-contrast theme & improvements for the forced-colors mode ([#5661](https://github.com/argilla-io/argilla/pull/5661))

- Add a high-contrast theme & improvements for the forced-colors mode. ([#5661](https://github.com/argilla-io/argilla/pull/5661))

### Fixed

- Fixed redirection problems after users sign-in using HF OAuth. ([#5635](https://github.com/argilla-io/argilla/pull/5635))

## [2.4.0](https://github.com/argilla-io/argilla/compare/v2.3.0...v2.4.0)

Expand Down
15 changes: 7 additions & 8 deletions argilla-frontend/middleware/route-guard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,25 @@

import { Context } from "@nuxt/types";
import { useRunningEnvironment } from "~/v1/infrastructure/services/useRunningEnvironment";
import { useLocalStorage } from "~/v1/infrastructure/services";

const { set } = useLocalStorage();

export default ({ $auth, route, redirect }: Context) => {
const { isRunningOnHuggingFace } = useRunningEnvironment();

// By-pass unknown routes. This is needed to avoid errors with API calls.
if (route.name == null) return;

switch (route.name) {
case "sign-in":
if ($auth.loggedIn) return redirect("/");

if (route.params.omitCTA) return;

if (isRunningOnHuggingFace()) {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { redirect: _, ...query } = route.query;

return redirect({
name: "welcome-hf-sign-in",
query,
});
}
break;
Expand All @@ -50,14 +52,11 @@ export default ({ $auth, route, redirect }: Context) => {
default:
if (!$auth.loggedIn) {
if (route.path !== "/") {
route.query.redirect = route.fullPath;
set("redirectTo", route.path);
}

redirect({
name: "sign-in",
query: {
...route.query,
},
});
}
}
Expand Down
14 changes: 12 additions & 2 deletions argilla-frontend/pages/oauth/_provider/useOAuthViewModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ import { useFetch, useRoute } from "@nuxtjs/composition-api";
import { useResolve } from "ts-injecty";
import { ProviderType } from "~/v1/domain/entities/oauth/OAuthProvider";
import { OAuthLoginUseCase } from "~/v1/domain/usecases/oauth-login-use-case";
import { useRoutes, useTranslate } from "~/v1/infrastructure/services";
import {
useRoutes,
useTranslate,
useLocalStorage,
} from "~/v1/infrastructure/services";
import { useNotifications } from "~/v1/infrastructure/services/useNotifications";

export const useOAuthViewModel = () => {
Expand All @@ -11,24 +15,30 @@ export const useOAuthViewModel = () => {
const routes = useRoute();
const router = useRoutes();
const oauthLoginUseCase = useResolve(OAuthLoginUseCase);
const { pop } = useLocalStorage();

useFetch(async () => {
await tryLogin();
});

const redirect = () => {
const redirect = pop("redirectTo");
router.go(redirect || "/");
};

const tryLogin = async () => {
const { params, query } = routes.value;

const provider = params.provider as ProviderType;

try {
await oauthLoginUseCase.login(provider, query);
redirect();
} catch {
notification.notify({
message: t("argilla.api.errors::UnauthorizedError"),
type: "danger",
});
} finally {
router.go("/");
}
};
Expand Down
10 changes: 0 additions & 10 deletions argilla-frontend/pages/sign-in.vue
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,8 @@ export default {
},
},
methods: {
nextRedirect() {
const redirect_url = this.$nuxt.$route.query.redirect || "/";
this.$router.push({
path: redirect_url,
});
},
async loginUser({ username, password }) {
await this.login(username, password);

this.$notification.clear();

this.nextRedirect();
},
async onLoginUser() {
try {
Expand Down
12 changes: 12 additions & 0 deletions argilla-frontend/pages/useSignInViewModel.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
import { useResolve } from "ts-injecty";
import { AuthLoginUseCase } from "~/v1/domain/usecases/auth-login-use-case";
import { useRoutes, useLocalStorage } from "~/v1/infrastructure/services";
import { useNotifications } from "~/v1/infrastructure/services/useNotifications";

export const useSignInViewModel = () => {
const useCase = useResolve(AuthLoginUseCase);
const router = useRoutes();
const notification = useNotifications();
const { pop } = useLocalStorage();

const redirect = () => {
const redirect = pop("redirectTo");
router.go(redirect || "/");
};

const login = async (username: string, password: string) => {
await useCase.login(username, password);
notification.clear();
redirect();
};

return {
Expand Down
33 changes: 33 additions & 0 deletions argilla-frontend/v1/domain/entities/hub/DatasetCreation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ const datasetInfo = {
dtype: "string",
_type: "Value",
},
extra: {
dtype: "int32",
_type: "Value",
},
},
],
metadata: {
Expand Down Expand Up @@ -160,6 +164,35 @@ describe("DatasetCreation", () => {
expect(chatField.required).toBeFalsy();
});

it("skip other list feature from chat fields", () => {
const builder = new DatasetCreationBuilder("FAKE", {
default: {
...datasetInfo.default,
features: {
some_list_feature: [
{
other: {
dtype: "string",
_type: "Value",
},
value: {
dtype: "string",
_type: "Value",
},
},
],
},
},
});

const datasetCreation = builder.build();

expect(
datasetCreation.fields.filter((f) => f.type.value === "no mapping")
.length
).toBe(1);
});

it("get no mapped feature", () => {
const builder = new DatasetCreationBuilder("FAKE", {
default: {
Expand Down
13 changes: 11 additions & 2 deletions argilla-frontend/v1/domain/entities/hub/Subset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import { QuestionCreation } from "./QuestionCreation";
type Structure = {
name: string;
options?: string[];
role?: string;
content?: string;
structure?: Structure[];
kindObject?: "Value" | "Image" | "ClassLabel" | "Sequence";
type?: "string" | MetadataTypes;
Expand All @@ -33,6 +35,7 @@ export class Subset {

for (const [name, value] of Object.entries<Feature>(datasetInfo.features)) {
if (Array.isArray(value)) {
const { role, content } = value[0];
this.structures.push({
name,
structure: value.map((v) => {
Expand All @@ -42,6 +45,8 @@ export class Subset {
name: key,
kindObject: value._type,
type: value.dtype,
role,
content,
};
}),
});
Expand Down Expand Up @@ -134,8 +139,12 @@ export class Subset {
return "text";

if (structure.kindObject === "Image") return "image";

if (structure.structure?.length > 0) return "chat";
if (
structure.structure?.length > 0 &&
structure.structure[0].content &&
structure.structure[0].role
)
return "chat";
};

const field = FieldCreation.from(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
type Options = "showShortcutsHelper" | "layout";
type Options = "showShortcutsHelper" | "layout" | "redirectTo";

const STORAGE_KEY = "argilla";

Expand Down Expand Up @@ -34,8 +34,15 @@ export const useLocalStorage = () => {
} catch {}
};

const pop = (key: Options) => {
const value = get(key);
set(key, null);
return value;
};

return {
get,
set,
pop,
};
};
3 changes: 2 additions & 1 deletion argilla-server/tests/unit/api/handlers/v1/test_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,15 @@ async def test_provider_huggingface_authentication(
):
with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings):
response = await async_client.get(
"/api/v1/oauth2/providers/huggingface/authentication", headers=owner_auth_header
"/api/v1/oauth2/providers/huggingface/authentication?extra=params", headers=owner_auth_header
)
assert response.status_code == 303

redirect_url = URL(response.headers.get("location"))
assert redirect_url.scheme == b"https"
assert redirect_url.host == b"huggingface.co"
assert b"/oauth/authorize?response_type=code&client_id=client_id" in redirect_url.target
assert b"&extra=params" in redirect_url.target

async def test_provider_authentication_with_oauth_disabled(
self,
Expand Down

0 comments on commit a5ad670

Please sign in to comment.