Skip to content

feat: add Github oauth provider #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
551 changes: 529 additions & 22 deletions Cargo.lock

Large diffs are not rendered by default.

24 changes: 18 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,26 @@ tokio = { version = "1.0", features = [
"sync",
] }
tokio-stream = { version = "0.1.15" }
tracing = { version = "0.1.40" }
tracing = { workspace = true }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
time = { version = "0.3.36", features = [] }
thiserror = { version = "1.0.63", features = [] }
dotenvy = { version = "0.15.7" }
uuid = { version = "1.10.0", features = ["v4", "serde"] }
regex = { version = "1.10.6" }
axum = { version = "0.7.5" }
thiserror = { workspace = true }
time = { workspace = true }
uuid = { workspace = true }
regex = { workspace = true }
reqwest = { workspace = true }
serde_json = { workspace = true }
serde = { workspace = true }

[build-dependencies]
tonic-build = "0.12"

[workspace.dependencies]
tracing = { version = "0.1.40" }
thiserror = { version = "1.0.63" }
time = { version = "0.3.36" }
uuid = { version = "1.10.0", features = ["v4", "serde"] }
regex = { version = "1.10.6" }
reqwest = { version = "0.12.7", features = ["json"] }
serde_json = { version = "1.0.128" }
serde = { version = "1.0", features = ["derive"] }
5 changes: 5 additions & 0 deletions proto/empty.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
syntax = "proto3";

package tailcall;

message Empty { }
28 changes: 28 additions & 0 deletions proto/tailcall.proto
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
syntax = "proto3";
import "empty.proto";

package tailcall;

Expand Down Expand Up @@ -32,3 +33,30 @@ enum GithubStatusEnum {
Error = 1;
Deployed = 2;
}

service GithubAuthService {
rpc Start(LoginRequest) returns (LoginLinkResponse);
rpc GetAccessToken(GetAccessTokenRequest) returns (GetAccessTokenResponse);
rpc UserInfo(Empty) returns (UserInfoResponse);
}

message LoginRequest {
string state = 1;
}

message LoginLinkResponse {
string url = 1;
}

message GetAccessTokenRequest {
string access_code = 1;
}

message GetAccessTokenResponse {
string access_token = 1;
}

