EOD
This commit is contained in:
777
leadfinder/src/main.rs
Normal file
777
leadfinder/src/main.rs
Normal file
@@ -0,0 +1,777 @@
|
||||
//TODO include [Y/n] confimrmation for API-calling commands and -y flag
|
||||
//TODO Use tracing crate
|
||||
use clap::{Parser, Subcommand, ValueEnum};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
env,
|
||||
fmt::{Debug, Display},
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
process::ExitCode,
|
||||
str::FromStr,
|
||||
};
|
||||
use tokio::time;
|
||||
|
||||
#[derive(Subcommand)]
|
||||
enum Task {
|
||||
/// Parses a given input file (.csv) and lists the contained companies. Panics if a `name` column does not exist.
|
||||
List {
|
||||
/// The csv file.
|
||||
csv_file: String,
|
||||
},
|
||||
/// Parses a given input file (.json) and counts the number of occurences per company as a reference client of other companies.
|
||||
/// Expects the json file structure given by the LLM prompts
|
||||
/// (TODO: Not quite. If there are multiple prompt results in a single file, concatination of multiple results has to be done manually at the moment. Thats why we like helix tho.)
|
||||
CountPartners {
|
||||
/// The json file, containing the required format.
|
||||
json_file: String,
|
||||
},
|
||||
/// WARNING: This might cost money to run!
|
||||
/// Tests a given system prompt on a given model with the context prompt ""Research Target Company: Pandaloop\n Website: https://pandaloop.de"
|
||||
TestWithPandaloop {
|
||||
#[arg(short, long)]
|
||||
model: Model,
|
||||
#[arg(short, long)]
|
||||
system_prompt_file: String,
|
||||
#[arg(short, long)]
|
||||
pretty: bool,
|
||||
},
|
||||
/// WARNING: This might cost money to run!
|
||||
SinglePrompt {
|
||||
idx: usize,
|
||||
#[arg(short, long)]
|
||||
csv_file: String,
|
||||
#[arg(short, long)]
|
||||
model: Model,
|
||||
#[arg(short, long)]
|
||||
system_prompt_file: String,
|
||||
#[arg(short, long)]
|
||||
pretty: bool,
|
||||
},
|
||||
/// WARNING: This might cost money to run!
|
||||
RangePrompt {
|
||||
range_start: usize,
|
||||
range_end: usize,
|
||||
#[arg(short, long)]
|
||||
csv_file: String,
|
||||
#[arg(short, long)]
|
||||
model: Model,
|
||||
#[arg(short, long)]
|
||||
system_prompt_file: String,
|
||||
#[arg(short, long)]
|
||||
pretty: bool,
|
||||
},
|
||||
/// WARNING: This might cost A LOT money to run!
|
||||
Multiprompt {
|
||||
range_start: usize,
|
||||
range_end: usize,
|
||||
range_step: usize,
|
||||
#[arg(short, long)]
|
||||
csv_file: String,
|
||||
#[arg(short, long)]
|
||||
model: Model,
|
||||
#[arg(short, long)]
|
||||
system_prompt_file: String,
|
||||
#[arg(short, long)]
|
||||
pretty: bool,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
enum Error {
|
||||
#[error("context index range is invalid: {0}")]
|
||||
InvalidRange(String),
|
||||
#[error("api key {0} does not exist in .env file")]
|
||||
ApiKeyNotFound(Model),
|
||||
#[error("Error parsing {filename:?} as csv file: {source}")]
|
||||
CsvFile {
|
||||
#[source]
|
||||
source: csv::Error,
|
||||
filename: Option<PathBuf>,
|
||||
},
|
||||
#[error("Error parsing json: {0}")]
|
||||
SerdeJson(#[from] serde_json::Error),
|
||||
#[error("Error parsing {filename:?} as json: {source}")]
|
||||
SerdeJsonFile {
|
||||
#[source]
|
||||
source: serde_json::Error,
|
||||
filename: Option<PathBuf>,
|
||||
},
|
||||
#[error("Error parsing json: \n\njson:\n{info}")]
|
||||
SerdeJsonInfo {
|
||||
#[source]
|
||||
source: serde_json::Error,
|
||||
info: String
|
||||
},
|
||||
#[error("reqwest Error: {0}")]
|
||||
Reqwest(#[from] reqwest::Error),
|
||||
#[error("io Error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CsvCompany {
|
||||
#[serde(rename = "Company Name")]
|
||||
name: String,
|
||||
#[serde(rename = "Website")]
|
||||
website: String,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SerdeJsonTargetCompany {
|
||||
#[serde(rename = "target_company")]
|
||||
name: String,
|
||||
connections: Vec<SerdeJsonPartnerCompany>,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SerdeJsonPartnerCompany {
|
||||
category: String,
|
||||
context: String,
|
||||
partner_name: String,
|
||||
people: Vec<String>,
|
||||
source_type: String,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SerdeJsonContact {
|
||||
value: String,
|
||||
value_type: String,
|
||||
category: String,
|
||||
label: String,
|
||||
source_url: String,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SerdeJsonEmployee {
|
||||
name: String,
|
||||
role: String,
|
||||
email: String,
|
||||
phone: String,
|
||||
linkedin_url: String,
|
||||
source_url: String,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SerdeJsonLead {
|
||||
company_name: String,
|
||||
website: String,
|
||||
industry: String,
|
||||
description: String,
|
||||
employee_count: String,
|
||||
general_contacts: Vec<SerdeJsonContact>,
|
||||
employees: Vec<SerdeJsonContact>,
|
||||
}
|
||||
|
||||
fn prettify(content: &str, model: Model) -> Result<String, Error> {
|
||||
let ser_all: serde_json::Value = serde_json::from_str(content)?;
|
||||
let cont = match model {
|
||||
Model::Gemini => ser_all["candidates"][0]["content"]["parts"][0]["text"]
|
||||
.as_str()
|
||||
.unwrap()
|
||||
.replace("\\\"", "\"")
|
||||
.replace("\\n", "\n"),
|
||||
Model::Perplexity => {
|
||||
let raw_msg = ser_all["choices"][0]["message"]["content"]
|
||||
.as_str()
|
||||
.unwrap()
|
||||
.replace("\\\"", "\"")
|
||||
.replace("\\n", "\n");
|
||||
let pref_stripped_msg = raw_msg.strip_prefix("```json").unwrap_or(&raw_msg);
|
||||
let bidir_stripped_msg = pref_stripped_msg
|
||||
.strip_suffix("```")
|
||||
.unwrap_or(&pref_stripped_msg);
|
||||
bidir_stripped_msg.to_owned()
|
||||
}
|
||||
};
|
||||
|
||||
eprintln!("{cont}");
|
||||
let ser_cont: serde_json::Value =
|
||||
serde_json::from_str(&cont).map_err(|source| Error::SerdeJson(source))?;
|
||||
let cont_pretty = serde_json::to_string_pretty(&ser_cont)?;
|
||||
Ok(cont_pretty)
|
||||
}
|
||||
|
||||
#[derive(ValueEnum, Clone, Copy, Default, Debug, Serialize)]
|
||||
enum Model {
|
||||
#[default]
|
||||
Gemini,
|
||||
Perplexity,
|
||||
}
|
||||
|
||||
impl Display for Model {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Model::Gemini => {
|
||||
write!(f, "Google Gemini: 2.5 Flash")
|
||||
}
|
||||
Model::Perplexity => {
|
||||
write!(f, "Perplexity AI: Sonar")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
struct ApiKeys {
|
||||
gemini_api_key: String,
|
||||
perplexity_api_key: String,
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[command(subcommand)]
|
||||
task: Task,
|
||||
}
|
||||
|
||||
fn parse_csv(file_path: &Path) -> Result<Vec<CsvCompany>, Error> {
|
||||
let mut reader = csv::ReaderBuilder::new()
|
||||
.has_headers(true)
|
||||
.from_path(&file_path)
|
||||
.map_err(|source| Error::CsvFile {
|
||||
source,
|
||||
filename: Some(file_path.to_owned()),
|
||||
})?;
|
||||
let mut companies = Vec::new();
|
||||
for res in reader.deserialize() {
|
||||
let record: CsvCompany = res.map_err(|source| Error::CsvFile {
|
||||
source,
|
||||
filename: Some(file_path.to_owned()),
|
||||
})?;
|
||||
companies.push(record);
|
||||
}
|
||||
Ok(companies)
|
||||
}
|
||||
|
||||
async fn prompt(
|
||||
model: Model,
|
||||
reqclient: &reqwest::Client,
|
||||
api_keys: &ApiKeys,
|
||||
system_prompt: &str,
|
||||
context_prompt: &str,
|
||||
) -> Result<reqwest::Response, reqwest::Error> {
|
||||
match model {
|
||||
Model::Gemini => {
|
||||
gemini_prompt(
|
||||
reqclient,
|
||||
&api_keys.gemini_api_key,
|
||||
system_prompt,
|
||||
context_prompt,
|
||||
)
|
||||
.await
|
||||
}
|
||||
Model::Perplexity => {
|
||||
perplexity_prompt(
|
||||
reqclient,
|
||||
&api_keys.perplexity_api_key,
|
||||
system_prompt,
|
||||
context_prompt,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn perplexity_prompt(
|
||||
reqclient: &reqwest::Client,
|
||||
perplexity_api_key: &str,
|
||||
system_prompt: &str,
|
||||
context_prompt: &str,
|
||||
) -> Result<reqwest::Response, reqwest::Error> {
|
||||
let body = format!(
|
||||
r#"{{
|
||||
"model": "sonar",
|
||||
"messages": [
|
||||
{{
|
||||
"role": "system",
|
||||
"content": "{system_prompt}"
|
||||
}},
|
||||
{{
|
||||
"role": "user",
|
||||
"content": "{context_prompt}"
|
||||
}}
|
||||
],
|
||||
"temperature": 0
|
||||
}}"#,
|
||||
system_prompt = system_prompt.replace('"', "\\\"").replace('\n', "\\n"),
|
||||
context_prompt = context_prompt.replace('"', "\\\"").replace('\n', "\\n")
|
||||
);
|
||||
|
||||
reqclient
|
||||
.post("https://api.perplexity.ai/chat/completions")
|
||||
.header("accept", "application/json")
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {perplexity_api_key}"))
|
||||
.body(body)
|
||||
.send()
|
||||
.await
|
||||
}
|
||||
|
||||
async fn gemini_prompt(
|
||||
reqclient: &reqwest::Client,
|
||||
gemini_api_key: &str,
|
||||
system_prompt: &str,
|
||||
context_prompt: &str,
|
||||
) -> Result<reqwest::Response, reqwest::Error> {
|
||||
let body = format!(
|
||||
r#"{{
|
||||
"contents": [
|
||||
{{
|
||||
"parts": [
|
||||
{{
|
||||
"text": "SYSTEM PROMPT:\n\n\n{system_prompt}\n\nContext Prompt:{context_prompt}\n\n\n\nSYSTEM PROMPT:\n\n\n{system_prompt}\n\nContext Prompt:{context_prompt}\n"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
],
|
||||
"tools": [
|
||||
{{
|
||||
"google_search": {{}}
|
||||
}}
|
||||
],
|
||||
"generationConfig": {{
|
||||
"temperature": 0.0,
|
||||
}}
|
||||
}}"#,
|
||||
system_prompt = system_prompt.replace('"', "\\\"").replace('\n', "\\n"),
|
||||
context_prompt = context_prompt.replace('"', "\\\"").replace('\n', "\\n")
|
||||
);
|
||||
|
||||
reqclient
|
||||
.post("https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent")
|
||||
.header("x-goog-api-key", format!("{gemini_api_key}"))
|
||||
.header("Content-Type", "application/json")
|
||||
.body(body)
|
||||
.send()
|
||||
.await
|
||||
}
|
||||
|
||||
async fn run(args: Args) -> Result<(), Error> {
|
||||
let _ = dotenv::dotenv();
|
||||
let mut vars = env::vars().collect::<HashMap<String, String>>();
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let gemini_key = match vars.remove("GEMINI_API_KEY") {
|
||||
Some(key) => key,
|
||||
None => {
|
||||
return Err(Error::ApiKeyNotFound(Model::Gemini));
|
||||
}
|
||||
};
|
||||
let perplexity_key = match vars.remove("PERPLEXITY_API_KEY") {
|
||||
Some(key) => key,
|
||||
None => {
|
||||
return Err(Error::ApiKeyNotFound(Model::Perplexity));
|
||||
}
|
||||
};
|
||||
let api_keys = ApiKeys {
|
||||
gemini_api_key: gemini_key,
|
||||
perplexity_api_key: perplexity_key,
|
||||
};
|
||||
|
||||
match args.task {
|
||||
Task::List { csv_file } => {
|
||||
let pathbuf = PathBuf::from_str(&csv_file).unwrap_or_else(|_| unreachable!());
|
||||
let csv = parse_csv(&pathbuf)?;
|
||||
for (idx, company) in csv.iter().enumerate() {
|
||||
println!(
|
||||
"n = {: >5} | Company = {:?} ({:?})",
|
||||
idx, company.name, company.website
|
||||
);
|
||||
}
|
||||
}
|
||||
Task::CountPartners { json_file } => {
|
||||
let pathbuf = PathBuf::from_str(&json_file).unwrap_or_else(|_| unreachable!());
|
||||
let datafile = fs::read_to_string(json_file).unwrap();
|
||||
let companies: Vec<SerdeJsonTargetCompany> =
|
||||
serde_json::from_str(&datafile).map_err(|source| Error::SerdeJsonFile {
|
||||
source,
|
||||
filename: Some(pathbuf),
|
||||
})?;
|
||||
let mut partner_occurrences = HashMap::new();
|
||||
for target_company in companies {
|
||||
let connections_array = target_company.connections;
|
||||
for connection in connections_array {
|
||||
if connection.category == "REFERENCE_CLIENT" {
|
||||
partner_occurrences
|
||||
.entry(connection.partner_name)
|
||||
.and_modify(|count| *count += 1)
|
||||
.or_insert(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut count_vec: Vec<_> = partner_occurrences.iter().collect();
|
||||
count_vec.sort_by(|a, b| b.1.cmp(a.1));
|
||||
for (k, v) in count_vec {
|
||||
println!("Company {:>60} appears {:>3} times", k, v);
|
||||
}
|
||||
}
|
||||
Task::TestWithPandaloop {
|
||||
model,
|
||||
pretty,
|
||||
system_prompt_file,
|
||||
} => {
|
||||
let system_prompt =
|
||||
fs::read_to_string(system_prompt_file).unwrap_or_else(|_| unreachable!());
|
||||
let context_prompt =
|
||||
"Research Target Company T: Pandaloop\nWebsite: https://pandaloop.de\n".to_owned();
|
||||
|
||||
let mut response = prompt(model, &client, &api_keys, &system_prompt, &context_prompt)
|
||||
.await?
|
||||
.text()
|
||||
.await?;
|
||||
if pretty {
|
||||
response = prettify(&response, model)?;
|
||||
}
|
||||
println!("{response}");
|
||||
}
|
||||
Task::SinglePrompt {
|
||||
idx,
|
||||
csv_file,
|
||||
model,
|
||||
system_prompt_file,
|
||||
pretty,
|
||||
} => {
|
||||
let pathbuf = PathBuf::from_str(&csv_file).unwrap_or_else(|_| unreachable!());
|
||||
let csv = parse_csv(&pathbuf)?;
|
||||
let system_prompt = fs::read_to_string(system_prompt_file)?;
|
||||
|
||||
if let Some(company) = csv.get(idx) {
|
||||
let context_prompt = format!(
|
||||
"Research Target Company: {}\nWebsite: {}\n",
|
||||
company.name, company.website
|
||||
);
|
||||
let mut response =
|
||||
prompt(model, &client, &api_keys, &system_prompt, &context_prompt)
|
||||
.await?
|
||||
.text()
|
||||
.await?;
|
||||
if pretty {
|
||||
response = prettify(&response, model)?;
|
||||
}
|
||||
println!("{response}");
|
||||
} else {
|
||||
eprintln!("Index {:?} out of range, skipping.", idx);
|
||||
}
|
||||
}
|
||||
Task::RangePrompt {
|
||||
range_start,
|
||||
range_end,
|
||||
csv_file,
|
||||
model,
|
||||
system_prompt_file,
|
||||
pretty,
|
||||
} => {
|
||||
let pathbuf = PathBuf::from_str(&csv_file).unwrap_or_else(|_| unreachable!());
|
||||
let csv = parse_csv(&pathbuf)?;
|
||||
let system_prompt = fs::read_to_string(system_prompt_file)?;
|
||||
|
||||
if range_start >= range_end {
|
||||
return Err(Error::InvalidRange(format!(
|
||||
"range args {range_start}..{range_end} are invalid: range_start should be smaller than range_end"
|
||||
)));
|
||||
}
|
||||
let mut context_prompt = String::new();
|
||||
for i in range_start..=range_end {
|
||||
if let Some(company) = csv.get(i) {
|
||||
context_prompt = format!(
|
||||
"{}\nResearch Target Company T: {}\nWebsite: {}\nTask: Find all companies that are partners or clients of T based on the research guidelines.",
|
||||
context_prompt, company.name, company.website
|
||||
);
|
||||
} else {
|
||||
eprintln!("Index {:?} out of range, skipping.", i);
|
||||
}
|
||||
}
|
||||
let mut response = prompt(model, &client, &api_keys, &system_prompt, &context_prompt)
|
||||
.await?
|
||||
.text()
|
||||
.await?;
|
||||
if pretty {
|
||||
response = prettify(&response, model)?;
|
||||
}
|
||||
println!("{response}");
|
||||
}
|
||||
Task::Multiprompt {
|
||||
range_start,
|
||||
range_end,
|
||||
range_step,
|
||||
csv_file,
|
||||
model,
|
||||
system_prompt_file,
|
||||
pretty,
|
||||
} => {
|
||||
let pathbuf = PathBuf::from_str(&csv_file).unwrap_or_else(|_| unreachable!());
|
||||
let csv = parse_csv(&pathbuf)?;
|
||||
let system_prompt = fs::read_to_string(system_prompt_file)?;
|
||||
if range_start >= range_end {}
|
||||
let mut i = range_start;
|
||||
while i <= range_end {
|
||||
let mut context_prompt = String::new();
|
||||
for j in i..i + range_step {
|
||||
if let Some(company) = csv.get(j) {
|
||||
context_prompt = format!(
|
||||
"{}\nResearch Target Company T: {}\nWebsite: {}\n\n\n",
|
||||
context_prompt, company.name, company.website
|
||||
);
|
||||
} else {
|
||||
eprintln!("Index {:?} out of range, skipping.", j);
|
||||
}
|
||||
}
|
||||
eprintln!(
|
||||
"Dispatching prompt for range {}..={}",
|
||||
i,
|
||||
i + range_step - 1
|
||||
);
|
||||
let response_result =
|
||||
prompt(model, &client, &api_keys, &system_prompt, &context_prompt).await;
|
||||
match response_result {
|
||||
Ok(response) => {
|
||||
if response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS {
|
||||
eprintln!("429 reached. Go to bed.");
|
||||
return Ok(());
|
||||
} else {
|
||||
}
|
||||
let mut response_text = response.text().await?;
|
||||
if pretty {
|
||||
response_text = prettify(&response_text, model)?;
|
||||
}
|
||||
println!("{response_text}");
|
||||
i += range_step;
|
||||
}
|
||||
Err(error) => {
|
||||
eprintln!("{}", error);
|
||||
let _ = time::sleep(time::Duration::from_millis(5000));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
// match args {
|
||||
// Args {
|
||||
// list: true,
|
||||
// test: _,
|
||||
// index: _,
|
||||
// range: _,
|
||||
// multiprompt_range: _,
|
||||
// output_file: _,
|
||||
// model: _,
|
||||
// pretty: _,
|
||||
// count_partners: _,
|
||||
// } => {
|
||||
// for (idx, company) in csv.iter().enumerate() {
|
||||
// println!("n = {: >5} | Company = {:?}", idx, company.name);
|
||||
// }
|
||||
// }
|
||||
// Args {
|
||||
// list: _,
|
||||
// test: _,
|
||||
// index: _,
|
||||
// range: _,
|
||||
// multiprompt_range: _,
|
||||
// output_file: _,
|
||||
// model: _,
|
||||
// pretty: _,
|
||||
// count_partners: true,
|
||||
// } => {
|
||||
// // TODO not hard-coded
|
||||
// let mut partner_occurrences = HashMap::new();
|
||||
// let datafile =
|
||||
// fs::read_to_string("/home/jan/pl/qwertus/leadfinder/out/perp_out.json").unwrap();
|
||||
// let json: serde_json::Value = serde_json::from_str(&datafile)?;
|
||||
// let company_array = json.as_array().unwrap_or_else(|| panic!("U fucked up! json file to inspect should have array of companies as first level value."));
|
||||
// for target_company in company_array {
|
||||
// let connections_array = target_company["connections"].as_array().unwrap_or_else(|| panic!("U fucked up! json file to inspect should have array of companies as first level value."));
|
||||
// for connection in connections_array {
|
||||
// if connection["category"].as_str().unwrap_or_else(|| panic!("Ur AI probably fucked up: connection company with no partner_name in file!")) == "REFERENCE_CLIENT" {
|
||||
// partner_occurrences
|
||||
// .entry(connection["partner_name"].as_str().unwrap_or_else(|| panic!("Ur AI probably fucked up: connection company with no partner_name in file!")))
|
||||
// .and_modify(|count| *count += 1)
|
||||
// .or_insert(1);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// let mut count_vec: Vec<_> = partner_occurrences.iter().collect();
|
||||
// count_vec.sort_by(|a, b| b.1.cmp(a.1));
|
||||
// for (k, v) in count_vec {
|
||||
// println!("Company {:>60} appears {:>3} times", k, v);
|
||||
// }
|
||||
// }
|
||||
// Args {
|
||||
// test: true,
|
||||
// list: _,
|
||||
// index: _,
|
||||
// range: _,
|
||||
// multiprompt_range: _,
|
||||
// output_file: _,
|
||||
// model,
|
||||
// pretty,
|
||||
// count_partners: _,
|
||||
// } => {
|
||||
// let context_prompt = "Research Target Company T: Pandaloop\n\
|
||||
// Website: https://pandaloop.de\n\
|
||||
// "
|
||||
// .to_owned();
|
||||
|
||||
// let mut response = prompt(model, &client, &api_keys, &system_prompt, &context_prompt)
|
||||
// .await?
|
||||
// .text()
|
||||
// .await?;
|
||||
// if pretty {
|
||||
// response = prettify(&response, model)?;
|
||||
// }
|
||||
// file.write_all(&response.as_bytes())?;
|
||||
// }
|
||||
// Args {
|
||||
// index: Some(i),
|
||||
// list: _,
|
||||
// test: _,
|
||||
// range: _,
|
||||
// multiprompt_range: _,
|
||||
// output_file: _,
|
||||
// model,
|
||||
// pretty,
|
||||
// count_partners: _,
|
||||
// } => {
|
||||
// if let Some(company) = csv.get(i) {
|
||||
// let context_prompt = format!(
|
||||
// "Research Target Company T: {}\nWebsite: {}\nTask: Find all companies that are partners or clients of T based on the research guidelines.",
|
||||
// company.name, company.website
|
||||
// );
|
||||
// let mut response =
|
||||
// prompt(model, &client, &api_keys, &system_prompt, &context_prompt)
|
||||
// .await?
|
||||
// .text()
|
||||
// .await?;
|
||||
// if pretty {
|
||||
// response = prettify(&response, model)?;
|
||||
// }
|
||||
// file.write_all(&response.as_bytes())?;
|
||||
// } else {
|
||||
// println!("Index {:?} out of range, skipping.", i);
|
||||
// }
|
||||
// }
|
||||
// Args {
|
||||
// list: _,
|
||||
// test: _,
|
||||
// index: _,
|
||||
// range: Some(range),
|
||||
// multiprompt_range: _,
|
||||
// output_file: _,
|
||||
// model,
|
||||
// pretty,
|
||||
// count_partners: _,
|
||||
// } => {
|
||||
// if range[0] >= range[1] {
|
||||
// println!("Range start larger than range end. Doing nothing.");
|
||||
// return Ok(());
|
||||
// }
|
||||
// let mut context_prompt = String::new();
|
||||
// for i in range[0]..=range[1] {
|
||||
// if let Some(company) = csv.get(i) {
|
||||
// context_prompt = format!(
|
||||
// "{}\nResearch Target Company T: {}\nWebsite: {}\nTask: Find all companies that are partners or clients of T based on the research guidelines.",
|
||||
// context_prompt, company.name, company.website
|
||||
// );
|
||||
// } else {
|
||||
// println!("Index {:?} out of range, skipping.", i);
|
||||
// }
|
||||
// }
|
||||
// let mut response = prompt(model, &client, &api_keys, &system_prompt, &context_prompt)
|
||||
// .await?
|
||||
// .text()
|
||||
// .await?;
|
||||
// if pretty {
|
||||
// response = prettify(&response, model)?;
|
||||
// }
|
||||
// file.write_all(&response.as_bytes())?;
|
||||
// }
|
||||
// Args {
|
||||
// list: _,
|
||||
// test: _,
|
||||
// index: _,
|
||||
// range: _,
|
||||
// multiprompt_range: Some(multirange),
|
||||
// output_file: _,
|
||||
// model,
|
||||
// pretty,
|
||||
// count_partners: _,
|
||||
// } => {
|
||||
// if multirange[0] >= multirange[1] {
|
||||
// println!("Range start larger than range end. Doing nothing.");
|
||||
// return Ok(());
|
||||
// }
|
||||
// let mut i = multirange[0];
|
||||
// while i <= multirange[1] {
|
||||
// let mut context_prompt = String::new();
|
||||
// for j in i..i + multirange[2] {
|
||||
// if let Some(company) = csv.get(j) {
|
||||
// context_prompt = format!(
|
||||
// "{}\nResearch Target Company T: {}\nWebsite: {}\nTask: Find all companies that are partners or clients of T based on the research guidelines.\n\n",
|
||||
// context_prompt, company.name, company.website
|
||||
// );
|
||||
// } else {
|
||||
// println!("Index {:?} out of range, skipping.", j);
|
||||
// }
|
||||
// }
|
||||
// println!(
|
||||
// "Dispatching prompt for range {}..={}",
|
||||
// i,
|
||||
// i + multirange[2] - 1
|
||||
// );
|
||||
// let response_result =
|
||||
// prompt(model, &client, &api_keys, &system_prompt, &context_prompt).await;
|
||||
// match response_result {
|
||||
// Ok(response) => {
|
||||
// if response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS {
|
||||
// println!("429 reached. Go to bed.");
|
||||
// return Ok(());
|
||||
// } else {
|
||||
// }
|
||||
// let mut response_text = response.text().await?;
|
||||
// if pretty {
|
||||
// response_text = prettify(&response_text, model)?;
|
||||
// }
|
||||
// file.write_all(&response_text.as_bytes())?;
|
||||
// i += multirange[2];
|
||||
// }
|
||||
// Err(error) => {
|
||||
// println!("{}", error);
|
||||
// let _ = time::sleep(time::Duration::from_millis(5000));
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// _ => unreachable!(),
|
||||
// }
|
||||
|
||||
// - denkwerk Partner: *fraenk*, *OTTO*, *SWR*, *Union Investment*, *ARAG*, erenja, BurdaForward, Stiebel Eltron, edding, polymore, Telekom, Santander, DeepL, motel one, Struggly, Fondation Beyeler, Charite, mainova, Esprit, Sport Cast, Remondis, condor, , Aktion Mensch, easy credit, Sparkasse KölnBonn, multiloop, Storck, Teambank, Ranger, Microsoft
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> ExitCode {
|
||||
let args = Args::parse();
|
||||
match run(args).await {
|
||||
Ok(_) => ExitCode::SUCCESS,
|
||||
Err(e) => {
|
||||
// match e {
|
||||
// Error::CsvFile { source, filename } => match source.kind() {
|
||||
// csv::ErrorKind::Io(e) => {
|
||||
// if e.kind() == std::io::ErrorKind::NotFound {
|
||||
// eprintln!("File {filename:?} does not exist.");
|
||||
// }
|
||||
// }
|
||||
// _ => {},
|
||||
// },
|
||||
// _ => eprintln!("{e}"),
|
||||
// }
|
||||
eprintln!("{e}");
|
||||
ExitCode::FAILURE
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user