diff --git a/src/main.rs b/src/main.rs index 5370ea4..15ccd8d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,8 @@ use clap::{Args,Parser,Subcommand}; +use futures::{StreamExt,TryStreamExt}; + +const READ_CONCURRENCY:usize=16; +const REMOTE_CONCURRENCY:usize=16; #[derive(Parser)] #[command(author,version,about,long_about=None)] @@ -11,6 +15,7 @@ struct Cli{ #[derive(Subcommand)] enum Commands{ Review(ReviewCommand), + UploadScripts(UploadScriptsCommand), } #[derive(Args)] @@ -18,6 +23,13 @@ struct ReviewCommand{ #[arg(long)] cookie:String, } +#[derive(Args)] +struct UploadScriptsCommand{ + #[arg(long)] + session_id:String, + #[arg(long)] + api_url:String, +} #[tokio::main] async fn main(){ @@ -26,6 +38,10 @@ async fn main(){ Commands::Review(command)=>review(ReviewConfig{ cookie:command.cookie, }).await.unwrap(), + Commands::UploadScripts(command)=>upload_scripts(UploadConfig{ + session_id:command.session_id, + api_url:command.api_url, + }).await.unwrap(), } } @@ -158,3 +174,242 @@ async fn review(config:ReviewConfig)->Result<(),ReviewError>{ Ok(()) } + +#[allow(dead_code)] +#[derive(Debug)] +enum ScriptUploadError{ + Cookie(submissions_api::CookieError), + Reqwest(submissions_api::ReqwestError), + AllowedSet(std::io::Error), + AllowedMap(GetMapError), + ReplaceMap(GetMapError), + BlockedSet(std::io::Error), + GOC(GOCError), + GOCPolicyReplace(GOCError), + GOCPolicyAllowed(GOCError), + GOCPolicyBlocked(GOCError), +} + +fn read_dir_stream(dir:tokio::fs::ReadDir)->impl futures::stream::Stream>{ + futures::stream::unfold(dir,|mut dir|async{ + match dir.next_entry().await{ + Ok(Some(entry))=>Some((Ok(entry),dir)), + Ok(None)=>None, // End of directory + Err(e)=>Some((Err(e),dir)), // Error encountered + } + }) +} + +async fn get_set_from_file(path:impl AsRef)->std::io::Result>{ + read_dir_stream(tokio::fs::read_dir(path).await?) + .map(|dir_entry|async{ + tokio::fs::read_to_string(dir_entry?.path()).await + }) + .buffer_unordered(READ_CONCURRENCY) + .try_collect().await +} + +async fn get_allowed_set()->std::io::Result>{ + get_set_from_file("scripts/allowed").await +} + +async fn get_blocked_set()->std::io::Result>{ + get_set_from_file("scripts/blocked").await +} + +#[allow(dead_code)] +#[derive(Debug)] +enum GetMapError{ + IO(std::io::Error), + FileStem, + ToStr, + ParseInt(std::num::ParseIntError), +} + +async fn get_allowed_map()->Result,GetMapError>{ + read_dir_stream(tokio::fs::read_dir("scripts/allowed").await.map_err(GetMapError::IO)?) + .map(|dir_entry|async{ + let path=dir_entry.map_err(GetMapError::IO)?.path(); + let id:u32=path.file_stem() + .ok_or(GetMapError::FileStem)? + .to_str() + .ok_or(GetMapError::ToStr)? + .parse().map_err(GetMapError::ParseInt)?; + let source=tokio::fs::read_to_string(path).await.map_err(GetMapError::IO)?; + Ok((id,source)) + }) + .buffer_unordered(READ_CONCURRENCY) + .try_collect().await +} + +async fn get_replace_map()->Result,GetMapError>{ + read_dir_stream(tokio::fs::read_dir("scripts/replace").await.map_err(GetMapError::IO)?) + .map(|dir_entry|async{ + let path=dir_entry.map_err(GetMapError::IO)?.path(); + let id:u32=path.file_stem() + .ok_or(GetMapError::FileStem)? + .to_str() + .ok_or(GetMapError::ToStr)? + .parse().map_err(GetMapError::ParseInt)?; + let source=tokio::fs::read_to_string(path).await.map_err(GetMapError::IO)?; + Ok((source,id)) + }) + .buffer_unordered(READ_CONCURRENCY) + .try_collect().await +} + +fn hash_source(source:&str)->u64{ + let mut hasher=siphasher::sip::SipHasher::new(); + std::hash::Hasher::write(&mut hasher,source.as_bytes()); + std::hash::Hasher::finish(&hasher) +} + +fn hash_format(hash:u64)->String{ + format!("{:016x}",hash) +} + +type GOCError=submissions_api::types::SingleItemError; +type GOCResult=Result; + +async fn get_or_create_script(api:&submissions_api::external::Context,source:&str)->GOCResult{ + let script_response=api.get_script_from_hash(submissions_api::types::HashRequest{ + hash:hash_format(hash_source(source)).as_str(), + }).await?; + + Ok(match script_response{ + Some(script_response)=>script_response.ID, + None=>api.create_script(submissions_api::types::CreateScriptRequest{ + Name:"Script", + Source:source, + SubmissionID:None, + }).await.map_err(GOCError::Other)?.ID + }) +} + +async fn check_or_create_script_poicy( + api:&submissions_api::external::Context, + hash:&str, + script_policy:submissions_api::types::CreateScriptPolicyRequest, +)->Result<(),GOCError>{ + let script_policy_result=api.get_script_policy_from_hash(submissions_api::types::HashRequest{ + hash, + }).await?; + + match script_policy_result{ + Some(script_policy_reponse)=>{ + // check that everything matches the expectation + assert!(hash==script_policy_reponse.FromScriptHash); + assert!(script_policy.ToScriptID==script_policy_reponse.ToScriptID); + assert!(script_policy.Policy==script_policy_reponse.Policy); + }, + None=>{ + // create a new policy + api.create_script_policy(script_policy).await.map_err(GOCError::Other)?; + } + } + + Ok(()) +} + +async fn do_policy( + api:&submissions_api::external::Context, + script_ids:&std::collections::HashMap<&str,submissions_api::types::ScriptID>, + source:&str, + to_script_id:submissions_api::types::ScriptID, + policy:submissions_api::types::Policy, +)->Result<(),GOCError>{ + let hash=hash_format(hash_source(source)); + check_or_create_script_poicy(api,hash.as_str(),submissions_api::types::CreateScriptPolicyRequest{ + FromScriptID:script_ids[source], + ToScriptID:to_script_id, + Policy:policy, + }).await +} + +struct UploadConfig{ + session_id:String, + api_url:String, +} + +async fn upload_scripts(config:UploadConfig)->Result<(),ScriptUploadError>{ + let cookie=submissions_api::Cookie::new(&config.session_id).map_err(ScriptUploadError::Cookie)?; + let api=&submissions_api::external::Context::new(config.api_url,cookie).map_err(ScriptUploadError::Reqwest)?; + + let ( + allowed_set_result, + allowed_map_result, + replace_map_result, + blocked_set_result, + )=tokio::join!( + get_allowed_set(), + get_allowed_map(), + get_replace_map(), + get_blocked_set(), + ); + + let allowed_set=allowed_set_result.map_err(ScriptUploadError::AllowedSet)?; + let allowed_map=allowed_map_result.map_err(ScriptUploadError::AllowedMap)?; + let replace_map=replace_map_result.map_err(ScriptUploadError::ReplaceMap)?; + let blocked_set=blocked_set_result.map_err(ScriptUploadError::BlockedSet)?; + + // create a unified deduplicated set of all scripts + let script_set:std::collections::HashSet<&str>=allowed_set.iter() + .map(String::as_str) + .chain( + replace_map.keys().map(String::as_str) + ).chain( + blocked_set.iter().map(String::as_str) + ).collect(); + + // get or create every unique script + let script_ids:std::collections::HashMap<&str,submissions_api::types::ScriptID>= + futures::stream::iter(script_set) + .map(|source|async move{ + let script_id=get_or_create_script(api,source).await?; + Ok((source,script_id)) + }) + .buffer_unordered(REMOTE_CONCURRENCY) + .try_collect().await.map_err(ScriptUploadError::GOC)?; + + // get or create policy for each script in each category + // + // replace + let replace_fut=futures::stream::iter(replace_map.iter().map(Ok)) + .try_for_each_concurrent(Some(REMOTE_CONCURRENCY),|(source,id)|async{ + do_policy( + api, + &script_ids, + source, + script_ids[allowed_map[id].as_str()], + submissions_api::types::Policy::Replace + ).await.map_err(ScriptUploadError::GOCPolicyReplace) + }); + + // allowed + let allowed_fut=futures::stream::iter(allowed_set.iter().map(Ok)) + .try_for_each_concurrent(Some(REMOTE_CONCURRENCY),|source|async{ + do_policy( + api, + &script_ids, + source, + script_ids[source.as_str()], + submissions_api::types::Policy::Allowed + ).await.map_err(ScriptUploadError::GOCPolicyAllowed) + }); + + // blocked + let blocked_fut=futures::stream::iter(blocked_set.iter().map(Ok)) + .try_for_each_concurrent(Some(REMOTE_CONCURRENCY),|source|async{ + do_policy( + api, + &script_ids, + source, + script_ids[source.as_str()], + submissions_api::types::Policy::Blocked + ).await.map_err(ScriptUploadError::GOCPolicyBlocked) + }); + + tokio::try_join!(replace_fut,allowed_fut,blocked_fut)?; + + Ok(()) +}