message UserInfoResponse {
int64 id = 1;
string username = 2;
}
13 changes: 13 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,19 @@ enum AppError {
Simple(String),
#[error("IO Error: {0}")]
IoError(#[from] std::io::Error),
#[error("Remote Error: {0}")]
RemoteRequestError(#[from] reqwest::Error),
}

type AppResult<T> = Result<T, AppError>;

// TODO: add logging to make debugging easier
impl From<AppError> for tonic::Status {
fn from(value: AppError) -> Self {
match value {
AppError::Simple(error) => tonic::Status::aborted(error),
AppError::IoError(_error) => tonic::Status::internal("IO Error"),
AppError::RemoteRequestError(_error) => tonic::Status::internal("Remote Request Error"),
}
}
}
45 changes: 28 additions & 17 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
use std::env;
use tailcall_launchpad::{
proto::github_service_server::GithubServiceServer,
services::github_service::GithubDeploymentService,
proto::{
github_auth_service_server::GithubAuthServiceServer,
github_service_server::GithubServiceServer,
},
services::{
github_auth_service::{auth_interceptor, GithubAuthService},
github_service::GithubDeploymentService,
},
};
use tonic::transport::Server;
use tracing_subscriber::prelude::*;
Expand All @@ -24,31 +31,35 @@ async fn main() {
}

// initialize services
let github_deployment_service = GithubServiceServer::new(GithubDeploymentService::default());
let github_deployment_service =
GithubServiceServer::with_interceptor(GithubDeploymentService::default(), auth_interceptor);

let client_id = env::var("OAUTH_CLIENT_ID").expect("OAUTH_CLIENT_ID is not set in .env file");
let client_secret =
env::var("OAUTH_CLIENT_SECRET").expect("OAUTH_CLIENT_SECRET is not set in .env file");
let github_auth_service = GithubAuthServiceServer::with_interceptor(
GithubAuthService::new(&client_id, &client_secret),
auth_interceptor,
);

// reflection service
let reflection_service = tonic_reflection::server::Builder::configure()
.register_encoded_file_descriptor_set(tailcall_launchpad::proto::FILE_DESCRIPTOR_SET)
.build_v1()
.unwrap();

// start server
let grpc_service = Server::builder()
.add_service(reflection_service)
.add_service(github_deployment_service)
.into_service()
.into_axum_router();

run(grpc_service).await;
}

async fn run(router: axum::Router) {
// extract important config variables
use std::env;
let host = env::var("SERVER_HOST").expect("SERVER_HOST is not set in .env file");
let port = env::var("SERVER_PORT").expect("SERVER_PORT is not set in .env file");
let addr = format!("{host}:{port}");

let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
axum::serve(listener, router).await.unwrap();
println!("server running {}", addr);
// start server
Server::builder()
.add_service(reflection_service)
.add_service(github_deployment_service)
.add_service(github_auth_service)
.serve(addr.parse().unwrap())
.await
.unwrap();
}
150 changes: 150 additions & 0 deletions src/services/github_auth_service.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
use regex::Regex;
use reqwest::StatusCode;
use serde::Deserialize;
use tonic::{Request, Response, Status};

use crate::{
proto::{
github_auth_service_server, Empty, GetAccessTokenRequest, GetAccessTokenResponse,
LoginLinkResponse, LoginRequest, UserInfoResponse,
},
AppError,
};

#[derive(Debug)]
pub struct GithubAuthService {
client_id: String,
client_secret: String,
}

impl GithubAuthService {
pub fn new(client_id: &str, client_secret: &str) -> Self {
Self {
client_id: client_id.to_string(),
client_secret: client_secret.to_string(),
}
}
}

#[tonic::async_trait]
impl github_auth_service_server::GithubAuthService for GithubAuthService {
async fn start(
&self,
request: Request<LoginRequest>,
) -> Result<Response<LoginLinkResponse>, Status> {
Ok(Response::new(LoginLinkResponse {
url: format!(
"https://github.com/login/oauth/authorize?client_id={}&state={}",
self.client_id,
request.into_inner().state
),
}))
}

async fn get_access_token(
&self,
request: Request<GetAccessTokenRequest>,
) -> Result<Response<GetAccessTokenResponse>, Status> {
let access_token = get_access_token(
&self.client_id,
&self.client_secret,
&request.into_inner().access_code,
)
.await?;
Ok(Response::new(GetAccessTokenResponse { access_token }))
}

async fn user_info(
&self,
request: Request<Empty>,
) -> Result<Response<UserInfoResponse>, Status> {
let extension = request.extensions().get::<AuthExtension>().unwrap();

match &extension.bearer {
Some(access_token) => {
let user = get_user(&self.client_id, &self.client_secret, access_token).await?;
Ok(Response::new(user))
}
None => Err(Status::unauthenticated("The request is not authorized")),
}
}
}

async fn get_access_token(
client_id: &str,
client_secret: &str,
access_code: &str,
) -> Result<String, AppError> {
let url = format!(
"https://github.com/login/oauth/access_token?client_id={}&client_secret={}&code={}",
client_id, client_secret, access_code
);
let res = reqwest::get(url).await?;
let body = res.text().await?;
let re = Regex::new("access_token=([a-z_A-Z0-9]+)").unwrap();
let captures = re.captures(&body).unwrap();
Ok(captures[1].to_string())
}

pub fn auth_interceptor(mut req: Request<()>) -> Result<Request<()>, Status> {
let bearer = match req.metadata().get("authorization") {
Some(bearer) => bearer
.to_str()
.map_err(|_| Status::invalid_argument("`authorization` header is bad formatted"))?
.split(" ")
.last()
.map(|bearer| bearer.to_string()),
None => None,
};

req.extensions_mut().insert(AuthExtension { bearer });

Ok(req)
}

#[derive(Clone)]
struct AuthExtension {
pub bearer: Option<String>,
}

async fn get_user(
client_id: &str,
client_secret: &str,
access_token: &str,
) -> Result<UserInfoResponse, AppError> {
let url = format!("https://api.github.com/applications/{}/token", client_id);

let client = reqwest::Client::new();
let res = client
.post(url)
.body(format!("{{\"access_token\":\"{}\"}}", access_token))
.basic_auth(client_id, Some(client_secret))
.header("Content-Type", "application/json")
.header("User-Agent", "Tailcall Launchpad")
.header("Accept", "application/vnd.github+json")
.header("X-GitHub-Api-Version", "2022-11-28")
.send()
.await?;
if res.status() == StatusCode::OK {
let json: AuthInfoJson = res.json().await.unwrap();

Ok(UserInfoResponse {
id: json.user.id,
username: json.user.login,
})
} else {
Err(AppError::Simple("Could not fetch user data.".to_string()))
}
}

#[derive(Deserialize)]
pub struct AuthInfoJson {
pub id: i64,
pub user: UserInfoJson,
}

#[derive(Deserialize)]
pub struct UserInfoJson {
pub id: i64,
pub login: String,
}
5 changes: 5 additions & 0 deletions src/services/github_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,10 @@ async fn send_error(message_channel: &MessageChannel, err: AppError) {
.send_status(Status::cancelled("IO Error"))
.await
}
AppError::RemoteRequestError(_) => {
message_channel
.send_status(Status::cancelled("Remote Error"))
.await
}
};
}
1 change: 1 addition & 0 deletions src/services/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod github_auth_service;
pub mod github_service;