diff --git a/src/main.rs b/src/main.rs index a658437..3446f29 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,24 +1,24 @@ use std::{io::{Read, Seek}, path::PathBuf}; use clap::{Args, Parser, Subcommand}; use anyhow::Result as AResult; +use futures::{StreamExt,TryStreamExt}; #[derive(Parser)] -#[command(author, version, about, long_about = None)] -#[command(propagate_version = true)] -struct Cli { - #[arg(long)] - path:Option, +#[command(author,version,about,long_about=None)] +#[command(propagate_version=true)] +struct Cli{ #[command(subcommand)] - command: Commands, + command:Commands, } #[derive(Subcommand)] -enum Commands { +enum Commands{ ExtractScripts(PathBufList), Interactive, Replace, Scan, Upload, + UploadScripts(UploadScriptsCommand) } #[derive(Args)] @@ -31,7 +31,14 @@ struct MapList { maps: Vec, } -fn main() -> AResult<()> { +#[derive(Args)] +struct UploadScriptsCommand{ + #[arg(long)] + session_id:PathBuf, +} + +#[tokio::main] +async fn main() -> AResult<()> { let cli = Cli::parse(); match cli.command { Commands::ExtractScripts(pathlist)=>extract_scripts(pathlist.paths), @@ -39,6 +46,7 @@ fn main() -> AResult<()> { Commands::Replace=>replace(), Commands::Scan=>scan(), Commands::Upload=>upload(), + Commands::UploadScripts(command)=>upload_scripts(command.session_id).await, } } @@ -753,3 +761,144 @@ fn interactive() -> AResult<()>{ std::fs::write("id",id.to_string())?; Ok(()) } + +fn hash_source(source:&str)->u64{ + let mut hasher=siphasher::sip::SipHasher::new(); + std::hash::Hasher::write(&mut hasher,source.as_bytes()); + std::hash::Hasher::finish(&hasher) +} + +fn hash_format(hash:u64)->String{ + format!("{:016x}",hash) +} + +type GOCError=submissions_api::types::SingleItemError; +type GOCResult=Result; + +async fn get_or_create_script(api:&submissions_api::external::Context,source:&str)->GOCResult{ + let script_response=api.get_script_from_hash(submissions_api::types::HashRequest{ + hash:hash_format(hash_source(source)).as_str(), + }).await?; + + Ok(match script_response{ + Some(script_response)=>script_response.ID, + None=>api.create_script(submissions_api::types::CreateScriptRequest{ + Name:"Script", + Source:source, + SubmissionID:None, + }).await.map_err(GOCError::Other)?.ID + }) +} + +async fn check_or_create_script_poicy( + api:&submissions_api::external::Context, + hash:&str, + script_policy:submissions_api::types::CreateScriptPolicyRequest, +)->Result<(),GOCError>{ + let script_policy_result=api.get_script_policy_from_hash(submissions_api::types::HashRequest{ + hash, + }).await?; + + match script_policy_result{ + Some(script_policy_reponse)=>{ + // check that everything matches the expectation + assert!(hash==script_policy_reponse.FromScriptHash); + assert!(script_policy.ToScriptID==script_policy_reponse.ToScriptID); + assert!(script_policy.Policy==script_policy_reponse.Policy); + }, + None=>{ + // create a new policy + api.create_script_policy(script_policy).await.map_err(GOCError::Other)?; + } + } + + Ok(()) +} + +async fn do_policy( + api:&submissions_api::external::Context, + script_ids:&std::collections::HashMap<&str,submissions_api::types::ScriptID>, + source:&str, + to_script_id:submissions_api::types::ScriptID, + policy:submissions_api::types::Policy, +)->Result<(),GOCError>{ + let hash=hash_format(hash_source(source)); + check_or_create_script_poicy(api,hash.as_str(),submissions_api::types::CreateScriptPolicyRequest{ + FromScriptID:script_ids[source], + ToScriptID:to_script_id, + Policy:policy, + }).await +} + +async fn upload_scripts(session_id:PathBuf)->AResult<()>{ + let cookie={ + let mut cookie=String::new(); + std::fs::File::open(session_id)?.read_to_string(&mut cookie)?; + submissions_api::Cookie::new(&cookie)? + }; + let api=&submissions_api::external::Context::new("http://localhost:8082".to_owned(),cookie)?; + + let allowed_set=get_allowed_set()?; + let allowed_map=get_allowed_map()?; + let replace_map=get_replace_map()?; + let blocked=get_blocked()?; + + // create a unified deduplicated set of all scripts + let script_set:std::collections::HashSet<&str>=allowed_set.iter() + .map(|s|s.as_str()) + .chain( + replace_map.keys().map(|s|s.as_str()) + ).chain( + blocked.iter().map(|s|s.as_str()) + ).collect(); + + // get or create every unique script + let script_ids:std::collections::HashMap<&str,submissions_api::types::ScriptID>= + futures::stream::iter(script_set) + .map(|source|async move{ + let script_id=get_or_create_script(api,source).await?; + Ok::<_,GOCError>((source,script_id)) + }) + .buffer_unordered(16) + .try_collect().await?; + + // get or create policy for each script in each category + // + // replace + futures::stream::iter(replace_map.iter().map(Ok)) + .try_for_each_concurrent(Some(16),|(source,id)|async{ + do_policy( + api, + &script_ids, + source, + script_ids[allowed_map[id].as_str()], + submissions_api::types::Policy::Replace + ).await + }).await?; + + // allowed + futures::stream::iter(allowed_set.iter().map(Ok)) + .try_for_each_concurrent(Some(16),|source|async{ + do_policy( + api, + &script_ids, + source, + script_ids[source.as_str()], + submissions_api::types::Policy::Allowed + ).await + }).await?; + + // blocked + futures::stream::iter(blocked.iter().map(Ok)) + .try_for_each_concurrent(Some(16),|source|async{ + do_policy( + api, + &script_ids, + source, + script_ids[source.as_str()], + submissions_api::types::Policy::Blocked + ).await + }).await?; + + Ok(()) +}